Spaces:
Running
on
Zero
Running
on
Zero
#!/usr/bin/env python3 | |
# Copyright (C) 2025 NVIDIA Corporation. All rights reserved. | |
# | |
# This work is licensed under the LICENSE file | |
# located at the root directory. | |
import os | |
import gradio as gr | |
import spaces | |
import torch | |
import numpy as np | |
from PIL import Image | |
import tempfile | |
import gc | |
from datetime import datetime | |
from addit_flux_pipeline import AdditFluxPipeline | |
from addit_flux_transformer import AdditFluxTransformer2DModel | |
from addit_scheduler import AdditFlowMatchEulerDiscreteScheduler | |
from addit_methods import add_object_generated, add_object_real | |
# Global model state, populated once at startup and read by the request handlers.
pipe = None
device = None


def _load_pipeline(target_device):
    """Build the fully configured Add-it pipeline on *target_device*.

    Loads the Add-it transformer from the FLUX.1-dev weights, wraps it in an
    AdditFluxPipeline (bfloat16), moves it to the device and installs the
    Add-it scheduler. The pipeline is returned only after every step succeeded,
    so callers never see a half-configured object.
    """
    transformer = AdditFluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    )
    pipeline = AdditFluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    ).to(target_device)
    # Swap in the Add-it scheduler, reusing the default scheduler's configuration.
    pipeline.scheduler = AdditFlowMatchEulerDiscreteScheduler.from_config(pipeline.scheduler.config)
    return pipeline


# Initialize model at startup
print("Initializing ADDIT model...")
try:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    # Publish the global only once the whole setup has succeeded. Assigning it
    # before the scheduler swap would leave a non-None pipe with the wrong
    # scheduler if that step raised, and handlers would treat it as ready.
    pipe = _load_pipeline(device)
    print("Model initialized successfully!")
except Exception as e:
    import traceback

    print(f"Error initializing model: {e}")
    traceback.print_exc()  # keep the full stack trace for debugging startup failures
    print("The application will start but model functionality will be unavailable.")
def validate_inputs(prompt_source, prompt_target, subject_token):
    """Validate the text inputs shared by both editing modes.

    Args:
        prompt_source: Prompt describing the source image.
        prompt_target: Prompt describing the edited image; must contain the subject token.
        subject_token: Word naming the object to insert.

    Returns:
        A user-facing error message string, or ``None`` when all inputs are valid.
    """
    import re

    if not prompt_source.strip():
        return "Source prompt cannot be empty"
    if not prompt_target.strip():
        return "Target prompt cannot be empty"
    token = subject_token.strip()
    if not token:
        return "Subject token cannot be empty"
    # Require a whole-word occurrence. A plain substring test would accept e.g.
    # "cat" inside "cathedral" or "hat" inside "that", even though the token
    # never appears in the prompt on its own. Lookarounds (rather than \b) keep
    # the check correct for tokens that begin or end with non-word characters.
    pattern = rf"(?<!\w){re.escape(token)}(?!\w)"
    if re.search(pattern, prompt_target) is None:
        return f"Subject token '{subject_token}' must appear as a whole word in the target prompt"
    return None
# ZeroGPU attaches a GPU only while a @spaces.GPU-decorated call runs; the
# decorator has no effect outside ZeroGPU environments.
# NOTE(review): uses the default GPU duration; pass duration=... if runs time out.
@spaces.GPU
def process_generated_image(
    prompt_source,
    prompt_target,
    subject_token,
    seed_src,
    seed_obj,
    extended_scale,
    structure_transfer_step,
    blend_steps,
    localization_model,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a source image from a prompt and insert an object into it with Add-it.

    Args:
        prompt_source: Prompt for the base image.
        prompt_target: Prompt for the edited image; must contain ``subject_token``.
        subject_token: Word naming the object to insert.
        seed_src: Seed for the source image.
        seed_obj: Seed for the object insertion.
        extended_scale: "Extended Scale" UI value, forwarded to ``add_object_generated``.
        structure_transfer_step: "Structure Transfer Step" UI value, forwarded unchanged.
        blend_steps: Comma-separated integer steps (e.g. "15,20"); empty disables blending.
        localization_model: Localization method name, forwarded unchanged.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors tqdm bars in the UI.

    Returns:
        ``(source_image, edited_image, status_message)``; both images are ``None`` on failure.
    """
    if pipe is None:
        return None, None, "Model not initialized. Please restart the application."

    error_msg = validate_inputs(prompt_source, prompt_target, subject_token)
    if error_msg:
        return None, None, error_msg

    # Log the request for server-side debugging.
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f"\n[{current_time}] Starting Generated Image Processing")
    print(f"Source Prompt: '{prompt_source}'")
    print(f"Target Prompt: '{prompt_target}'")
    print(f"Subject Token: '{subject_token}'")
    print(f"Source Seed: {seed_src}, Object Seed: {seed_obj}")
    print(f"Extended Scale: {extended_scale}, Structure Transfer Step: {structure_transfer_step}")
    print(f"Blend Steps: '{blend_steps}', Localization Model: '{localization_model}'")

    # Empty or whitespace-only input yields [] (no blending). Report malformed
    # entries clearly instead of surfacing a raw int() error.
    try:
        blend_steps_list = [int(x.strip()) for x in blend_steps.split(",") if x.strip()]
    except ValueError:
        return None, None, (
            f"Invalid blend steps '{blend_steps}': use comma-separated integers "
            "(e.g. '15,20') or leave empty"
        )

    try:
        src_image, edited_image = add_object_generated(
            pipe=pipe,
            prompt_source=prompt_source,
            prompt_object=prompt_target,
            subject_token=subject_token,
            seed_src=seed_src,
            seed_obj=seed_obj,
            show_attention=False,
            extended_scale=extended_scale,
            structure_transfer_step=structure_transfer_step,
            blend_steps=blend_steps_list,
            localization_model=localization_model,
            display_output=False,
        )
        return src_image, edited_image, "Images generated successfully!"
    except Exception as e:
        import traceback

        error_msg = f"Error generating images: {e}"
        print(error_msg)
        traceback.print_exc()  # keep the stack trace in the server log
        return None, None, error_msg
# ZeroGPU attaches a GPU only while a @spaces.GPU-decorated call runs; the
# decorator has no effect outside ZeroGPU environments.
# NOTE(review): uses the default GPU duration; pass duration=... if runs time out.
@spaces.GPU
def process_real_image(
    source_image,
    prompt_source,
    prompt_target,
    subject_token,
    seed_src,
    seed_obj,
    extended_scale,
    structure_transfer_step,
    blend_steps,
    localization_model,
    use_offset,
    disable_inversion,
    progress=gr.Progress(track_tqdm=True),
):
    """Insert an object into an uploaded image with Add-it.

    Args:
        source_image: Uploaded PIL image; resized to 1024x1024 before processing.
        prompt_source: Prompt describing the uploaded image.
        prompt_target: Prompt for the edited image; must contain ``subject_token``.
        subject_token: Word naming the object to insert.
        seed_src: Source seed, forwarded to ``add_object_real``.
        seed_obj: Seed for the object insertion.
        extended_scale: "Extended Scale" UI value, forwarded unchanged.
        structure_transfer_step: "Structure Transfer Step" UI value, forwarded unchanged.
        blend_steps: Comma-separated integer steps (e.g. "15,20"); empty disables blending.
        localization_model: Localization method name, forwarded unchanged.
        use_offset: "Use Offset" checkbox, forwarded unchanged.
        disable_inversion: When True, ``use_inversion=False`` is passed downstream.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors tqdm bars in the UI.

    Returns:
        ``(source_image, edited_image, status_message)``; both images are ``None`` on failure.
    """
    if pipe is None:
        return None, None, "Model not initialized. Please restart the application."
    if source_image is None:
        return None, None, "Please upload a source image"

    error_msg = validate_inputs(prompt_source, prompt_target, subject_token)
    if error_msg:
        return None, None, error_msg

    # Log the request for server-side debugging.
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f"\n[{current_time}] Starting Real Image Processing")
    print(f"Source Image Size: {source_image.size}")
    print(f"Source Prompt: '{prompt_source}'")
    print(f"Target Prompt: '{prompt_target}'")
    print(f"Subject Token: '{subject_token}'")
    print(f"Source Seed: {seed_src}, Object Seed: {seed_obj}")
    print(f"Extended Scale: {extended_scale}, Structure Transfer Step: {structure_transfer_step}")
    print(f"Blend Steps: '{blend_steps}', Localization Model: '{localization_model}'")
    print(f"Use Offset: {use_offset}, Disable Inversion: {disable_inversion}")

    # Empty or whitespace-only input yields [] (no blending). Report malformed
    # entries clearly instead of surfacing a raw int() error.
    try:
        blend_steps_list = [int(x.strip()) for x in blend_steps.split(",") if x.strip()]
    except ValueError:
        return None, None, (
            f"Invalid blend steps '{blend_steps}': use comma-separated integers "
            "(e.g. '15,20') or leave empty"
        )

    try:
        # Fixed 1024x1024 resize; aspect ratio is not preserved, so non-square
        # inputs are stretched (the UI warns users about this).
        source_image = source_image.resize((1024, 1024))
        src_image, edited_image = add_object_real(
            pipe=pipe,
            source_image=source_image,
            prompt_source=prompt_source,
            prompt_object=prompt_target,
            subject_token=subject_token,
            seed_src=seed_src,
            seed_obj=seed_obj,
            extended_scale=extended_scale,
            structure_transfer_step=structure_transfer_step,
            blend_steps=blend_steps_list,
            localization_model=localization_model,
            use_offset=use_offset,
            show_attention=False,
            use_inversion=not disable_inversion,
            display_output=False,
        )
        return src_image, edited_image, "Image edited successfully!"
    except Exception as e:
        import traceback

        error_msg = f"Error processing image: {e}"
        print(error_msg)
        traceback.print_exc()  # keep the stack trace in the server log
        return None, None, error_msg
def _build_generated_tab():
    """Add the "Generated Images" tab (inputs, outputs, event wiring, examples) to the enclosing gr.Tabs."""
    with gr.TabItem("🎭 Generated Images"):
        gr.Markdown("### Generate a base image and add objects to it")
        with gr.Row():
            with gr.Column(scale=1):
                gen_prompt_source = gr.Textbox(
                    label="Source Prompt",
                    placeholder="A photo of a cat sitting on the couch",
                    value="A photo of a cat sitting on the couch",
                )
                gen_prompt_target = gr.Textbox(
                    label="Target Prompt",
                    placeholder="A photo of a cat wearing a blue hat sitting on the couch",
                    value="A photo of a cat wearing a blue hat sitting on the couch",
                )
                gen_subject_token = gr.Textbox(
                    label="Subject Token",
                    placeholder="hat",
                    value="hat",
                    info="Single token representing the object to add **(must appear in target prompt)**",
                )
                with gr.Accordion("Advanced Settings", open=False):
                    gen_seed_src = gr.Number(label="Source Seed", value=1, precision=0)
                    gen_seed_obj = gr.Number(label="Object Seed", value=42, precision=0)
                    gen_extended_scale = gr.Slider(
                        label="Extended Scale",
                        minimum=1.0,
                        maximum=1.3,
                        value=1.05,
                        step=0.01,
                    )
                    gen_structure_transfer_step = gr.Slider(
                        label="Structure Transfer Step",
                        minimum=0,
                        maximum=10,
                        value=2,
                        step=1,
                    )
                    gen_blend_steps = gr.Textbox(
                        label="Blend Steps",
                        value="15",
                        info="Comma-separated list of steps (e.g., '15,20') or empty for no blending",
                    )
                    gen_localization_model = gr.Dropdown(
                        label="Localization Model",
                        choices=[
                            "attention_points_sam",
                            "attention",
                            "attention_box_sam",
                            "attention_mask_sam",
                            "grounding_sam",
                        ],
                        value="attention_points_sam",
                    )
                gen_submit_btn = gr.Button("🎨 Generate & Edit", variant="primary")
            with gr.Column(scale=2):
                with gr.Row():
                    gen_src_output = gr.Image(label="Generated Source Image", type="pil")
                    gen_edited_output = gr.Image(label="Edited Image", type="pil")
                gen_status = gr.Textbox(label="Status", interactive=False)

        gen_submit_btn.click(
            fn=process_generated_image,
            inputs=[
                gen_prompt_source, gen_prompt_target, gen_subject_token,
                gen_seed_src, gen_seed_obj, gen_extended_scale,
                gen_structure_transfer_step, gen_blend_steps,
                gen_localization_model,
            ],
            outputs=[gen_src_output, gen_edited_output, gen_status],
        )

        # Examples fill only the prompt fields; advanced settings keep their current values.
        gr.Examples(
            examples=[
                ["A photo of a man sitting on a bench", "A photo of a man sitting on a bench with a dog", "dog"],
                ["A photo of a cat sitting on the couch", "A photo of a cat wearing a blue hat sitting on the couch", "hat"],
                ["A car driving through an empty street", "A pink car driving through an empty street", "car"],
            ],
            inputs=[gen_prompt_source, gen_prompt_target, gen_subject_token],
            label="Example Prompts",
        )


def _build_real_tab():
    """Add the "Real Images" tab (upload, inputs, outputs, event wiring, examples) to the enclosing gr.Tabs."""
    with gr.TabItem("📸 Real Images"):
        gr.Markdown("### Upload an image and add objects to it")
        gr.HTML("<p style='color: red; font-weight: bold; margin: -15px -10px;'>Note: Images will be resized to 1024x1024 pixels. For best results, use square images.</p>")
        with gr.Row():
            with gr.Column(scale=1):
                real_source_image = gr.Image(label="Source Image", type="pil")
                real_prompt_source = gr.Textbox(
                    label="Source Prompt",
                    placeholder="A photo of a bed in a dark room",
                    value="A photo of a bed in a dark room",
                )
                real_prompt_target = gr.Textbox(
                    label="Target Prompt",
                    placeholder="A photo of a dog lying on a bed in a dark room",
                    value="A photo of a dog lying on a bed in a dark room",
                )
                real_subject_token = gr.Textbox(
                    label="Subject Token",
                    placeholder="dog",
                    value="dog",
                    info="Single token representing the object to add **(must appear in target prompt)**",
                )
                with gr.Accordion("Advanced Settings", open=False):
                    real_seed_src = gr.Number(label="Source Seed", value=1, precision=0)
                    real_seed_obj = gr.Number(label="Object Seed", value=0, precision=0)
                    real_extended_scale = gr.Slider(
                        label="Extended Scale",
                        minimum=1.0,
                        maximum=1.3,
                        value=1.1,
                        step=0.01,
                    )
                    real_structure_transfer_step = gr.Slider(
                        label="Structure Transfer Step",
                        minimum=0,
                        maximum=10,
                        value=4,
                        step=1,
                    )
                    real_blend_steps = gr.Textbox(
                        label="Blend Steps",
                        value="18",
                        info="Comma-separated list of steps (e.g., '15,20') or empty for no blending",
                    )
                    real_localization_model = gr.Dropdown(
                        label="Localization Model",
                        choices=[
                            "attention",
                            "attention_points_sam",
                            "attention_box_sam",
                            "attention_mask_sam",
                            "grounding_sam",
                        ],
                        value="attention",
                    )
                    real_use_offset = gr.Checkbox(label="Use Offset", value=False)
                    real_disable_inversion = gr.Checkbox(label="Disable Inversion", value=False)
                real_submit_btn = gr.Button("🎨 Edit Image", variant="primary")
            with gr.Column(scale=2):
                with gr.Row():
                    real_src_output = gr.Image(label="Source Image", type="pil")
                    real_edited_output = gr.Image(label="Edited Image", type="pil")
                real_status = gr.Textbox(label="Status", interactive=False)

        real_submit_btn.click(
            fn=process_real_image,
            inputs=[
                real_source_image, real_prompt_source, real_prompt_target, real_subject_token,
                real_seed_src, real_seed_obj, real_extended_scale,
                real_structure_transfer_step, real_blend_steps,
                real_localization_model, real_use_offset,
                real_disable_inversion,
            ],
            outputs=[real_src_output, real_edited_output, real_status],
        )

        # Example image paths are relative to the working directory the app is launched from.
        gr.Examples(
            examples=[
                [
                    "images/bed_dark_room.jpg",
                    "A photo of a bed in a dark room",
                    "A photo of a dog lying on a bed in a dark room",
                    "dog",
                ],
                [
                    "images/flower.jpg",
                    "A photo of a flower",
                    "A bee standing on a flower",
                    "bee",
                ],
            ],
            inputs=[real_source_image, real_prompt_source, real_prompt_target, real_subject_token],
            label="Example Images & Prompts",
        )


def _build_tips():
    """Add the collapsible "Tips for Better Results" section."""
    # Built from explicit lines so the Markdown does not depend on how the
    # surrounding code is indented (4+ leading spaces would render as code).
    tips = "\n".join([
        "- **Prompt Design**: The Target Prompt should be similar to the Source Prompt, but include a description of the new object to insert",
        "- **Seed Variation**: Try different values for Object Seed - some prompts may require a few attempts to get satisfying results",
        "- **Localization Models**: The most effective options are `attention_points_sam` and `attention`",
        "- **Object Placement Issues**: If the object is not added to the image:",
        "  - Try **decreasing** Structure Transfer Step",
        "  - Try **increasing** Extended Scale",
        "- **Flexibility**: To allow more flexibility in modifying the source image, leave Blend Steps empty to disable blending",
    ])
    with gr.Accordion("💡 Tips for Better Results", open=False):
        gr.Markdown(tips)


def create_interface():
    """Create the Add-it Gradio interface.

    Builds a two-tab app ("Generated Images" and "Real Images") plus a tips
    section, and shows whether the global pipeline loaded at startup.

    Returns:
        gr.Blocks: the assembled demo (not yet launched).
    """
    model_ready = pipe is not None
    model_status = "Model ready!" if model_ready else "Model initialization failed - functionality unavailable"
    status_color = "green" if model_ready else "red"

    with gr.Blocks(
        title="🎨 Add-it: Training-Free Object Insertion in Images With Pretrained Diffusion Models",
        theme=gr.themes.Soft(),
    ) as demo:
        gr.HTML(f"""
        <div style="text-align: center; margin-bottom: 20px;">
            <h1>🎨 Add-it: Training-Free Object Insertion</h1>
            <p>Add objects to images using pretrained diffusion models</p>
            <p><a href="https://research.nvidia.com/labs/par/addit/" target="_blank">🌐 Project Website</a> |
            <a href="https://arxiv.org/abs/2411.07232" target="_blank">📄 Paper</a> |
            <a href="https://github.com/NVlabs/addit" target="_blank">💻 Code</a></p>
            <p style="color: {status_color}; font-weight: bold;">Status: {model_status}</p>
        </div>
        """)

        with gr.Tabs():
            _build_generated_tab()
            _build_real_tab()

        _build_tips()

    return demo
# Build the UI at module level so tools that look for a module-level `demo`
# (e.g. Gradio reload mode) can find it; launch only when run as a script.
demo = create_interface()

if __name__ == "__main__":
    # Bind on all interfaces at the conventional Gradio/Spaces port.
    # share=True requests a public share link when run locally; Gradio does
    # not support share links on Hugging Face Spaces and ignores it there.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )