Update app.py
app.py CHANGED
@@ -15,15 +15,8 @@ from transformers import AutoProcessor, AutoModelForMaskGeneration, pipeline
 from dataclasses import dataclass
 from typing import Any, List, Dict, Optional, Union, Tuple
 
-# --- Constants and Setup ---
-# Ensure all required modules are available
-check_min_version("0.29.0.dev0")
-
-# Set a seed for reproducibility. The original script uses a fixed seed.
-generator = torch.Generator(device="cuda").manual_seed(42)
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-
 # --- Helper Dataclasses (Identical to diptych_prompting_inference.py) ---
 @dataclass
 class BoundingBox:
@@ -220,15 +213,22 @@ def run_diptych_prompting(
     input_image: Image.Image,
     subject_name: str,
     target_prompt: str,
-    attn_enforce: float,
-    ctrl_scale: float,
-    width: int,
-    height: int,
-    pixel_offset: int,
-    num_steps: int,
-    guidance: float,
+    attn_enforce: float = 1.3,
+    ctrl_scale: float = 0.95,
+    width: int = 768,
+    height: int = 768,
+    pixel_offset: int = 8,
+    num_steps: int = 30,
+    guidance: float = 3.5,
+    seed: int = 42,
+    randomize_seed: bool = False,
     progress=gr.Progress(track_tqdm=True)
 ):
+    if randomize_seed:
+        actual_seed = random.randint(0, 9223372036854775807)
+    else:
+        actual_seed = seed
+
     if input_image is None: raise gr.Error("Please upload a reference image.")
     if not subject_name: raise gr.Error("Please provide the subject's name (e.g., 'a red car').")
     if not target_prompt: raise gr.Error("Please provide a target prompt.")
@@ -261,6 +261,7 @@ def run_diptych_prompting(
 
     # 4. Run Inference (using parameters identical to the original script)
     progress(0.4, desc="Running diffusion process...")
+    generator = torch.Generator(device="cuda").manual_seed(actual_seed)
     result = pipe(
         prompt=diptych_text_prompt,
         height=diptych_size[1],
@@ -313,6 +314,8 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             width = gr.Slider(minimum=512, maximum=1024, value=768, step=64, label="Image Width")
             height = gr.Slider(minimum=512, maximum=1024, value=768, step=64, label="Image Height")
             pixel_offset = gr.Slider(minimum=0, maximum=32, value=8, step=1, label="Padding (Pixel Offset)")
+            seed = gr.Slider(minimum=0, maximum=9223372036854775807, value=42, step=1, label="Seed")
+            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
         with gr.Column(scale=1):
             output_image = gr.Image(type="pil", label="Generated Image")
 
@@ -325,12 +328,12 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         inputs=[input_image, subject_name, target_prompt],
         outputs=output_image,
         fn=run_diptych_prompting,
-        cache_examples=
+        cache_examples="lazy",
     )
 
     run_button.click(
         fn=run_diptych_prompting,
-        inputs=[input_image, subject_name, target_prompt, attn_enforce, ctrl_scale, width, height, pixel_offset, num_steps, guidance],
+        inputs=[input_image, subject_name, target_prompt, attn_enforce, ctrl_scale, width, height, pixel_offset, num_steps, guidance, seed, randomize_seed],
         outputs=output_image
     )
 
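For reference, a minimal standalone sketch of the seed handling this commit introduces. The helper name pick_seed is purely illustrative (the commit inlines this logic inside run_diptych_prompting), the sketch picks the device dynamically whereas the commit hard-codes device="cuda", and it assumes random is already imported at the top of app.py, since none of the hunks shown here add that import.

    import random

    import torch

    def pick_seed(seed: int = 42, randomize_seed: bool = False) -> int:
        # Mirrors the commit's inlined logic: honor the user-supplied seed,
        # or draw a fresh one from the full signed 64-bit range.
        if randomize_seed:
            return random.randint(0, 9223372036854775807)
        return seed

    # The chosen seed drives a per-call torch.Generator, replacing the
    # module-level generator fixed to manual_seed(42) that the first hunk removes.
    actual_seed = pick_seed(seed=42, randomize_seed=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"  # the commit uses "cuda" directly
    generator = torch.Generator(device=device).manual_seed(actual_seed)

The resulting generator is presumably passed to the pipe(...) call, which is why it is now created immediately before result = pipe(...) rather than once at import time.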
|