Update app.py

app.py CHANGED
@@ -10,7 +10,7 @@ import traceback
 
 # Constants
 IMG_SIZE = 128
-TIMESTEPS =
+TIMESTEPS = 500
 NUM_CLASSES = 2
 
 # Global Cancellation Flag
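TIMESTEPS now has a concrete value (500) and sets the length of the reverse loop in the new sample() below. The beta schedule itself is outside this diff; purely as a point of reference, a common linear DDPM schedule that would feed the alpha_bars buffer registered in the next hunk looks like the sketch below (names are illustrative, not taken from app.py):

import torch

TIMESTEPS = 500
betas = torch.linspace(1e-4, 0.02, TIMESTEPS)  # assumed linear schedule, not from app.py
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)      # same form as the registered alpha_bars buffer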
@@ -135,47 +135,56 @@ class DiffusionModel(nn.Module):
         self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0))
 
     @torch.no_grad()
-            t = torch.full((num_images,), i, device=device, dtype=torch.long)
-            x_t += noise * torch.sqrt(beta_t)
-        # Fix 3: Simplified scaling
-        x_t = torch.clamp(x_t, -1., 1.)
-        return (x_t + 1) / 2  # Scale to [0,1]
+    def sample(model, num_images, timesteps, img_size, num_classes, labels, device, progress_callback=None):
+        # Initialize with properly scaled noise
+        x_t = torch.randn((num_images, 3, img_size, img_size), device=device) * 0.5  # Reduced initial noise scale
+
+        # Convert labels to proper format
+        if labels.ndim == 1:
+            labels_one_hot = torch.zeros(num_images, num_classes, device=device)
+            labels_one_hot[torch.arange(num_images), labels] = 1
+            labels = labels_one_hot
+        else:
+            labels = labels.float().to(device)
+
+        # Reverse diffusion process
+        for t in reversed(range(timesteps)):
+            if cancel_event.is_set():
+                return None
+
+            t_tensor = torch.full((num_images,), t, device=device, dtype=torch.long)
+
+            # Model prediction with proper scaling
+            pred_noise = model.model(x_t, labels, t_tensor.float())
+
+            # Calculate diffusion parameters
+            alpha_t = model.alphas[t].to(device)
+            alpha_bar_t = model.alpha_bars[t].to(device)
+            beta_t = model.betas[t].to(device)
+
+            # Improved denoising step
+            if t > 0:
+                noise = torch.randn_like(x_t) * 0.5  # Reduced noise scale
+            else:
+                noise = torch.zeros_like(x_t)
+
+            # More stable prediction
+            x_t = (x_t - (1 - alpha_t) / torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_t)
+            x_t = x_t + noise * torch.sqrt(beta_t)
+
+            if progress_callback:
+                progress_callback((timesteps - t) / timesteps)
+
+        # Better image normalization
+        x_t = torch.clamp(x_t, -1., 1.)
+
+        # Alternative normalization approach
+        min_val = x_t.min()
+        max_val = x_t.max()
+        x_t = (x_t - min_val) / (max_val - min_val + 1e-8)  # Ensure we don't divide by zero
+
+        return x_t
 
 def load_model(model_path, device):
     unet = UNet(num_classes=NUM_CLASSES).to(device)
@@ -231,7 +240,7 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
     if label_str not in label_map:
         raise gr.Error("Invalid condition selected")
 
-    labels = torch.zeros(num_images, NUM_CLASSES, device=device
+    labels = torch.zeros(num_images, NUM_CLASSES, device=device)
     labels[:, label_map[label_str]] = 1
 
     try:
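The fix here is the missing closing parenthesis, a syntax error in the old file. The same one-hot batch can also be built with F.one_hot; a sketch using names from generate_images above:

import torch
import torch.nn.functional as F

# idx holds the selected class index once per requested image.
idx = torch.full((num_images,), label_map[label_str], device=device)
labels = F.one_hot(idx, num_classes=NUM_CLASSES).float()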
@@ -242,10 +251,9 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
 
         with torch.no_grad():
             print(f"Generating {num_images} images for {label_str}")
-            print(f"Labels shape: {labels.shape}, device: {labels.device}")
-
             images = loaded_model.sample(
                 num_images=num_images,
+                timesteps=TIMESTEPS,
                 img_size=IMG_SIZE,
                 num_classes=NUM_CLASSES,
                 labels=labels,
@@ -256,17 +264,15 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
         if images is None:
             return None, None
 
-        # Diagnostic print
         print(f"Generated images range: {images.min().item():.3f}, {images.max().item():.3f}")
 
         processed_images = []
        for img in images:
-            img_np =
-            img_np = img_np.squeeze(-1)
+            # Convert to numpy and ensure proper range
+            img_np = img.cpu().numpy().transpose(1, 2, 0)
+            img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
 
+            # Convert to PIL Image
             pil_img = Image.fromarray(img_np)
             processed_images.append(pil_img)
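The new conversion assumes sample() returns (N, 3, H, W) floats in [0, 1]; the removed squeeze(-1) suggests the old path produced single-channel HWC arrays instead. The per-image steps, pulled out as a standalone helper for clarity (a sketch, not part of the commit):

import numpy as np
import torch
from PIL import Image

def tensor_to_pil(img: torch.Tensor) -> Image.Image:
    """Convert one (3, H, W) float tensor in [0, 1] to a PIL image."""
    img_np = img.cpu().numpy().transpose(1, 2, 0)          # CHW -> HWC
    img_np = (img_np * 255).clip(0, 255).astype(np.uint8)  # [0, 1] -> [0, 255]
    return Image.fromarray(img_np)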
@@ -276,30 +282,11 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
         else:
             return None, processed_images
 
-    except torch.cuda.OutOfMemoryError:
-        torch.cuda.empty_cache()
-        raise gr.Error("Out of GPU memory - try generating fewer images")
     except Exception as e:
         traceback.print_exc()
-
-        raise gr.Error(f"Generation failed: {str(e)}")
-        return None, None
+        raise gr.Error(f"Generation failed: {str(e)}")
     finally:
         torch.cuda.empty_cache()
-
-# Load model
-MODEL_NAME = "model_weights.pth"  # Updated to look in root folder
-model_path = MODEL_NAME
-print("Loading model...")
-try:
-    loaded_model = load_model(model_path, device)
-    print("Model loaded successfully!")
-except Exception as e:
-    print(f"Failed to load model: {e}")
-    # Create a dummy model for demo purposes
-    print("Creating dummy model for demonstration")
-    loaded_model = DiffusionModel(UNet(num_classes=NUM_CLASSES)).to(device)
-
 
 # Gradio UI
 with gr.Blocks(theme=gr.themes.Soft(
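With the dedicated except torch.cuda.OutOfMemoryError branch removed, OOM now surfaces through the generic handler as "Generation failed: ...". If the friendlier OOM message is still wanted without a separate branch, one option is the sketch below (run_safely is a hypothetical wrapper, not in app.py; torch.cuda.OutOfMemoryError is the same exception class the old code caught):

import traceback
import torch
import gradio as gr

def run_safely(do_generate):
    # do_generate stands in for the body of generate_images.
    try:
        return do_generate()
    except Exception as e:
        traceback.print_exc()
        if isinstance(e, torch.cuda.OutOfMemoryError):
            torch.cuda.empty_cache()
            raise gr.Error("Out of GPU memory - try generating fewer images")
        raise gr.Error(f"Generation failed: {str(e)}")
    finally:
        torch.cuda.empty_cache()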