# Gradio demo for pix2pix-zero image editing: BLIP captioning + DDIM inversion
# + edit-direction guidance on Stable Diffusion 1.4.
import os

import gradio as gr
import torch
from PIL import Image

from transformers import AutoProcessor, BlipForConditionalGeneration

from diffusers import UNet2DConditionModel, DDIMScheduler
from src.utils.ddim_inv import DDIMInversion
from src.utils.scheduler import DDIMInverseScheduler
from src.utils.edit_directions import construct_direction_prompts
from src.utils.edit_pipeline import EditingPipeline

torch_dtype = torch.float16
device = "cuda" if torch.cuda.is_available() else "cpu"

# BLIP (large) auto-captions the uploaded image; the caption is used as the source prompt.
blip_processor_large = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model_large = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model_large.to(device)

# The DDIM-inversion and editing pipelines share a single SD 1.4 UNet instance.
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch_dtype).to(device)
pipe_inversion = DDIMInversion.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch_dtype, unet=unet).to(device)
pipe_inversion.scheduler = DDIMInverseScheduler.from_config(pipe_inversion.scheduler.config)

pipe_editing = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch_dtype, unet=unet).to(device)
pipe_editing.scheduler = DDIMScheduler.from_config(pipe_editing.scheduler.config)

def load_sentence_embeddings(l_sentences, tokenizer, text_encoder, device="cuda"):
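    """Encode each sentence with the SD text encoder and return the mean
    embedding, shaped (1, seq_len, hidden_dim)."""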
    with torch.no_grad():
        l_embeddings = []
        for sent in l_sentences:
            text_inputs = tokenizer(
                    sent,
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )
            text_input_ids = text_inputs.input_ids
            prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]
            l_embeddings.append(prompt_embeds)
    return torch.cat(l_embeddings, dim=0).mean(dim=0).unsqueeze(0)

def generate_caption(processor, model, image, tokenizer=None, use_float_16=False):
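    """Caption a PIL image with BLIP; decode with the processor unless a
    separate tokenizer is provided."""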
    inputs = processor(images=image, return_tensors="pt").to(device)

    if use_float_16:
        inputs = inputs.to(torch.float16)
    
    generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)

    if tokenizer is not None:
        generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    else:
        generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
   
    return generated_caption

def generate_inversion(prompt, image, num_ddim_steps=50):
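    """Resize the image to 512x512 and DDIM-invert it conditioned on `prompt`,
    returning the inverted latent used as the starting noise for editing."""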
    image = image.resize((512,512), Image.Resampling.LANCZOS)
    x_inv, x_inv_image, x_dec_img = pipe_inversion(
            prompt, 
            guidance_scale=1,
            num_inversion_steps=num_ddim_steps,
            img=image,
            torch_dtype=torch_dtype
        )
    return x_inv[0]

def run_captioning(image):
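    """Gradio handler: auto-caption the uploaded image with BLIP."""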
    caption = generate_caption(blip_processor_large, blip_model_large, image).strip()
    return caption
    
def run_editing(image, original_prompt, edit_prompt, ddim_steps=50, xa_guidance=0.1, negative_guidance_scale=5.0):
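    """Fast-mode handler: invert the image with the source prompt, then edit it
    along the direction between the source and target prompt embeddings
    (construct_direction_prompts is assumed to return that difference)."""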
    inverted_noise = generate_inversion(original_prompt, image)
    source_prompt_embeddings = load_sentence_embeddings([original_prompt], pipe_editing.tokenizer, pipe_editing.text_encoder, device=device)
    target_prompt_embeddings = load_sentence_embeddings([edit_prompt], pipe_editing.tokenizer, pipe_editing.text_encoder, device=device)
    rec_pil, edit_pil = pipe_editing(
        original_prompt,
        num_inference_steps=ddim_steps,
        x_in=inverted_noise.unsqueeze(0),
        edit_dir=construct_direction_prompts(source_prompt_embeddings,target_prompt_embeddings),
        guidance_amount=xa_guidance,
        guidance_scale=negative_guidance_scale,
        negative_prompt=original_prompt # use the unedited prompt for the negative prompt
    )
    return edit_pil[0]

def run_editing_quality(image, item_from, item_from_other, item_to, item_to_other, ddim_steps=50, xa_guidance=0.1, negative_guidance_scale=5.0):
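    """Quality-mode handler: caption the image with BLIP, invert it with that
    caption, and edit it along a pre-computed direction between the mean
    embeddings of the source and target categories."""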
    caption = generate_caption(blip_processor_large, blip_model_large, image).strip()
    item_from_selected = item_from if item_from_other == "" else item_from_other
    item_to_selected = item_to if item_to_other == "" else item_to_other
    inverted_noise = generate_inversion(caption, image)
    emb_dir = "assets/embeddings_sd_1.4"
    embs_a = torch.load(os.path.join(emb_dir, f"{item_from_selected}.pt"), map_location=device)
    embs_b = torch.load(os.path.join(emb_dir, f"{item_to_selected}.pt"), map_location=device)
    # edit direction = difference between the mean embeddings of the target and source categories
    edit_dir = (embs_b.mean(0) - embs_a.mean(0)).unsqueeze(0)
    rec_pil, edit_pil = pipe_editing(
        caption,
        num_inference_steps=ddim_steps,
        x_in=inverted_noise.unsqueeze(0),
        edit_dir=edit_dir,
        guidance_amount=xa_guidance,
        guidance_scale=negative_guidance_scale,
        negative_prompt=caption # use the unedited prompt for the negative prompt
    )
    return edit_pil[0]

css = '''
#generate_button{height: 100%}
#quality_description{text-align: center; margin-top: 1em}
'''
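# UI: uploaded image and edit controls on the left, edited result on the right.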
with gr.Blocks(css=css) as demo:
    gr.Markdown('''## Edit with Words - Pix2Pix Zero demo
Upload an image to edit it. You can try `Fast mode` with free-form prompts, or `Quality mode` with pre-computed edit directions.
''')
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Upload your image", type="pil", shape=(512, 512))
            with gr.Tabs():
                with gr.TabItem("Fast mode"):
                    with gr.Row():
                        with gr.Column(scale=10):
                            original_prompt = gr.Textbox(label="Image description - type a caption for the image or generate it")
                        with gr.Column(scale=1, min_width=180):
                            btn_caption = gr.Button("Generate caption", elem_id="generate_button")
                    edit_prompt = gr.Textbox(label="Edit prompt - what would you like to change in the image above? Change one thing at a time")
                    btn_edit_fast = gr.Button("Edit Image")
                with gr.TabItem("Quality mode"):
                    gr.Markdown("Quality mode is temporarily limited to 4 categories. It will soon be dynamic, so you can create your own edit directions.", elem_id="quality_description")
                    with gr.Row():
                        with gr.Column():
                            item_from = gr.Dropdown(label="What to identify in your image", choices=["cat", "dog", "horse", "zebra"], value="cat")
                            item_from_other = gr.Textbox(visible=False, label="Type what to identify in your image")
                            item_from.change(lambda choice: gr.update(visible=choice == "Other"), item_from, item_from_other)
                        with gr.Column():
                            item_to = gr.Dropdown(label="What to replace it with", choices=["cat", "dog", "horse", "zebra"], value="dog")
                            item_to_other = gr.Textbox(visible=False, label="Type what to replace it with")
                            item_to.change(lambda choice: gr.update(visible=choice == "Other"), item_to, item_to_other)
                    btn_edit_quality = gr.Button("Edit Image")
            with gr.Accordion(label="Advanced settings", open=False):
                steps = gr.Slider(minimum=2, maximum=50, step=1, value=50, label="Inference Steps")
                xa_guidance = gr.Slider(minimum=0.0, maximum=10.0, step=0.05, value=0.1, label="Cross-attention (xa) guidance")
                negative_scale = gr.Slider(minimum=0.0, maximum=20.0, step=0.1, value=5.0, label="Negative Guidance Scale")
        with gr.Column():
            image_output = gr.Image(label="Image with edits", type="pil", shape=(512, 512))
            
    btn_caption.click(fn=run_captioning, inputs=image, outputs=original_prompt)
    
    btn_edit_fast.click(fn=run_editing, inputs=[image, original_prompt, edit_prompt, steps, xa_guidance, negative_scale], outputs=[image_output])
    btn_edit_quality.click(fn=run_editing_quality, inputs=[image, item_from, item_from_other, item_to, item_to_other, steps, xa_guidance, negative_scale], outputs=[image_output])
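    # Cached Quality-mode examples (cat/dog/horse category swaps).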
    gr.Examples(
        examples=[
            [os.path.join(os.path.dirname(__file__), "assets/test_images/cats/cat_1.png"), "cat", "", "dog", ""],
            [os.path.join(os.path.dirname(__file__), "assets/test_images/cats/cat_2.png"), "cat", "", "horse", ""],
            [os.path.join(os.path.dirname(__file__), "assets/test_images/dogs/dog_1.png"), "dog", "", "horse", ""],
            [os.path.join(os.path.dirname(__file__), "assets/test_images/dogs/dog_2.png"), "dog", "", "cat", ""],
        ],
        inputs=[image, item_from, item_from_other, item_to, item_to_other],
        outputs=image_output,
        fn=run_editing_quality,
        cache_examples=True,
    )
demo.launch()