addit / app.py
YoadTew's picture
Add application file
504c7e8
raw
history blame
17.8 kB
#!/usr/bin/env python3
# Copyright (C) 2025 NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the LICENSE file
# located at the root directory.
import os
import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
import tempfile
import gc
from addit_flux_pipeline import AdditFluxPipeline
from addit_flux_transformer import AdditFluxTransformer2DModel
from addit_scheduler import AdditFlowMatchEulerDiscreteScheduler
from addit_methods import add_object_generated, add_object_real
# Global variables for model.
# ``pipe`` stays ``None`` unless every initialization step below succeeds; the
# request handlers and the status banner in the UI both test ``pipe is None``
# to decide whether the model can be used.
pipe = None
device = None

# Initialize model at startup
print("Initializing ADDIT model...")
try:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load transformer
    my_transformer = AdditFluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        torch_dtype=torch.bfloat16
    )

    # Load pipeline
    pipe = AdditFluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        transformer=my_transformer,
        torch_dtype=torch.bfloat16
    ).to(device)

    # Set scheduler
    pipe.scheduler = AdditFlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)
    print("Model initialized successfully!")
except Exception as e:
    # Broad catch is deliberate: this is the startup boundary, and the app is
    # meant to come up (with a visible failure status) even when loading fails.
    #
    # Reset ``pipe`` so a failure that happens after the pipeline object was
    # created (for example while replacing the scheduler) does not leave a
    # partially configured pipeline that the rest of the app treats as ready.
    pipe = None
    print(f"Error initializing model: {str(e)}")
    print("The application will start but model functionality will be unavailable.")
def validate_inputs(prompt_source, prompt_target, subject_token):
    """Validate the text fields shared by both editing modes.

    Args:
        prompt_source: Source prompt text.
        prompt_target: Target prompt text.
        subject_token: Text naming the object to add; it must occur as a
            substring of ``prompt_target``.

    Returns:
        The error message for the first check that fails, or ``None`` when
        every check passes.
    """
    required_fields = (
        ("Source prompt", prompt_source),
        ("Target prompt", prompt_target),
        ("Subject token", subject_token),
    )
    for label, value in required_fields:
        # Treat None like an empty field so a missing value produces the
        # normal validation message instead of an AttributeError on .strip().
        if value is None or not value.strip():
            return f"{label} cannot be empty"

    # Plain substring test against the target prompt, using the token exactly
    # as it was entered (no trimming or case folding).
    if subject_token not in prompt_target:
        return f"Subject token '{subject_token}' must appear in the target prompt"
    return None
@spaces.GPU
def process_generated_image(
    prompt_source,
    prompt_target,
    subject_token,
    seed_src,
    seed_obj,
    extended_scale,
    structure_transfer_step,
    blend_steps,
    localization_model,
    progress=gr.Progress(track_tqdm=True)
):
    """Run ``add_object_generated`` for the "Generated Images" tab.

    ``blend_steps`` is comma-separated text that is parsed into a list of
    ints before the call; the target prompt is forwarded as ``prompt_object``.

    Returns:
        A ``(source_image, edited_image, status_message)`` tuple. Both images
        are ``None`` when the model is unavailable, validation fails, or any
        exception is raised while parsing or generating.
    """
    # Nothing can be done if startup did not produce a pipeline.
    if pipe is None:
        return None, None, "Model not initialized. Please restart the application."

    # Reject bad prompts / subject token before doing any expensive work.
    problem = validate_inputs(prompt_source, prompt_target, subject_token)
    if problem:
        return None, None, problem

    try:
        # Blank text, or blank entries between commas, contribute no steps.
        blend_steps_list = []
        if blend_steps.strip():
            for piece in blend_steps.split(','):
                piece = piece.strip()
                if piece:
                    blend_steps_list.append(int(piece))

        source_img, edited_img = add_object_generated(
            pipe=pipe,
            prompt_source=prompt_source,
            prompt_object=prompt_target,
            subject_token=subject_token,
            seed_src=seed_src,
            seed_obj=seed_obj,
            show_attention=False,
            extended_scale=extended_scale,
            structure_transfer_step=structure_transfer_step,
            blend_steps=blend_steps_list,
            localization_model=localization_model,
            display_output=False
        )
    except Exception as e:
        # Report the failure in the status box rather than raising.
        failure = f"Error generating images: {str(e)}"
        print(failure)
        return None, None, failure

    return source_img, edited_img, "Images generated successfully!"
