#!/usr/bin/env python3
# Copyright (C) 2025 NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the LICENSE file
# located at the root directory.

"""Gradio demo for ADDIT object insertion on top of FLUX.1-dev.

The ADDIT pipeline is loaded once at import time into the module-level
``pipe``. Two request handlers are defined: ``process_generated_image`` and
``process_real_image``. Only the latter is wired into the UI built by
``create_interface``.
"""

import gc
import os
import tempfile
import traceback
from datetime import datetime

# NOTE(review): third-party import order (gradio, spaces, torch) is kept as in
# the original; `spaces` (ZeroGPU) may need to be imported before torch
# touches CUDA -- verify before reordering.
import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image

from addit_flux_pipeline import AdditFluxPipeline
from addit_flux_transformer import AdditFluxTransformer2DModel
from addit_scheduler import AdditFlowMatchEulerDiscreteScheduler
from addit_methods import add_object_generated, add_object_real

MODEL_ID = "black-forest-labs/FLUX.1-dev"

# Real images are resized to this size before editing (matches the UI note).
IMAGE_SIZE = (1024, 1024)

# Global variables for model
pipe = None
device = None


def _load_pipeline(target_device):
    """Build the ADDIT FLUX pipeline on ``target_device`` and return it.

    The pipeline is returned only after every step (transformer load,
    pipeline load, scheduler swap) has succeeded. A failure partway through
    therefore never leaves a half-configured pipeline in the global ``pipe``.
    """
    transformer = AdditFluxTransformer2DModel.from_pretrained(
        MODEL_ID, subfolder="transformer", torch_dtype=torch.bfloat16
    )
    pipeline = AdditFluxPipeline.from_pretrained(
        MODEL_ID, transformer=transformer, torch_dtype=torch.bfloat16
    ).to(target_device)
    # Replace the default scheduler with the ADDIT variant, reusing its config.
    pipeline.scheduler = AdditFlowMatchEulerDiscreteScheduler.from_config(
        pipeline.scheduler.config
    )
    return pipeline


# Initialize model at startup
print("Initializing ADDIT model...")
try:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    pipe = _load_pipeline(device)
    print("Model initialized successfully!")
except Exception as e:
    # Deliberate best-effort: the UI still starts and reports the failure.
    pipe = None
    print(f"Error initializing model: {str(e)}")
    traceback.print_exc()
    print("The application will start but model functionality will be unavailable.")


def validate_inputs(prompt_source, prompt_target, subject_token):
    """Validate the text inputs shared by both editing modes.

    Args:
        prompt_source: Prompt describing the source image.
        prompt_target: Prompt describing the image with the object added.
        subject_token: Word naming the object to insert. After stripping
            surrounding whitespace, it must occur in ``prompt_target``.

    Returns:
        An error message string, or ``None`` when the inputs are valid.
        ``None`` values are treated as empty strings.
    """
    token = (subject_token or "").strip()
    if not (prompt_source or "").strip():
        return "Source prompt cannot be empty"
    if not (prompt_target or "").strip():
        return "Target prompt cannot be empty"
    if not token:
        return "Subject token cannot be empty"
    if token not in prompt_target:
        return f"Subject token '{token}' must appear in the target prompt"
    return None


def _parse_blend_steps(blend_steps):
    """Parse a comma-separated string such as ``"15,20"`` into ``[15, 20]``.

    Empty, whitespace-only or ``None`` input yields ``[]`` (no blending).
    Empty items between commas are skipped.

    Raises:
        ValueError: An item is not an integer. The message is user-facing.
    """
    if not blend_steps or not blend_steps.strip():
        return []
    steps = []
    for item in blend_steps.split(","):
        item = item.strip()
        if not item:
            continue
        try:
            steps.append(int(item))
        except ValueError:
            raise ValueError(
                f"Invalid blend step '{item}': expected comma-separated integers (e.g. '15,20')"
            ) from None
    return steps


def _prepare_run_args(subject_token, seed_src, seed_obj, structure_transfer_step, blend_steps):
    """Convert raw Gradio widget values into the types the ADDIT helpers get.

    Returns:
        A tuple ``(subject_token, seed_src, seed_obj, structure_transfer_step,
        blend_steps_list)``. The token is stripped, and the seeds and the step
        are converted to ``int``.

    Raises:
        ValueError: A seed is missing (for example, a cleared Number field),
            or the blend-step list is malformed. The message is user-facing.
    """
    if seed_src is None or seed_obj is None:
        raise ValueError("Source Seed and Object Seed must both be set")
    return (
        (subject_token or "").strip(),
        int(seed_src),
        int(seed_obj),
        int(structure_transfer_step),
        _parse_blend_steps(blend_steps),
    )


def _release_memory():
    """Collect Python garbage and return cached CUDA blocks after a request."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


@spaces.GPU
def process_generated_image(
    prompt_source,
    prompt_target,
    subject_token,
    seed_src,
    seed_obj,
    extended_scale,
    structure_transfer_step,
    blend_steps,
    localization_model,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a source image and an edited copy with the object inserted.

    ``blend_steps`` is a comma-separated string of step indices; an empty
    string disables blending. ``progress`` exists so that Gradio tracks the
    tqdm bars raised during generation.

    Returns:
        ``(source_image, edited_image, status_message)``. Both images are
        ``None`` when validation or generation fails.
    """
    if pipe is None:
        return None, None, "Model not initialized. Please restart the application."

    # Validate inputs
    error_msg = validate_inputs(prompt_source, prompt_target, subject_token)
    if error_msg:
        return None, None, error_msg

    # Print current time and input information
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f"\n[{current_time}] Starting Generated Image Processing")
    print(f"Source Prompt: '{prompt_source}'")
    print(f"Target Prompt: '{prompt_target}'")
    print(f"Subject Token: '{subject_token}'")
    print(f"Source Seed: {seed_src}, Object Seed: {seed_obj}")
    print(f"Extended Scale: {extended_scale}, Structure Transfer Step: {structure_transfer_step}")
    print(f"Blend Steps: '{blend_steps}', Localization Model: '{localization_model}'")

    try:
        (subject_token, seed_src, seed_obj,
         structure_transfer_step, blend_steps_list) = _prepare_run_args(
            subject_token, seed_src, seed_obj, structure_transfer_step, blend_steps
        )
    except ValueError as e:
        return None, None, str(e)

    try:
        # Generate images
        src_image, edited_image = add_object_generated(
            pipe=pipe,
            prompt_source=prompt_source,
            prompt_object=prompt_target,
            subject_token=subject_token,
            seed_src=seed_src,
            seed_obj=seed_obj,
            show_attention=False,
            extended_scale=extended_scale,
            structure_transfer_step=structure_transfer_step,
            blend_steps=blend_steps_list,
            localization_model=localization_model,
            display_output=False,
        )
        return src_image, edited_image, "Images generated successfully!"
    except Exception as e:
        # Request boundary: report the failure in the UI and keep the traceback in the logs.
        error_msg = f"Error generating images: {str(e)}"
        print(error_msg)
        traceback.print_exc()
        return None, None, error_msg
    finally:
        _release_memory()


@spaces.GPU
def process_real_image(
    source_image,
    prompt_source,
    prompt_target,
    subject_token,
    seed_src,
    seed_obj,
    extended_scale,
    structure_transfer_step,
    blend_steps,
    localization_model,
    use_offset,
    disable_inversion,
    progress=gr.Progress(track_tqdm=True),
):
    """Insert an object into an uploaded image.

    ``source_image`` is a PIL image; the ``gr.Image`` uses ``type="pil"``.
    It is resized to ``IMAGE_SIZE`` before editing. Inversion runs unless
    ``disable_inversion`` is set.

    Returns:
        ``(source_image, edited_image, status_message)``. Both images are
        ``None`` when validation or editing fails.
    """
    if pipe is None:
        return None, None, "Model not initialized. Please restart the application."

    if source_image is None:
        return None, None, "Please upload a source image"

    # Validate inputs
    error_msg = validate_inputs(prompt_source, prompt_target, subject_token)
    if error_msg:
        return None, None, error_msg

    # Print current time and input information
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f"\n[{current_time}] Starting Real Image Processing")
    print(f"Source Image Size: {source_image.size}")
    print(f"Source Prompt: '{prompt_source}'")
    print(f"Target Prompt: '{prompt_target}'")
    print(f"Subject Token: '{subject_token}'")
    print(f"Source Seed: {seed_src}, Object Seed: {seed_obj}")
    print(f"Extended Scale: {extended_scale}, Structure Transfer Step: {structure_transfer_step}")
    print(f"Blend Steps: '{blend_steps}', Localization Model: '{localization_model}'")
    print(f"Use Offset: {use_offset}, Disable Inversion: {disable_inversion}")

    try:
        (subject_token, seed_src, seed_obj,
         structure_transfer_step, blend_steps_list) = _prepare_run_args(
            subject_token, seed_src, seed_obj, structure_transfer_step, blend_steps
        )
    except ValueError as e:
        return None, None, str(e)

    try:
        # Resize source image
        source_image = source_image.resize(IMAGE_SIZE)

        # Process image
        src_image, edited_image = add_object_real(
            pipe=pipe,
            source_image=source_image,
            prompt_source=prompt_source,
            prompt_object=prompt_target,
            subject_token=subject_token,
            seed_src=seed_src,
            seed_obj=seed_obj,
            extended_scale=extended_scale,
            structure_transfer_step=structure_transfer_step,
            blend_steps=blend_steps_list,
            localization_model=localization_model,
            use_offset=use_offset,
            show_attention=False,
            use_inversion=not disable_inversion,
            display_output=False,
        )
        return src_image, edited_image, "Image edited successfully!"
    except Exception as e:
        # Request boundary: report the failure in the UI and keep the traceback in the logs.
        error_msg = f"Error processing image: {str(e)}"
        print(error_msg)
        traceback.print_exc()
        return None, None, error_msg
    finally:
        _release_memory()


def create_interface():
    """Build the Gradio Blocks UI for real-image object insertion.

    Only ``process_real_image`` is wired to a control here;
    ``process_generated_image`` is not exposed in this UI.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    # Show model status in the interface
    model_status = (
        "Model ready!"
        if pipe is not None
        else "Model initialization failed - functionality unavailable"
    )

    with gr.Blocks(
        title="🎨 Add-it: Training-Free Object Insertion in Images With Pretrained Diffusion Models",
        theme=gr.themes.Soft(),
    ) as demo:
        # NOTE(review): this copy of the file had lost the header's original
        # markup, including the link targets for Website / Paper / Code.
        # Only its visible text is reproduced here; restore the <a href> anchors.
        gr.HTML(f"""
        <div style="text-align: center;">
            <p>Add objects to images using pretrained diffusion models</p>
            <p>🌐 Project Website | 📄 Paper | 💻 Code</p>
            <p><strong>Status:</strong> {model_status}</p>
            <p><strong>Note:</strong> Images will be resized to 1024x1024 pixels. For best results, use square images.</p>
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=1):
                real_source_image = gr.Image(label="Source Image", type="pil")
                real_prompt_source = gr.Textbox(
                    label="Source Prompt",
                    placeholder="A photo of a bed in a dark room",
                    value="A photo of a bed in a dark room",
                )
                real_prompt_target = gr.Textbox(
                    label="Target Prompt",
                    placeholder="A photo of a dog lying on a bed in a dark room",
                    value="A photo of a dog lying on a bed in a dark room",
                )
                real_subject_token = gr.Textbox(
                    label="Subject Token",
                    placeholder="dog",
                    value="dog",
                    info="Single token representing the object to add **(must appear in target prompt)**",
                )

                with gr.Accordion("Advanced Settings", open=False):
                    real_seed_src = gr.Number(label="Source Seed", value=1, precision=0)
                    real_seed_obj = gr.Number(label="Object Seed", value=0, precision=0)
                    real_extended_scale = gr.Slider(
                        label="Extended Scale",
                        minimum=1.0,
                        maximum=1.3,
                        value=1.1,
                        step=0.01,
                    )
                    real_structure_transfer_step = gr.Slider(
                        label="Structure Transfer Step",
                        minimum=0,
                        maximum=10,
                        value=4,
                        step=1,
                    )
                    real_blend_steps = gr.Textbox(
                        label="Blend Steps",
                        value="18",
                        info="Comma-separated list of steps (e.g., '15,20') or empty for no blending",
                    )
                    real_localization_model = gr.Dropdown(
                        label="Localization Model",
                        choices=[
                            "attention",
                            "attention_points_sam",
                            "attention_box_sam",
                            "attention_mask_sam",
                            "grounding_sam",
                        ],
                        value="attention",
                    )
                    real_use_offset = gr.Checkbox(label="Use Offset", value=False)
                    real_disable_inversion = gr.Checkbox(label="Disable Inversion", value=False)

                real_submit_btn = gr.Button("🎨 Edit Image", variant="primary")

            with gr.Column(scale=2):
                with gr.Row():
                    real_src_output = gr.Image(label="Source Image", type="pil")
                    real_edited_output = gr.Image(label="Edited Image", type="pil")
                real_status = gr.Textbox(label="Status", interactive=False)

        real_submit_btn.click(
            fn=process_real_image,
            inputs=[
                real_source_image,
                real_prompt_source,
                real_prompt_target,
                real_subject_token,
                real_seed_src,
                real_seed_obj,
                real_extended_scale,
                real_structure_transfer_step,
                real_blend_steps,
                real_localization_model,
                real_use_offset,
                real_disable_inversion,
            ],
            outputs=[real_src_output, real_edited_output, real_status],
        )

        # Examples for real images
        gr.Examples(
            examples=[
                [
                    "images/bed_dark_room.jpg",
                    "A photo of a bed in a dark room",
                    "A photo of a dog lying on a bed in a dark room",
                    "dog",
                ],
                [
                    "images/flower.jpg",
                    "A photo of a flower",
                    "A bee standing on a flower",
                    "bee",
                ],
            ],
            inputs=[
                real_source_image,
                real_prompt_source,
                real_prompt_target,
                real_subject_token,
            ],
            label="Example Images & Prompts",
        )

        # Tips (the UI has no attention visualization control, so none is referenced)
        with gr.Accordion("💡 Tips for Better Results", open=False):
            gr.Markdown("""
            - **Prompt Design**: The Target Prompt should be similar to the Source Prompt, but include a description of the new object to insert
            - **Seed Variation**: Try different values for Object Seed - some prompts may require a few attempts to get satisfying results
            - **Localization Models**: The most effective options are `attention_points_sam` and `attention`
            - **Object Placement Issues**: If the object is not added to the image:
              - Try **decreasing** Structure Transfer Step
              - Try **increasing** Extended Scale
            - **Flexibility**: To allow more flexibility in modifying the source image, leave Blend Steps empty to send an empty list
            """)

    return demo


demo = create_interface()

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )