Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| import numpy as np | |
| import spaces | |
| import torch | |
| import random | |
| from PIL import Image | |
| from pipeline_flux_kontext import FluxKontextPipeline | |
| from diffusers.utils import load_image | |
| # --- LoRA 配置 --- | |
| # 结构: "LoRA显示名称": {"file": "LoRA文件名.safetensors", "adapter_name": "唯一的适配器名称"} | |
| LORA_REPO_ID = "IdlecloudX/Flux_and_Wan_Lora" | |
| LORA_SETS = { | |
| "Remove Clothes": { | |
| "file": "change_clothes_to_nothing_000012800.safetensors", | |
| "adapter_name": "remove_clothes" | |
| } | |
| } | |
| # ------------------------- | |
| # 加载 Kontext 模型 | |
| MAX_SEED = np.iinfo(np.int32).max | |
| print("正在加载 FLUX Kontext pipeline...") | |
| pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda") | |
| print("Pipeline 加载完成。") | |
| # --- 加载所有定义的 LoRA 权重 --- | |
| for name, lora_config in LORA_SETS.items(): | |
| print(f"--- 正在加载 LoRA: {name} ---") | |
| try: | |
| pipe.load_lora_weights( | |
| LORA_REPO_ID, | |
| weight_name=lora_config['file'], | |
| adapter_name=lora_config['adapter_name'] | |
| ) | |
| print(f"'{name}' LoRA 加载成功。") | |
| except Exception as e: | |
| print(f"加载 LoRA '{name}' ({lora_config['file']}) 失败: {e}") | |
| print("请检查 LORA_REPO_ID 和文件名是否正确,或者 LoRA 是否与当前模型兼容。") | |
| # ------------------------------------ | |
| def concatenate_images(images, direction="horizontal"): | |
| """ | |
| 将多个PIL图像水平或垂直拼接。 | |
| """ | |
| if not images: | |
| return None | |
| valid_images = [img for img in images if img is not None] | |
| if not valid_images: | |
| return None | |
| if len(valid_images) == 1: | |
| return valid_images[0].convert("RGB") | |
| valid_images = [img.convert("RGB") for img in valid_images] | |
| if direction == "horizontal": | |
| total_width = sum(img.width for img in valid_images) | |
| max_height = max(img.height for img in valid_images) | |
| concatenated = Image.new('RGB', (total_width, max_height), (255, 255, 255)) | |
| x_offset = 0 | |
| for img in valid_images: | |
| y_offset = (max_height - img.height) // 2 | |
| concatenated.paste(img, (x_offset, y_offset)) | |
| x_offset += img.width | |
| else: # vertical | |
| max_width = max(img.width for img in valid_images) | |
| total_height = sum(img.height for img in valid_images) | |
| concatenated = Image.new('RGB', (max_width, total_height), (255, 255, 255)) | |
| y_offset = 0 | |
| for img in valid_images: | |
| x_offset = (max_width - img.width) // 2 | |
| concatenated.paste(img, (x_offset, y_offset)) | |
| y_offset += img.height | |
| return concatenated | |
| def infer(input_images, prompt, selected_loras, seed=42, randomize_seed=False, guidance_scale=2.5, progress=gr.Progress(track_tqdm=True)): | |
| if randomize_seed: | |
| seed = random.randint(0, MAX_SEED) | |
| if input_images is None: | |
| raise gr.Error("请至少上传一张图片。") | |
| if not isinstance(input_images, list): | |
| input_images = [input_images] | |
| valid_images = [img[0] for img in input_images if img is not None] | |
| if not valid_images: | |
| raise gr.Error("请上传至少一张有效的图片。") | |
| concatenated_image = concatenate_images(valid_images, "horizontal") | |
| if concatenated_image is None: | |
| raise gr.Error("处理输入图片失败。") | |
| final_prompt = f"From the provided reference images, create a unified, cohesive image such that {prompt}. Maintain the identity and characteristics of each subject while adjusting their proportions, scale, and positioning to create a harmonious, naturally balanced composition. Blend and integrate all elements seamlessly with consistent lighting, perspective, and style.the final result should look like a single naturally captured scene where all subjects are properly sized and positioned relative to each other, not assembled from multiple sources." | |
| # --- LoRA 应用逻辑 --- | |
| active_adapters = [] | |
| if selected_loras: | |
| for lora_name in selected_loras: | |
| if lora_name in LORA_SETS: | |
| active_adapters.append(LORA_SETS[lora_name]["adapter_name"]) | |
| if active_adapters: | |
| print(f"正在启用选择的 LoRA 适配器: {active_adapters}") | |
| pipe.set_adapters(active_adapters, adapter_weights=[1.0] * len(active_adapters)) | |
| else: | |
| pipe.disable_lora() | |
| image = pipe( | |
| image=concatenated_image, | |
| prompt=final_prompt, | |
| guidance_scale=guidance_scale, | |
| width=concatenated_image.size[0], | |
| height=concatenated_image.size[1], | |
| generator=torch.Generator().manual_seed(seed), | |
| ).images[0] | |
| if active_adapters: | |
| print("推理完成,正在禁用 LoRA 适配器。") | |
| pipe.disable_lora() | |
| return image, seed, gr.update(visible=True) | |
| css=""" | |
| #col-container { | |
| margin: 0 auto; | |
| max-width: 960px; | |
| } | |
| """ | |
| with gr.Blocks(css=css) as demo: | |
| with gr.Column(elem_id="col-container"): | |
| gr.Markdown(f"""# FLUX.1 Kontext [dev] - Multi-Image with LoRA | |
| 使用 FLUX.1 Kontext [dev] 将多张图片中的元素组合成一张新图,并支持应用自定义 LoRA 风格。 | |
| """) | |
| with gr.Row(): | |
| with gr.Column(): | |
| input_images = gr.Gallery( | |
| label="上传用于编辑的图片", | |
| show_label=True, | |
| elem_id="gallery_input", | |
| columns=3, | |
| rows=2, | |
| object_fit="contain", | |
| height="auto", | |
| file_types=['image'], | |
| type='pil' | |
| ) | |
| with gr.Row(): | |
| prompt = gr.Text( | |
| label="Prompt", | |
| show_label=False, | |
| info = "描述期望的输出构图", | |
| max_lines=1, | |
| placeholder="例如:左边图片里的狗坐在右边图片的长椅上", | |
| container=False, | |
| ) | |
| run_button = gr.Button("Run", scale=0) | |
| lora_selection = gr.CheckboxGroup( | |
| choices=list(LORA_SETS.keys()), | |
| label="选择 LoRA 风格 (可多选)", | |
| info="选择一个或多个风格进行叠加。" | |
| ) | |
| with gr.Accordion("高级设置", open=False): | |
| seed = gr.Slider( | |
| label="Seed", | |
| minimum=0, | |
| maximum=MAX_SEED, | |
| step=1, | |
| value=0, | |
| ) | |
| randomize_seed = gr.Checkbox(label="随机种子", value=True) | |
| guidance_scale = gr.Slider( | |
| label="引导系数 (Guidance Scale)", | |
| minimum=1, | |
| maximum=10, | |
| step=0.1, | |
| value=2.5, | |
| ) | |
| with gr.Column(): | |
| result = gr.Image(label="结果", show_label=False, interactive=False) | |
| reuse_button = gr.Button("复用此图", visible=False) | |
| inputs = [input_images, prompt, lora_selection, seed, randomize_seed, guidance_scale] | |
| outputs = [result, seed, reuse_button] | |
| gr.on( | |
| triggers=[run_button.click, prompt.submit], | |
| fn = infer, | |
| inputs = inputs, | |
| outputs = outputs | |
| ) | |
| reuse_button.click( | |
| fn = lambda image: [image] if image is not None else [], | |
| inputs = [result], | |
| outputs = [input_images] | |
| ) | |
| demo.launch() | |