import math
import random

import gradio as gr
import torch
from PIL import Image, ImageOps
from diffusers import StableDiffusionInstructPix2PixPipeline
import spaces
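
# Text shown in the demo UI: usage tips, the paper link, and the page header.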
help_text = """
Considerations while editing:
1. The Base-Model, trained on the PIPE dataset, is great for some tasks, while the Finetuned-MB-Model, fine-tuned on the MagicBrush dataset, can be better for others. Please try both until you are satisfied.
2. Image CFG controls how closely the output follows the original image; higher values keep the result more consistent with the input.
3. Text CFG works in the opposite direction; higher values follow the instruction more strongly and change the image more.
4. Using different seed values will produce varied outputs.
5. Increasing the number of steps can enhance the results.
6. The Stable Diffusion autoencoder struggles with small faces in images.
"""

article = """
<p style='text-align: center'>
<a href='https://arxiv.org/abs/2404.18212' target='_blank'>
Paint by Inpaint: Learning to Add Image Objects by Removing Them First</a>
</p>
"""

description = """
<p style="text-align: center;">
Gradio demo for <strong>Paint by Inpaint: Learning to Add Image Objects by Removing Them First</strong>; visit our <a href='https://rotsteinnoam.github.io/Paint-by-Inpaint/' target='_blank'>project page</a>. <br>
The demo includes both models trained for image object addition on the <a href='https://huggingface.co/datasets/paint-by-inpaint/PIPE' target='_blank'>PIPE dataset</a> and models trained on other datasets for general editing. <br>
</p>
"""

# Base models
object_addition_base_model_id = "paint-by-inpaint/add-base"
general_editing_base_model_id = "paint-by-inpaint/general-base"
# MagicBrush finetuned models
object_addition_finetuned_model_id = "paint-by-inpaint/add-finetuned-mb"
general_editing_finetuned_model_id = "paint-by-inpaint/general-finetuned-mb"

device = "cuda" if torch.cuda.is_available() else "cpu"
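

# Load an InstructPix2Pix pipeline in half precision and move it to the target device.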
def load_model(model_id):
    return StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
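
# Instantiate all four pipelines up front so switching models in the UI is instant.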
pipe_object_addition_base = load_model(object_addition_base_model_id)
pipe_object_addition_finetuned = load_model(object_addition_finetuned_model_id)
pipe_general_editing_base = load_model(general_editing_base_model_id)
pipe_general_editing_finetuned = load_model(general_editing_finetuned_model_id)
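
# For reference, a minimal direct call to one of these pipelines outside the UI
# (a sketch using the bundled example image and the same defaults as the controls below):
#
#     img = Image.open("examples/messi.jpeg").convert("RGB")
#     out = pipe_object_addition_base(
#         "Add a royal silver crown", image=img,
#         guidance_scale=7.5, image_guidance_scale=1.5, num_inference_steps=50,
#     ).images[0]
#     out.save("crowned.png")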


@spaces.GPU  # assumed from the otherwise-unused `spaces` import: ZeroGPU Spaces run GPU work under this decorator (a no-op elsewhere)
def generate(
    input_image: Image.Image,
    instruction: str,
    model_choice: int,
    steps: int,
    randomize_seed: bool,
    seed: int,
    text_cfg_scale: float,
    image_cfg_scale: float,
    task_type: str,
):
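    # Resolve the seed, then pick the pipeline for the requested task and model
    # variant (model_choice is the radio index: 0 = Base-Model, 1 = Finetuned-MB-Model).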
    seed = random.randint(0, 100000) if randomize_seed else seed
    if task_type == "object_addition":
        pipe = pipe_object_addition_base if model_choice == 0 else pipe_object_addition_finetuned
    else:
        pipe = pipe_general_editing_base if model_choice == 0 else pipe_general_editing_finetuned
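
    # Resize so the longer side is about 512 px, then snap both sides to
    # multiples of 64, as the Stable Diffusion UNet requires.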
    width, height = input_image.size
    factor = 512 / max(width, height)
    factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
    width = int((width * factor) // 64) * 64
    height = int((height * factor) // 64) * 64
    input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
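
    # With no instruction there is nothing to edit; return the resized input in the
    # image slot so the output shape matches [seed, text_cfg, image_cfg, image].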
    if instruction == "":
        return [seed, text_cfg_scale, image_cfg_scale, input_image]
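
    # torch.manual_seed seeds and returns the default CPU generator, which the
    # pipeline accepts for reproducible sampling.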
    generator = torch.manual_seed(seed)
    edited_image = pipe(
        instruction, image=input_image,
        guidance_scale=text_cfg_scale, image_guidance_scale=image_cfg_scale,
        num_inference_steps=steps, generator=generator,
    ).images[0]
    return [seed, text_cfg_scale, image_cfg_scale, edited_image]
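

# Restore the control panel to its defaults; the order matches the reset buttons'
# `outputs` list: [steps, randomize_seed, seed, text_cfg_scale, image_cfg_scale, edited_image].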
def reset():
    return [50, "Randomize Seed", 2024, 7.5, 1.5, None]


with gr.Blocks(css=".compact-box .gr-row { margin-bottom: 5px; } .compact-box .gr-number input, .compact-box .gr-radio label { padding: 5px 10px; }") as demo:
    gr.HTML("""
        <div style="text-align: center;">
            <h1 style="font-weight: 900; margin-bottom: 7px;">Paint by Inpaint</h1>
            {description}
        </div>
    """.format(description=description))
    with gr.Tabs():
        with gr.Tab("Object Addition"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil", interactive=True)
                    instruction = gr.Textbox(lines=1, label="Addition Instruction", interactive=True, max_lines=1, placeholder="Enter addition instruction here")
                    model_choice = gr.Radio(
                        ["Base-Model", "Finetuned-MB-Model"],
                        value="Base-Model",
                        type="index",
                        label="Choose Model",
                        interactive=True,
                    )
                    with gr.Group(elem_classes="compact-box"):
                        with gr.Row():
                            steps = gr.Number(value=50, precision=0, label="Steps", interactive=True)
                            with gr.Column():
                                with gr.Row():
                                    seed = gr.Number(value=2024, precision=0, label="Seed", interactive=True)
                                    randomize_seed = gr.Radio(
                                        ["Fix Seed", "Randomize Seed"],
                                        value="Randomize Seed",
                                        type="index",
                                        show_label=False,
                                        interactive=True,
                                    )
                                with gr.Row():
                                    text_cfg_scale = gr.Number(value=7.5, label="Text CFG", interactive=True)
                                    image_cfg_scale = gr.Number(value=1.5, label="Image CFG", interactive=True)
                    with gr.Row():
                        generate_button = gr.Button("Generate")
                        reset_button = gr.Button("Reset")
                with gr.Column():
                    edited_image = gr.Image(label="Edited Image", type="pil", interactive=False)
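
            # The lambda binds task_type so a single generate() serves both tabs.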
            generate_button.click(
                fn=lambda *args: generate(*args, task_type="object_addition"),
                inputs=[
                    input_image,
                    instruction,
                    model_choice,
                    steps,
                    randomize_seed,
                    seed,
                    text_cfg_scale,
                    image_cfg_scale,
                ],
                outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image],
            )
            reset_button.click(
                fn=reset,
                inputs=[],
                outputs=[steps, randomize_seed, seed, text_cfg_scale, image_cfg_scale, edited_image],
            )
| with gr.Tab("General Editing"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| input_image_editing = gr.Image(label="Input Image", type="pil", interactive=True) | |
| instruction_editing = gr.Textbox(lines=1, label="Editing Instruction", interactive=True, max_lines=1, placeholder="Enter editing instruction here") | |
| model_choice_editing = gr.Radio( | |
| ["Base-Model", "Finetuned-MB-Model"], | |
| value="Base-Model", | |
| type="index", | |
| label="Choose Model", | |
| interactive=True, | |
| ) | |
                    with gr.Group(elem_classes="compact-box"):
                        with gr.Row():
                            steps_editing = gr.Number(value=50, precision=0, label="Steps", interactive=True)
                            with gr.Column():
                                with gr.Row():
                                    seed_editing = gr.Number(value=2024, precision=0, label="Seed", interactive=True)
                                    randomize_seed_editing = gr.Radio(
                                        ["Fix Seed", "Randomize Seed"],
                                        value="Randomize Seed",
                                        type="index",
                                        show_label=False,
                                        interactive=True,
                                    )
                                with gr.Row():
                                    text_cfg_scale_editing = gr.Number(value=7.5, label="Text CFG", interactive=True)
                                    image_cfg_scale_editing = gr.Number(value=1.5, label="Image CFG", interactive=True)
                    with gr.Row():
                        generate_button_editing = gr.Button("Generate")
                        reset_button_editing = gr.Button("Reset")
                with gr.Column():
                    edited_image_editing = gr.Image(label="Edited Image", type="pil", interactive=False)
            generate_button_editing.click(
                fn=lambda *args: generate(*args, task_type="general_editing"),
                inputs=[
                    input_image_editing,
                    instruction_editing,
                    model_choice_editing,
                    steps_editing,
                    randomize_seed_editing,
                    seed_editing,
                    text_cfg_scale_editing,
                    image_cfg_scale_editing,
                ],
                outputs=[seed_editing, text_cfg_scale_editing, image_cfg_scale_editing, edited_image_editing],
            )
            reset_button_editing.click(
                fn=reset,
                inputs=[],
                outputs=[steps_editing, randomize_seed_editing, seed_editing, text_cfg_scale_editing, image_cfg_scale_editing, edited_image_editing],
            )
    gr.Markdown(help_text)
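
    # Click-to-load examples for the Object Addition tab.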
    examples = [
        ["examples/messi.jpeg", "Add a royal silver crown"],
        ["examples/coffee.jpg", "Add steamed milk"],
    ]
    gr.Examples(
        examples=examples,
        inputs=[input_image, instruction],
        outputs=[edited_image],
    )
    gr.HTML(article)
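
# Serve requests through a queue; max_threads=1 keeps the worker pool single-threaded,
# as the shared pipelines are not safe to call concurrently.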
demo.queue()
demo.launch(share=False, max_threads=1)