Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -127,74 +127,98 @@ class DiffusionModel(nn.Module):
|
|
127 |
|
128 |
# More conservative noise schedule
|
129 |
scale = 1000 / timesteps
|
130 |
-
beta_start =
|
131 |
-
beta_end =
|
132 |
-
self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
|
133 |
|
134 |
self.alphas = 1. - self.betas
|
135 |
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
136 |
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
|
137 |
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - self.alphas_cumprod))
|
138 |
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
else:
|
149 |
-
|
150 |
-
|
151 |
-
for t in reversed(range(timesteps)):
|
152 |
-
if cancel_event.is_set():
|
153 |
-
return None
|
154 |
-
|
155 |
-
t_tensor = torch.full((num_images,), t, device=device, dtype=torch.long)
|
156 |
-
|
157 |
-
# Predict noise with model
|
158 |
-
pred_noise = self.model(x_t, labels, t_tensor.float())
|
159 |
-
|
160 |
-
# Get current alpha values
|
161 |
-
alpha_t = self.alphas[t]
|
162 |
-
alpha_bar_t = self.alphas_cumprod[t]
|
163 |
-
alpha_bar_t_prev = self.alphas_cumprod[t-1] if t > 0 else torch.tensor(1.0)
|
164 |
-
|
165 |
-
# Calculate coefficients
|
166 |
-
beta_t = self.betas[t]
|
167 |
-
sqrt_recip_alpha_t = torch.sqrt(1.0 / alpha_t)
|
168 |
-
sqrt_one_minus_alpha_bar_t = torch.sqrt(1.0 - alpha_bar_t)
|
169 |
-
|
170 |
-
# Calculate predicted x0
|
171 |
-
pred_x0 = (x_t - sqrt_one_minus_alpha_bar_t * pred_noise) * sqrt_recip_alpha_t
|
172 |
-
|
173 |
-
# Calculate direction pointing to x_t
|
174 |
-
pred_dir = torch.sqrt(1.0 - alpha_bar_t_prev) * pred_noise
|
175 |
-
|
176 |
-
# Noise for next step
|
177 |
-
if t > 0:
|
178 |
-
noise = torch.randn_like(x_t) * 0.5
|
179 |
-
else:
|
180 |
-
noise = torch.zeros_like(x_t)
|
181 |
-
|
182 |
-
# Update x_t with stability checks
|
183 |
-
x_t = torch.sqrt(alpha_bar_t_prev) * pred_x0 + pred_dir + noise * torch.sqrt(beta_t)
|
184 |
|
185 |
-
|
186 |
-
|
187 |
-
x_t = torch.randn_like(x_t) * 0.1
|
188 |
-
|
189 |
-
if progress_callback:
|
190 |
-
progress_callback((timesteps - t) / timesteps)
|
191 |
-
|
192 |
-
# Gentle normalization
|
193 |
-
x_t = (x_t - x_t.min()) / (x_t.max() - x_t.min() + 1e-8) # [0, 1]
|
194 |
-
x_t = torch.clamp(x_t, 0, 1) # Final safety clamp
|
195 |
|
196 |
-
|
|
|
|
|
197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
def load_model(model_path, device):
|
199 |
unet = UNet(num_classes=NUM_CLASSES).to(device)
|
200 |
diffusion_model = DiffusionModel(unet).to(device)
|
@@ -270,24 +294,30 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
|
|
270 |
|
271 |
with torch.no_grad():
|
272 |
images = loaded_model.sample(
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
)
|
281 |
-
|
282 |
if images is None:
|
283 |
return None, None
|
284 |
-
|
285 |
processed_images = []
|
286 |
for img in images:
|
287 |
-
img_np = img.cpu().numpy()
|
|
|
|
|
288 |
img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
|
289 |
pil_img = Image.fromarray(img_np)
|
|
|
|
|
|
|
290 |
processed_images.append(pil_img)
|
|
|
291 |
|
292 |
if num_images == 1:
|
293 |
return processed_images[0], processed_images
|
|
|
127 |
|
128 |
# More conservative noise schedule
|
129 |
scale = 1000 / timesteps
|
130 |
+
beta_start = 0.0001
|
131 |
+
beta_end = 0.02
|
132 |
+
self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)**1.5
|
133 |
|
134 |
self.alphas = 1. - self.betas
|
135 |
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
136 |
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
|
137 |
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - self.alphas_cumprod))
|
138 |
|
139 |
+
@torch.no_grad()
def sample(self, num_images, timesteps, img_size, num_classes, labels, device, progress_callback=None):
    """Run the reverse-diffusion loop and return the post-processed batch.

    Args:
        num_images: Batch size of the generated tensor.
        timesteps: Number of reverse steps; t runs from timesteps - 1 down to 0.
        img_size: Side length of the square, 3-channel starting noise.
        num_classes: Width of the one-hot label encoding.
        labels: Class labels. A 1-D tensor is expanded to one-hot rows;
            any other rank is handed to the model unchanged.
        device: Device on which noise, labels and timestep tensors are created.
        progress_callback: Optional callable invoked after every step with
            (timesteps - t) / timesteps.

    Returns:
        The output of self._post_process applied to the final sample after
        clamping to [-1, 1] and rescaling to [0, 1], or None when
        cancel_event is set during sampling.
    """
    # NOTE(review): t indexes self.alphas_cumprod directly, so `timesteps`
    # must not exceed the length of the schedule — confirm against callers.
    # NOTE(review): cancel_event is not defined in this block; presumably a
    # module-level threading.Event — verify.

    # Start from Gaussian noise damped to 0.7 of unit scale.
    current = torch.randn((num_images, 3, img_size, img_size), device=device) * 0.7

    # Integer class ids -> one-hot rows.
    if labels.ndim == 1:
        one_hot = torch.zeros(num_images, num_classes, device=device)
        one_hot[torch.arange(num_images), labels] = 1
        labels = one_hot

    for step in range(timesteps - 1, -1, -1):
        if cancel_event.is_set():
            return None

        step_batch = torch.full((num_images,), step, device=device, dtype=torch.long)

        # Noise estimate from the wrapped network.
        eps = self.model(current, labels, step_batch.float())

        # Cumulative alpha for this step and the previous one (1.0 at t == 0).
        abar = self.alphas_cumprod[step]
        abar_prev = torch.tensor(1.0) if step == 0 else self.alphas_cumprod[step - 1]

        # Estimate of the clean image implied by the current sample.
        x0_estimate = (current - torch.sqrt(1 - abar) * eps) / torch.sqrt(abar)

        # Component pointing back towards the noisy sample.
        direction = torch.sqrt(1 - abar_prev) * eps

        # Fresh noise shrinks linearly with t and is absent on the last step.
        if step == 0:
            fresh_noise = torch.zeros_like(current)
        else:
            fresh_noise = torch.randn_like(current) * (0.3 * (step / timesteps))

        current = torch.sqrt(abar_prev) * x0_estimate + direction + fresh_noise

        if progress_callback:
            progress_callback((timesteps - step) / timesteps)

    # Map [-1, 1] onto [0, 1] before the tensor-level post-processing.
    unit_range = (torch.clamp(current, -1, 1) + 1) / 2
    return self._post_process(unit_range)
