Vedansh-7 commited on
Commit
9cb7ad8
·
1 Parent(s): 736b0bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -22
app.py CHANGED
@@ -125,26 +125,21 @@ class DiffusionModel(nn.Module):
125
  self.model = model
126
  self.timesteps = timesteps
127
 
128
- # Better noise schedule for medical images
129
  scale = 1000 / timesteps
130
  beta_start = scale * 0.0001
131
  beta_end = scale * 0.02
132
  self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
133
 
134
- # More stable alpha calculations
135
  self.alphas = 1. - self.betas
136
  self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
137
  self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
138
  self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - self.alphas_cumprod))
139
-
140
- # Parameters for posterior variance
141
- self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / self.alphas_cumprod))
142
- self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / self.alphas_cumprod - 1))
143
 
144
  @torch.no_grad()
145
  def sample(self, num_images, timesteps, img_size, num_classes, labels, device, progress_callback=None):
146
- # Initialize with proper scale
147
- x_t = torch.randn((num_images, 3, img_size, img_size), device=device) * 0.5
148
 
149
  if labels.ndim == 1:
150
  labels_one_hot = torch.zeros(num_images, num_classes, device=device)
@@ -167,37 +162,39 @@ class DiffusionModel(nn.Module):
167
  alpha_bar_t = self.alphas_cumprod[t]
168
  alpha_bar_t_prev = self.alphas_cumprod[t-1] if t > 0 else torch.tensor(1.0)
169
 
170
- # Calculate coefficients for cleaner sampling
171
  beta_t = self.betas[t]
172
- sqrt_recip_alphas_t = self.sqrt_recip_alphas_cumprod[t]
173
- sqrt_one_minus_alphas_bar_t = self.sqrt_one_minus_alphas_cumprod[t]
174
 
175
- # Main denoising equation
176
- pred_x0 = (x_t - sqrt_one_minus_alphas_bar_t * pred_noise) / sqrt_recip_alphas_t
177
- pred_x0 = torch.clamp(pred_x0, -1., 1.)
178
 
179
  # Calculate direction pointing to x_t
180
- dir_xt = torch.sqrt(1. - alpha_bar_t_prev - beta_t**2) * pred_noise
181
 
182
  # Noise for next step
183
  if t > 0:
184
- noise = torch.randn_like(x_t) * 0.25 # Reduced noise scale
185
  else:
186
  noise = torch.zeros_like(x_t)
187
 
188
- # Update x_t
189
- x_t = torch.sqrt(alpha_bar_t_prev) * pred_x0 + dir_xt + noise * torch.sqrt(beta_t)
190
 
 
 
 
 
191
  if progress_callback:
192
  progress_callback((timesteps - t) / timesteps)
193
 
194
- # Better normalization approach
195
- x_t = torch.clamp(x_t, -1., 1.)
196
- x_t = (x_t + 1) / 2 # Scale to [0, 1]
197
 
198
  return x_t
199
 
200
-
201
  def load_model(model_path, device):
202
  unet = UNet(num_classes=NUM_CLASSES).to(device)
203
  diffusion_model = DiffusionModel(unet).to(device)
 
125
  self.model = model
126
  self.timesteps = timesteps
127
 
128
+ # More conservative noise schedule
129
  scale = 1000 / timesteps
130
  beta_start = scale * 0.0001
131
  beta_end = scale * 0.02
132
  self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
133
 
 
134
  self.alphas = 1. - self.betas
135
  self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
136
  self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
137
  self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - self.alphas_cumprod))
 
 
 
 
138
 
139
  @torch.no_grad()
140
  def sample(self, num_images, timesteps, img_size, num_classes, labels, device, progress_callback=None):
141
+ # Initialize with standard normal distribution (scale=1.0)
142
+ x_t = torch.randn((num_images, 3, img_size, img_size), device=device)
143
 
144
  if labels.ndim == 1:
145
  labels_one_hot = torch.zeros(num_images, num_classes, device=device)
 
162
  alpha_bar_t = self.alphas_cumprod[t]
163
  alpha_bar_t_prev = self.alphas_cumprod[t-1] if t > 0 else torch.tensor(1.0)
164
 
165
+ # Calculate coefficients
166
  beta_t = self.betas[t]
167
+ sqrt_recip_alpha_t = torch.sqrt(1.0 / alpha_t)
168
+ sqrt_one_minus_alpha_bar_t = torch.sqrt(1.0 - alpha_bar_t)
169
 
170
+ # Calculate predicted x0
171
+ pred_x0 = (x_t - sqrt_one_minus_alpha_bar_t * pred_noise) * sqrt_recip_alpha_t
 
172
 
173
  # Calculate direction pointing to x_t
174
+ pred_dir = torch.sqrt(1.0 - alpha_bar_t_prev) * pred_noise
175
 
176
  # Noise for next step
177
  if t > 0:
178
+ noise = torch.randn_like(x_t) * 0.5
179
  else:
180
  noise = torch.zeros_like(x_t)
181
 
182
+ # Update x_t with stability checks
183
+ x_t = torch.sqrt(alpha_bar_t_prev) * pred_x0 + pred_dir + noise * torch.sqrt(beta_t)
184
 
185
+ # Numerical stability check
186
+ if torch.isnan(x_t).any() or torch.isinf(x_t).any():
187
+ x_t = torch.randn_like(x_t) * 0.1
188
+
189
  if progress_callback:
190
  progress_callback((timesteps - t) / timesteps)
191
 
192
+ # Gentle normalization
193
+ x_t = (x_t - x_t.min()) / (x_t.max() - x_t.min() + 1e-8) # [0, 1]
194
+ x_t = torch.clamp(x_t, 0, 1) # Final safety clamp
195
 
196
  return x_t
197
 
 
198
  def load_model(model_path, device):
199
  unet = UNet(num_classes=NUM_CLASSES).to(device)
200
  diffusion_model = DiffusionModel(unet).to(device)