Update app.py
app.py CHANGED
@@ -125,26 +125,21 @@ class DiffusionModel(nn.Module):
         self.model = model
         self.timesteps = timesteps
 
-        #
+        # More conservative noise schedule
         scale = 1000 / timesteps
         beta_start = scale * 0.0001
         beta_end = scale * 0.02
         self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
 
-        # More stable alpha calculations
         self.alphas = 1. - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
         self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
         self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - self.alphas_cumprod))
-
-        # Parameters for posterior variance
-        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / self.alphas_cumprod))
-        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / self.alphas_cumprod - 1))
 
     @torch.no_grad()
     def sample(self, num_images, timesteps, img_size, num_classes, labels, device, progress_callback=None):
-        # Initialize with
-        x_t = torch.randn((num_images, 3, img_size, img_size), device=device)
+        # Initialize with standard normal distribution (scale=1.0)
+        x_t = torch.randn((num_images, 3, img_size, img_size), device=device)
 
         if labels.ndim == 1:
             labels_one_hot = torch.zeros(num_images, num_classes, device=device)
@@ -167,37 +162,39 @@ class DiffusionModel(nn.Module):
             alpha_bar_t = self.alphas_cumprod[t]
             alpha_bar_t_prev = self.alphas_cumprod[t-1] if t > 0 else torch.tensor(1.0)
 
-            # Calculate coefficients
+            # Calculate coefficients
             beta_t = self.betas[t]
-
-
+            sqrt_recip_alpha_t = torch.sqrt(1.0 / alpha_t)
+            sqrt_one_minus_alpha_bar_t = torch.sqrt(1.0 - alpha_bar_t)
 
-            #
-            pred_x0 = (x_t -
-            pred_x0 = torch.clamp(pred_x0, -1., 1.)
+            # Calculate predicted x0
+            pred_x0 = (x_t - sqrt_one_minus_alpha_bar_t * pred_noise) * sqrt_recip_alpha_t
 
             # Calculate direction pointing to x_t
-
+            pred_dir = torch.sqrt(1.0 - alpha_bar_t_prev) * pred_noise
 
             # Noise for next step
             if t > 0:
-                noise = torch.randn_like(x_t) * 0.
+                noise = torch.randn_like(x_t) * 0.5
             else:
                 noise = torch.zeros_like(x_t)
 
-            # Update x_t
-            x_t = torch.sqrt(alpha_bar_t_prev) * pred_x0 +
+            # Update x_t with stability checks
+            x_t = torch.sqrt(alpha_bar_t_prev) * pred_x0 + pred_dir + noise * torch.sqrt(beta_t)
 
+            # Numerical stability check
+            if torch.isnan(x_t).any() or torch.isinf(x_t).any():
+                x_t = torch.randn_like(x_t) * 0.1
+
             if progress_callback:
                 progress_callback((timesteps - t) / timesteps)
 
-        #
-        x_t =
-        x_t = (x_t
+        # Gentle normalization
+        x_t = (x_t - x_t.min()) / (x_t.max() - x_t.min() + 1e-8)  # [0, 1]
+        x_t = torch.clamp(x_t, 0, 1)  # Final safety clamp
 
         return x_t
 
-
 def load_model(model_path, device):
     unet = UNet(num_classes=NUM_CLASSES).to(device)
     diffusion_model = DiffusionModel(unet).to(device)
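
For reference, a minimal self-contained sketch of the reverse step that this change introduces in sample(). It reproduces the arithmetic of the new "+" lines; alpha_t and pred_noise are not defined in the hunks shown, so their definitions below (the per-step alphas[t] and the model's noise prediction) are assumptions, not the file's exact code.

import torch

def reverse_step(x_t, pred_noise, t, betas, alphas, alphas_cumprod):
    # Per-step and cumulative schedule terms (alpha_t is an assumed definition)
    beta_t = betas[t]
    alpha_t = alphas[t]
    alpha_bar_t = alphas_cumprod[t]
    alpha_bar_t_prev = alphas_cumprod[t - 1] if t > 0 else torch.tensor(1.0)

    sqrt_recip_alpha_t = torch.sqrt(1.0 / alpha_t)
    sqrt_one_minus_alpha_bar_t = torch.sqrt(1.0 - alpha_bar_t)

    # Predicted x0 and the direction term, as in the updated loop body
    pred_x0 = (x_t - sqrt_one_minus_alpha_bar_t * pred_noise) * sqrt_recip_alpha_t
    pred_dir = torch.sqrt(1.0 - alpha_bar_t_prev) * pred_noise

    # Scaled noise on every step except the final one
    noise = torch.randn_like(x_t) * 0.5 if t > 0 else torch.zeros_like(x_t)

    x_next = torch.sqrt(alpha_bar_t_prev) * pred_x0 + pred_dir + noise * torch.sqrt(beta_t)

    # Fallback the new code applies when the update produces NaN or Inf values
    if torch.isnan(x_next).any() or torch.isinf(x_next).any():
        x_next = torch.randn_like(x_next) * 0.1
    return x_next

If alpha_t is indeed the per-step self.alphas[t], the x0 prediction here divides by sqrt(alpha_t) rather than by sqrt(alpha_bar_t) as in the standard DDPM formulation; the sketch keeps the diff's arithmetic as written.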
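
A hypothetical usage sketch, assuming load_model() returns the DiffusionModel after loading weights (its body is cut off in the diff) and that the checkpoint path and image size shown here are placeholders; NUM_CLASSES is defined elsewhere in app.py.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_model("model.pt", device)  # assumed to return the DiffusionModel

# Integer class labels; sample() one-hot encodes them when labels.ndim == 1
labels = torch.randint(0, NUM_CLASSES, (4,), device=device)

images = model.sample(
    num_images=4,
    timesteps=model.timesteps,
    img_size=64,                         # placeholder resolution
    num_classes=NUM_CLASSES,
    labels=labels,
    device=device,
    progress_callback=lambda p: print(f"{p:.0%} done"),
)
# Per the new final clamp, the returned tensor is normalized to [0, 1]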