Sqxww committed on
Commit
808dc9c
·
1 Parent(s): 35b0893

添加prompt tab

Browse files
Files changed (2) hide show
  1. app.py +126 -2
  2. ominicontrol.py +89 -15
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import spaces
3
- from ominicontrol import generate_image
4
  import os
5
 
6
  from huggingface_hub import login
@@ -193,10 +193,134 @@ def infer(
193
  )
194
  return result_image
195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
  if USE_ZERO_GPU:
198
  infer = spaces.GPU(infer)
 
199
 
200
  if __name__ == "__main__":
201
- demo = gradio_interface()
202
  demo.launch(server_name="0.0.0.0", ssr_mode=False)
 
1
  import gradio as gr
2
  import spaces
3
+ from ominicontrol import generate_image, generate_image_with_prompt
4
  import os
5
 
6
  from huggingface_hub import login
 
193
  )
194
  return result_image
195
 
196
+ def prompt_gradio_interface():
197
+ with gr.Blocks(css=css) as demo:
198
+ with gr.Row(equal_height=False):
199
+ with gr.Column(variant="panel", elem_classes="inputPanel"):
200
+ original_image = gr.Image(
201
+ type="pil",
202
+ label="Condition Image",
203
+ width=400,
204
+ height=400,
205
+ )
206
+ prompt = gr.Textbox(
207
+ label="Prompt",
208
+ )
209
+ # Advanced settings
210
+ with gr.Accordion(
211
+ "⚙️ Advanced Settings", open=False
212
+ ) as advanced_settings:
213
+ inference_mode = gr.Radio(
214
+ ["High Quality", "Fast"],
215
+ value="High Quality",
216
+ label="Generating Mode",
217
+ )
218
+ image_ratio = gr.Radio(
219
+ ["Auto", "Square(1:1)", "Portrait(2:3)", "Landscape(3:2)"],
220
+ label="Image Ratio",
221
+ value="Auto",
222
+ )
223
+ use_random_seed = gr.Checkbox(label="Use Random Seed", value=True)
224
+ seed = gr.Number(
225
+ label="Seed",
226
+ value=42,
227
+ visible=(not use_random_seed.value),
228
+ )
229
+ use_random_seed.change(
230
+ lambda x: gr.update(visible=(not x)),
231
+ use_random_seed,
232
+ seed,
233
+ show_progress="hidden",
234
+ )
235
+ image_guidance = gr.Slider(
236
+ label="Image Guidance",
237
+ minimum=1.1,
238
+ maximum=5,
239
+ value=1.5,
240
+ step=0.1,
241
+ )
242
+ steps = gr.Slider(
243
+ label="Steps",
244
+ minimum=10,
245
+ maximum=50,
246
+ value=20,
247
+ step=1,
248
+ )
249
+ inference_mode.change(
250
+ lambda x: gr.update(interactive=(x == "High Quality")),
251
+ inference_mode,
252
+ image_guidance,
253
+ show_progress="hidden",
254
+ )
255
+
256
+ btn = gr.Button("Generate Image", variant="primary")
257
+
258
+ with gr.Column(elem_classes="outputPanel"):
259
+ output_image = gr.Image(
260
+ type="pil",
261
+ width=600,
262
+ height=600,
263
+ label="Output Image",
264
+ interactive=False,
265
+ sources=None,
266
+ )
267
+
268
+ # with gr.Row():
269
+ btn.click(
270
+ fn=prompt_infer,
271
+ inputs=[
272
+ original_image,
273
+ prompt,
274
+ inference_mode,
275
+ image_guidance,
276
+ image_ratio,
277
+ use_random_seed,
278
+ seed,
279
+ steps,
280
+ ],
281
+ outputs=[
282
+ output_image,
283
+ ],
284
+ )
285
+
286
+ return demo
287
+
288
+ def prompt_infer(
289
+ original_image,
290
+ prompt,
291
+ inference_mode,
292
+ image_guidance,
293
+ image_ratio,
294
+ use_random_seed,
295
+ seed,
296
+ steps,
297
+ ):
298
+ result_image = generate_image_with_prompt(
299
+ image=original_image,
300
+ prompt=prompt,
301
+ inference_mode=inference_mode,
302
+ image_guidance=image_guidance,
303
+ image_ratio=image_ratio,
304
+ use_random_seed=use_random_seed,
305
+ seed=seed,
306
+ steps=steps,
307
+ )
308
+ return result_image
309
+
310
+ def multi_gradio_interface():
311
+ with gr.Blocks(css="style.css") as demo:
312
+ with gr.Tabs():
313
+ with gr.Tab(label="Style"):
314
+ gradio_interface()
315
+ with gr.Tab(label="Prompt"):
316
+ prompt_gradio_interface()
317
+
318
+ return demo
319
 
320
  if USE_ZERO_GPU:
321
  infer = spaces.GPU(infer)
322
+ prompt_infer = spaces.GPU(prompt_infer)
323
 
324
  if __name__ == "__main__":
325
+ demo = multi_gradio_interface()
326
  demo.launch(server_name="0.0.0.0", ssr_mode=False)
ominicontrol.py CHANGED
@@ -12,6 +12,9 @@ pipe = FluxPipeline.from_pretrained(
12
  )
