multimodalart (HF Staff) committed · verified
Commit 7ce70f3 · Parent(s): 0ca2595

Create app.py

Files changed (1): app.py (+340, -0)

app.py ADDED
import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import cv2
import os

from diffusers.utils import load_image, check_min_version
# controlnet_flux and pipeline_flux_controlnet_inpaint are local modules expected to be
# provided alongside this app (they are not part of the diffusers package itself).
from controlnet_flux import FluxControlNetModel
from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline
from diffusers.models.attention_processor import Attention
from transformers import AutoProcessor, AutoModelForMaskGeneration, pipeline
from dataclasses import dataclass
from typing import Any, List, Dict, Optional, Union, Tuple

# --- Constants and Setup ---
# Ensure a recent enough diffusers version is installed.
check_min_version("0.29.0.dev0")

device = "cuda" if torch.cuda.is_available() else "cpu"

# Fixed seed for reproducibility, as in the original script.
generator = torch.Generator(device=device).manual_seed(42)


# --- Helper Dataclasses (Identical to diptych_prompting_inference.py) ---
@dataclass
class BoundingBox:
    xmin: int
    ymin: int
    xmax: int
    ymax: int

    @property
    def xyxy(self) -> List[float]:
        return [self.xmin, self.ymin, self.xmax, self.ymax]


@dataclass
class DetectionResult:
    score: float
    label: str
    box: BoundingBox
    mask: Optional[np.ndarray] = None

    @classmethod
    def from_dict(cls, detection_dict: Dict) -> 'DetectionResult':
        return cls(score=detection_dict['score'],
                   label=detection_dict['label'],
                   box=BoundingBox(xmin=detection_dict['box']['xmin'],
                                   ymin=detection_dict['box']['ymin'],
                                   xmax=detection_dict['box']['xmax'],
                                   ymax=detection_dict['box']['ymax']))
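
# Note: DetectionResult.from_dict mirrors the dictionaries returned by the transformers
# zero-shot-object-detection pipeline, whose results look like
# {'score': float, 'label': str, 'box': {'xmin': ..., 'ymin': ..., 'xmax': ..., 'ymax': ...}}.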

# --- Helper Functions (Identical to diptych_prompting_inference.py) ---
def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return []
    largest_contour = max(contours, key=cv2.contourArea)
    return largest_contour.reshape(-1, 2).tolist()

def polygon_to_mask(polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]) -> np.ndarray:
    mask = np.zeros(image_shape, dtype=np.uint8)
    if not polygon:
        return mask
    pts = np.array(polygon, dtype=np.int32)
    cv2.fillPoly(mask, [pts], color=(255,))
    return mask

def get_boxes(results: List[DetectionResult]) -> List[List[List[float]]]:
    boxes = [result.box.xyxy for result in results]
    return [boxes]

def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
    masks = masks.cpu().float().permute(0, 2, 3, 1).mean(axis=-1)
    masks = (masks > 0).int().numpy().astype(np.uint8)
    masks = list(masks)

    if polygon_refinement:
        for idx, mask in enumerate(masks):
            shape = mask.shape
            polygon = mask_to_polygon(mask)
            refined_mask = polygon_to_mask(polygon, shape)
            masks[idx] = refined_mask
    return masks

def detect(
    object_detector, image: Image.Image, labels: List[str], threshold: float = 0.3, detector_id: Optional[str] = None
) -> List[DetectionResult]:
    labels = [label if label.endswith(".") else label + "." for label in labels]
    results = object_detector(image, candidate_labels=labels, threshold=threshold)
    return [DetectionResult.from_dict(result) for result in results]

def segment(
    segmentator, processor, image: Image.Image, detection_results: List[DetectionResult], polygon_refinement: bool = False
) -> List[DetectionResult]:
    if not detection_results:
        return []
    boxes = get_boxes(detection_results)
    inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = segmentator(**inputs)
    masks = processor.post_process_masks(
        masks=outputs.pred_masks, original_sizes=inputs.original_sizes, reshaped_input_sizes=inputs.reshaped_input_sizes
    )[0]
    masks = refine_masks(masks, polygon_refinement)
    for detection_result, mask in zip(detection_results, masks):
        detection_result.mask = mask
    return detection_results

def grounded_segmentation(
    detect_pipeline, segmentator, segment_processor, image: Image.Image, labels: List[str],
) -> Tuple[np.ndarray, List[DetectionResult]]:
    detections = detect(detect_pipeline, image, labels, threshold=0.3)
    detections = segment(segmentator, segment_processor, image, detections, polygon_refinement=True)
    return np.array(image), detections

def segment_image(image, object_name, detector, segmentator, seg_processor):
    image_array, detections = grounded_segmentation(detector, segmentator, seg_processor, image, [object_name])
    if not detections or detections[0].mask is None:
        raise gr.Error(f"Could not segment the subject '{object_name}' in the image. Please try a clearer image or a more specific subject name.")

    mask_expanded = np.expand_dims(detections[0].mask / 255, axis=-1)
    segment_result = image_array * mask_expanded + np.ones_like(image_array) * (1 - mask_expanded) * 255
    return Image.fromarray(segment_result.astype(np.uint8))

def make_diptych(image):
    ref_image_np = np.array(image)
    diptych_np = np.concatenate([ref_image_np, np.zeros_like(ref_image_np)], axis=1)
    return Image.fromarray(diptych_np)
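
# Note on the attention processor below: when attn_enforce != 1.0 it computes the attention
# weights explicitly, reshapes the image-token block onto the diptych's (height x width)
# latent grid, and multiplies the weights that right-panel queries place on left-panel keys
# by attn_enforce, i.e. it strengthens how much the generated right panel attends to the
# reference left panel. With attn_enforce == 1.0 it falls back to the standard (and faster)
# F.scaled_dot_product_attention path.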

# --- Custom Attention Processor (EXACTLY as in diptych_prompting_inference.py) ---
class CustomFluxAttnProcessor2_0:
    def __init__(self, height=44, width=88, attn_enforce=1.0):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
        self.height = height
        self.width = width
        self.num_pixels = height * width
        self.step = 0
        self.attn_enforce = attn_enforce

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        self.step += 1
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        inner_dim, head_dim = key.shape[-1], key.shape[-1] // attn.heads
        query, key, value = [x.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) for x in [query, key, value]]

        if attn.norm_q is not None: query = attn.norm_q(query)
        if attn.norm_k is not None: key = attn.norm_k(key)

        if encoder_hidden_states is not None:
            encoder_q = attn.add_q_proj(encoder_hidden_states).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            encoder_k = attn.add_k_proj(encoder_hidden_states).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            encoder_v = attn.add_v_proj(encoder_hidden_states).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            if attn.norm_added_q is not None: encoder_q = attn.norm_added_q(encoder_q)
            if attn.norm_added_k is not None: encoder_k = attn.norm_added_k(encoder_k)
            query, key, value = [torch.cat([e, x], dim=2) for e, x in zip([encoder_q, encoder_k, encoder_v], [query, key, value])]

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        if self.attn_enforce != 1.0:
            # Explicit attention so the image-token block can be re-weighted.
            attn_probs = (torch.einsum('bhqd,bhkd->bhqk', query, key) * attn.scale).softmax(dim=-1)
            img_attn_probs = attn_probs[:, :, -self.num_pixels:, -self.num_pixels:].reshape((batch_size, attn.heads, self.height, self.width, self.height, self.width))
            img_attn_probs[:, :, :, self.width//2:, :, :self.width//2] *= self.attn_enforce
            img_attn_probs = img_attn_probs.reshape((batch_size, attn.heads, self.num_pixels, self.num_pixels))
            attn_probs[:, :, -self.num_pixels:, -self.num_pixels:] = img_attn_probs
            hidden_states = torch.einsum('bhqk,bhkd->bhqd', attn_probs, value)
        else:
            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hs, hs = hidden_states[:, : encoder_hidden_states.shape[1]], hidden_states[:, encoder_hidden_states.shape[1] :]
            hs = attn.to_out[0](hs)
            hs = attn.to_out[1](hs)
            encoder_hs = attn.to_add_out(encoder_hs)
            return hs, encoder_hs
        else:
            return hidden_states

# --- Model Loading (executed once at startup) ---
print("--- Loading Models: This may take a few minutes and requires >40GB VRAM ---")
controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", torch_dtype=torch.bfloat16)
pipe = FluxControlNetInpaintingPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16).to(device)
pipe.transformer.to(torch.bfloat16)
pipe.controlnet.to(torch.bfloat16)
base_attn_procs = pipe.transformer.attn_processors.copy()

print("Loading segmentation models...")
detector_id, segmenter_id = "IDEA-Research/grounding-dino-tiny", "facebook/sam-vit-base"
segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)
segment_processor = AutoProcessor.from_pretrained(segmenter_id)
object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=device)
print("--- All models loaded successfully! ---")
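
# If the available GPU cannot hold everything in bfloat16 at once, diffusers' generic
# offloading hooks are one possible (untested here) mitigation, e.g.
#   pipe.enable_model_cpu_offload()
# instead of moving the whole pipeline to the GPU, at the cost of slower inference.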

# --- Main Inference Function for Gradio ---
def run_diptych_prompting(
    input_image: Image.Image,
    subject_name: str,
    target_prompt: str,
    attn_enforce: float,
    ctrl_scale: float,
    width: int,
    height: int,
    pixel_offset: int,
    num_steps: int,
    guidance: float,
    progress=gr.Progress(track_tqdm=True)
):
    if input_image is None: raise gr.Error("Please upload a reference image.")
    if not subject_name: raise gr.Error("Please provide the subject's name (e.g., 'a red car').")
    if not target_prompt: raise gr.Error("Please provide a target prompt.")

    # 1. Prepare dimensions (logic from the original script's main block)
    padded_width = width + pixel_offset * 2
    padded_height = height + pixel_offset * 2
    diptych_size = (padded_width * 2, padded_height)

    # 2. Prepare prompts and images
    progress(0, desc="Resizing and segmenting reference image...")
    base_prompt = f"a photo of {subject_name}"
    diptych_text_prompt = f"A diptych with two side-by-side images of same {subject_name}. On the left, {base_prompt}. On the right, replicate this {subject_name} exactly but as {target_prompt}"

    reference_image = input_image.resize((padded_width, padded_height)).convert("RGB")
    segmented_image = segment_image(reference_image, subject_name, object_detector, segmentator, segment_processor)

    progress(0.2, desc="Creating diptych and mask...")
    # Mask: keep the left (reference) panel, inpaint the right panel.
    mask_image = np.concatenate([np.zeros((padded_height, padded_width, 3)), np.ones((padded_height, padded_width, 3)) * 255], axis=1)
    mask_image = Image.fromarray(mask_image.astype(np.uint8))
    diptych_image_prompt = make_diptych(segmented_image)

    # 3. Set up attention processors (logic from the original script's main block)
    progress(0.3, desc="Setting up attention processors...")
    new_attn_procs = base_attn_procs.copy()
    for k in new_attn_procs:
        # Use the full diptych dimensions for the attention processor.
        new_attn_procs[k] = CustomFluxAttnProcessor2_0(height=padded_height // 16, width=padded_width * 2 // 16, attn_enforce=attn_enforce)
    pipe.transformer.set_attn_processor(new_attn_procs)

    # 4. Run inference (parameters identical to the original script)
    progress(0.4, desc="Running diffusion process...")
    result = pipe(
        prompt=diptych_text_prompt,
        height=diptych_size[1],
        width=diptych_size[0],
        control_image=diptych_image_prompt,
        control_mask=mask_image,
        num_inference_steps=num_steps,
        generator=generator,
        controlnet_conditioning_scale=ctrl_scale,
        guidance_scale=guidance,  # used for guidance embeds if enabled
        negative_prompt="",
        true_guidance_scale=guidance  # CRITICAL: matches the original script's CFG scale
    ).images[0]

    # 5. Final cropping (logic from the original script's main block)
    progress(0.95, desc="Finalizing image...")
    # Keep only the right panel...
    result = result.crop((padded_width, 0, padded_width * 2, padded_height))
    # ...then trim the pixel-offset padding.
    result = result.crop((pixel_offset, pixel_offset, padded_width - pixel_offset, padded_height - pixel_offset))

    return result
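
# Sizing note: the diptych canvas is twice the padded width, and the attention processor's
# grid is the canvas size divided by 16 -- FLUX's VAE downsamples by 8 and the transformer
# packs 2x2 latent patches, so each image token covers roughly a 16x16 pixel area.
# width // 2 in the processor therefore marks the seam between the reference (left) and
# generated (right) panels.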

# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# Diptych Prompting: Zero-Shot Subject-Driven Image Generation
### Official Gradio Demo for the paper "[Large-Scale Text-to-Image Model with Inpainting is a Zero-Shot Subject-Driven Image Generator](https://diptychprompting.github.io/)"
**Instructions:**
1. Upload a clear image of a subject.
2. Describe the subject in the "Subject Name" box (e.g., 'a plush bear', 'a red sports car').
3. Describe the new scene for your subject in the "Target Prompt" box.

> **Note:** This demo runs on a powerful GPU and requires over **40GB of VRAM**. Inference may take 1-2 minutes.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="1. Reference Image")
            subject_name = gr.Textbox(label="2. Subject Name", placeholder="e.g., a plush bear")
            target_prompt = gr.Textbox(label="3. Target Prompt", placeholder="e.g., riding a skateboard on the moon")
            run_button = gr.Button("Generate Image", variant="primary")
            with gr.Accordion("Advanced Settings", open=False):
                attn_enforce = gr.Slider(minimum=1.0, maximum=2.0, value=1.3, step=0.05, label="Attention Enforcement")
                ctrl_scale = gr.Slider(minimum=0.5, maximum=1.0, value=0.95, step=0.01, label="ControlNet Scale")
                num_steps = gr.Slider(minimum=20, maximum=50, value=30, step=1, label="Inference Steps")
                guidance = gr.Slider(minimum=1.0, maximum=10.0, value=3.5, step=0.1, label="Guidance Scale")
                width = gr.Slider(minimum=512, maximum=1024, value=768, step=64, label="Image Width")
                height = gr.Slider(minimum=512, maximum=1024, value=768, step=64, label="Image Height")
                pixel_offset = gr.Slider(minimum=0, maximum=32, value=8, step=1, label="Padding (Pixel Offset)")
        with gr.Column(scale=1):
            output_image = gr.Image(type="pil", label="Generated Image")

    gr.Examples(
        examples=[
            ["./assets/bear_plushie.jpg", "a bear plushie", "a bear plushie riding a skateboard"],
            ["./assets/corgi.jpg", "a corgi dog", "a corgi dog wearing a superhero cape and flying"],
            ["./assets/teapot.png", "a blue and white teapot", "a blue and white teapot in a field of flowers"],
        ],
        inputs=[input_image, subject_name, target_prompt],
        outputs=output_image,
        fn=run_diptych_prompting,
        cache_examples=False,
    )

    run_button.click(
        fn=run_diptych_prompting,
        inputs=[input_image, subject_name, target_prompt, attn_enforce, ctrl_scale, width, height, pixel_offset, num_steps, guidance],
        outputs=output_image
    )
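
# Note: with cache_examples=False (and run_on_click left at its default), selecting an
# example above should only pre-fill the three inputs; generation still runs via the
# "Generate Image" button, which also passes the advanced settings.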

if __name__ == "__main__":
    if not os.path.exists("./assets"):
        os.makedirs("./assets")
        print("Created './assets' directory. Please add example images like 'bear_plushie.jpg' there for the examples to work.")

    demo.launch(share=True)