|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import gradio as gr |
|
import spaces |
|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
import tempfile |
|
import gc |
|
|
|
from addit_flux_pipeline import AdditFluxPipeline |
|
from addit_flux_transformer import AdditFluxTransformer2DModel |
|
from addit_scheduler import AdditFlowMatchEulerDiscreteScheduler |
|
from addit_methods import add_object_generated, add_object_real |
|
|
|
|
|
# Module-level handles read by the rest of the app. `pipe` stays None unless
# the whole initialization sequence below succeeds.
pipe = None
device = None

print("Initializing ADDIT model...")
try:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load the transformer from the "transformer" subfolder of FLUX.1-dev in bfloat16.
    my_transformer = AdditFluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    )

    # Build the pipeline around that transformer and move it to the chosen device.
    pipe = AdditFluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        transformer=my_transformer,
        torch_dtype=torch.bfloat16,
    ).to(device)

    # Swap in the ADDIT scheduler, reusing the configuration of the one the
    # pipeline was loaded with.
    pipe.scheduler = AdditFlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)

    print("Model initialized successfully!")
except Exception as e:
    # Best-effort startup: the app still launches without a model. Reset
    # `pipe` so a failure after the pipeline was assigned (e.g. during the
    # scheduler swap) is not mistaken for a ready model.
    pipe = None
    print(f"Error initializing model: {str(e)}")
    print("The application will start but model functionality will be unavailable.")
|
|
|
def validate_inputs(prompt_source, prompt_target, subject_token):
    """Validate the text inputs supplied by the user.

    Checks are applied in order: source prompt, target prompt, subject token,
    then that the subject token occurs in the target prompt. A value of
    ``None`` is treated the same as an empty string.

    Args:
        prompt_source: Source prompt text (or ``None``).
        prompt_target: Target prompt text (or ``None``).
        subject_token: Subject token text (or ``None``).

    Returns:
        The error message for the first failed check, or ``None`` when every
        check passes.
    """
    required_fields = (
        (prompt_source, "Source prompt cannot be empty"),
        (prompt_target, "Target prompt cannot be empty"),
        (subject_token, "Subject token cannot be empty"),
    )
    for text, message in required_fields:
        # Guard against None so a missing value yields a validation message
        # instead of an AttributeError on .strip().
        if text is None or not text.strip():
            return message

    # Substring check on the raw values, exactly as entered.
    if subject_token not in prompt_target:
        return f"Subject token '{subject_token}' must appear in the target prompt"

    return None