|
192 |
+
|
193 |
+
def _post_process(self, image_tensor):
    """Apply a contrast boost and a mild 3x3 blur, returning NHWC images.

    Args:
        image_tensor: 4-D batch of 3-channel images. A tensor whose last
            dimension is 3 is read as NHWC (the layout this method has
            always accepted); otherwise a tensor whose second dimension
            is 3 is read as NCHW, which is what ``sample`` produces.

    Returns:
        Tensor of shape (N, H, W, 3) clamped to [0, 1].

    Raises:
        ValueError: If the input is not a 4-D tensor with a 3-channel axis
            in position 1 or -1.
    """
    if image_tensor.ndim != 4:
        raise ValueError(
            f"expected a 4-D image batch, got {image_tensor.ndim} dimensions"
        )

    # Normalise the layout to NCHW for conv2d. NHWC takes precedence so that
    # every input the previous implementation handled behaves the same.
    if image_tensor.shape[-1] == 3:
        channels_first = image_tensor.permute(0, 3, 1, 2)
    elif image_tensor.shape[1] == 3:
        channels_first = image_tensor
    else:
        raise ValueError(
            "expected 3 channels in dimension 1 (NCHW) or -1 (NHWC), "
            f"got shape {tuple(image_tensor.shape)}"
        )

    # Contrast adjustment: stretch values around the global mean by 1.2x.
    mean_val = channels_first.mean()
    channels_first = (channels_first - mean_val) * 1.2 + mean_val

    # Mild blur as a depthwise convolution; build the kernel once and cache it.
    blur_kernel = getattr(self, '_blur_kernel', None)
    if blur_kernel is None:
        blur_kernel = torch.tensor([
            [0.05, 0.1, 0.05],
            [0.1, 0.4, 0.1],
            [0.05, 0.1, 0.05],
        ], dtype=torch.float32).view(1, 1, 3, 3).repeat(3, 1, 1, 1)
        self._blur_kernel = blur_kernel

    # Match the image's device/dtype on every call, including the first one
    # (a freshly built kernel lives on the CPU).
    blur_kernel = blur_kernel.to(device=channels_first.device, dtype=channels_first.dtype)

    # groups=3 blurs each colour channel independently; padding=1 keeps H and W.
    blurred = torch.nn.functional.conv2d(
        channels_first,
        blur_kernel,
        padding=1,
        groups=3,
    )

    # Callers convert each image with PIL, which expects channels last.
    return torch.clamp(blurred.permute(0, 2, 3, 1), 0, 1)
|
220 |
+
|
221 |
+
|
222 |
def load_model(model_path, device):
|
223 |
unet = UNet(num_classes=NUM_CLASSES).to(device)
|
224 |
diffusion_model = DiffusionModel(unet).to(device)
|
|
|
294 |
|
295 |
with torch.no_grad():
|
296 |
images = loaded_model.sample(
|
297 |
+
num_images=num_images,
|
298 |
+
timesteps=int(TIMESTEPS * 1.5), # More timesteps for cleaner images
|
299 |
+
img_size=IMG_SIZE,
|
300 |
+
num_classes=NUM_CLASSES,
|
301 |
+
labels=labels,
|
302 |
+
device=device,
|
303 |
+
progress_callback=progress_callback
|
304 |
)
|
305 |
+
|
306 |
if images is None:
|
307 |
return None, None
|
308 |
+
|
309 |
processed_images = []
|
310 |
for img in images:
|
311 |
+
img_np = img.cpu().numpy()
|
312 |
+
|
313 |
+
# Convert to PIL with enhanced contrast
|
314 |
img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
|
315 |
pil_img = Image.fromarray(img_np)
|
316 |
+
|
317 |
+
# Apply additional PIL-based enhancements
|
318 |
+
pil_img = pil_img.filter(ImageFilter.SMOOTH_MORE)
|
319 |
processed_images.append(pil_img)
|
320 |
+
|
321 |
|
322 |
if num_images == 1:
|
323 |
return processed_images[0], processed_images
|