import os

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, BlipForConditionalGeneration
from diffusers import UNet2DConditionModel, DDIMScheduler
from src.utils.ddim_inv import DDIMInversion
from src.utils.scheduler import DDIMInverseScheduler
from src.utils.edit_directions import construct_direction, construct_direction_prompts
from src.utils.edit_pipeline import EditingPipeline
#from src.make_edit_direction import load_sentence_embeddings
torch_dtype = torch.float16
device = "cuda" if torch.cuda.is_available() else "cpu"
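
# BLIP image-captioning model, used to auto-generate a caption / source prompt for the uploaded image.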
blip_processor_large = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model_large = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model_large.to(device)
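
# Stable Diffusion v1.4 UNet, shared by both the DDIM-inversion and the editing pipeline.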
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch.float16).to(device)
pipe_inversion = DDIMInversion.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch_dtype, unet=unet).to(device)
pipe_inversion.scheduler = DDIMInverseScheduler.from_config(pipe_inversion.scheduler.config)
pipe_editing = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch_dtype, unet=unet).to(device)
pipe_editing.scheduler = DDIMScheduler.from_config(pipe_editing.scheduler.config)
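

# Encode each sentence with the pipeline's text encoder and average the embeddings
# into a single prompt embedding (used to build edit directions from prompt pairs).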
def load_sentence_embeddings(l_sentences, tokenizer, text_encoder, device="cuda"):
    with torch.no_grad():
        l_embeddings = []
        for sent in l_sentences:
            text_inputs = tokenizer(
                sent,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]
            l_embeddings.append(prompt_embeds)
    return torch.cat(l_embeddings, dim=0).mean(dim=0).unsqueeze(0)
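

# Caption an image with BLIP; decode with the processor unless an explicit tokenizer is given.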
def generate_caption(processor, model, image, tokenizer=None, use_float_16=False):
    inputs = processor(images=image, return_tensors="pt").to(device)
    if use_float_16:
        inputs = inputs.to(torch.float16)
    generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
    if tokenizer is not None:
        generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    else:
        generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_caption
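

# Invert the input image to DDIM noise under the given prompt
# (guidance_scale=1, i.e. no classifier-free guidance during inversion).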
def generate_inversion(prompt, image, num_ddim_steps=50):
    image = image.resize((512, 512), Image.Resampling.LANCZOS)
    x_inv, x_inv_image, x_dec_img = pipe_inversion(
        prompt,
        guidance_scale=1,
        num_inversion_steps=num_ddim_steps,
        img=image,
        torch_dtype=torch_dtype,
    )
    return x_inv[0]
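

# Gradio callback for the "Generate caption" button: fills the image-description textbox.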
def run_captioning(image):
    caption = generate_caption(blip_processor_large, blip_model_large, image).strip()
    return caption
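

# Fast mode: invert the image with the source prompt, then edit along the direction
# between the source-prompt and edit-prompt embeddings.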
def run_editing(image, original_prompt, edit_prompt, ddim_steps=50, xa_guidance=0.1, negative_guidance_scale=5.0):
    inverted_noise = generate_inversion(original_prompt, image)
    source_prompt_embeddings = load_sentence_embeddings([original_prompt], pipe_editing.tokenizer, pipe_editing.text_encoder, device=device)
    target_prompt_embeddings = load_sentence_embeddings([edit_prompt], pipe_editing.tokenizer, pipe_editing.text_encoder, device=device)
    rec_pil, edit_pil = pipe_editing(
        original_prompt,
        num_inference_steps=ddim_steps,
        x_in=inverted_noise.unsqueeze(0),
        edit_dir=construct_direction_prompts(source_prompt_embeddings, target_prompt_embeddings),
        guidance_amount=xa_guidance,
        guidance_scale=negative_guidance_scale,
        negative_prompt=original_prompt,  # use the unedited prompt for the negative prompt
    )
    return edit_pil[0]
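

# Quality mode: caption the image, invert it, and edit along a direction computed
# from the precomputed sentence-embedding banks in assets/embeddings_sd_1.4.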
def run_editing_quality(image, item_from, item_from_other, item_to, item_to_other, ddim_steps=50, xa_guidance=0.1, negative_guidance_scale=5.0):
    caption = generate_caption(blip_processor_large, blip_model_large, image).strip()
    item_from_selected = item_from if item_from_other == "" else item_from_other
    item_to_selected = item_to if item_to_other == "" else item_to_other
    inverted_noise = generate_inversion(caption, image)
    emb_dir = "assets/embeddings_sd_1.4"
    embs_a = torch.load(os.path.join(emb_dir, f"{item_from_selected}.pt"))
    embs_b = torch.load(os.path.join(emb_dir, f"{item_to_selected}.pt"))
    edit_dir = (embs_b.mean(0) - embs_a.mean(0)).unsqueeze(0)
    rec_pil, edit_pil = pipe_editing(
        caption,
        num_inference_steps=ddim_steps,
        x_in=inverted_noise.unsqueeze(0),
        edit_dir=edit_dir,
        guidance_amount=xa_guidance,
        guidance_scale=negative_guidance_scale,
        negative_prompt=caption,  # use the unedited prompt for the negative prompt
    )
    return edit_pil[0]
css = '''
#generate_button{height: 100%}
#quality_description{text-align: center; margin-top: 1em}
'''
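
# Gradio interface: Fast mode (free-form prompts) and Quality mode (preset edit directions).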
with gr.Blocks(css=css) as demo:
    gr.Markdown('''## Edit with Words - Pix2Pix Zero demo
    Upload an image to edit it. You can try `Fast mode` with prompts, or `Quality mode` with custom edit directions.
    ''')
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Upload your image", type="pil", shape=(512, 512))
            with gr.Tabs():
                with gr.TabItem("Fast mode"):
                    with gr.Row():
                        with gr.Column(scale=10):
                            original_prompt = gr.Textbox(label="Image description - type a caption for the image or generate one automatically")
                        with gr.Column(scale=1, min_width=180):
                            btn_caption = gr.Button("Generate caption", elem_id="generate_button")
                    edit_prompt = gr.Textbox(label="Edit prompt - describe what you would like to change in the image above (change one thing at a time)")
                    btn_edit_fast = gr.Button("Edit Image")
with gr.TabItem("Quality mode"):
gr.Markdown("Quality mode temporarely set to only 4 categories. Soon to be dynamic, so you can create your own edit directions.", elem_id="quality_description")
with gr.Row():
with gr.Column():
item_from = gr.Dropdown(label="What to identify in your image", choices=["cat", "dog", "horse", "zebra"], value="cat")
item_from_other = gr.Textbox(visible=False, label="Type what to identify on your image")
item_from.change(lambda choice: gr.Dropdown.update(visible=choice=="Other"), item_from, item_from_other)
with gr.Column():
item_to = gr.Dropdown(label="What to replace what you identified for", choices=["cat", "dog", "horse", "zebra"], value="dog")
item_to_other = gr.Textbox(visible=False, label="Type what to replace what you identified for")
item_to.change(lambda choice: gr.Dropdown.update(visible=choice=="Other"), item_to, item_to_other)
btn_edit_quality = gr.Button("Edit Image")
            with gr.Accordion(label="Advanced settings", open=False):
                steps = gr.Slider(minimum=2, maximum=50, step=1, value=50, label="Inference Steps")
                xa_guidance = gr.Slider(minimum=0.0, maximum=10.0, step=0.05, value=0.1, label="Cross-attention (xa) guidance")
                negative_scale = gr.Slider(minimum=0.0, maximum=20.0, step=0.1, value=5.0, label="Negative Guidance Scale")
        with gr.Column():
            image_output = gr.Image(label="Image with edits", type="pil", shape=(512, 512))
    btn_caption.click(fn=run_captioning, inputs=image, outputs=original_prompt)
    btn_edit_fast.click(fn=run_editing, inputs=[image, original_prompt, edit_prompt, steps, xa_guidance, negative_scale], outputs=[image_output])
    btn_edit_quality.click(fn=run_editing_quality, inputs=[image, item_from, item_from_other, item_to, item_to_other, steps, xa_guidance, negative_scale], outputs=[image_output])
    gr.Examples(
        examples=[
            [os.path.join(os.path.dirname(__file__), "assets/test_images/cats/cat_1.png"), "cat", "", "dog", ""],
            [os.path.join(os.path.dirname(__file__), "assets/test_images/cats/cat_2.png"), "cat", "", "horse", ""],
            [os.path.join(os.path.dirname(__file__), "assets/test_images/dogs/dog_1.png"), "dog", "", "horse", ""],
            [os.path.join(os.path.dirname(__file__), "assets/test_images/dogs/dog_2.png"), "dog", "", "cat", ""],
        ],
        inputs=[image, item_from, item_from_other, item_to, item_to_other],
        outputs=image_output,
        fn=run_editing_quality,
        cache_examples=True,
    )

demo.launch()