Shivdutta committed on
Commit 9091147 · verified · 1 Parent(s): e95351d

Upload 2 files

Files changed (2)
  1. app.py +151 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,151 @@
+ import torch
+ import numpy as np
+ from torchvision.transforms.functional import to_tensor
+ from PIL import Image
+ from tqdm.auto import tqdm
+
+ def blue_loss(images):
+     """
+     Guidance loss on the blue channel: returns the negative variance of the
+     blue values, so descending this loss during guidance increases their spread.
+     """
+     # images is a (batch, channels, height, width) float tensor in roughly [0, 1];
+     # keep it in the autograd graph so the guidance gradient can flow back to the latents.
+     # Extract the blue channel (channel index 2 in RGB)
+     blue_channel = images[:, 2, :, :]
+
+     # Calculate variance of the blue channel
+     variance = torch.var(blue_channel)
+
+     # Return negative variance as the loss
+     return -variance
+
+ def generate_with_prompt_style_guidance(prompt, style, seed=42):
+     # Append the placeholder token 's' that will be swapped for the learned style embedding
+     prompt = prompt + ' in style of s'
+
+     # Load the learned textual-inversion embedding for the chosen style
+     embed = torch.load(style)
+
+     height = 512
+     width = 512
+     num_inference_steps = 10
+     guidance_scale = 8
+     generator = torch.manual_seed(seed)
+     batch_size = 1
+     contrast_loss_scale = 200
+     blue_loss_scale = 100  # Scale for blue loss
+
+     # Prep text
+     text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+     with torch.no_grad():
+         text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
+
+     input_ids = text_input.input_ids.to(torch_device)
+
+     # Get token embeddings
+     token_embeddings = token_emb_layer(input_ids)
+
+     # The new embedding - the learned style token
+     replacement_token_embedding = embed[list(embed.keys())[0]].to(torch_device)
+
+     # Insert this into the token embeddings (338 is the token id of the trailing 's' added above)
+     token_embeddings[0, torch.where(input_ids[0] == 338)] = replacement_token_embedding.to(torch_device)
+
+     # Combine with pos embs
+     input_embeddings = token_embeddings + position_embeddings
+
+     # Feed through to get final output embs
+     modified_output_embeddings = get_output_embeds(input_embeddings)
+
+     # And the uncond. input as before:
+     max_length = text_input.input_ids.shape[-1]
+     uncond_input = tokenizer(
+         [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+     )
+     with torch.no_grad():
+         uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
+
+     text_embeddings = torch.cat([uncond_embeddings, modified_output_embeddings])
+
+     # Prep Scheduler
+     scheduler.set_timesteps(num_inference_steps)
+
+     # Prep latents
+     latents = torch.randn(
+         (batch_size, unet.config.in_channels, height // 8, width // 8),
+         generator=generator,
+     )
+     latents = latents.to(torch_device)
+     latents = latents * scheduler.init_noise_sigma
+
+     # Loop
+     for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
+         # Expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
+         latent_model_input = torch.cat([latents] * 2)
+         sigma = scheduler.sigmas[i]
+         latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+
+         # Predict the noise residual
+         with torch.no_grad():
+             noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
+
+         # Perform CFG
+         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+         # Additional guidance every 5th step
+         if i % 5 == 0:
+             # Requires grad on the latents
+             latents = latents.detach().requires_grad_()
+
+             # Get the predicted x0
+             latents_x0 = latents - sigma * noise_pred
+
+             # Decode to image space
+             denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5
+
+             # Calculate losses
+             contrast_loss_val = contrast_loss(denoised_images) * contrast_loss_scale
+             blue_loss_val = blue_loss(denoised_images) * blue_loss_scale
+
+             # Combine losses
+             total_loss = contrast_loss_val + blue_loss_val
+
+             # Get gradient
+             cond_grad = torch.autograd.grad(total_loss, latents)[0]
+
+             # Modify the latents based on this gradient
+             latents = latents.detach() - cond_grad * sigma**2
+
+         # Now step with scheduler
+         latents = scheduler.step(noise_pred, t, latents).prev_sample
+
+     return latents_to_pil(latents)[0]
+
+ import gradio as gr
+
+ dict_styles = {
+     'Dr Strange': 'styles/learned_embeds_dr_strange.bin',
+     'GTA-5': 'styles/learned_embeds_gta5.bin',
+     'Manga': 'styles/learned_embeds_manga.bin',
+     'Pokemon': 'styles/learned_embeds_pokemon.bin',
+ }
+
+ def inference(prompt, style):
+     if prompt is not None and style is not None:
+         style = dict_styles[style]
+         result = generate_with_prompt_style_guidance(prompt, style)
+         return np.array(result)
+     else:
+         return None
+
+ title = "Stable Diffusion and Textual Inversion"
+ description = "A simple Gradio interface to stylize Stable Diffusion outputs"
+ # The example style must be one of the dropdown choices below
+ examples = [['A man sipping wine wearing a spacesuit on the moon', 'Manga']]
+
+ demo = gr.Interface(inference,
+                     inputs=[gr.Textbox(label='Prompt'),
+                             gr.Dropdown(['Dr Strange', 'GTA-5', 'Manga', 'Pokemon'], label='Style')],
+                     outputs=[gr.Image(label="Stable Diffusion Output")],
+                     title=title,
+                     description=description,
+                     examples=examples)
+ demo.launch()
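
Note: app.py references several globals and helpers that are not defined in this file (torch_device, tokenizer, text_encoder, token_emb_layer, position_embeddings, get_output_embeds, scheduler, unet, vae, contrast_loss, latents_to_pil). Below is a minimal sketch of what they could look like, assuming the standard Stable Diffusion v1.4 components and a K-LMS scheduler (the sigma-based guidance math above implies a scheduler that exposes sigmas); the model ids, the contrast_loss definition, and the use of the private _build_causal_attention_mask helper (present in the pinned transformers==4.25.1) are assumptions, not part of this commit.

    import torch
    from PIL import Image
    from transformers import CLIPTextModel, CLIPTokenizer
    from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler

    torch_device = "cuda" if torch.cuda.is_available() else "cpu"

    # CLIP tokenizer/text encoder and the SD v1.4 VAE/UNet (model ids are assumptions)
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(torch_device)
    vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(torch_device)
    unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet").to(torch_device)

    # K-LMS scheduler, matching the sigma-based guidance used in app.py
    scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012,
                                     beta_schedule="scaled_linear", num_train_timesteps=1000)

    # Embedding layers used to splice the learned style token into the prompt
    token_emb_layer = text_encoder.text_model.embeddings.token_embedding
    pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
    position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
    position_embeddings = pos_emb_layer(position_ids)

    def get_output_embeds(input_embeddings):
        # Push pre-built input embeddings through the CLIP text transformer
        bsz, seq_len = input_embeddings.shape[:2]
        causal_attention_mask = text_encoder.text_model._build_causal_attention_mask(
            bsz, seq_len, dtype=input_embeddings.dtype)
        encoder_outputs = text_encoder.text_model.encoder(
            inputs_embeds=input_embeddings,
            attention_mask=None,
            causal_attention_mask=causal_attention_mask.to(torch_device),
            output_hidden_states=True,
        )
        # Apply the final layer norm to get the embeddings the UNet expects
        return text_encoder.text_model.final_layer_norm(encoder_outputs[0])

    def contrast_loss(images):
        # Hypothetical companion loss: reward higher contrast by returning
        # the negative variance of all pixel intensities
        return -torch.var(images)

    def latents_to_pil(latents):
        # Scale the latents back, decode with the VAE and convert to PIL images
        latents = (1 / 0.18215) * latents
        with torch.no_grad():
            images = vae.decode(latents).sample
        images = (images / 2 + 0.5).clamp(0, 1)
        images = images.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (images * 255).round().astype("uint8")
        return [Image.fromarray(im) for im in images]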
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch
+ transformers==4.25.1
+ diffusers
+ ftfy
+ torchvision
+ tqdm
+ numpy
+ accelerate
+ scipy
+ Pillow
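
requirements.txt does not list gradio; on a Hugging Face Space the Gradio SDK normally provides it, but a local run needs it installed alongside these packages. As a quick sanity check outside the UI, the inference helper from app.py can be called directly; the prompt and output filename below are only illustrative.

    import numpy as np
    from PIL import Image

    # Illustrative only: any prompt plus one of the dict_styles keys from app.py
    result = inference('A man sipping wine wearing a spacesuit on the moon', 'Manga')
    if result is not None:
        Image.fromarray(result.astype(np.uint8)).save('output.png')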