import torch
import numpy as np
from PIL import Image
from tqdm.auto import tqdm
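
# --- Assumed model setup (sketch) -------------------------------------------
# The generation code below relies on several globals that are not defined in
# this file: tokenizer, text_encoder, token_emb_layer, position_embeddings,
# get_output_embeds, unet, vae, scheduler, torch_device and latents_to_pil.
# What follows is a minimal sketch of one way to provide them, assuming
# Stable Diffusion v1.4 components and the manual text-encoder plumbing used in
# the Hugging Face "Stable Diffusion Deep Dive" notebook. The checkpoint names,
# scheduler settings and helper implementations here are assumptions; adjust
# them to whatever this Space actually uses.
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel

torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# Pretrained components (assumed checkpoints)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(torch_device)
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(torch_device)
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet").to(torch_device)
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012,
                                 beta_schedule="scaled_linear", num_train_timesteps=1000)

# Embedding layers of the CLIP text model, used to build prompt embeddings by hand
token_emb_layer = text_encoder.text_model.embeddings.token_embedding
pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
position_ids = torch.arange(tokenizer.model_max_length, device=torch_device).unsqueeze(0)
position_embeddings = pos_emb_layer(position_ids)

def get_output_embeds(input_embeddings):
    """Run hand-built token embeddings through the CLIP text transformer."""
    bsz, seq_len = input_embeddings.shape[:2]
    # Additive causal mask of shape (batch, 1, seq_len, seq_len): -inf above the diagonal
    causal_mask = torch.triu(
        torch.full((seq_len, seq_len), float("-inf"), device=input_embeddings.device),
        diagonal=1,
    )[None, None].expand(bsz, 1, seq_len, seq_len)
    encoder_outputs = text_encoder.text_model.encoder(
        inputs_embeds=input_embeddings,
        attention_mask=None,
        causal_attention_mask=causal_mask,
    )
    return text_encoder.text_model.final_layer_norm(encoder_outputs[0])

def latents_to_pil(latents):
    """Decode a batch of latents into a list of PIL images."""
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    return [Image.fromarray(img) for img in images]
# -----------------------------------------------------------------------------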

def blue_loss(images):
    """
    Guidance loss on the blue channel of a batch of decoded images.

    `images` is a (batch, channels, height, width) tensor with values roughly
    in [0, 1], as produced inside the sampling loop, and must keep its autograd
    history so the gradient can flow back to the latents.
    """
    # Extract the blue channel (index 2 in RGB, channels-first layout)
    blue_channel = images[:, 2]

    # Return the negative variance: gradient descent on this loss pushes the
    # image toward more variation in the blue channel
    return -torch.var(blue_channel)
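
# `contrast_loss` is used in the guidance step below but is not defined in this
# file. The stand-in sketched here is a hypothetical helper, kept analogous to
# blue_loss: gradient descent on the negative pixel standard deviation pushes
# the decoded images toward higher contrast. Replace it with the real loss if
# that lives elsewhere in the repo.
def contrast_loss(images):
    """Encourage contrast by rewarding the pixel standard deviation."""
    return -torch.std(images)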

def generate_with_prompt_style_guidance(prompt, style, seed=42):
    # Append the placeholder token ' s'; its embedding is replaced below with
    # the learned style embedding loaded from the chosen .bin file
    prompt = prompt + ' in style of s'

    embed = torch.load(style, map_location=torch_device)

    height = 512
    width = 512
    num_inference_steps = 10
    guidance_scale = 8
    generator = torch.manual_seed(seed)
    batch_size = 1
    contrast_loss_scale = 200
    blue_loss_scale = 100  # Scale for blue loss

    # Prep text: tokenize the prompt. The conditional embeddings are built by
    # hand below from the token embeddings, so no text_encoder pass is needed here.
    text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    input_ids = text_input.input_ids.to(torch_device)

    # Get token embeddings
    token_embeddings = token_emb_layer(input_ids)

    # The learned style embedding (the single tensor stored in the .bin file)
    replacement_token_embedding = embed[list(embed.keys())[0]].to(torch_device)

    # Insert it at the position of the placeholder token ' s' (token id 338)
    token_embeddings[0, torch.where(input_ids[0] == 338)] = replacement_token_embedding.to(torch_device)

    # Combine with pos embs
    input_embeddings = token_embeddings + position_embeddings

    # Feed through to get final output embs
    modified_output_embeddings = get_output_embeds(input_embeddings)

    # And the uncond. input as before:
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]

    text_embeddings = torch.cat([uncond_embeddings, modified_output_embeddings])

    # Prep Scheduler
    scheduler.set_timesteps(num_inference_steps)

    # Prep latents
    latents = torch.randn(
        (batch_size, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    )
    latents = latents.to(torch_device)
    latents = latents * scheduler.init_noise_sigma

    # Loop
    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
        # Expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
        latent_model_input = torch.cat([latents] * 2)
        sigma = scheduler.sigmas[i]
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

        # Perform CFG
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Additional loss-based guidance, applied every 5 steps
        if i % 5 == 0:
            # Requires grad on the latents
            latents = latents.detach().requires_grad_()

            # Get the predicted x0
            latents_x0 = latents - sigma * noise_pred

            # Decode to image space
            denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5

            # Calculate losses
            contrast_loss_val = contrast_loss(denoised_images) * contrast_loss_scale
            blue_loss_val = blue_loss(denoised_images) * blue_loss_scale

            # Combine losses
            total_loss = contrast_loss_val + blue_loss_val

            # Get gradient
            cond_grad = torch.autograd.grad(total_loss, latents)[0]

            # Modify the latents based on this gradient
            latents = latents.detach() - cond_grad * sigma**2

        # Now step with scheduler
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    return latents_to_pil(latents)[0]

import gradio as gr

dict_styles = { 
    'Dr Strange': 'styles/learned_embeds_dr_strange.bin', 
    'GTA-5':'styles/learned_embeds_gta5.bin', 
    'Manga':'styles/learned_embeds_manga.bin', 
    'Pokemon':'styles/learned_embeds_pokemon.bin', 
}

def inference(prompt, style): 
    if prompt and style:
        style = dict_styles[style]
        result = generate_with_prompt_style_guidance(prompt, style)
        return np.array(result)
    else:
        return None

title = "Stable Diffusion and Textual Inversion"
description = "A simple Gradio interface to stylize Stable Diffusion outputs"
examples = [['A man sipping wine wearing a spacesuit on the moon', 'Dr Strange']]

demo = gr.Interface(inference,
                    inputs=[gr.Textbox(label='Prompt'), 
                            gr.Dropdown(['Dr Strange', 'GTA-5', 'Manga', 'Pokemon'], label='Style')],
                    outputs=[gr.Image(label="Stable Diffusion Output")],
                    title=title,
                    description=description,
                    examples=examples)
demo.launch()