13
  pipe = pipe.to("cuda")
14
 
 
 
 
15
  pipe.unload_lora_weights()
16
 
17
  pipe.load_lora_weights(
@@ -34,19 +37,6 @@ pipe.load_lora_weights(
34
  weight_name=f"v0/snoopy.safetensors",
35
  adapter_name="snoopy",
36
  )
37
- # ref: https://civitai.com/models/715472/flux-hayao-miyazaki-ghibli
38
- pipe.load_lora_weights(
39
- "./lora",
40
- weight_name="MaoMu_Ghibli.safetensors",
41
- adapter_name="MaoMu_Ghibli",
42
- )
43
- # ref: https://civitai.com/models/824739/flux-3d-animation-style-lora
44
- pipe.load_lora_weights(
45
- "./lora",
46
- weight_name="3d_animation.safetensors",
47
- adapter_name="3d_animation",
48
- )
49
-
50
 
51
  def generate_image(
52
  image,
@@ -72,8 +62,6 @@ def generate_image(
72
  "Irasutoya Illustration": "irasutoya",
73
  "The Simpsons": "simpsons",
74
  "Snoopy": "snoopy",
75
- "3D Animation": "3d_animation",
76
- "MaoMu Ghibli": "MaoMu_Ghibli",
77
  }[style]
78
  pipe.set_adapters(activate_adapter_name)
79
 
@@ -145,3 +133,89 @@ def generate_image(
145
 
146
  return result_img
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  )
13
  pipe = pipe.to("cuda")
14
 
15
+ prompt_pipe = FluxPipeline.from_pipe(pipe)
16
+ prompt_pipe = prompt_pipe.to("cuda")
17
+
18
  pipe.unload_lora_weights()
19
 
20
  pipe.load_lora_weights(
 
37
  weight_name=f"v0/snoopy.safetensors",
38
  adapter_name="snoopy",
39
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  def generate_image(
42
  image,
 
62
  "Irasutoya Illustration": "irasutoya",
63
  "The Simpsons": "simpsons",
64
  "Snoopy": "snoopy",
 
 
65
  }[style]
66
  pipe.set_adapters(activate_adapter_name)
67
 
 
133
 
134
  return result_img
135
 
136
+
137
+ def generate_image_with_prompt(
138
+ image,
139
+ prompt,
140
+ inference_mode,
141
+ image_guidance,
142
+ image_ratio,
143
+ steps,
144
+ use_random_seed,
145
+ seed,
146
+ ):
147
+ # Prepare Condition
148
+ def resize(img, factor=16):
149
+ w, h = img.size
150
+ new_w, new_h = w // factor * factor, h // factor * factor
151
+ padding_w, padding_h = (w - new_w) // 2, (h - new_h) // 2
152
+ img = img.crop((padding_w, padding_h, new_w + padding_w, new_h + padding_h))
153
+ return img
154
+
155
+ original_width, original_height = image.size
156
+
157
+ factor = 512 / max(image.size)
158
+ image = resize(
159
+ image.resize(
160
+ (int(image.size[0] * factor), int(image.size[1] * factor)),
161
+ Image.LANCZOS,
162
+ )
163
+ )
164
+
165
+ delta = -image.size[0] // 16
166
+ condition = Condition(
167
+ "subject",
168
+ # activate_adapter_name,
169
+ image,
170
+ position_delta=(0, delta),
171
+ )
172
+
173
+ # Prepare seed
174
+ if use_random_seed:
175
+ seed = random.randint(0, 2**32 - 1)
176
+ seed_everything(seed)
177
+
178
+ # Image guidance scale
179
+ image_guidance = 1.0 if inference_mode == "Fast" else image_guidance
180
+
181
+ # Output size
182
+ if image_ratio == "Auto":
183
+ r = image.size[0] / image.size[1]
184
+ ratio = min([0.67, 1, 1.5], key=lambda x: abs(x - r))
185
+ else:
186
+ ratio = {
187
+ "Square(1:1)": 1,
188
+ "Portrait(2:3)": 0.67,
189
+ "Landscape(3:2)": 1.5,
190
+ }[image_ratio]
191
+ width, height = {
192
+ 0.67: (640, 960),
193
+ 1: (640, 640),
194
+ 1.5: (960, 640),
195
+ }[ratio]
196
+
197
+
198
+ output_factor = max(width, height) / max(original_width, original_height)
199
+ width = int(original_width * output_factor)
200
+ height = int(original_height * output_factor)
201
+
202
+ print(
203
+ f"Image Ratio: {image_ratio}, Inference Mode: {inference_mode}, Image Guidance: {image_guidance}, Seed: {seed}, Steps: {steps}, Ratio: {ratio}, Size: {width}x{height}"
204
+ )
205
+ # Generate
206
+ result_img = generate(
207
+ prompt_pipe,
208
+ prompt=prompt,
209
+ conditions=[condition],
210
+ num_inference_steps=steps,
211
+ width=width,
212
+ height=height,
213
+ image_guidance_scale=image_guidance,
214
+ default_lora=True,
215
+ max_sequence_length=32,
216
+ ).images[0]
217
+ # result_img = image
218
+
219
+ result_img = result_img.resize((width, height), Image.LANCZOS)
220
+
221
+ return result_img