awacke1 committed (verified)
Commit 3a68d7d · 1 Parent(s): 72e6740

Update app.py

Files changed (1):
  1. app.py +24 -88
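The updated file fetches checkpoints with huggingface_hub.hf_hub_download called with keyword arguments. As a minimal sketch of that call pattern (repo ID and filename copied from the diff; whether that repo actually serves this file is an assumption the app itself guards with a try/except):

# Minimal sketch of the hf_hub_download call pattern used in the updated app.py.
# The function downloads (or reuses a cached copy of) the file and returns its local path.
from huggingface_hub import hf_hub_download

vqvae_model_path = hf_hub_download(
    repo_id="dalle-mini/vqgan_imagenet_f16_16384",  # repo ID as written in the diff
    filename="flax_model.msgpack",                  # filename as written in the diff
)
print(vqvae_model_path)  # local path inside the Hugging Face cache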
app.py CHANGED
@@ -14,8 +14,8 @@ from torchvision import transforms
 from torchvision.transforms import functional as TF
 from tqdm import trange
 from transformers import CLIPProcessor, CLIPModel
-from huggingface_hub import hf_hub_download # FIXED: Replaced deprecated function
-import gradio as gr # 🎨 The magic canvas for AI-powered image generation!
+from huggingface_hub import hf_hub_download
+import gradio as gr
 import math
 
 # -----------------------------------------------------------------------------
@@ -27,10 +27,6 @@ import math
 class VQVAE2(nn.Module):
     def __init__(self, n_embed=8192, embed_dim=256, ch=128):
         super().__init__()
-        # This is a simplified placeholder. The actual architecture would be more complex.
-        # The key is having a 'decode' method that matches the state_dict.
-        # A full implementation would require the original model's architecture file.
-        # For this fix, we assume a basic structure that allows loading the state_dict.
         self.decoder = nn.Sequential(
             nn.Conv2d(embed_dim, ch * 4, 3, padding=1),
             nn.ReLU(),
@@ -40,40 +36,32 @@ class VQVAE2(nn.Module):
             nn.ReLU(),
             nn.ConvTranspose2d(ch, 3, 4, stride=2, padding=1),
         )
-
+
     def decode(self, latents):
-        # A real VQVAE would involve lookup tables, but for generation we only need the decoder part.
-        # This part is highly dependent on the model checkpoint.
-        # The following is a guess to make it runnable, assuming latents are ready for the decoder.
         return self.decoder(latents)
 
 # Diffusion Model Definition
 class Diffusion(nn.Module):
-    def __init__(self, n_inputs=3, n_embed=512, n_head=8, n_layer=12):
+    # FIXED: Changed n_head default from 8 to 4 to make dimensions compatible
+    def __init__(self, n_inputs=3, n_embed=512, n_head=4, n_layer=12):
         super().__init__()
-        # This is also a placeholder for the architecture.
-        # A full UNet-style model is expected here. The key is that it can be called
-        # with x, t, and conditional embeddings, and returns the predicted noise.
         self.time_embed = nn.Embedding(1000, n_inputs * 4)
         self.cond_embed = nn.Linear(n_embed, n_inputs * 4)
-
+
         self.layers = nn.ModuleList([
-            nn.TransformerEncoderLayer(d_model=n_inputs*4, nhead=n_head, dim_feedforward=2048, dropout=0.1, activation='gelu')
+            nn.TransformerEncoderLayer(d_model=n_inputs*4, nhead=n_head, dim_feedforward=2048, dropout=0.1, activation='gelu', batch_first=True)
             for _ in range(n_layer)
         ])
         self.out = nn.Linear(n_inputs*4, n_inputs)
 
     def forward(self, x, t, c):
-        # A very simplified forward pass
-        # The actual model is likely a UNet with cross-attention.
         bs, ch, h, w = x.shape
         x = x.permute(0, 2, 3, 1).reshape(bs, h * w, ch)
-
+
         t_emb = self.time_embed(t.long())
         c_emb = self.cond_embed(c)
         emb = t_emb + c_emb
-
-        # This is a gross simplification; a real model would use cross-attention here.
+
         x_out = self.out(x + emb.unsqueeze(1))
         x_out = x_out.reshape(bs, h, w, ch).permute(0, 3, 1, 2)
         return x_out
@@ -81,42 +69,35 @@ class Diffusion(nn.Module):
 
 # Sampling Function Definitions
 def get_sigmas(n_steps):
-    """Returns the sigma schedule."""
     t = torch.linspace(1, 0, n_steps + 1)
     return ((t[:-1] ** 2) / (t[1:] ** 2) - 1).sqrt()
 
 @torch.no_grad()
 def plms_sample(model, x, steps, **kwargs):
-    """Poor Man's LMS Sampler"""
     ts = x.new_ones([x.shape[0]])
     sigmas = get_sigmas(steps)
     model_fn = lambda x, t: model(x, t * 1000, **kwargs)
-
-    x_outs = []
+
     old_denoised = None
-
+
     for i in trange(len(sigmas) -1, disable=True):
         denoised = model_fn(x, ts * sigmas[i])
-
+
         if old_denoised is None:
             d = (denoised - x) / sigmas[i]
         else:
-            d = (3 * denoised - old_denoised) / 2 - x / sigmas[i] # LMS step
+            d = (3 * denoised - old_denoised) / 2 - x / sigmas[i]
 
         x = x + d * (sigmas[i+1] - sigmas[i])
         old_denoised = denoised
-    x_outs.append(x)
-    return x_outs[-1]
 
-# NOTE: DDIM and DDPM samplers would be defined here as well if needed.
-# For simplicity, we are only defining the 'plms' sampler used in the UI default.
+    return x
+
 def ddim_sample(model, x, steps, eta, **kwargs):
-    # This is a placeholder for a full DDIM implementation
     print("Warning: DDIM sampler is not fully implemented. Using PLMS instead.")
     return plms_sample(model, x, steps, **kwargs)
 
 def ddpm_sample(model, x, steps, **kwargs):
-    # This is a placeholder for a full DDPM implementation
     print("Warning: DDPM sampler is not fully implemented. Using PLMS instead.")
     return plms_sample(model, x, steps, **kwargs)
 
@@ -125,16 +106,12 @@ def ddpm_sample(model, x, steps, **kwargs):
 # -----------------------------------------------------------------------------
 
 # 🖼️ Download the necessary model files from HuggingFace
-# NOTE: The HuggingFace URLs you provided might be placeholders.
-# Make sure these point to the correct model files.
 try:
-    # FIXED: Using the new hf_hub_download function with keyword arguments
     vqvae_model_path = hf_hub_download(repo_id="dalle-mini/vqgan_imagenet_f16_16384", filename="flax_model.msgpack")
     diffusion_model_path = hf_hub_download(repo_id="huggingface/dalle-2", filename="diffusion_model.ckpt")
 except Exception as e:
     print(f"Could not download models. Please ensure the HuggingFace URLs are correct.")
     print("Using placeholder models which will not produce good images.")
-    # Create dummy files if download fails to allow script to run
     Path("vqvae_model.ckpt").touch()
     Path("diffusion_model.ckpt").touch()
     vqvae_model_path = "vqvae_model.ckpt"
@@ -142,26 +119,17 @@ except Exception as e:
 
 
 # 📐 Utility Functions: Math and images, what could go wrong?
-# These functions help parse prompts and resize/crop images to fit nicely
-
 def parse_prompt(prompt, default_weight=3.):
-    """
-    🎯 Parses a prompt into text and weight.
-    """
     vals = prompt.rsplit(':', 1)
     vals = vals + ['', default_weight][len(vals):]
     return vals[0], float(vals[1])
 
 def resize_and_center_crop(image, size):
-    """
-    ✂️ Resize and crop image to center it beautifully.
-    """
     fac = max(size[0] / image.size[0], size[1] / image.size[1])
     image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
     return TF.center_crop(image, size[::-1])
 
 # 🧠 Model loading: the brain of our operation! 🔥
-
 device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 print('Using device:', device)
 print('loading models... 🛠️')
@@ -171,23 +139,17 @@ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device
 clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
 # Load VQ-VAE-2 Autoencoder
-# NOTE: The VQVAE2 class is a placeholder. Loading a real checkpoint will likely fail
-# unless the class definition perfectly matches the architecture of the saved model.
 try:
     vqvae = VQVAE2()
-    # vqvae.load_state_dict(torch.load(vqvae_model_path, map_location=device))
     print("Skipping VQVAE weight loading due to placeholder architecture.")
 except Exception as e:
     print(f"Could not load VQVAE model: {e}. Using placeholder.")
     vqvae = VQVAE2()
 vqvae.eval().requires_grad_(False).to(device)
 
-
 # Load Diffusion Model
-# NOTE: The Diffusion class is a placeholder. This will also likely fail.
 try:
     diffusion_model = Diffusion()
-    # diffusion_model.load_state_dict(torch.load(diffusion_model_path, map_location=device))
     print("Skipping Diffusion Model weight loading due to placeholder architecture.")
 except Exception as e:
     print(f"Could not load Diffusion model: {e}. Using placeholder.")
@@ -196,27 +158,19 @@ diffusion_model = diffusion_model.to(device).eval().requires_grad_(False)
 
 
 # 🎨 The key function: Where the magic happens!
-# This is where we generate images based on text and image prompts
-
 def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method='ddim', eta=None):
-    """
-    🖼️ Generates a list of PIL images based on given text and image prompts.
-    """
     zero_embed = torch.zeros([1, clip_model.config.projection_dim], device=device)
     target_embeds, weights = [zero_embed], []
 
-    # Parse text prompts and encode with CLIP
     for prompt in prompts:
         inputs = clip_processor(text=prompt, return_tensors="pt").to(device)
         text_embed = clip_model.get_text_features(**inputs).float()
         target_embeds.append(text_embed)
         weights.append(1.0)
 
-    # Correctly process image prompts from Gradio
-    # Assign a default weight for image prompts
    image_prompt_weight = 1.0
    for image_path in images:
-        if image_path: # Check if a path was actually provided
+        if image_path:
            try:
                img = Image.open(image_path).convert('RGB')
                img = resize_and_center_crop(img, (224, 224))
@@ -227,28 +181,23 @@ def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method
            except Exception as e:
                print(f"Warning: Could not process image prompt {image_path}. Error: {e}")
 
-
-    # Adjust weights and set seed
    weights = torch.tensor([1 - sum(weights), *weights], device=device)
    torch.manual_seed(seed)
 
-    # 💡 Model function with classifier-free guidance
    def cfg_model_fn(x, t):
        n = x.shape[0]
        n_conds = len(target_embeds)
        x_in = x.repeat([n_conds, 1, 1, 1])
        t_in = t.repeat([n_conds])
        embed_in = torch.cat(target_embeds).repeat_interleave(n, 0)
-
-        # Ensure correct dimensions for the placeholder Diffusion model
+
        if isinstance(diffusion_model, Diffusion):
-            embed_in = embed_in[:, :512] # Adjust embed dim if needed
-
+            embed_in = embed_in[:, :512]
+
        vs = diffusion_model(x_in, t_in, embed_in).view([n_conds, n, *x.shape[1:]])
        v = vs.mul(weights[:, None, None, None, None]).sum(0)
        return v
 
-    # 🎞️ Run the sampler to generate images
    def run(x, steps):
        if method == 'ddpm':
            return ddpm_sample(cfg_model_fn, x, steps)
@@ -258,52 +207,39 @@ def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method
            return plms_sample(cfg_model_fn, x, steps)
        assert False, f"Unknown method: {method}"
 
-    # 🏃‍♂️ Generate the output images
    batch_size = n
    x = torch.randn([n, 3, 64, 64], device=device)
-
    pil_ims = []
    for i in trange(0, n, batch_size):
        cur_batch_size = min(n - i, batch_size)
        out_latents = run(x[i:i + cur_batch_size], steps)
-
-        # The VQVAE expects specific dimensions. Adjusting for the placeholder.
-        # This will likely need tuning for the real model.
+
        if isinstance(vqvae, VQVAE2):
-            out_latents = F.interpolate(out_latents, size=32) # Guessing latent size
-            # A real VQVAE needs quantized inputs, not raw latents. This will not produce good images.
-            # We're just making it runnable.
-            quant_guess = F.gumbel_softmax(out_latents, hard=True).permute(0, 2, 3, 1) # (B, H, W, C)
-            pil_ims.append(transforms.ToPILImage()(quant_guess[0].permute(2, 0, 1)))
+            outs = vqvae.decode(out_latents)
+            for j, out in enumerate(outs):
+                pil_ims.append(transforms.ToPILImage()(out.clamp(0, 1)))
        else:
            outs = vqvae.decode(out_latents)
            for j, out in enumerate(outs):
                pil_ims.append(transforms.ToPILImage()(out.clamp(0, 1)))
-
    return pil_ims
 
 # 🖌️ Interface: Gradio's brush to paint the UI
 def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
-    """
-    💡 Gradio function to wrap image generation.
-    """
    if seed is None:
        seed = random.randint(0, 10000)
    prompts = [prompt]
    im_prompts = []
    if im_prompt is not None:
        im_prompts = [im_prompt]
-
    try:
        pil_ims = generate(n=1, prompts=prompts, images=im_prompts, seed=seed, steps=n_steps, method=method)
        return pil_ims[0]
    except Exception as e:
        print(f"ERROR during generation: {e}")
-        # Return a blank image on failure
        return Image.new('RGB', (256, 256), color = 'red')
 
-
-# 🖼️ Gradio UI: The interface where users can input text or image prompts
+# 🖼️ Gradio UI
 iface = gr.Interface(
    fn=gen_ims,
    inputs=[
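A note on the "# FIXED: Changed n_head default from 8 to 4" line above: nn.MultiheadAttention, used inside nn.TransformerEncoderLayer, requires d_model to be divisible by nhead, and with d_model = n_inputs * 4 = 12 a head count of 8 fails at construction while 4 works; batch_first=True makes the layer accept the (batch, sequence, feature) tensors produced by the reshape in Diffusion.forward. A small illustrative sketch, separate from app.py:

# Illustration of the constraint behind the n_head change and the batch_first flag.
import torch
import torch.nn as nn

d_model = 3 * 4  # n_inputs * 4, as in the Diffusion class above

try:
    nn.TransformerEncoderLayer(d_model=d_model, nhead=8, batch_first=True)
except Exception as e:
    print("nhead=8 rejected:", e)  # embed_dim must be divisible by num_heads

layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, batch_first=True)
x = torch.randn(2, 8 * 8, d_model)  # (batch, h*w, channels), matching the reshape in forward
print(layer(x).shape)               # torch.Size([2, 64, 12])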