import spaces
import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import cv2
import os
import random
from diffusers.utils import load_image, check_min_version
from controlnet_flux import FluxControlNetModel
from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline
from diffusers.models.attention_processor import Attention
from transformers import AutoProcessor, AutoModelForMaskGeneration, pipeline
from dataclasses import dataclass
from typing import Any, List, Dict, Optional, Union, Tuple

device = "cuda" if torch.cuda.is_available() else "cpu"
# --- Helper Dataclasses (Identical to diptych_prompting_inference.py) ---
@dataclass
class BoundingBox:
    xmin: int
    ymin: int
    xmax: int
    ymax: int

    @property
    def xyxy(self) -> List[float]:
        return [self.xmin, self.ymin, self.xmax, self.ymax]

@dataclass
class DetectionResult:
    score: float
    label: str
    box: BoundingBox
    mask: Optional[np.ndarray] = None

    @classmethod
    def from_dict(cls, detection_dict: Dict) -> 'DetectionResult':
        return cls(score=detection_dict['score'],
                   label=detection_dict['label'],
                   box=BoundingBox(xmin=detection_dict['box']['xmin'],
                                   ymin=detection_dict['box']['ymin'],
                                   xmax=detection_dict['box']['xmax'],
                                   ymax=detection_dict['box']['ymax']))
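# Illustrative only (not part of the original script): the zero-shot object
# detection pipeline returns dicts shaped like the one below, which
# DetectionResult.from_dict converts into the dataclasses above. The literal
# values here are made up for the example.
#
#   sample = {'score': 0.92, 'label': 'a plush bear.',
#             'box': {'xmin': 10, 'ymin': 20, 'xmax': 200, 'ymax': 240}}
#   det = DetectionResult.from_dict(sample)
#   det.box.xyxy  # -> [10, 20, 200, 240]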
# --- Helper Functions (Identical to diptych_prompting_inference.py) ---
def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return []
    largest_contour = max(contours, key=cv2.contourArea)
    return largest_contour.reshape(-1, 2).tolist()

def polygon_to_mask(polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]) -> np.ndarray:
    mask = np.zeros(image_shape, dtype=np.uint8)
    if not polygon:
        return mask
    pts = np.array(polygon, dtype=np.int32)
    cv2.fillPoly(mask, [pts], color=(255,))
    return mask

def get_boxes(results: List[DetectionResult]) -> List[List[List[float]]]:
    boxes = [result.box.xyxy for result in results]
    return [boxes]
def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
    masks = masks.cpu().float().permute(0, 2, 3, 1).mean(axis=-1)
    masks = (masks > 0).int().numpy().astype(np.uint8)
    masks = list(masks)
    if polygon_refinement:
        for idx, mask in enumerate(masks):
            shape = mask.shape
            polygon = mask_to_polygon(mask)
            refined_mask = polygon_to_mask(polygon, shape)
            masks[idx] = refined_mask
    return masks

def detect(
    object_detector, image: Image.Image, labels: List[str], threshold: float = 0.3, detector_id: Optional[str] = None
) -> List[DetectionResult]:
    labels = [label if label.endswith(".") else label + "." for label in labels]
    results = object_detector(image, candidate_labels=labels, threshold=threshold)
    return [DetectionResult.from_dict(result) for result in results]
def segment(
    segmentator, processor, image: Image.Image, detection_results: List[DetectionResult], polygon_refinement: bool = False
) -> List[DetectionResult]:
    if not detection_results:
        return []
    boxes = get_boxes(detection_results)
    inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = segmentator(**inputs)
    masks = processor.post_process_masks(
        masks=outputs.pred_masks, original_sizes=inputs.original_sizes, reshaped_input_sizes=inputs.reshaped_input_sizes
    )[0]
    masks = refine_masks(masks, polygon_refinement)
    for detection_result, mask in zip(detection_results, masks):
        detection_result.mask = mask
    return detection_results

def grounded_segmentation(
    detect_pipeline, segmentator, segment_processor, image: Image.Image, labels: List[str],
) -> Tuple[np.ndarray, List[DetectionResult]]:
    detections = detect(detect_pipeline, image, labels, threshold=0.3)
    detections = segment(segmentator, segment_processor, image, detections, polygon_refinement=True)
    return np.array(image), detections
def segment_image(image, object_name, detector, segmentator, seg_processor):
    image_array, detections = grounded_segmentation(detector, segmentator, seg_processor, image, [object_name])
    if not detections or detections[0].mask is None:
        raise gr.Error(f"Could not segment the subject '{object_name}' in the image. Please try a clearer image or a more specific subject name.")
    # Composite the detected subject onto a white background using the binary mask.
    mask_expanded = np.expand_dims(detections[0].mask / 255, axis=-1)
    segment_result = image_array * mask_expanded + np.ones_like(image_array) * (1 - mask_expanded) * 255
    return Image.fromarray(segment_result.astype(np.uint8))

def make_diptych(image):
    # Place the reference image on the left and an empty (black) panel of the
    # same size on the right; the right half is what the pipeline inpaints.
    ref_image_np = np.array(image)
    diptych_np = np.concatenate([ref_image_np, np.zeros_like(ref_image_np)], axis=1)
    return Image.fromarray(diptych_np)
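# Illustrative only: how make_diptych pairs with the inpainting mask built in
# run_diptych_prompting below (zeros = keep the left panel, 255 = inpaint the
# right panel). Variable names are made up; sizes use the UI defaults
# (768 + 2*8 padding = 784 pixels per panel).
#
#   panel_w, panel_h = 784, 784
#   diptych = make_diptych(reference)                        # (784*2) x 784 image
#   mask = np.concatenate([np.zeros((panel_h, panel_w, 3)),
#                          np.ones((panel_h, panel_w, 3)) * 255], axis=1)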
# --- Custom Attention Processor (EXACTLY as in diptych_prompting_inference.py) ---
class CustomFluxAttnProcessor2_0:
    def __init__(self, height=44, width=88, attn_enforce=1.0):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.height = height
        self.width = width
        self.num_pixels = height * width
        self.step = 0
        self.attn_enforce = attn_enforce

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        self.step += 1
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        inner_dim, head_dim = key.shape[-1], key.shape[-1] // attn.heads
        query, key, value = [x.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) for x in [query, key, value]]
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)
        if encoder_hidden_states is not None:
            encoder_q = attn.add_q_proj(encoder_hidden_states).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            encoder_k = attn.add_k_proj(encoder_hidden_states).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            encoder_v = attn.add_v_proj(encoder_hidden_states).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            if attn.norm_added_q is not None:
                encoder_q = attn.norm_added_q(encoder_q)
            if attn.norm_added_k is not None:
                encoder_k = attn.norm_added_k(encoder_k)
            query, key, value = [torch.cat([e, x], dim=2) for e, x in zip([encoder_q, encoder_k, encoder_v], [query, key, value])]
        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)
        if self.attn_enforce != 1.0:
            # Materialize the attention map and scale the weights of right-panel
            # (generated) image tokens attending to left-panel (reference) tokens.
            attn_probs = (torch.einsum('bhqd,bhkd->bhqk', query, key) * attn.scale).softmax(dim=-1)
            img_attn_probs = attn_probs[:, :, -self.num_pixels:, -self.num_pixels:].reshape((batch_size, attn.heads, self.height, self.width, self.height, self.width))
            img_attn_probs[:, :, :, self.width//2:, :, :self.width//2] *= self.attn_enforce
            img_attn_probs = img_attn_probs.reshape((batch_size, attn.heads, self.num_pixels, self.num_pixels))
            attn_probs[:, :, -self.num_pixels:, -self.num_pixels:] = img_attn_probs
            hidden_states = torch.einsum('bhqk,bhkd->bhqd', attn_probs, value)
        else:
            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype)
        if encoder_hidden_states is not None:
            encoder_hs, hs = hidden_states[:, : encoder_hidden_states.shape[1]], hidden_states[:, encoder_hidden_states.shape[1]:]
            hs = attn.to_out[0](hs)
            hs = attn.to_out[1](hs)
            encoder_hs = attn.to_add_out(encoder_hs)
            return hs, encoder_hs
        else:
            return hidden_states
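# Illustrative only (assumes the default UI settings of 768x768 with 8px
# padding): FLUX latent tokens are downsampled 16x relative to pixels, so the
# processor's (height, width) describe the diptych's latent token grid, and
# width // 2 marks the boundary between the left (reference) and right
# (generated) panels. The variable names below are made up for the example.
#
#   padded = 768 + 2 * 8                     # 784 pixels per panel side
#   lat_h = padded // 16                     # 49 latent rows
#   lat_w = padded * 2 // 16                 # 98 latent columns across both panels
#   demo_proc = CustomFluxAttnProcessor2_0(height=lat_h, width=lat_w, attn_enforce=1.3)
#   # Columns [lat_w // 2:] are right-panel tokens; their attention onto
#   # columns [:lat_w // 2] (the reference panel) is multiplied by attn_enforce.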
# --- Model Loading (executed once at startup) ---
print("--- Loading Models: This may take a few minutes and requires >40GB VRAM ---")
controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", torch_dtype=torch.bfloat16)
pipe = FluxControlNetInpaintingPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16).to(device)
pipe.transformer.to(torch.bfloat16)
pipe.controlnet.to(torch.bfloat16)
base_attn_procs = pipe.transformer.attn_processors.copy()

print("Loading segmentation models...")
detector_id, segmenter_id = "IDEA-Research/grounding-dino-tiny", "facebook/sam-vit-base"
segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)
segment_processor = AutoProcessor.from_pretrained(segmenter_id)
object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=device)
print("--- All models loaded successfully! ---")
# --- Main Inference Function for Gradio ---
@spaces.GPU(duration=120)  # assumed ZeroGPU decorator (the script imports `spaces`); duration sized for the ~1-2 minute inference noted in the UI
def run_diptych_prompting(
    input_image: Image.Image,
    subject_name: str,
    target_prompt: str,
    attn_enforce: float = 1.3,
    ctrl_scale: float = 0.95,
    width: int = 768,
    height: int = 768,
    pixel_offset: int = 8,
    num_steps: int = 30,
    guidance: float = 3.5,
    seed: int = 42,
    randomize_seed: bool = False,
    progress=gr.Progress(track_tqdm=True)
):
    if randomize_seed:
        actual_seed = random.randint(0, 9223372036854775807)
    else:
        actual_seed = seed
    if input_image is None:
        raise gr.Error("Please upload a reference image.")
    if not subject_name:
        raise gr.Error("Please provide the subject's name (e.g., 'a red car').")
    if not target_prompt:
        raise gr.Error("Please provide a target prompt.")
    # 1. Prepare dimensions (logic from original script's main block)
    padded_width = width + pixel_offset * 2
    padded_height = height + pixel_offset * 2
    diptych_size = (padded_width * 2, padded_height)

    # 2. Prepare prompts and images
    progress(0, desc="Resizing and segmenting reference image...")
    base_prompt = f"a photo of {subject_name}"
    diptych_text_prompt = f"A diptych with two side-by-side images of same {subject_name}. On the left, {base_prompt}. On the right, replicate this {subject_name} exactly but as {target_prompt}"
    reference_image = input_image.resize((padded_width, padded_height)).convert("RGB")
    segmented_image = segment_image(reference_image, subject_name, object_detector, segmentator, segment_processor)

    progress(0.2, desc="Creating diptych and mask...")
    # Inpainting mask: keep the left (reference) panel, regenerate the right panel.
    mask_image = np.concatenate([np.zeros((padded_height, padded_width, 3)), np.ones((padded_height, padded_width, 3)) * 255], axis=1)
    mask_image = Image.fromarray(mask_image.astype(np.uint8))
    diptych_image_prompt = make_diptych(segmented_image)

    # 3. Setup Attention Processor (logic from original script's main block)
    progress(0.3, desc="Setting up attention processors...")
    new_attn_procs = base_attn_procs.copy()
    for k in new_attn_procs:
        # Use the full diptych dimensions, in latent tokens, for the attention processor
        new_attn_procs[k] = CustomFluxAttnProcessor2_0(height=padded_height // 16, width=padded_width * 2 // 16, attn_enforce=attn_enforce)
    pipe.transformer.set_attn_processor(new_attn_procs)

    # 4. Run Inference (using parameters identical to the original script)
    progress(0.4, desc="Running diffusion process...")
    generator = torch.Generator(device="cuda").manual_seed(actual_seed)
    result = pipe(
        prompt=diptych_text_prompt,
        height=diptych_size[1],
        width=diptych_size[0],
        control_image=diptych_image_prompt,
        control_mask=mask_image,
        num_inference_steps=num_steps,
        generator=generator,
        controlnet_conditioning_scale=ctrl_scale,
        guidance_scale=guidance,       # used for guidance embeds if enabled
        negative_prompt="",
        true_guidance_scale=guidance   # matches the original script's CFG scale
    ).images[0]

    # 5. Final cropping (logic from original script's main block)
    progress(0.95, desc="Finalizing image...")
    # Keep only the right (generated) panel...
    result = result.crop((padded_width, 0, padded_width * 2, padded_height))
    # ...then trim the pixel-offset padding.
    result = result.crop((pixel_offset, pixel_offset, padded_width - pixel_offset, padded_height - pixel_offset))
    return result
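# Illustrative only: how one might call run_diptych_prompting directly, outside
# the Gradio UI, with the default settings. The image path is taken from the
# examples list below; a no-op callable replaces the Gradio progress tracker
# since there is no active event to report to.
#
#   ref = Image.open("./assets/bear_plushie.jpg")
#   out = run_diptych_prompting(ref, "a bear plushie",
#                               "a bear plushie riding a skateboard",
#                               randomize_seed=False, seed=42,
#                               progress=lambda *args, **kwargs: None)
#   out.save("result.png")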
# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Diptych Prompting: Zero-Shot Subject-Driven Image Generation
        ### Official Gradio Demo for the paper "[Large-Scale Text-to-Image Model with Inpainting is a Zero-Shot Subject-Driven Image Generator](https://diptychprompting.github.io/)"
        **Instructions:**
        1. Upload a clear image of a subject.
        2. Describe the subject in the "Subject Name" box (e.g., 'a plush bear', 'a red sports car').
        3. Describe the new scene for your subject in the "Target Prompt" box.
        > **Note:** This demo runs on a powerful GPU and requires over **40GB of VRAM**. Inference may take 1-2 minutes.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="1. Reference Image")
            subject_name = gr.Textbox(label="2. Subject Name", placeholder="e.g., a plush bear")
            target_prompt = gr.Textbox(label="3. Target Prompt", placeholder="e.g., riding a skateboard on the moon")
            run_button = gr.Button("Generate Image", variant="primary")
            with gr.Accordion("Advanced Settings", open=False):
                attn_enforce = gr.Slider(minimum=1.0, maximum=2.0, value=1.3, step=0.05, label="Attention Enforcement")
                ctrl_scale = gr.Slider(minimum=0.5, maximum=1.0, value=0.95, step=0.01, label="ControlNet Scale")
                num_steps = gr.Slider(minimum=20, maximum=50, value=30, step=1, label="Inference Steps")
                guidance = gr.Slider(minimum=1.0, maximum=10.0, value=3.5, step=0.1, label="Guidance Scale")
                width = gr.Slider(minimum=512, maximum=1024, value=768, step=64, label="Image Width")
                height = gr.Slider(minimum=512, maximum=1024, value=768, step=64, label="Image Height")
                pixel_offset = gr.Slider(minimum=0, maximum=32, value=8, step=1, label="Padding (Pixel Offset)")
                seed = gr.Slider(minimum=0, maximum=9223372036854775807, value=42, step=1, label="Seed")
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
        with gr.Column(scale=1):
            output_image = gr.Image(type="pil", label="Generated Image")

    gr.Examples(
        examples=[
            ["./assets/bear_plushie.jpg", "a bear plushie", "a bear plushie riding a skateboard"],
            ["./assets/corgi.jpg", "a corgi dog", "a corgi dog wearing a superhero cape and flying"],
            ["./assets/teapot.png", "a blue and white teapot", "a blue and white teapot in a field of flowers"],
        ],
        inputs=[input_image, subject_name, target_prompt],
        outputs=output_image,
        fn=run_diptych_prompting,
        cache_examples="lazy",
    )

    run_button.click(
        fn=run_diptych_prompting,
        inputs=[input_image, subject_name, target_prompt, attn_enforce, ctrl_scale, width, height, pixel_offset, num_steps, guidance, seed, randomize_seed],
        outputs=output_image
    )

if __name__ == "__main__":
    if not os.path.exists("./assets"):
        os.makedirs("./assets")
        print("Created './assets' directory. Please add example images like 'bear_plushie.jpg' there for the examples to work.")
    demo.launch(share=True)