Update app.py

app.py CHANGED
@@ -125,7 +125,6 @@ class DiffusionModel(nn.Module):
         self.model = model
         self.timesteps = timesteps
 
-        # Use the exact same noise schedule as Colab
         beta_start = 0.0001
         beta_end = 0.02
         self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
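Note: sample() below also uses self.alphas and self.alpha_bars, which are defined outside this hunk. A minimal sketch of the conventional DDPM derivation from the linear beta schedule above (an assumption, since the defining lines are not shown):

    # Assumed standard DDPM quantities derived from self.betas (not part of this diff):
    self.alphas = 1.0 - self.betas                       # per-step signal retention
    self.alpha_bars = torch.cumprod(self.alphas, dim=0)  # cumulative product over steps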
@@ -134,29 +133,28 @@ class DiffusionModel(nn.Module):
 
     @torch.no_grad()
     def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
-        """
-
-
-
-        # Identical label handling
+        """Your exact sampling function from Colab"""
+        x_t = torch.randn(num_images, 3, img_size, img_size).to(device)
+
         if labels.ndim == 1:
-
-
+            labels_one_hot = torch.zeros(num_images, num_classes).to(device)
+            labels_one_hot[torch.arange(num_images), labels] = 1
+            labels = labels_one_hot
+        else:
+            labels = labels.to(device)
 
-        # Same sampling loop
         for t in reversed(range(self.timesteps)):
             if cancel_event.is_set():
                 return None
 
-            t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float)
+            t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float) # Pass time as float
+
             predicted_noise = self.model(x_t, labels, t_tensor)
 
-            # Identical coefficients calculation
             beta_t = self.betas[t].to(device)
             alpha_t = self.alphas[t].to(device)
             alpha_bar_t = self.alpha_bars[t].to(device)
 
-            # Same mean/variance calculation
             mean = (1 / torch.sqrt(alpha_t)) * (x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * predicted_noise)
             variance = beta_t
 
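The mean/variance above are the standard DDPM posterior terms. The update that consumes them sits in unchanged context between this hunk and the next; a minimal sketch of the usual ancestral-sampling step (an assumption, since those lines are elided from the diff):

    # Assumed standard DDPM ancestral update in the unchanged lines:
    if t > 0:
        noise = torch.randn_like(x_t)
        x_t = mean + torch.sqrt(variance) * noise
    else:
        x_t = mean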
@@ -170,35 +168,15 @@ class DiffusionModel(nn.Module):
             if progress_callback:
                 progress_callback((self.timesteps - t) / self.timesteps)
 
-
-
+        x_0 = torch.clamp(x_t, -1., 1.)
+
         mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
         std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
-
-
-
-        return
+        x_0 = std * x_0 + mean
+        x_0 = torch.clamp(x_0, 0., 1.)
+
+        return x_0
 
-    def _post_process(self, images):
-        """Apply post-processing to reduce noise and enhance contrast"""
-        # Normalize to [0,1]
-        images = torch.clamp(images, -1, 1)
-        images = (images + 1) / 2
-
-        # Apply mild blur (convert NHWC to NCHW for conv2d)
-        if images.dim() == 4 and images.shape[-1] != 3: # NCHW format
-            images = images.permute(0, 2, 3, 1)
-
-        x = images.permute(0, 3, 1, 2) # NHWC to NCHW
-        x = torch.nn.functional.conv2d(x, self.blur_kernel, padding=1, groups=3)
-        images = x.permute(0, 2, 3, 1) # NCHW to NHWC
-
-        # Contrast adjustment
-        mean_val = images.mean(dim=(1,2,3), keepdim=True)
-        images = (images - mean_val) * 1.2 + mean_val
-
-        return torch.clamp(images, 0, 1)
-
 def load_model(model_path, device):
     unet = UNet(num_classes=NUM_CLASSES).to(device)
     diffusion_model = DiffusionModel(unet).to(device)
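The added denormalization inverts a torchvision-style Normalize with ImageNet statistics; a sketch of the forward transform it presumably undoes at training time (an assumption, as the training pipeline is not part of this diff):

    # Assumed training-time transform that x_0 = std * x_0 + mean reverses:
    from torchvision import transforms
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])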
@@ -207,20 +185,17 @@ def load_model(model_path, device):
     try:
         checkpoint = torch.load(model_path, map_location=device)
 
-        # Handle both full model and state_dict loading
         if 'model_state_dict' in checkpoint:
             state_dict = checkpoint['model_state_dict']
         else:
             state_dict = checkpoint
 
-        # Handle both prefixed and non-prefixed state dicts
         if all(k.startswith('model.') for k in state_dict.keys()):
             state_dict = {k[6:]: v for k, v in state_dict.items()}
 
         unet.load_state_dict(state_dict, strict=False)
         print("Model loaded successfully")
 
-        # Verify model loading
         test_input = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(device)
         test_labels = torch.zeros(1, NUM_CLASSES).to(device)
         test_time = torch.tensor([1]).to(device)
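The k[6:] slice drops exactly len('model.') == 6 characters, remapping keys saved from the DiffusionModel wrapper onto the bare UNet. A self-contained sketch with hypothetical keys:

    # Hypothetical checkpoint keys, for illustration only:
    state_dict = {'model.conv.weight': 1, 'model.conv.bias': 2}
    if all(k.startswith('model.') for k in state_dict):
        state_dict = {k[len('model.'):]: v for k, v in state_dict.items()}
    print(state_dict)  # {'conv.weight': 1, 'conv.bias': 2}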
@@ -244,7 +219,6 @@ try:
     print("Model loaded successfully!")
 except Exception as e:
     print(f"Failed to load model: {e}")
-    # Create a dummy model if loading fails
     print("Creating dummy model for demonstration")
     loaded_model = DiffusionModel(UNet(num_classes=NUM_CLASSES)).to(device)
 
@@ -263,7 +237,6 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
     if label_str not in label_map:
         raise gr.Error("Invalid condition selected")
 
-    # Create one-hot encoded labels
    labels = torch.zeros(num_images, NUM_CLASSES)
     labels[:, label_map[label_str]] = 1
 
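An equivalent way to build the same one-hot batch with torch.nn.functional.one_hot, shown only as a cross-check (num_images, NUM_CLASSES, and label_map as in the surrounding code):

    import torch.nn.functional as F
    idx = torch.full((num_images,), label_map[label_str], dtype=torch.long)
    labels = F.one_hot(idx, num_classes=NUM_CLASSES).float()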
@@ -288,7 +261,6 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
 
     processed_images = []
     for img in images:
-        # Convert to numpy and permute dimensions (C,H,W) -> (H,W,C)
         img_np = img.cpu().permute(1, 2, 0).numpy()
         img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
         pil_img = Image.fromarray(img_np)
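Image.fromarray expects an HWC uint8 array, which is why the permute and * 255 cast come first; the [0, 1] range is guaranteed by the final clamp added in sample(). A tiny self-contained check (hypothetical tensor):

    import numpy as np
    import torch
    from PIL import Image
    img = torch.rand(3, 64, 64)  # hypothetical CHW float in [0, 1]
    arr = (img.permute(1, 2, 0).numpy() * 255).clip(0, 255).astype(np.uint8)
    print(Image.fromarray(arr).size)  # (64, 64), mode RGB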