import spaces
import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import cv2
import os
import random
from diffusers.utils import load_image, check_min_version
from controlnet_flux import FluxControlNetModel
from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline
from diffusers.models.attention_processor import Attention
from transformers import AutoProcessor, AutoModelForMaskGeneration, pipeline
from dataclasses import dataclass
from typing import Any, List, Dict, Optional, Union, Tuple
device = "cuda" if torch.cuda.is_available() else "cpu"
# --- Helper Dataclasses (Identical to diptych_prompting_inference.py) ---
@dataclass
class BoundingBox:
xmin: int
ymin: int
xmax: int
ymax: int
@property
def xyxy(self) -> List[float]:
return [self.xmin, self.ymin, self.xmax, self.ymax]
@dataclass
class DetectionResult:
score: float
label: str
box: BoundingBox
    mask: Optional[np.ndarray] = None
@classmethod
def from_dict(cls, detection_dict: Dict) -> 'DetectionResult':
return cls(score=detection_dict['score'],
label=detection_dict['label'],
box=BoundingBox(xmin=detection_dict['box']['xmin'],
ymin=detection_dict['box']['ymin'],
xmax=detection_dict['box']['xmax'],
ymax=detection_dict['box']['ymax']))
# --- Helper Functions (Identical to diptych_prompting_inference.py) ---
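# Return the largest external contour of a binary mask as a list of [x, y] points.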
def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
if not contours:
return []
largest_contour = max(contours, key=cv2.contourArea)
return largest_contour.reshape(-1, 2).tolist()
def polygon_to_mask(polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]) -> np.ndarray:
mask = np.zeros(image_shape, dtype=np.uint8)
if not polygon:
return mask
pts = np.array(polygon, dtype=np.int32)
cv2.fillPoly(mask, [pts], color=(255,))
return mask
def get_boxes(results: List[DetectionResult]) -> List[List[List[float]]]:
boxes = [result.box.xyxy for result in results]
return [boxes]
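# Convert SAM's predicted masks to binary uint8 arrays; with polygon_refinement, each
# mask is smoothed by refilling its largest contour, dropping small disconnected blobs.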
def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
masks = masks.cpu().float().permute(0, 2, 3, 1).mean(axis=-1)
masks = (masks > 0).int().numpy().astype(np.uint8)
masks = list(masks)
if polygon_refinement:
for idx, mask in enumerate(masks):
shape = mask.shape
polygon = mask_to_polygon(mask)
refined_mask = polygon_to_mask(polygon, shape)
masks[idx] = refined_mask
return masks
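# Zero-shot object detection with the grounding pipeline. Grounding DINO expects each
# text query to end with a period, so one is appended when missing.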
def detect(
object_detector, image: Image.Image, labels: List[str], threshold: float = 0.3, detector_id: Optional[str] = None
) -> List[DetectionResult]:
labels = [label if label.endswith(".") else label + "." for label in labels]
results = object_detector(image, candidate_labels=labels, threshold=threshold)
return [DetectionResult.from_dict(result) for result in results]
def segment(
segmentator, processor, image: Image.Image, detection_results: List[DetectionResult], polygon_refinement: bool = False
) -> List[DetectionResult]:
if not detection_results:
return []
boxes = get_boxes(detection_results)
inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(device)
with torch.no_grad():
outputs = segmentator(**inputs)
masks = processor.post_process_masks(
masks=outputs.pred_masks, original_sizes=inputs.original_sizes, reshaped_input_sizes=inputs.reshaped_input_sizes
)[0]
masks = refine_masks(masks, polygon_refinement)
for detection_result, mask in zip(detection_results, masks):
detection_result.mask = mask
return detection_results
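# Full grounding pipeline: detect the labeled subject, then prompt SAM with the
# detected boxes to obtain per-detection segmentation masks.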
def grounded_segmentation(
detect_pipeline, segmentator, segment_processor, image: Image.Image, labels: List[str],
) -> Tuple[np.ndarray, List[DetectionResult]]:
detections = detect(detect_pipeline, image, labels, threshold=0.3)
detections = segment(segmentator, segment_processor, image, detections, polygon_refinement=True)
return np.array(image), detections
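# Isolate the named subject and composite it onto a plain white background.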
def segment_image(image, object_name, detector, segmentator, seg_processor):
image_array, detections = grounded_segmentation(detector, segmentator, seg_processor, image, [object_name])
if not detections or detections[0].mask is None:
raise gr.Error(f"Could not segment the subject '{object_name}' in the image. Please try a clearer image or a more specific subject name.")
mask_expanded = np.expand_dims(detections[0].mask / 255, axis=-1)
segment_result = image_array * mask_expanded + np.ones_like(image_array) * (1 - mask_expanded) * 255
return Image.fromarray(segment_result.astype(np.uint8))
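# Build the diptych canvas: the reference image on the left, a black panel of the
# same size on the right for the pipeline to inpaint.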
def make_diptych(image):
ref_image_np = np.array(image)
diptych_np = np.concatenate([ref_image_np, np.zeros_like(ref_image_np)], axis=1)
return Image.fromarray(diptych_np)
# --- Custom Attention Processor (EXACTLY as in diptych_prompting_inference.py) ---
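# Standard FLUX joint attention, extended with an optional "attention enforcement" step
# that reweights how tokens of the generated (right) panel attend to the reference (left) panel.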
class CustomFluxAttnProcessor2_0:
def __init__(self, height=44, width=88, attn_enforce=1.0):
if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CustomFluxAttnProcessor2_0 requires PyTorch 2.0; please upgrade PyTorch to use it.")
self.height = height
self.width = width
self.num_pixels = height * width
self.step = 0
self.attn_enforce = attn_enforce
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
self.step += 1
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim, head_dim = key.shape[-1], key.shape[-1] // attn.heads
query, key, value = [x.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) for x in [query, key, value]]
if attn.norm_q is not None: query = attn.norm_q(query)
if attn.norm_k is not None: key = attn.norm_k(key)
if encoder_hidden_states is not None:
encoder_q = attn.add_q_proj(encoder_hidden_states).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
encoder_k = attn.add_k_proj(encoder_hidden_states).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
encoder_v = attn.add_v_proj(encoder_hidden_states).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_added_q is not None: encoder_q = attn.norm_added_q(encoder_q)
if attn.norm_added_k is not None: encoder_k = attn.norm_added_k(encoder_k)
query, key, value = [torch.cat([e, x], dim=2) for e, x in zip([encoder_q, encoder_k, encoder_v], [query, key, value])]
if image_rotary_emb is not None:
from diffusers.models.embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
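        # Attention enforcement: materialize the attention matrix explicitly and scale the
        # weights that right-panel (generated) queries place on left-panel (reference) image
        # keys, strengthening subject transfer across the diptych.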
if self.attn_enforce != 1.0:
attn_probs = (torch.einsum('bhqd,bhkd->bhqk', query, key) * attn.scale).softmax(dim=-1)
img_attn_probs = attn_probs[:, :, -self.num_pixels:, -self.num_pixels:].reshape((batch_size, attn.heads, self.height, self.width, self.height, self.width))
img_attn_probs[:, :, :, self.width//2:, :, :self.width//2] *= self.attn_enforce
img_attn_probs = img_attn_probs.reshape((batch_size, attn.heads, self.num_pixels, self.num_pixels))
attn_probs[:, :, -self.num_pixels:, -self.num_pixels:] = img_attn_probs
hidden_states = torch.einsum('bhqk,bhkd->bhqd', attn_probs, value)
else:
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype)
if encoder_hidden_states is not None:
encoder_hs, hs = hidden_states[:, : encoder_hidden_states.shape[1]], hidden_states[:, encoder_hidden_states.shape[1] :]
hs = attn.to_out[0](hs)
hs = attn.to_out[1](hs)
encoder_hs = attn.to_add_out(encoder_hs)
return hs, encoder_hs
else:
return hidden_states
# --- Model Loading (executed once at startup) ---
print("--- Loading Models: This may take a few minutes and requires >40GB VRAM ---")
controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", torch_dtype=torch.bfloat16)
pipe = FluxControlNetInpaintingPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16).to(device)
pipe.transformer.to(torch.bfloat16)
pipe.controlnet.to(torch.bfloat16)
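# Keep the default attention processors so each request can install fresh
# CustomFluxAttnProcessor2_0 instances sized for its own diptych resolution.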
base_attn_procs = pipe.transformer.attn_processors.copy()
print("Loading segmentation models...")
detector_id, segmenter_id = "IDEA-Research/grounding-dino-tiny", "facebook/sam-vit-base"
segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)
segment_processor = AutoProcessor.from_pretrained(segmenter_id)
object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=device)
print("--- All models loaded successfully! ---")
# --- Main Inference Function for Gradio ---
@spaces.GPU(duration=70)
def run_diptych_prompting(
input_image: Image.Image,
subject_name: str,
target_prompt: str,
attn_enforce: float = 1.3,
ctrl_scale: float = 0.95,
width: int = 768,
height: int = 768,
pixel_offset: int = 8,
num_steps: int = 30,
guidance: float = 3.5,
seed: int = 42,
randomize_seed: bool = False,
progress=gr.Progress(track_tqdm=True)
):
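    # Zero-shot subject-driven generation via diptych inpainting: segment the subject,
    # place it on the left panel, and inpaint the right panel according to the target prompt.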
if randomize_seed:
actual_seed = random.randint(0, 9223372036854775807)
else:
actual_seed = seed
if input_image is None: raise gr.Error("Please upload a reference image.")
if not subject_name: raise gr.Error("Please provide the subject's name (e.g., 'a red car').")
if not target_prompt: raise gr.Error("Please provide a target prompt.")
# 1. Prepare dimensions (logic from original script's main block)
padded_width = width + pixel_offset * 2
padded_height = height + pixel_offset * 2
diptych_size = (padded_width * 2, padded_height)
# 2. Prepare prompts and images
progress(0, desc="Resizing and segmenting reference image...")
base_prompt = f"a photo of {subject_name}"
diptych_text_prompt = f"A diptych with two side-by-side images of same {subject_name}. On the left, {base_prompt}. On the right, replicate this {subject_name} exactly but as {target_prompt}"
reference_image = input_image.resize((padded_width, padded_height)).convert("RGB")
segmented_image = segment_image(reference_image, subject_name, object_detector, segmentator, segment_processor)
progress(0.2, desc="Creating diptych and mask...")
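    # Inpainting mask: black (keep) over the left reference panel, white (regenerate)
    # over the right panel.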
mask_image = np.concatenate([np.zeros((padded_height, padded_width, 3)), np.ones((padded_height, padded_width, 3)) * 255], axis=1)
mask_image = Image.fromarray(mask_image.astype(np.uint8))
diptych_image_prompt = make_diptych(segmented_image)
# 3. Setup Attention Processor (logic from original script's main block)
progress(0.3, desc="Setting up attention processors...")
new_attn_procs = base_attn_procs.copy()
for k in new_attn_procs:
        # Size the attention maps for the full diptych on FLUX's latent token grid
        # (one token per 16x16 pixel patch: 8x VAE downsampling, then 2x2 patchification).
        new_attn_procs[k] = CustomFluxAttnProcessor2_0(height=padded_height // 16, width=padded_width * 2 // 16, attn_enforce=attn_enforce)
pipe.transformer.set_attn_processor(new_attn_procs)
# 4. Run Inference (using parameters identical to the original script)
progress(0.4, desc="Running diffusion process...")
    generator = torch.Generator(device=device).manual_seed(actual_seed)
result = pipe(
prompt=diptych_text_prompt,
height=diptych_size[1],
width=diptych_size[0],
control_image=diptych_image_prompt,
control_mask=mask_image,
num_inference_steps=num_steps,
generator=generator,
controlnet_conditioning_scale=ctrl_scale,
guidance_scale=guidance, # This is used for guidance embeds if enabled
negative_prompt="",
        true_guidance_scale=guidance  # true CFG scale, matching the original script
).images[0]
# 5. Final cropping (logic from original script's main block)
progress(0.95, desc="Finalizing image...")
# Crop the right panel
result = result.crop((padded_width, 0, padded_width * 2, padded_height))
# Crop the pixel offset padding
result = result.crop((pixel_offset, pixel_offset, padded_width - pixel_offset, padded_height - pixel_offset))
return result
# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
# Diptych Prompting: Zero-Shot Subject-Driven Image Generation
### Official Gradio Demo for the paper "[Large-Scale Text-to-Image Model with Inpainting is a Zero-Shot Subject-Driven Image Generator](https://diptychprompting.github.io/)"
**Instructions:**
1. Upload a clear image of a subject.
2. Describe the subject in the "Subject Name" box (e.g., 'a plush bear', 'a red sports car').
3. Describe the new scene for your subject in the "Target Prompt" box.
> **Note:** This demo runs on a powerful GPU and requires over **40GB of VRAM**. Inference may take 1-2 minutes.
"""
)
with gr.Row():
with gr.Column(scale=1):
input_image = gr.Image(type="pil", label="1. Reference Image")
subject_name = gr.Textbox(label="2. Subject Name", placeholder="e.g., a plush bear")
target_prompt = gr.Textbox(label="3. Target Prompt", placeholder="e.g., riding a skateboard on the moon")
run_button = gr.Button("Generate Image", variant="primary")
with gr.Accordion("Advanced Settings", open=False):
attn_enforce = gr.Slider(minimum=1.0, maximum=2.0, value=1.3, step=0.05, label="Attention Enforcement")
ctrl_scale = gr.Slider(minimum=0.5, maximum=1.0, value=0.95, step=0.01, label="ControlNet Scale")
num_steps = gr.Slider(minimum=20, maximum=50, value=30, step=1, label="Inference Steps")
guidance = gr.Slider(minimum=1.0, maximum=10.0, value=3.5, step=0.1, label="Guidance Scale")
width = gr.Slider(minimum=512, maximum=1024, value=768, step=64, label="Image Width")
height = gr.Slider(minimum=512, maximum=1024, value=768, step=64, label="Image Height")
pixel_offset = gr.Slider(minimum=0, maximum=32, value=8, step=1, label="Padding (Pixel Offset)")
seed = gr.Slider(minimum=0, maximum=9223372036854775807, value=42, step=1, label="Seed")
randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
with gr.Column(scale=1):
output_image = gr.Image(type="pil", label="Generated Image")
gr.Examples(
examples=[
["./assets/bear_plushie.jpg", "a bear plushie", "a bear plushie riding a skateboard"],
["./assets/corgi.jpg", "a corgi dog", "a corgi dog wearing a superhero cape and flying"],
["./assets/teapot.png", "a blue and white teapot", "a blue and white teapot in a field of flowers"],
],
inputs=[input_image, subject_name, target_prompt],
outputs=output_image,
fn=run_diptych_prompting,
cache_examples="lazy",
)
run_button.click(
fn=run_diptych_prompting,
inputs=[input_image, subject_name, target_prompt, attn_enforce, ctrl_scale, width, height, pixel_offset, num_steps, guidance, seed, randomize_seed],
outputs=output_image
)
if __name__ == "__main__":
if not os.path.exists("./assets"):
os.makedirs("./assets")
print("Created './assets' directory. Please add example images like 'bear_plushie.jpg' there for the examples to work.")
demo.launch(share=True)