Vedansh-7 committed on
Commit
522b335
·
1 Parent(s): f6ed1f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -13
app.py CHANGED
@@ -124,18 +124,26 @@ class DiffusionModel(nn.Module):
124
  super().__init__()
125
  self.model = model
126
  self.timesteps = timesteps
127
- self.time_dim = time_dim
128
-
129
- # Fix 1: Ensure consistent float32 types
130
  scale = 1000 / timesteps
131
  beta_start = scale * 0.0001
132
  beta_end = scale * 0.02
133
  self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
 
 
134
  self.alphas = 1. - self.betas
135
- self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0))
 
 
 
 
 
 
136
 
137
  @torch.no_grad()
138
  def sample(self, num_images, timesteps, img_size, num_classes, labels, device, progress_callback=None):
 
139
  x_t = torch.randn((num_images, 3, img_size, img_size), device=device) * 0.5
140
 
141
  if labels.ndim == 1:
@@ -151,27 +159,41 @@ class DiffusionModel(nn.Module):
151
 
152
  t_tensor = torch.full((num_images,), t, device=device, dtype=torch.long)
153
 
 
154
  pred_noise = self.model(x_t, labels, t_tensor.float())
155
 
156
- alpha_t = self.alphas[t].to(device)
157
- alpha_bar_t = self.alpha_bars[t].to(device)
158
- beta_t = self.betas[t].to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
 
160
  if t > 0:
161
- noise = torch.randn_like(x_t) * 0.5
162
  else:
163
  noise = torch.zeros_like(x_t)
164
 
165
- x_t = (x_t - (1 - alpha_t)/torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_t)
166
- x_t = x_t + noise * torch.sqrt(beta_t)
167
 
168
  if progress_callback:
169
  progress_callback((timesteps - t) / timesteps)
170
 
 
171
  x_t = torch.clamp(x_t, -1., 1.)
172
- min_val = x_t.min()
173
- max_val = x_t.max()
174
- x_t = (x_t - min_val) / (max_val - min_val + 1e-8)
175
 
176
  return x_t
177
 
 
124
  super().__init__()
125
  self.model = model
126
  self.timesteps = timesteps
127
+
128
+ # Better noise schedule for medical images
 
129
  scale = 1000 / timesteps
130
  beta_start = scale * 0.0001
131
  beta_end = scale * 0.02
132
  self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
133
+
134
+ # More stable alpha calculations
135
  self.alphas = 1. - self.betas
136
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
137
+ self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
138
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - self.alphas_cumprod))
139
+
140
+ # Parameters for posterior variance
141
+ self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / self.alphas_cumprod))
142
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / self.alphas_cumprod - 1))
143
 
144
  @torch.no_grad()
145
  def sample(self, num_images, timesteps, img_size, num_classes, labels, device, progress_callback=None):
146
+ # Initialize with proper scale
147
  x_t = torch.randn((num_images, 3, img_size, img_size), device=device) * 0.5
148
 
149
  if labels.ndim == 1:
 
159
 
160
  t_tensor = torch.full((num_images,), t, device=device, dtype=torch.long)
161
 
162
+ # Predict noise with model
163
  pred_noise = self.model(x_t, labels, t_tensor.float())
164
 
165
+ # Get current alpha values
166
+ alpha_t = self.alphas[t]
167
+ alpha_bar_t = self.alphas_cumprod[t]
168
+ alpha_bar_t_prev = self.alphas_cumprod[t-1] if t > 0 else torch.tensor(1.0)
169
+
170
+ # Calculate coefficients for cleaner sampling
171
+ beta_t = self.betas[t]
172
+ sqrt_recip_alphas_t = self.sqrt_recip_alphas_cumprod[t]
173
+ sqrt_one_minus_alphas_bar_t = self.sqrt_one_minus_alphas_cumprod[t]
174
+
175
+ # Main denoising equation
176
+ pred_x0 = (x_t - sqrt_one_minus_alphas_bar_t * pred_noise) / sqrt_recip_alphas_t
177
+ pred_x0 = torch.clamp(pred_x0, -1., 1.)
178
+
179
+ # Calculate direction pointing to x_t
180
+ dir_xt = torch.sqrt(1. - alpha_bar_t_prev - beta_t**2) * pred_noise
181
 
182
+ # Noise for next step
183
  if t > 0:
184
+ noise = torch.randn_like(x_t) * 0.25 # Reduced noise scale
185
  else:
186
  noise = torch.zeros_like(x_t)
187
 
188
+ # Update x_t
189
+ x_t = torch.sqrt(alpha_bar_t_prev) * pred_x0 + dir_xt + noise * torch.sqrt(beta_t)
190
 
191
  if progress_callback:
192
  progress_callback((timesteps - t) / timesteps)
193
 
194
+ # Better normalization approach
195
  x_t = torch.clamp(x_t, -1., 1.)
196
+ x_t = (x_t + 1) / 2 # Scale to [0, 1]
 
 
197
 
198
  return x_t
199