Update app.py
app.py CHANGED
@@ -124,18 +124,26 @@ class DiffusionModel(nn.Module):
         super().__init__()
         self.model = model
         self.timesteps = timesteps
-
-
-        # Fix 1: Ensure consistent float32 types
+
+        # Better noise schedule for medical images
         scale = 1000 / timesteps
         beta_start = scale * 0.0001
         beta_end = scale * 0.02
         self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
+
+        # More stable alpha calculations
         self.alphas = 1. - self.betas
-        self.
+        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - self.alphas_cumprod))
+
+        # Parameters for posterior variance
+        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / self.alphas_cumprod))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / self.alphas_cumprod - 1))

     @torch.no_grad()
     def sample(self, num_images, timesteps, img_size, num_classes, labels, device, progress_callback=None):
+        # Initialize with proper scale
         x_t = torch.randn((num_images, 3, img_size, img_size), device=device) * 0.5

         if labels.ndim == 1:
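For reference, the buffers registered above are the standard cumulative-product terms of the beta schedule. A minimal sketch of how `sqrt_alphas_cumprod` and `sqrt_one_minus_alphas_cumprod` are typically used to noise a clean image during training; the `q_sample` helper and the `dm` instance are illustrative names, not part of this commit:

import torch

# Illustrative only: the usual forward process these buffers parameterize,
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
def q_sample(dm, x_0, t):
    # dm: a DiffusionModel instance with the buffers above; t: integer timestep index
    eps = torch.randn_like(x_0)
    x_t = dm.sqrt_alphas_cumprod[t] * x_0 + dm.sqrt_one_minus_alphas_cumprod[t] * eps
    return x_t, eps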
@@ -151,27 +159,41 @@ class DiffusionModel(nn.Module):

             t_tensor = torch.full((num_images,), t, device=device, dtype=torch.long)

+            # Predict noise with model
             pred_noise = self.model(x_t, labels, t_tensor.float())

-
-
-
+            # Get current alpha values
+            alpha_t = self.alphas[t]
+            alpha_bar_t = self.alphas_cumprod[t]
+            alpha_bar_t_prev = self.alphas_cumprod[t-1] if t > 0 else torch.tensor(1.0)
+
+            # Calculate coefficients for cleaner sampling
+            beta_t = self.betas[t]
+            sqrt_recip_alphas_t = self.sqrt_recip_alphas_cumprod[t]
+            sqrt_one_minus_alphas_bar_t = self.sqrt_one_minus_alphas_cumprod[t]
+
+            # Main denoising equation
+            pred_x0 = (x_t - sqrt_one_minus_alphas_bar_t * pred_noise) / sqrt_recip_alphas_t
+            pred_x0 = torch.clamp(pred_x0, -1., 1.)
+
+            # Calculate direction pointing to x_t
+            dir_xt = torch.sqrt(1. - alpha_bar_t_prev - beta_t**2) * pred_noise

+            # Noise for next step
             if t > 0:
-                noise = torch.randn_like(x_t) * 0.
+                noise = torch.randn_like(x_t) * 0.25  # Reduced noise scale
             else:
                 noise = torch.zeros_like(x_t)

-
-            x_t =
+            # Update x_t
+            x_t = torch.sqrt(alpha_bar_t_prev) * pred_x0 + dir_xt + noise * torch.sqrt(beta_t)

             if progress_callback:
                 progress_callback((timesteps - t) / timesteps)

+        # Better normalization approach
         x_t = torch.clamp(x_t, -1., 1.)
-
-        max_val = x_t.max()
-        x_t = (x_t - min_val) / (max_val - min_val + 1e-8)
+        x_t = (x_t + 1) / 2  # Scale to [0, 1]

         return x_t
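For comparison, a sketch of the textbook DDIM-style update that the new sampling block resembles, written with the standard coefficients; `ddim_step` is an illustrative helper, not part of this commit, and `sigma_t` stands for the per-step noise scale (which the commit derives from `beta_t` and a 0.25 factor):

import torch

# Reference formulation only (DDIM-style step), assuming alpha_bar_t,
# alpha_bar_prev, and sigma_t are 0-dim tensors:
#   x0_hat  = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
#   x_{t-1} = sqrt(alpha_bar_prev) * x0_hat
#             + sqrt(1 - alpha_bar_prev - sigma_t^2) * eps + sigma_t * z
def ddim_step(x_t, eps, alpha_bar_t, alpha_bar_prev, sigma_t):
    x0_hat = (x_t - torch.sqrt(1. - alpha_bar_t) * eps) / torch.sqrt(alpha_bar_t)
    x0_hat = torch.clamp(x0_hat, -1., 1.)
    dir_xt = torch.sqrt(1. - alpha_bar_prev - sigma_t ** 2) * eps
    return torch.sqrt(alpha_bar_prev) * x0_hat + dir_xt + sigma_t * torch.randn_like(x_t)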
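A minimal usage sketch for the updated sample method, assuming the constructor takes the network and the timestep count (as the attributes above suggest); the network, image size, class count, and label values are placeholders, not values from this Space:

import torch

# Hypothetical usage; `unet` stands in for the trained noise-prediction
# network that DiffusionModel wraps (not shown in this diff).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
diffusion = DiffusionModel(unet, timesteps=1000).to(device)
labels = torch.tensor([0, 1, 2, 3], device=device)  # one class id per image
images = diffusion.sample(num_images=4, timesteps=1000, img_size=64,
                          num_classes=4, labels=labels, device=device,
                          progress_callback=lambda p: print(f"{p:.0%}"))
# `images` comes back in [0, 1] with shape (4, 3, 64, 64) after the new normalization.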