@spaces.GPU
def process_real_image(
    source_image,
    prompt_source,
    prompt_target,
    subject_token,
    seed_src,
    seed_obj,
    extended_scale,
    structure_transfer_step,
    blend_steps,
    localization_model,
    use_offset,
    disable_inversion,
    progress=gr.Progress(track_tqdm=True)
):
    """Run ``add_object_real`` for the "Real Images" tab.

    The uploaded image is resized to 1024x1024, ``blend_steps`` is parsed from
    comma-separated text into a list of ints, and ``disable_inversion`` is
    passed on inverted as ``use_inversion``.

    Returns:
        A ``(source_image, edited_image, status_message)`` tuple. Both images
        are ``None`` when the model is unavailable, no image was supplied,
        validation fails, or any exception is raised during processing.
    """
    # Nothing can be done if startup did not produce a pipeline.
    if pipe is None:
        return None, None, "Model not initialized. Please restart the application."

    if source_image is None:
        return None, None, "Please upload a source image"

    # Reject bad prompts / subject token before doing any expensive work.
    problem = validate_inputs(prompt_source, prompt_target, subject_token)
    if problem:
        return None, None, problem

    try:
        # Fixed working resolution; the aspect ratio is not preserved.
        source_image = source_image.resize((1024, 1024))

        # Blank text, or blank entries between commas, contribute no steps.
        blend_steps_list = []
        if blend_steps.strip():
            for piece in blend_steps.split(','):
                piece = piece.strip()
                if piece:
                    blend_steps_list.append(int(piece))

        source_img, edited_img = add_object_real(
            pipe=pipe,
            source_image=source_image,
            prompt_source=prompt_source,
            prompt_object=prompt_target,
            subject_token=subject_token,
            seed_src=seed_src,
            seed_obj=seed_obj,
            extended_scale=extended_scale,
            structure_transfer_step=structure_transfer_step,
            blend_steps=blend_steps_list,
            localization_model=localization_model,
            use_offset=use_offset,
            show_attention=False,
            use_inversion=not disable_inversion,
            display_output=False
        )
    except Exception as e:
        # Report the failure in the status box rather than raising.
        failure = f"Error processing image: {str(e)}"
        print(failure)
        return None, None, failure

    return source_img, edited_img, "Image edited successfully!"
