Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -125,99 +125,91 @@ class DiffusionModel(nn.Module):
|
|
125 |
self.model = model
|
126 |
self.timesteps = timesteps
|
127 |
|
128 |
-
#
|
129 |
-
scale = 1000 / timesteps
|
130 |
beta_start = 0.0001
|
131 |
beta_end = 0.02
|
132 |
-
self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
|
133 |
|
134 |
self.alphas = 1. - self.betas
|
135 |
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
|
|
|
|
136 |
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
|
137 |
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - self.alphas_cumprod))
|
138 |
-
|
139 |
-
@torch.no_grad()
|
140 |
-
def sample(self, num_images, timesteps, img_size, num_classes, labels, device, progress_callback=None):
|
141 |
-
# Initialize with reduced noise scale
|
142 |
-
x_t = torch.randn((num_images, 3, img_size, img_size), device=device) * 0.7
|
143 |
-
|
144 |
-
# Convert labels if needed
|
145 |
-
if labels.ndim == 1:
|
146 |
-
labels_one_hot = torch.zeros(num_images, num_classes, device=device)
|
147 |
-
labels_one_hot[torch.arange(num_images), labels] = 1
|
148 |
-
labels = labels_one_hot
|
149 |
-
|
150 |
-
for t in reversed(range(timesteps)):
|
151 |
-
if cancel_event.is_set():
|
152 |
-
return None
|
153 |
-
|
154 |
-
t_tensor = torch.full((num_images,), t, device=device, dtype=torch.long)
|
155 |
|
156 |
-
#
|
157 |
-
|
|
|
158 |
|
159 |
-
#
|
160 |
-
|
161 |
-
|
162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
|
164 |
-
|
165 |
-
|
166 |
|
167 |
-
|
168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
-
#
|
171 |
-
if
|
172 |
-
|
173 |
-
noise = torch.randn_like(x_t) * noise_scale
|
174 |
-
else:
|
175 |
-
noise = torch.zeros_like(x_t)
|
176 |
|
177 |
-
|
178 |
-
|
|
|
179 |
|
180 |
-
#
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
# Enhanced normalization with contrast adjustment
|
185 |
-
x_t = torch.clamp(x_t, -1, 1)
|
186 |
-
x_t = (x_t + 1) / 2 # Scale to [0,1]
|
187 |
-
|
188 |
-
# Post-processing directly in the tensor
|
189 |
-
x_t = self._post_process(x_t)
|
190 |
-
|
191 |
-
return x_t
|
192 |
-
|
193 |
-
def _post_process(self, image_tensor):
|
194 |
-
"""Apply simple post-processing to reduce noise"""
|
195 |
-
# Contrast adjustment
|
196 |
-
mean_val = image_tensor.mean()
|
197 |
-
image_tensor = (image_tensor - mean_val) * 1.2 + mean_val
|
198 |
-
|
199 |
-
# Mild Gaussian blur (implemented as depthwise convolution)
|
200 |
-
if hasattr(self, '_blur_kernel'):
|
201 |
-
blur_kernel = self._blur_kernel.to(image_tensor.device)
|
202 |
-
else:
|
203 |
-
blur_kernel = torch.tensor([
|
204 |
-
[0.05, 0.1, 0.05],
|
205 |
-
[0.1, 0.4, 0.1],
|
206 |
-
[0.05, 0.1, 0.05]
|
207 |
-
], dtype=torch.float32).view(1, 1, 3, 3).repeat(3, 1, 1, 1)
|
208 |
-
self._blur_kernel = blur_kernel
|
209 |
|
210 |
-
|
211 |
-
padding = (1, 1, 1, 1)
|
212 |
-
image_tensor = torch.nn.functional.conv2d(
|
213 |
-
image_tensor.permute(0, 3, 1, 2), # NHWC to NCHW
|
214 |
-
blur_kernel,
|
215 |
-
padding=1,
|
216 |
-
groups=3
|
217 |
-
).permute(0, 2, 3, 1) # Back to NHWC
|
218 |
-
|
219 |
-
return torch.clamp(image_tensor, 0, 1)
|
220 |
-
|
221 |
|
222 |
def load_model(model_path, device):
|
223 |
unet = UNet(num_classes=NUM_CLASSES).to(device)
|
@@ -295,7 +287,6 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
|
|
295 |
with torch.no_grad():
|
296 |
images = loaded_model.sample(
|
297 |
num_images=num_images,
|
298 |
-
timesteps=int(TIMESTEPS * 1.5), # More timesteps for cleaner images
|
299 |
img_size=IMG_SIZE,
|
300 |
num_classes=NUM_CLASSES,
|
301 |
labels=labels,
|
|
|
125 |
self.model = model
|
126 |
self.timesteps = timesteps
|
127 |
|
128 |
+
# Improved noise schedule
|
|
|
129 |
beta_start = 0.0001
|
130 |
beta_end = 0.02
|
131 |
+
self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
|
132 |
|
133 |
self.alphas = 1. - self.betas
|
134 |
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
135 |
+
|
136 |
+
# Pre-calculate values for sampling
|
137 |
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
|
138 |
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - self.alphas_cumprod))
|
139 |
+
self.register_buffer('sqrt_recip_alphas', torch.sqrt(1. / self.alphas))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
+
# Calculations for posterior q(x_{t-1} | x_t, x_0)
|
142 |
+
posterior_variance = self.betas * (1. - self.alphas_cumprod[:-1]) / (1. - self.alphas_cumprod[1:])
|
143 |
+
self.register_buffer('posterior_variance', posterior_variance)
|
144 |
|
145 |
+
# Blur kernel for post-processing
|
146 |
+
self.register_buffer('blur_kernel', torch.tensor([
|
147 |
+
[0.05, 0.1, 0.05],
|
148 |
+
[0.1, 0.4, 0.1],
|
149 |
+
[0.05, 0.1, 0.05]
|
150 |
+
], dtype=torch.float32).view(1, 1, 3, 3).repeat(3, 1, 1, 1))
|
151 |
+
|
152 |
+
@torch.no_grad()
|
153 |
+
def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
|
154 |
+
"""Generate samples from the model"""
|
155 |
+
shape = (num_images, 3, img_size, img_size)
|
156 |
+
x_t = torch.randn(shape, device=device) * 0.7 # Slightly reduced initial noise
|
157 |
|
158 |
+
if labels.ndim == 1:
|
159 |
+
labels = torch.zeros(num_images, num_classes, device=device).scatter_(1, labels.unsqueeze(1), 1)
|
160 |
|
161 |
+
for t in reversed(range(self.timesteps)):
|
162 |
+
if cancel_event.is_set():
|
163 |
+
return None
|
164 |
+
|
165 |
+
t_batch = torch.full((num_images,), t, device=device, dtype=torch.long)
|
166 |
+
pred_noise = self.model(x_t, labels, t_batch.float())
|
167 |
+
|
168 |
+
alpha_bar_t = self.alphas_cumprod[t]
|
169 |
+
alpha_bar_t_prev = self.alphas_cumprod[t-1] if t > 0 else torch.tensor(1.0)
|
170 |
+
|
171 |
+
# Calculate predicted x0
|
172 |
+
pred_x0 = (x_t - torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_bar_t)
|
173 |
+
|
174 |
+
# Calculate direction pointing to x_t
|
175 |
+
pred_dir = torch.sqrt(1 - alpha_bar_t_prev) * pred_noise
|
176 |
+
|
177 |
+
# Dynamic noise scaling
|
178 |
+
if t > 0:
|
179 |
+
noise_scale = 0.3 * (t / self.timesteps)
|
180 |
+
noise = torch.randn_like(x_t) * noise_scale
|
181 |
+
else:
|
182 |
+
noise = torch.zeros_like(x_t)
|
183 |
+
|
184 |
+
# Update x_t
|
185 |
+
x_t = torch.sqrt(alpha_bar_t_prev) * pred_x0 + pred_dir + noise
|
186 |
+
|
187 |
+
if progress_callback:
|
188 |
+
progress_callback((self.timesteps - t) / self.timesteps)
|
189 |
+
|
190 |
+
# Post-processing
|
191 |
+
x_t = self._post_process(x_t)
|
192 |
+
return x_t
|
193 |
+
|
194 |
+
def _post_process(self, images):
|
195 |
+
"""Apply post-processing to reduce noise and enhance contrast"""
|
196 |
+
# Normalize to [0,1]
|
197 |
+
images = torch.clamp(images, -1, 1)
|
198 |
+
images = (images + 1) / 2
|
199 |
|
200 |
+
# Apply mild blur (convert NHWC to NCHW for conv2d)
|
201 |
+
if images.dim() == 4 and images.shape[-1] != 3: # NCHW format
|
202 |
+
images = images.permute(0, 2, 3, 1)
|
|
|
|
|
|
|
203 |
|
204 |
+
x = images.permute(0, 3, 1, 2) # NHWC to NCHW
|
205 |
+
x = torch.nn.functional.conv2d(x, self.blur_kernel, padding=1, groups=3)
|
206 |
+
images = x.permute(0, 2, 3, 1) # NCHW to NHWC
|
207 |
|
208 |
+
# Contrast adjustment
|
209 |
+
mean_val = images.mean(dim=(1,2,3), keepdim=True)
|
210 |
+
images = (images - mean_val) * 1.2 + mean_val
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
|
212 |
+
return torch.clamp(images, 0, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
213 |
|
214 |
def load_model(model_path, device):
|
215 |
unet = UNet(num_classes=NUM_CLASSES).to(device)
|
|
|
287 |
with torch.no_grad():
|
288 |
images = loaded_model.sample(
|
289 |
num_images=num_images,
|
|
|
290 |
img_size=IMG_SIZE,
|
291 |
num_classes=NUM_CLASSES,
|
292 |
labels=labels,
|