# FLUX_1_Kontext / app.py
import gradio as gr
import numpy as np
import spaces
import torch
import random
from PIL import Image
from pipeline_flux_kontext import FluxKontextPipeline
from diffusers.utils import load_image
# --- LoRA configuration ---
# Structure: "LoRA display name": {"file": "LoRA filename.safetensors", "adapter_name": "unique adapter name"}
LORA_REPO_ID = "IdlecloudX/Flux_and_Wan_Lora"
LORA_SETS = {
"Remove Clothes": {
"file": "change_clothes_to_nothing_000012800.safetensors",
"adapter_name": "remove_clothes"
}
}
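# Additional LoRAs can be registered by adding entries of the same shape, e.g. (hypothetical file name):
# "My Style": {"file": "my_style.safetensors", "adapter_name": "my_style"},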
# -------------------------
# Load the Kontext model
MAX_SEED = np.iinfo(np.int32).max
print("正在加载 FLUX Kontext pipeline...")
pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
print("Pipeline 加载完成。")
# --- 加载所有定义的 LoRA 权重 ---
for name, lora_config in LORA_SETS.items():
print(f"--- 正在加载 LoRA: {name} ---")
try:
pipe.load_lora_weights(
LORA_REPO_ID,
weight_name=lora_config['file'],
adapter_name=lora_config['adapter_name']
)
print(f"'{name}' LoRA 加载成功。")
except Exception as e:
print(f"加载 LoRA '{name}' ({lora_config['file']}) 失败: {e}")
print("请检查 LORA_REPO_ID 和文件名是否正确,或者 LoRA 是否与当前模型兼容。")
# ------------------------------------
def concatenate_images(images, direction="horizontal"):
"""
将多个PIL图像水平或垂直拼接。
"""
if not images:
return None
valid_images = [img for img in images if img is not None]
if not valid_images:
return None
if len(valid_images) == 1:
return valid_images[0].convert("RGB")
valid_images = [img.convert("RGB") for img in valid_images]
if direction == "horizontal":
total_width = sum(img.width for img in valid_images)
max_height = max(img.height for img in valid_images)
concatenated = Image.new('RGB', (total_width, max_height), (255, 255, 255))
x_offset = 0
for img in valid_images:
y_offset = (max_height - img.height) // 2
concatenated.paste(img, (x_offset, y_offset))
x_offset += img.width
else: # vertical
max_width = max(img.width for img in valid_images)
total_height = sum(img.height for img in valid_images)
concatenated = Image.new('RGB', (max_width, total_height), (255, 255, 255))
y_offset = 0
for img in valid_images:
x_offset = (max_width - img.width) // 2
concatenated.paste(img, (x_offset, y_offset))
y_offset += img.height
return concatenated
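# Illustrative usage (assuming two local files a.png and b.png exist):
#   canvas = concatenate_images([Image.open("a.png"), Image.open("b.png")], "horizontal")
#   # canvas.width == a.width + b.width, canvas.height == max(a.height, b.height)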
@spaces.GPU
def infer(input_images, prompt, selected_loras, seed=42, randomize_seed=False, guidance_scale=2.5, progress=gr.Progress(track_tqdm=True)):
if randomize_seed:
seed = random.randint(0, MAX_SEED)
if input_images is None:
raise gr.Error("请至少上传一张图片。")
if not isinstance(input_images, list):
input_images = [input_images]
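    # gr.Gallery with type='pil' delivers (image, caption) tuples, so keep only the image part.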
valid_images = [img[0] for img in input_images if img is not None]
if not valid_images:
raise gr.Error("请上传至少一张有效的图片。")
concatenated_image = concatenate_images(valid_images, "horizontal")
if concatenated_image is None:
raise gr.Error("处理输入图片失败。")
    final_prompt = f"From the provided reference images, create a unified, cohesive image such that {prompt}. Maintain the identity and characteristics of each subject while adjusting their proportions, scale, and positioning to create a harmonious, naturally balanced composition. Blend and integrate all elements seamlessly with consistent lighting, perspective, and style. The final result should look like a single naturally captured scene where all subjects are properly sized and positioned relative to each other, not assembled from multiple sources."
    # --- LoRA application logic ---
active_adapters = []
if selected_loras:
for lora_name in selected_loras:
if lora_name in LORA_SETS:
active_adapters.append(LORA_SETS[lora_name]["adapter_name"])
if active_adapters:
print(f"正在启用选择的 LoRA 适配器: {active_adapters}")
pipe.set_adapters(active_adapters, adapter_weights=[1.0] * len(active_adapters))
else:
pipe.disable_lora()
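    # Run Kontext on the stitched reference image; the output canvas matches the concatenated
    # size, and a manually seeded generator keeps results reproducible.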
image = pipe(
image=concatenated_image,
prompt=final_prompt,
guidance_scale=guidance_scale,
width=concatenated_image.size[0],
height=concatenated_image.size[1],
generator=torch.Generator().manual_seed(seed),
).images[0]
if active_adapters:
print("推理完成,正在禁用 LoRA 适配器。")
pipe.disable_lora()
return image, seed, gr.update(visible=True)
css="""
#col-container {
margin: 0 auto;
max-width: 960px;
}
"""
with gr.Blocks(css=css) as demo:
with gr.Column(elem_id="col-container"):
gr.Markdown(f"""# FLUX.1 Kontext [dev] - Multi-Image with LoRA
Combine elements from multiple images into one new image with FLUX.1 Kontext [dev], with support for applying custom LoRA styles.
""")
with gr.Row():
with gr.Column():
input_images = gr.Gallery(
label="上传用于编辑的图片",
show_label=True,
elem_id="gallery_input",
columns=3,
rows=2,
object_fit="contain",
height="auto",
file_types=['image'],
type='pil'
)
with gr.Row():
prompt = gr.Text(
label="Prompt",
show_label=False,
info = "描述期望的输出构图",
max_lines=1,
placeholder="例如:左边图片里的狗坐在右边图片的长椅上",
container=False,
)
run_button = gr.Button("Run", scale=0)
lora_selection = gr.CheckboxGroup(
choices=list(LORA_SETS.keys()),
label="选择 LoRA 风格 (可多选)",
info="选择一个或多个风格进行叠加。"
)
with gr.Accordion("高级设置", open=False):
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=0,
)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
guidance_scale = gr.Slider(
label="引导系数 (Guidance Scale)",
minimum=1,
maximum=10,
step=0.1,
value=2.5,
)
with gr.Column():
                result = gr.Image(label="Result", show_label=False, interactive=False)
                reuse_button = gr.Button("Reuse this image", visible=False)
inputs = [input_images, prompt, lora_selection, seed, randomize_seed, guidance_scale]
outputs = [result, seed, reuse_button]
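    # Bind both the Run button and pressing Enter in the prompt box to the same inference call.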
gr.on(
triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=inputs,
        outputs=outputs
)
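    # "Reuse this image" feeds the result back into the input gallery as a single-item list.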
reuse_button.click(
        fn=lambda image: [image] if image is not None else [],
        inputs=[result],
        outputs=[input_images]
)
demo.launch()