def create_interface():
    """Build the Gradio Blocks UI for the Add-it demo.

    The layout has a header with a model-status line, two tabs ("Generated
    Images" and "Real Images") that are wired to ``process_generated_image``
    and ``process_real_image`` respectively, and a collapsible tips section.

    The status line is computed from the module-level ``pipe`` when this
    function runs, so it reflects the state at build time only.

    Returns:
        The ``gr.Blocks`` object. It is not launched here.
    """
    # NOTE(review): the container nesting below was reconstructed from a copy
    # of this file whose indentation had been lost — confirm the placement of
    # buttons, checkboxes and status boxes against the rendered layout.

    # Show model status in the interface
    model_status = "Model ready!" if pipe is not None else "Model initialization failed - functionality unavailable"
    with gr.Blocks(title="🎨 Add-it: Training-Free Object Insertion in Images With Pretrained Diffusion Models", theme=gr.themes.Soft()) as demo:
        # Header: title, project links, and a green/red status line.
        gr.HTML(f"""
<div style="text-align: center; margin-bottom: 20px;">
<h1>🎨 Add-it: Training-Free Object Insertion</h1>
<p>Add objects to images using pretrained diffusion models</p>
<p><a href="https://research.nvidia.com/labs/par/addit/" target="_blank">🌐 Project Website</a> |
<a href="https://arxiv.org/abs/2411.07232" target="_blank">📄 Paper</a> |
<a href="https://github.com/NVlabs/addit" target="_blank">💻 Code</a></p>
<p style="color: {'green' if pipe is not None else 'red'}; font-weight: bold;">Status: {model_status}</p>
</div>
""")
        # Main interface
        with gr.Tabs():
            # Generated Images Tab
            with gr.TabItem("🎭 Generated Images"):
                gr.Markdown("### Generate a base image and add objects to it")
                with gr.Row():
                    # Left column: prompts, advanced settings, submit button.
                    with gr.Column(scale=1):
                        gen_prompt_source = gr.Textbox(
                            label="Source Prompt",
                            placeholder="A photo of a cat sitting on the couch",
                            value="A photo of a cat sitting on the couch"
                        )
                        gen_prompt_target = gr.Textbox(
                            label="Target Prompt",
                            placeholder="A photo of a cat wearing a red hat sitting on the couch",
                            value="A photo of a cat wearing a red hat sitting on the couch"
                        )
                        gen_subject_token = gr.Textbox(
                            label="Subject Token",
                            placeholder="hat",
                            value="hat",
                            info="Single token representing the object to add **(must appear in target prompt)**"
                        )
                        # Defaults in this tab differ from the Real Images tab
                        # (extended scale 1.05, structure transfer step 2,
                        # blend steps "15", localization "attention_points_sam").
                        with gr.Accordion("Advanced Settings", open=False):
                            gen_seed_src = gr.Number(label="Source Seed", value=6311, precision=0)
                            gen_seed_obj = gr.Number(label="Object Seed", value=1, precision=0)
                            gen_extended_scale = gr.Slider(
                                label="Extended Scale",
                                minimum=1.0,
                                maximum=1.3,
                                value=1.05,
                                step=0.01
                            )
                            gen_structure_transfer_step = gr.Slider(
                                label="Structure Transfer Step",
                                minimum=0,
                                maximum=10,
                                value=2,
                                step=1
                            )
                            # Free text; the handler parses it into a list of ints.
                            gen_blend_steps = gr.Textbox(
                                label="Blend Steps",
                                value="15",
                                info="Comma-separated list of steps (e.g., '15,20') or empty for no blending"
                            )
                            gen_localization_model = gr.Dropdown(
                                label="Localization Model",
                                choices=[
                                    "attention_points_sam",
                                    "attention",
                                    "attention_box_sam",
                                    "attention_mask_sam",
                                    "grounding_sam"
                                ],
                                value="attention_points_sam"
                            )
                        gen_submit_btn = gr.Button("🎨 Generate & Edit", variant="primary")
                    # Right column: the two result images and a status box.
                    with gr.Column(scale=2):
                        with gr.Row():
                            gen_src_output = gr.Image(label="Generated Source Image", type="pil")
                            gen_edited_output = gr.Image(label="Edited Image", type="pil")
                        gen_status = gr.Textbox(label="Status", interactive=False)
                # The inputs list follows the order of the positional
                # parameters of process_generated_image; the outputs list
                # follows the order of its returned 3-tuple.
                gen_submit_btn.click(
                    fn=process_generated_image,
                    inputs=[
                        gen_prompt_source, gen_prompt_target, gen_subject_token,
                        gen_seed_src, gen_seed_obj, gen_extended_scale,
                        gen_structure_transfer_step, gen_blend_steps,
                        gen_localization_model
                    ],
                    outputs=[gen_src_output, gen_edited_output, gen_status]
                )
                # Examples for generated images
                # Only the three prompt fields are listed as inputs here; the
                # advanced settings are not part of the examples.
                gr.Examples(
                    examples=[
                        ["A photo of a man sitting on a bench", "A photo of a man sitting on a bench with a dog", "dog"],
                        ["A photo of a cat sitting on the couch", "A photo of a cat wearing a red hat sitting on the couch", "hat"],
                        ["A car driving through an empty street", "A pink car driving through an empty street", "car"]
                    ],
                    inputs=[
                        gen_prompt_source, gen_prompt_target, gen_subject_token
                    ],
                    label="Example Prompts"
                )
            # Real Images Tab
            with gr.TabItem("📸 Real Images"):
                gr.Markdown("### Upload an image and add objects to it")
                with gr.Row():
                    # Left column: image upload, prompts, advanced settings,
                    # submit button.
                    with gr.Column(scale=1):
                        real_source_image = gr.Image(label="Source Image", type="pil")
                        real_prompt_source = gr.Textbox(
                            label="Source Prompt",
                            placeholder="A photo of a bed in a dark room",
                            value="A photo of a bed in a dark room"
                        )
                        real_prompt_target = gr.Textbox(
                            label="Target Prompt",
                            placeholder="A photo of a dog lying on a bed in a dark room",
                            value="A photo of a dog lying on a bed in a dark room"
                        )
                        real_subject_token = gr.Textbox(
                            label="Subject Token",
                            placeholder="dog",
                            value="dog",
                            info="Single token representing the object to add **(must appear in target prompt)**"
                        )
                        # Defaults in this tab differ from the Generated Images
                        # tab (extended scale 1.1, structure transfer step 4,
                        # blend steps "18", localization "attention").
                        with gr.Accordion("Advanced Settings", open=False):
                            real_seed_src = gr.Number(label="Source Seed", value=6311, precision=0)
                            real_seed_obj = gr.Number(label="Object Seed", value=1, precision=0)
                            real_extended_scale = gr.Slider(
                                label="Extended Scale",
                                minimum=1.0,
                                maximum=1.3,
                                value=1.1,
                                step=0.01
                            )
                            real_structure_transfer_step = gr.Slider(
                                label="Structure Transfer Step",
                                minimum=0,
                                maximum=10,
                                value=4,
                                step=1
                            )
                            # Free text; the handler parses it into a list of ints.
                            real_blend_steps = gr.Textbox(
                                label="Blend Steps",
                                value="18",
                                info="Comma-separated list of steps (e.g., '15,20') or empty for no blending"
                            )
                            real_localization_model = gr.Dropdown(
                                label="Localization Model",
                                choices=[
                                    "attention",
                                    "attention_points_sam",
                                    "attention_box_sam",
                                    "attention_mask_sam",
                                    "grounding_sam"
                                ],
                                value="attention"
                            )
                            real_use_offset = gr.Checkbox(label="Use Offset", value=False)
                            # The handler passes the negation of this value on
                            # as use_inversion.
                            real_disable_inversion = gr.Checkbox(label="Disable Inversion", value=False)
                        real_submit_btn = gr.Button("🎨 Edit Image", variant="primary")
                    # Right column: the two result images and a status box.
                    with gr.Column(scale=2):
                        with gr.Row():
                            real_src_output = gr.Image(label="Source Image", type="pil")
                            real_edited_output = gr.Image(label="Edited Image", type="pil")
                        real_status = gr.Textbox(label="Status", interactive=False)
                # The inputs list follows the order of the positional
                # parameters of process_real_image; the outputs list follows
                # the order of its returned 3-tuple.
                real_submit_btn.click(
                    fn=process_real_image,
                    inputs=[
                        real_source_image, real_prompt_source, real_prompt_target, real_subject_token,
                        real_seed_src, real_seed_obj, real_extended_scale,
                        real_structure_transfer_step, real_blend_steps,
                        real_localization_model, real_use_offset,
                        real_disable_inversion
                    ],
                    outputs=[real_src_output, real_edited_output, real_status]
                )
                # Examples for real images
                # Image paths are relative — presumably resolved against the
                # app's working directory; verify the images/ folder ships
                # with the app.
                gr.Examples(
                    examples=[
                        [
                            "images/bed_dark_room.jpg",
                            "A photo of a bed in a dark room",
                            "A photo of a dog lying on a bed in a dark room",
                            "dog"
                        ],
                        [
                            "images/flower.jpg",
                            "A photo of a flower",
                            "A bee standing on a flower",
                            "bee"
                        ]
                    ],
                    inputs=[
                        real_source_image, real_prompt_source, real_prompt_target, real_subject_token
                    ],
                    label="Example Images & Prompts"
                )
        # Tips
        # NOTE(review): the text below mentions "Show Attention", but this
        # interface defines no such control and both handlers pass
        # show_attention=False — confirm whether the tip should be reworded.
        with gr.Accordion("💡 Tips for Better Results", open=False):
            gr.Markdown("""
- **Prompt Design**: The Target Prompt should be similar to the Source Prompt, but include a description of the new object to insert
- **Seed Variation**: Try different values for Object Seed - some prompts may require a few attempts to get satisfying results
- **Localization Models**: The most effective options are `attention_points_sam` and `attention`. Use Show Attention to visualize localization performance
- **Object Placement Issues**: If the object is not added to the image:
- Try **decreasing** Structure Transfer Step
- Try **increasing** Extended Scale
- **Flexibility**: To allow more flexibility in modifying the source image, leave Blend Steps empty to send an empty list
""")
    return demo
# Build the interface once when the module is executed, then start serving it
# with the fixed host/port/share settings below.
demo = create_interface()
_launch_options = {
    "server_name": "0.0.0.0",
    "server_port": 7860,
    "share": True,
}
demo.launch(**_launch_options)