Vedansh-7 committed on
Commit
e92022e
·
1 Parent(s): bb3aba9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -22
app.py CHANGED
@@ -120,64 +120,58 @@ class UNet(nn.Module):
120
  return output
121
 
122
  class DiffusionModel(nn.Module):
123
- def __init__(self, model, timesteps=TIMESTEPS, time_dim=256):
124
  super().__init__()
125
  self.model = model
126
  self.timesteps = timesteps
127
 
128
- # Noise schedule from working code
129
  beta_start = 0.0001
130
  beta_end = 0.02
131
  self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
132
-
133
  self.alphas = 1. - self.betas
134
  self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0))
135
- self.register_buffer('sqrt_one_minus_alpha_bars', torch.sqrt(1. - self.alpha_bars))
136
- self.register_buffer('sqrt_recip_alphas', torch.sqrt(1. / self.alphas))
137
 
138
  @torch.no_grad()
139
  def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
140
- """Improved sampling method based on working code"""
 
141
  x_t = torch.randn((num_images, 3, img_size, img_size), device=device)
142
 
143
- # Handle labels (class indices or one-hot)
144
  if labels.ndim == 1:
145
  labels = torch.zeros(num_images, num_classes, device=device).scatter_(1, labels.unsqueeze(1), 1)
146
- else:
147
- labels = labels.float().to(device)
148
 
 
149
  for t in reversed(range(self.timesteps)):
150
  if cancel_event.is_set():
151
  return None
152
 
153
  t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float)
154
-
155
- # Predict noise with model
156
- pred_noise = self.model(x_t, labels, t_tensor)
157
-
158
- # Calculate coefficients from working code
159
  beta_t = self.betas[t].to(device)
160
  alpha_t = self.alphas[t].to(device)
161
  alpha_bar_t = self.alpha_bars[t].to(device)
162
-
163
- # Improved reverse diffusion step
164
- mean = (1 / torch.sqrt(alpha_t)) * (x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * pred_noise)
165
  variance = beta_t
166
-
167
  if t > 0:
168
  noise = torch.randn_like(x_t)
169
  else:
170
  noise = torch.zeros_like(x_t)
171
-
172
  x_t = mean + torch.sqrt(variance) * noise
173
 
174
  if progress_callback:
175
  progress_callback((self.timesteps - t) / self.timesteps)
176
 
177
- # Improved normalization from working code
178
  x_t = torch.clamp(x_t, -1., 1.)
179
-
180
- # Denormalize using ImageNet stats (from working code)
181
  mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
182
  std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
183
  x_t = std * x_t + mean
 
120
  return output
121
 
122
  class DiffusionModel(nn.Module):
123
+ def __init__(self, model, timesteps=TIMESTEPS):
124
  super().__init__()
125
  self.model = model
126
  self.timesteps = timesteps
127
 
128
+ # Use the exact same noise schedule as Colab
129
  beta_start = 0.0001
130
  beta_end = 0.02
131
  self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
 
132
  self.alphas = 1. - self.betas
133
  self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0))
 
 
134
 
135
  @torch.no_grad()
136
  def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
137
+ """Identical implementation to Colab version"""
138
+ # Start with random noise (same scale)
139
  x_t = torch.randn((num_images, 3, img_size, img_size), device=device)
140
 
141
+ # Identical label handling
142
  if labels.ndim == 1:
143
  labels = torch.zeros(num_images, num_classes, device=device).scatter_(1, labels.unsqueeze(1), 1)
144
+ labels = labels.to(device)
 
145
 
146
+ # Same sampling loop
147
  for t in reversed(range(self.timesteps)):
148
  if cancel_event.is_set():
149
  return None
150
 
151
  t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float)
152
+ predicted_noise = self.model(x_t, labels, t_tensor)
153
+
154
+ # Identical coefficients calculation
 
 
155
  beta_t = self.betas[t].to(device)
156
  alpha_t = self.alphas[t].to(device)
157
  alpha_bar_t = self.alpha_bars[t].to(device)
158
+
159
+ # Same mean/variance calculation
160
+ mean = (1 / torch.sqrt(alpha_t)) * (x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * predicted_noise)
161
  variance = beta_t
162
+
163
  if t > 0:
164
  noise = torch.randn_like(x_t)
165
  else:
166
  noise = torch.zeros_like(x_t)
167
+
168
  x_t = mean + torch.sqrt(variance) * noise
169
 
170
  if progress_callback:
171
  progress_callback((self.timesteps - t) / self.timesteps)
172
 
173
+ # Identical denormalization
174
  x_t = torch.clamp(x_t, -1., 1.)
 
 
175
  mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
176
  std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
177
  x_t = std * x_t + mean