|
|
|
@spaces.GPU
def process_generated_image(
    prompt_source,
    prompt_target,
    subject_token,
    seed_src,
    seed_obj,
    extended_scale,
    structure_transfer_step,
    blend_steps,
    localization_model,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate a source image and an edited copy with the object added.

    Args:
        prompt_source: Source prompt; validated, then forwarded unchanged.
        prompt_target: Target prompt; forwarded as ``prompt_object``.
        subject_token: Token that must occur in ``prompt_target``.
        seed_src: Forwarded as ``seed_src``.
        seed_obj: Forwarded as ``seed_obj``.
        extended_scale: Forwarded as ``extended_scale``.
        structure_transfer_step: Forwarded as ``structure_transfer_step``.
        blend_steps: Comma-separated integers as text; blank means no steps.
        localization_model: Forwarded as ``localization_model``.
        progress: Gradio progress object created with ``track_tqdm=True``;
            not referenced in the body.

    Returns:
        ``(source_image, edited_image, status_message)``. Both images are
        ``None`` whenever the model is missing, validation fails, or an
        exception is raised while generating.
    """
    # Only reads the module-level pipeline, so no `global` statement is needed.
    if pipe is None:
        return None, None, "Model not initialized. Please restart the application."

    validation_error = validate_inputs(prompt_source, prompt_target, subject_token)
    if validation_error:
        return None, None, validation_error

    try:
        # Turn "15, 20" into [15, 20]; blank text or empty pieces give no entries.
        parsed_steps = []
        if blend_steps.strip():
            for piece in blend_steps.split(','):
                piece = piece.strip()
                if piece:
                    parsed_steps.append(int(piece))

        generation_options = dict(
            pipe=pipe,
            prompt_source=prompt_source,
            prompt_object=prompt_target,
            subject_token=subject_token,
            seed_src=seed_src,
            seed_obj=seed_obj,
            show_attention=False,
            extended_scale=extended_scale,
            structure_transfer_step=structure_transfer_step,
            blend_steps=parsed_steps,
            localization_model=localization_model,
            display_output=False,
        )
        src_image, edited_image = add_object_generated(**generation_options)
    except Exception as e:
        # Report any failure through the status field rather than raising.
        failure_msg = f"Error generating images: {str(e)}"
        print(failure_msg)
        return None, None, failure_msg

    return src_image, edited_image, "Images generated successfully!"
|
|
|
@spaces.GPU
def process_real_image(
    source_image,
    prompt_source,
    prompt_target,
    subject_token,
    seed_src,
    seed_obj,
    extended_scale,
    structure_transfer_step,
    blend_steps,
    localization_model,
    use_offset,
    disable_inversion,
    progress=gr.Progress(track_tqdm=True)
):
    """Add an object to an uploaded image.

    Args:
        source_image: Uploaded image (must support ``.resize``); ``None``
            is rejected with a status message.
        prompt_source: Source prompt; validated, then forwarded unchanged.
        prompt_target: Target prompt; forwarded as ``prompt_object``.
        subject_token: Token that must occur in ``prompt_target``.
        seed_src: Forwarded as ``seed_src``.
        seed_obj: Forwarded as ``seed_obj``.
        extended_scale: Forwarded as ``extended_scale``.
        structure_transfer_step: Forwarded as ``structure_transfer_step``.
        blend_steps: Comma-separated integers as text; blank means no steps.
        localization_model: Forwarded as ``localization_model``.
        use_offset: Forwarded as ``use_offset``.
        disable_inversion: Negated and forwarded as ``use_inversion``.
        progress: Gradio progress object created with ``track_tqdm=True``;
            not referenced in the body.

    Returns:
        ``(source_image, edited_image, status_message)``. Both images are
        ``None`` whenever the model or upload is missing, validation fails,
        or an exception is raised while editing.
    """
    # Only reads the module-level pipeline, so no `global` statement is needed.
    if pipe is None:
        return None, None, "Model not initialized. Please restart the application."

    if source_image is None:
        return None, None, "Please upload a source image"

    validation_error = validate_inputs(prompt_source, prompt_target, subject_token)
    if validation_error:
        return None, None, validation_error

    try:
        # The upload is always resized to 1024x1024 before editing.
        resized_image = source_image.resize((1024, 1024))

        # Turn "15, 20" into [15, 20]; blank text or empty pieces give no entries.
        parsed_steps = []
        if blend_steps.strip():
            for piece in blend_steps.split(','):
                piece = piece.strip()
                if piece:
                    parsed_steps.append(int(piece))

        edit_options = dict(
            pipe=pipe,
            source_image=resized_image,
            prompt_source=prompt_source,
            prompt_object=prompt_target,
            subject_token=subject_token,
            seed_src=seed_src,
            seed_obj=seed_obj,
            extended_scale=extended_scale,
            structure_transfer_step=structure_transfer_step,
            blend_steps=parsed_steps,
            localization_model=localization_model,
            use_offset=use_offset,
            show_attention=False,
            use_inversion=not disable_inversion,
            display_output=False,
        )
        src_image, edited_image = add_object_real(**edit_options)
    except Exception as e:
        # Report any failure through the status field rather than raising.
        failure_msg = f"Error processing image: {str(e)}"
        print(failure_msg)
        return None, None, failure_msg

    return src_image, edited_image, "Image edited successfully!"
|
|
|
def _build_prompt_inputs(source_prompt, target_prompt, subject_token):
    """Create the Source Prompt / Target Prompt / Subject Token textboxes.

    Must be called inside an open Gradio layout context. Each argument is
    used both as the placeholder and as the initial value of its textbox.

    Returns:
        ``[source_prompt_box, target_prompt_box, subject_token_box]``.
    """
    source_box = gr.Textbox(
        label="Source Prompt",
        placeholder=source_prompt,
        value=source_prompt,
    )
    target_box = gr.Textbox(
        label="Target Prompt",
        placeholder=target_prompt,
        value=target_prompt,
    )
    subject_box = gr.Textbox(
        label="Subject Token",
        placeholder=subject_token,
        value=subject_token,
        info="Single token representing the object to add **(must appear in target prompt)**",
    )
    return [source_box, target_box, subject_box]


def _build_advanced_inputs(extended_scale, structure_transfer_step, blend_steps, localization_choices):
    """Create the tuning widgets that both tabs share.

    Must be called inside an open Gradio layout context. The arguments are
    the per-tab default values; the first entry of ``localization_choices``
    is the pre-selected localization model.

    Returns:
        ``[seed_src, seed_obj, extended_scale, structure_transfer_step,
        blend_steps, localization_model]`` components, in that order.
    """
    seed_src = gr.Number(label="Source Seed", value=6311, precision=0)
    seed_obj = gr.Number(label="Object Seed", value=1, precision=0)
    extended_scale_slider = gr.Slider(
        label="Extended Scale",
        minimum=1.0,
        maximum=1.3,
        value=extended_scale,
        step=0.01,
    )
    structure_step_slider = gr.Slider(
        label="Structure Transfer Step",
        minimum=0,
        maximum=10,
        value=structure_transfer_step,
        step=1,
    )
    blend_steps_box = gr.Textbox(
        label="Blend Steps",
        value=blend_steps,
        info="Comma-separated list of steps (e.g., '15,20') or empty for no blending",
    )
    localization_dropdown = gr.Dropdown(
        label="Localization Model",
        choices=localization_choices,
        value=localization_choices[0],
    )
    return [
        seed_src,
        seed_obj,
        extended_scale_slider,
        structure_step_slider,
        blend_steps_box,
        localization_dropdown,
    ]


def _build_result_panel(source_label):
    """Create the two output images (side by side) and the status textbox.

    Must be called inside an open Gradio layout context.

    Returns:
        ``[source_image_output, edited_image_output, status_box]``.
    """
    with gr.Row():
        source_output = gr.Image(label=source_label, type="pil")
        edited_output = gr.Image(label="Edited Image", type="pil")
    status_box = gr.Textbox(label="Status", interactive=False)
    return [source_output, edited_output, status_box]


def create_interface():
    """Create the Gradio interface.

    Builds a ``gr.Blocks`` app with a header showing whether the module-level
    ``pipe`` is available, one tab that calls ``process_generated_image``, one
    tab that calls ``process_real_image``, and a collapsible tips section.

    Returns:
        The constructed ``gr.Blocks`` instance (not launched).
    """
    model_ready = pipe is not None
    model_status = "Model ready!" if model_ready else "Model initialization failed - functionality unavailable"
    status_color = "green" if model_ready else "red"

    with gr.Blocks(title="🎨 Add-it: Training-Free Object Insertion in Images With Pretrained Diffusion Models", theme=gr.themes.Soft()) as demo:
        gr.HTML(f"""
        <div style="text-align: center; margin-bottom: 20px;">
            <h1>🎨 Add-it: Training-Free Object Insertion</h1>
            <p>Add objects to images using pretrained diffusion models</p>
            <p><a href="https://research.nvidia.com/labs/par/addit/" target="_blank">🌐 Project Website</a> |
            <a href="https://arxiv.org/abs/2411.07232" target="_blank">📄 Paper</a> |
            <a href="https://github.com/NVlabs/addit" target="_blank">💻 Code</a></p>
            <p style="color: {status_color}; font-weight: bold;">Status: {model_status}</p>
        </div>
        """)

        with gr.Tabs():
            with gr.TabItem("🎭 Generated Images"):
                gr.Markdown("### Generate a base image and add objects to it")

                with gr.Row():
                    with gr.Column(scale=1):
                        gen_prompts = _build_prompt_inputs(
                            "A photo of a cat sitting on the couch",
                            "A photo of a cat wearing a red hat sitting on the couch",
                            "hat",
                        )

                        with gr.Accordion("Advanced Settings", open=False):
                            gen_advanced = _build_advanced_inputs(
                                extended_scale=1.05,
                                structure_transfer_step=2,
                                blend_steps="15",
                                localization_choices=[
                                    "attention_points_sam",
                                    "attention",
                                    "attention_box_sam",
                                    "attention_mask_sam",
                                    "grounding_sam",
                                ],
                            )

                        gen_submit_btn = gr.Button("🎨 Generate & Edit", variant="primary")

                    with gr.Column(scale=2):
                        gen_outputs = _build_result_panel("Generated Source Image")

                # Input order must match the positional parameters of
                # process_generated_image.
                gen_submit_btn.click(
                    fn=process_generated_image,
                    inputs=gen_prompts + gen_advanced,
                    outputs=gen_outputs,
                )

                gr.Examples(
                    examples=[
                        ["A photo of a man sitting on a bench", "A photo of a man sitting on a bench with a dog", "dog"],
                        ["A photo of a cat sitting on the couch", "A photo of a cat wearing a red hat sitting on the couch", "hat"],
                        ["A car driving through an empty street", "A pink car driving through an empty street", "car"],
                    ],
                    inputs=gen_prompts,
                    label="Example Prompts",
                )

            with gr.TabItem("📸 Real Images"):
                gr.Markdown("### Upload an image and add objects to it")

                with gr.Row():
                    with gr.Column(scale=1):
                        real_source_image = gr.Image(label="Source Image", type="pil")
                        real_prompts = _build_prompt_inputs(
                            "A photo of a bed in a dark room",
                            "A photo of a dog lying on a bed in a dark room",
                            "dog",
                        )

                        with gr.Accordion("Advanced Settings", open=False):
                            real_advanced = _build_advanced_inputs(
                                extended_scale=1.1,
                                structure_transfer_step=4,
                                blend_steps="18",
                                localization_choices=[
                                    "attention",
                                    "attention_points_sam",
                                    "attention_box_sam",
                                    "attention_mask_sam",
                                    "grounding_sam",
                                ],
                            )
                            real_use_offset = gr.Checkbox(label="Use Offset", value=False)
                            real_disable_inversion = gr.Checkbox(label="Disable Inversion", value=False)

                        real_submit_btn = gr.Button("🎨 Edit Image", variant="primary")

                    with gr.Column(scale=2):
                        real_outputs = _build_result_panel("Source Image")

                # Input order must match the positional parameters of
                # process_real_image.
                real_submit_btn.click(
                    fn=process_real_image,
                    inputs=(
                        [real_source_image]
                        + real_prompts
                        + real_advanced
                        + [real_use_offset, real_disable_inversion]
                    ),
                    outputs=real_outputs,
                )

                gr.Examples(
                    examples=[
                        [
                            "images/bed_dark_room.jpg",
                            "A photo of a bed in a dark room",
                            "A photo of a dog lying on a bed in a dark room",
                            "dog",
                        ],
                        [
                            "images/flower.jpg",
                            "A photo of a flower",
                            "A bee standing on a flower",
                            "bee",
                        ],
                    ],
                    inputs=[real_source_image] + real_prompts,
                    label="Example Images & Prompts",
                )

        # The tips only mention controls that exist in this interface.
        with gr.Accordion("💡 Tips for Better Results", open=False):
            gr.Markdown("""
            - **Prompt Design**: The Target Prompt should be similar to the Source Prompt, but include a description of the new object to insert
            - **Seed Variation**: Try different values for Object Seed - some prompts may require a few attempts to get satisfying results
            - **Localization Models**: The most effective options are `attention_points_sam` and `attention`
            - **Object Placement Issues**: If the object is not added to the image:
                - Try **decreasing** Structure Transfer Step
                - Try **increasing** Extended Scale
            - **Flexibility**: To allow more flexibility in modifying the source image, leave Blend Steps empty to send an empty list
            """)

    return demo
|
|
|
# Build the UI at import time so the module always exposes `demo`.
demo = create_interface()

# Start the server only when this file is executed as a script, so importing
# the module does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
        share=True,
    )