artificialguybr committed on
Commit 49ffc6c · verified · 1 Parent(s): 9a828e5

Create app.py

Files changed (1)
  1. app.py +249 -0
app.py ADDED
@@ -0,0 +1,249 @@
+ import gradio as gr
+ import numpy as np
+ import random
+ import torch
+ import spaces
+ from PIL import Image
+ from huggingface_hub import hf_hub_download
+ from safetensors.torch import load_file
+ from tqdm import tqdm
+ import gc
+
+ from qwenimage.pipeline_qwen_image_edit import QwenImageEditPipeline
+ from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
+ from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
+
+
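+ # Registry of selectable LoRAs: "type" decides which image input the UI shows
+ # (style reference vs. edit input) and "method" decides how the weights are merged.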
+ LORA_CONFIG = {
+     "None": {
+         "repo_id": None,
+         "filename": None,
+         "type": "edit",
+         "method": "none",
+         "prompt_template": "{prompt}",
+         "description": "Use the base Qwen-Image-Edit model without any LoRA.",
+     },
+     "InStyle (Style Transfer)": {
+         "repo_id": "peteromallet/Qwen-Image-Edit-InStyle",
+         "filename": "InStyle-0.5.safetensors",
+         "type": "style",
+         "method": "manual_fuse",
+         "prompt_template": "Make an image in this style of {prompt}",
+         "description": "Transfers the style from a reference image to a new image described by the prompt.",
+     },
+     "InScene (In-Scene Editing)": {
+         "repo_id": "flymy-ai/qwen-image-edit-inscene-lora",
+         "filename": "flymy_qwen_image_edit_inscene_lora.safetensors",
+         "type": "edit",
+         "method": "standard",
+         "prompt_template": "{prompt}",
+         "description": "Improves in-scene editing, object positioning, and camera perspective changes.",
+     },
+     "Face Segmentation": {
+         "repo_id": "TsienDragon/qwen-image-edit-lora-face-segmentation",
+         "filename": "pytorch_lora_weights.safetensors",
+         "type": "edit",
+         "method": "standard",
+         "prompt_template": "change the face to face segmentation mask",
+         "description": "Transforms a facial image into a precise segmentation mask.",
+     },
+     "Object Remover": {
+         "repo_id": "valiantcat/Qwen-Image-Edit-Remover-General-LoRA",
+         "filename": "qwen-edit-remover.safetensors",
+         "type": "edit",
+         "method": "standard",
+         "prompt_template": "Remove {prompt}",
+         "description": "Removes objects from an image while maintaining background consistency.",
+     },
+ }
+
+ print("Initializing model...")
+ dtype = torch.bfloat16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ pipe = QwenImageEditPipeline.from_pretrained(
+     "Qwen/Qwen-Image-Edit",
+     torch_dtype=dtype
+ ).to(device)
+
+ # Use the custom transformer class and the FlashAttention-3 double-stream attention processor.
+ pipe.transformer.__class__ = QwenImageTransformer2DModel
+ pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
+
+ # Keep a pristine copy of the transformer weights so fused LoRAs can be fully reverted later.
+ # state_dict() returns live references, so the tensors are copied (to CPU) before any in-place fusion.
+ original_transformer_state_dict = {k: v.detach().to("cpu", copy=True) for k, v in pipe.transformer.state_dict().items()}
+ print("Base model loaded and ready.")
+
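+ # Manually fold LoRA weights into each matching Linear layer: W <- W + alpha * (up @ down).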
+ def fuse_lora_manual(transformer, lora_state_dict, alpha=1.0):
+     key_mapping = {}
+     for key in lora_state_dict.keys():
+         base_key = key.replace('diffusion_model.', '').rsplit('.lora_', 1)[0]
+         if base_key not in key_mapping:
+             key_mapping[base_key] = {}
+         if 'lora_A' in key:
+             key_mapping[base_key]['down'] = lora_state_dict[key]
+         elif 'lora_B' in key:
+             key_mapping[base_key]['up'] = lora_state_dict[key]
+
+     for name, module in tqdm(transformer.named_modules(), desc="Fusing layers"):
+         if name in key_mapping and isinstance(module, torch.nn.Linear):
+             lora_weights = key_mapping[name]
+             if 'down' in lora_weights and 'up' in lora_weights:
+                 device = module.weight.device
+                 dtype = module.weight.dtype
+                 lora_down = lora_weights['down'].to(device, dtype=dtype)
+                 lora_up = lora_weights['up'].to(device, dtype=dtype)
+                 merged_delta = lora_up @ lora_down
+                 module.weight.data += alpha * merged_delta
+     return transformer
+
+ def load_and_fuse_lora(lora_name):
+     """Reset the transformer, then load the selected LoRA and fuse it into the global pipeline in place."""
+     config = LORA_CONFIG[lora_name]
+
+     print("Resetting transformer to original state...")
+     pipe.transformer.load_state_dict(original_transformer_state_dict)
+
+     if config["method"] == "none":
+         print("No LoRA selected. Using base model.")
+         return
+
+     print(f"Loading LoRA: {lora_name}")
+     lora_path = hf_hub_download(repo_id=config["repo_id"], filename=config["filename"])
+
+     if config["method"] == "standard":
+         print("Using standard loading method...")
+         pipe.load_lora_weights(lora_path)
+         print("Fusing LoRA into the model...")
+         pipe.fuse_lora()
+     elif config["method"] == "manual_fuse":
+         print("Using manual fusion method...")
+         lora_state_dict = load_file(lora_path)
+         pipe.transformer = fuse_lora_manual(pipe.transformer, lora_state_dict)
+
+     gc.collect()
+     torch.cuda.empty_cache()
+     print(f"LoRA '{lora_name}' is now active.")
+
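+ # ZeroGPU: request a GPU allocation of up to 60 seconds for each inference call.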
+ @spaces.GPU(duration=60)
+ def infer(
+     lora_name,
+     input_image,
+     style_image,
+     prompt,
+     seed,
+     randomize_seed,
+     true_guidance_scale,
+     num_inference_steps,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     if not lora_name:
+         raise gr.Error("Please select a LoRA model.")
+
+     config = LORA_CONFIG[lora_name]
+
+     if config["type"] == "style":
+         if style_image is None:
+             raise gr.Error("Style Transfer LoRA requires a Style Reference Image.")
+         image_for_pipeline = style_image
+     else: # 'edit'
+         if input_image is None:
+             raise gr.Error("This LoRA requires an Input Image.")
+         image_for_pipeline = input_image
+
+     if not prompt and config["prompt_template"] != "change the face to face segmentation mask":
+         raise gr.Error("A text prompt is required for this LoRA.")
+
+     load_and_fuse_lora(lora_name)
+
+     final_prompt = config["prompt_template"].format(prompt=prompt)
+
+     if randomize_seed:
+         seed = random.randint(0, np.iinfo(np.int32).max)
+     generator = torch.Generator(device=device).manual_seed(int(seed))
+
+     print("--- Running Inference ---")
+     print(f"LoRA: {lora_name}")
+     print(f"Prompt: {final_prompt}")
+     print(f"Seed: {seed}, Steps: {num_inference_steps}, CFG: {true_guidance_scale}")
+
+     with torch.inference_mode():
+         result_image = pipe(
+             image=image_for_pipeline,
+             prompt=final_prompt,
+             negative_prompt=" ",
+             num_inference_steps=int(num_inference_steps),
+             generator=generator,
+             true_cfg_scale=true_guidance_scale,
+         ).images[0]
+
+     pipe.unfuse_lora()
+     gc.collect()
+     torch.cuda.empty_cache()
+
+     return result_image, seed
+
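+ # Toggle which inputs are visible based on the selected LoRA's type and prompt template.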
+ def on_lora_change(lora_name):
+     config = LORA_CONFIG[lora_name]
+     is_style_lora = config["type"] == "style"
+     return {
+         lora_description: gr.Markdown(visible=True, value=f"**Description:** {config['description']}"),
+         input_image_box: gr.Image(visible=not is_style_lora),
+         style_image_box: gr.Image(visible=is_style_lora),
+         prompt_box: gr.Textbox(visible=(config["prompt_template"] != "change the face to face segmentation mask"))
+     }
+
+ with gr.Blocks(css="#col-container { margin: 0 auto; max-width: 1024px; }") as demo:
+     with gr.Column(elem_id="col-container"):
+         gr.HTML('<img src="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/qwen_image_edit_logo.png" alt="Qwen-Image Logo" style="width: 400px; margin: 0 auto; display: block;">')
+         gr.Markdown("<h2 style='text-align: center;'>Qwen-Image-Edit Multi-LoRA Playground</h2>")
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 lora_selector = gr.Dropdown(
+                     label="Select LoRA Model",
+                     choices=list(LORA_CONFIG.keys()),
+                     value="InStyle (Style Transfer)"
+                 )
+                 lora_description = gr.Markdown(visible=False)
+
+                 input_image_box = gr.Image(label="Input Image", type="pil", visible=False)
+                 style_image_box = gr.Image(label="Style Reference Image", type="pil", visible=True)
+
+                 prompt_box = gr.Textbox(label="Prompt", placeholder="Describe the content or object to remove...")
+
+                 run_button = gr.Button("Generate!", variant="primary")
+
+             with gr.Column(scale=1):
+                 result_image = gr.Image(label="Result", type="pil")
+                 used_seed = gr.Number(label="Used Seed", interactive=False)
+
+                 with gr.Accordion("Advanced Settings", open=False):
+                     seed_slider = gr.Slider(label="Seed", minimum=0, maximum=np.iinfo(np.int32).max, step=1, value=42)
+                     randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True)
+                     cfg_slider = gr.Slider(label="Guidance Scale (CFG)", minimum=1.0, maximum=10.0, step=0.1, value=4.0)
+                     steps_slider = gr.Slider(label="Inference Steps", minimum=10, maximum=50, step=1, value=25)
+
+     # Update the UI when the selection changes, and run the same handler once on load
+     # so the visible inputs match the default LoRA selection.
+     lora_selector.change(
+         fn=on_lora_change,
+         inputs=lora_selector,
+         outputs=[lora_description, input_image_box, style_image_box, prompt_box]
+     )
+     demo.load(
+         fn=on_lora_change,
+         inputs=lora_selector,
+         outputs=[lora_description, input_image_box, style_image_box, prompt_box]
+     )
+
+     run_button.click(
+         fn=infer,
+         inputs=[
+             lora_selector,
+             input_image_box, style_image_box,
+             prompt_box,
+             seed_slider, randomize_seed_checkbox,
+             cfg_slider, steps_slider
+         ],
+         outputs=[result_image, used_seed]
+     )
+
+ if __name__ == "__main__":
+     demo.launch()