Spaces: Running on Zero
Create app.py
app.py
ADDED
@@ -0,0 +1,340 @@
import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import cv2
import os

from diffusers.utils import load_image, check_min_version
from controlnet_flux import FluxControlNetModel
from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline
from diffusers.models.attention_processor import Attention
from transformers import AutoProcessor, AutoModelForMaskGeneration, pipeline
from dataclasses import dataclass
from typing import Any, List, Dict, Optional, Union, Tuple

# --- Constants and Setup ---
# Ensure a recent enough diffusers version is installed
check_min_version("0.29.0.dev0")

device = "cuda" if torch.cuda.is_available() else "cpu"

# Set a seed for reproducibility. The original script uses a fixed seed.
generator = torch.Generator(device=device).manual_seed(42)


# --- Helper Dataclasses (Identical to diptych_prompting_inference.py) ---
@dataclass
class BoundingBox:
    xmin: int
    ymin: int
    xmax: int
    ymax: int

    @property
    def xyxy(self) -> List[float]:
        return [self.xmin, self.ymin, self.xmax, self.ymax]


@dataclass
class DetectionResult:
    score: float
    label: str
    box: BoundingBox
    mask: Optional[np.ndarray] = None

    @classmethod
    def from_dict(cls, detection_dict: Dict) -> 'DetectionResult':
        return cls(score=detection_dict['score'],
                   label=detection_dict['label'],
                   box=BoundingBox(xmin=detection_dict['box']['xmin'],
                                   ymin=detection_dict['box']['ymin'],
                                   xmax=detection_dict['box']['xmax'],
                                   ymax=detection_dict['box']['ymax']))
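A quick, illustrative sketch (not part of app.py): a raw detection dict in the format the zero-shot object detector returns is wrapped into these dataclasses as follows; the values are made up.

# Illustrative only -- hypothetical detector output wrapped via the classes above.
raw = {"score": 0.92, "label": "a bear plushie.",
       "box": {"xmin": 10, "ymin": 20, "xmax": 200, "ymax": 240}}
det = DetectionResult.from_dict(raw)
print(det.box.xyxy)  # -> [10, 20, 200, 240]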
# --- Helper Functions (Identical to diptych_prompting_inference.py) ---
def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return []
    largest_contour = max(contours, key=cv2.contourArea)
    return largest_contour.reshape(-1, 2).tolist()


def polygon_to_mask(polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]) -> np.ndarray:
    mask = np.zeros(image_shape, dtype=np.uint8)
    if not polygon:
        return mask
    pts = np.array(polygon, dtype=np.int32)
    cv2.fillPoly(mask, [pts], color=(255,))
    return mask


def get_boxes(results: List[DetectionResult]) -> List[List[List[float]]]:
    boxes = [result.box.xyxy for result in results]
    return [boxes]


def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
    masks = masks.cpu().float().permute(0, 2, 3, 1).mean(axis=-1)
    masks = (masks > 0).int().numpy().astype(np.uint8)
    masks = list(masks)

    if polygon_refinement:
        for idx, mask in enumerate(masks):
            shape = mask.shape
            polygon = mask_to_polygon(mask)
            refined_mask = polygon_to_mask(polygon, shape)
            masks[idx] = refined_mask
    return masks


def detect(
    object_detector, image: Image.Image, labels: List[str], threshold: float = 0.3, detector_id: Optional[str] = None
) -> List[DetectionResult]:
    labels = [label if label.endswith(".") else label + "." for label in labels]
    results = object_detector(image, candidate_labels=labels, threshold=threshold)
    return [DetectionResult.from_dict(result) for result in results]


def segment(
    segmentator, processor, image: Image.Image, detection_results: List[DetectionResult], polygon_refinement: bool = False
) -> List[DetectionResult]:
    if not detection_results:
        return []
    boxes = get_boxes(detection_results)
    inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = segmentator(**inputs)
    masks = processor.post_process_masks(
        masks=outputs.pred_masks, original_sizes=inputs.original_sizes, reshaped_input_sizes=inputs.reshaped_input_sizes
    )[0]
    masks = refine_masks(masks, polygon_refinement)
    for detection_result, mask in zip(detection_results, masks):
        detection_result.mask = mask
    return detection_results


def grounded_segmentation(
    detect_pipeline, segmentator, segment_processor, image: Image.Image, labels: List[str],
) -> Tuple[np.ndarray, List[DetectionResult]]:
    detections = detect(detect_pipeline, image, labels, threshold=0.3)
    detections = segment(segmentator, segment_processor, image, detections, polygon_refinement=True)
    return np.array(image), detections


def segment_image(image, object_name, detector, segmentator, seg_processor):
    image_array, detections = grounded_segmentation(detector, segmentator, seg_processor, image, [object_name])
    if not detections or detections[0].mask is None:
        raise gr.Error(f"Could not segment the subject '{object_name}' in the image. Please try a clearer image or a more specific subject name.")

    # Composite the subject onto a white background using the segmentation mask
    mask_expanded = np.expand_dims(detections[0].mask / 255, axis=-1)
    segment_result = image_array * mask_expanded + np.ones_like(image_array) * (1 - mask_expanded) * 255
    return Image.fromarray(segment_result.astype(np.uint8))


def make_diptych(image):
    # Left panel: segmented reference; right panel: black region to be inpainted
    ref_image_np = np.array(image)
    diptych_np = np.concatenate([ref_image_np, np.zeros_like(ref_image_np)], axis=1)
    return Image.fromarray(diptych_np)
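For intuition about the geometry these helpers set up: make_diptych places the segmented reference in the left panel and a black right panel that the inpainting pipeline will fill. A tiny self-contained sketch (not part of app.py):

# Illustrative only -- shape of the diptych canvas produced by make_diptych.
import numpy as np
ref = np.full((4, 4, 3), 255, dtype=np.uint8)               # toy white "reference"
diptych = np.concatenate([ref, np.zeros_like(ref)], axis=1)
print(diptych.shape)         # (4, 8, 3): left half = reference, right half = zeros
print(diptych[:, 4:].max())  # 0: the right panel is empty and will be inpainted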
# --- Custom Attention Processor (EXACTLY as in diptych_prompting_inference.py) ---
class CustomFluxAttnProcessor2_0:
    def __init__(self, height=44, width=88, attn_enforce=1.0):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.height = height
        self.width = width
        self.num_pixels = height * width
        self.step = 0
        self.attn_enforce = attn_enforce

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        self.step += 1
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        inner_dim, head_dim = key.shape[-1], key.shape[-1] // attn.heads
        query, key, value = [x.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) for x in [query, key, value]]

        if attn.norm_q is not None: query = attn.norm_q(query)
        if attn.norm_k is not None: key = attn.norm_k(key)

        if encoder_hidden_states is not None:
            encoder_q = attn.add_q_proj(encoder_hidden_states).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            encoder_k = attn.add_k_proj(encoder_hidden_states).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            encoder_v = attn.add_v_proj(encoder_hidden_states).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            if attn.norm_added_q is not None: encoder_q = attn.norm_added_q(encoder_q)
            if attn.norm_added_k is not None: encoder_k = attn.norm_added_k(encoder_k)
            query, key, value = [torch.cat([e, x], dim=2) for e, x in zip([encoder_q, encoder_k, encoder_v], [query, key, value])]

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        if self.attn_enforce != 1.0:
            attn_probs = (torch.einsum('bhqd,bhkd->bhqk', query, key) * attn.scale).softmax(dim=-1)
            img_attn_probs = attn_probs[:, :, -self.num_pixels:, -self.num_pixels:].reshape((batch_size, attn.heads, self.height, self.width, self.height, self.width))
            img_attn_probs[:, :, :, self.width//2:, :, :self.width//2] *= self.attn_enforce
            img_attn_probs = img_attn_probs.reshape((batch_size, attn.heads, self.num_pixels, self.num_pixels))
            attn_probs[:, :, -self.num_pixels:, -self.num_pixels:] = img_attn_probs
            hidden_states = torch.einsum('bhqk,bhkd->bhqd', attn_probs, value)
        else:
            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hs, hs = hidden_states[:, : encoder_hidden_states.shape[1]], hidden_states[:, encoder_hidden_states.shape[1]:]
            hs = attn.to_out[0](hs)
            hs = attn.to_out[1](hs)
            encoder_hs = attn.to_add_out(encoder_hs)
            return hs, encoder_hs
        else:
            return hidden_states
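The attn_enforce branch above computes attention explicitly so that weights from right-panel queries to left-panel (reference) keys can be scaled up. A toy sketch of that reweighting on a uniform attention map, with shapes chosen only for illustration (not part of app.py):

# Illustrative only -- the cross-panel reweighting applied when attn_enforce != 1.0.
import torch
b, heads, h, w = 1, 1, 2, 4                 # tiny latent grid; width spans both panels
num_pixels = h * w
probs = torch.full((b, heads, num_pixels, num_pixels), 1.0 / num_pixels)
probs = probs.reshape(b, heads, h, w, h, w)
probs[:, :, :, w // 2:, :, :w // 2] *= 1.3  # right-panel queries -> left-panel keys
probs = probs.reshape(b, heads, num_pixels, num_pixels)
print(probs[0, 0, -1, :w // 2])             # boosted weights for a right-panel query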
# --- Model Loading (executed once at startup) ---
print("--- Loading Models: This may take a few minutes and requires >40GB VRAM ---")
controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", torch_dtype=torch.bfloat16)
pipe = FluxControlNetInpaintingPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16).to(device)
pipe.transformer.to(torch.bfloat16)
pipe.controlnet.to(torch.bfloat16)
base_attn_procs = pipe.transformer.attn_processors.copy()

print("Loading segmentation models...")
detector_id, segmenter_id = "IDEA-Research/grounding-dino-tiny", "facebook/sam-vit-base"
segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)
segment_processor = AutoProcessor.from_pretrained(segmenter_id)
object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=device)
print("--- All models loaded successfully! ---")
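A small sanity check that can be run right after loading, to confirm the heavy modules landed on the expected device and dtype (illustrative, not required by the app):

# Illustrative only -- verify device/dtype placement after loading.
p = next(pipe.transformer.parameters())
print(p.device, p.dtype)                      # expected: cuda / torch.bfloat16 on a GPU box
print(next(segmentator.parameters()).device)  # SAM should sit on the same device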
# --- Main Inference Function for Gradio ---
def run_diptych_prompting(
    input_image: Image.Image,
    subject_name: str,
    target_prompt: str,
    attn_enforce: float,
    ctrl_scale: float,
    width: int,
    height: int,
    pixel_offset: int,
    num_steps: int,
    guidance: float,
    progress=gr.Progress(track_tqdm=True)
):
    if input_image is None: raise gr.Error("Please upload a reference image.")
    if not subject_name: raise gr.Error("Please provide the subject's name (e.g., 'a red car').")
    if not target_prompt: raise gr.Error("Please provide a target prompt.")

    # 1. Prepare dimensions (logic from original script's main block)
    padded_width = width + pixel_offset * 2
    padded_height = height + pixel_offset * 2
    diptych_size = (padded_width * 2, padded_height)

    # 2. Prepare prompts and images
    progress(0, desc="Resizing and segmenting reference image...")
    base_prompt = f"a photo of {subject_name}"
    diptych_text_prompt = f"A diptych with two side-by-side images of same {subject_name}. On the left, {base_prompt}. On the right, replicate this {subject_name} exactly but as {target_prompt}"

    reference_image = input_image.resize((padded_width, padded_height)).convert("RGB")
    segmented_image = segment_image(reference_image, subject_name, object_detector, segmentator, segment_processor)

    progress(0.2, desc="Creating diptych and mask...")
    mask_image = np.concatenate([np.zeros((padded_height, padded_width, 3)), np.ones((padded_height, padded_width, 3)) * 255], axis=1)
    mask_image = Image.fromarray(mask_image.astype(np.uint8))
    diptych_image_prompt = make_diptych(segmented_image)

    # 3. Setup Attention Processor (logic from original script's main block)
    progress(0.3, desc="Setting up attention processors...")
    new_attn_procs = base_attn_procs.copy()
    for k in new_attn_procs:
        # Use full diptych dimensions for the attention processor
        new_attn_procs[k] = CustomFluxAttnProcessor2_0(height=padded_height // 16, width=padded_width * 2 // 16, attn_enforce=attn_enforce)
    pipe.transformer.set_attn_processor(new_attn_procs)

    # 4. Run Inference (using parameters identical to the original script)
    progress(0.4, desc="Running diffusion process...")
    result = pipe(
        prompt=diptych_text_prompt,
        height=diptych_size[1],
        width=diptych_size[0],
        control_image=diptych_image_prompt,
        control_mask=mask_image,
        num_inference_steps=num_steps,
        generator=generator,
        controlnet_conditioning_scale=ctrl_scale,
        guidance_scale=guidance,  # This is used for guidance embeds if enabled
        negative_prompt="",
        true_guidance_scale=guidance  # **CRITICAL FIX**: This matches the original script's CFG scale
    ).images[0]

    # 5. Final cropping (logic from original script's main block)
    progress(0.95, desc="Finalizing image...")
    # Crop the right panel
    result = result.crop((padded_width, 0, padded_width * 2, padded_height))
    # Crop the pixel offset padding
    result = result.crop((pixel_offset, pixel_offset, padded_width - pixel_offset, padded_height - pixel_offset))

    return result
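The same entry point can be exercised without the UI. A hedged sketch using the default slider values exposed below; the asset path comes from the examples list, and progress is replaced with a no-op so the call works outside a Gradio event:

# Illustrative only -- direct call with the UI's default slider values.
ref = Image.open("./assets/bear_plushie.jpg")   # assumes this example asset exists
out = run_diptych_prompting(
    input_image=ref,
    subject_name="a bear plushie",
    target_prompt="a bear plushie riding a skateboard",
    attn_enforce=1.3, ctrl_scale=0.95,
    width=768, height=768, pixel_offset=8,
    num_steps=30, guidance=3.5,
    progress=lambda *args, **kwargs: None,
)
out.save("diptych_result.png")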
# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Diptych Prompting: Zero-Shot Subject-Driven Image Generation
        ### Official Gradio Demo for the paper "[Large-Scale Text-to-Image Model with Inpainting is a Zero-Shot Subject-Driven Image Generator](https://diptychprompting.github.io/)"
        **Instructions:**
        1. Upload a clear image of a subject.
        2. Describe the subject in the "Subject Name" box (e.g., 'a plush bear', 'a red sports car').
        3. Describe the new scene for your subject in the "Target Prompt" box.

        > **Note:** This demo runs on a powerful GPU and requires over **40GB of VRAM**. Inference may take 1-2 minutes.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="1. Reference Image")
            subject_name = gr.Textbox(label="2. Subject Name", placeholder="e.g., a plush bear")
            target_prompt = gr.Textbox(label="3. Target Prompt", placeholder="e.g., riding a skateboard on the moon")
            run_button = gr.Button("Generate Image", variant="primary")
            with gr.Accordion("Advanced Settings", open=False):
                attn_enforce = gr.Slider(minimum=1.0, maximum=2.0, value=1.3, step=0.05, label="Attention Enforcement")
                ctrl_scale = gr.Slider(minimum=0.5, maximum=1.0, value=0.95, step=0.01, label="ControlNet Scale")
                num_steps = gr.Slider(minimum=20, maximum=50, value=30, step=1, label="Inference Steps")
                guidance = gr.Slider(minimum=1.0, maximum=10.0, value=3.5, step=0.1, label="Guidance Scale")
                width = gr.Slider(minimum=512, maximum=1024, value=768, step=64, label="Image Width")
                height = gr.Slider(minimum=512, maximum=1024, value=768, step=64, label="Image Height")
                pixel_offset = gr.Slider(minimum=0, maximum=32, value=8, step=1, label="Padding (Pixel Offset)")
        with gr.Column(scale=1):
            output_image = gr.Image(type="pil", label="Generated Image")

    gr.Examples(
        examples=[
            ["./assets/bear_plushie.jpg", "a bear plushie", "a bear plushie riding a skateboard"],
            ["./assets/corgi.jpg", "a corgi dog", "a corgi dog wearing a superhero cape and flying"],
            ["./assets/teapot.png", "a blue and white teapot", "a blue and white teapot in a field of flowers"],
        ],
        inputs=[input_image, subject_name, target_prompt],
        outputs=output_image,
        fn=run_diptych_prompting,
        cache_examples=False,
    )

    run_button.click(
        fn=run_diptych_prompting,
        inputs=[input_image, subject_name, target_prompt, attn_enforce, ctrl_scale, width, height, pixel_offset, num_steps, guidance],
        outputs=output_image
    )


if __name__ == "__main__":
    if not os.path.exists("./assets"):
        os.makedirs("./assets")
        print("Created './assets' directory. Please add example images like 'bear_plushie.jpg' there for the examples to work.")

    demo.launch(share=True)