Update app.py

app.py CHANGED
@@ -134,56 +134,46 @@ class DiffusionModel(nn.Module):
         self.alphas = 1. - self.betas
         self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0))

-    @torch.no_grad()
-    def sample(
-        x_t = torch.randn((num_images, 3, img_size, img_size), device=device) * 0.5  # Reduced initial noise scale
-
-        # Convert labels to proper format
-        if labels.ndim == 1:
-            labels_one_hot = torch.zeros(num_images, num_classes, device=device)
-            labels_one_hot[torch.arange(num_images), labels] = 1
-            labels = labels_one_hot
-        else:
-            labels = labels.float().to(device)
-
-        # Reverse diffusion process
-        for t in reversed(range(timesteps)):
-            if cancel_event.is_set():
-                return None
-
-            t_tensor = torch.full((num_images,), t, device=device, dtype=torch.long)
-
-            # Model prediction with proper scaling
-            pred_noise = model.model(x_t, labels, t_tensor.float())
-
-            # Calculate diffusion parameters
-            alpha_t = model.alphas[t].to(device)
-            alpha_bar_t = model.alpha_bars[t].to(device)
-            beta_t = model.betas[t].to(device)
-            x_t = (x_t - min_val) / (max_val - min_val + 1e-8)  # Ensure we don't divide by zero
-
-        return x_t
+    @torch.no_grad()
+    def sample(self, num_images, timesteps, img_size, num_classes, labels, device, progress_callback=None):
+        x_t = torch.randn((num_images, 3, img_size, img_size), device=device) * 0.5
+
+        if labels.ndim == 1:
+            labels_one_hot = torch.zeros(num_images, num_classes, device=device)
+            labels_one_hot[torch.arange(num_images), labels] = 1
+            labels = labels_one_hot
+        else:
+            labels = labels.float().to(device)
+
+        for t in reversed(range(timesteps)):
+            if cancel_event.is_set():
+                return None
+
+            t_tensor = torch.full((num_images,), t, device=device, dtype=torch.long)
+
+            pred_noise = self.model(x_t, labels, t_tensor.float())
+
+            alpha_t = self.alphas[t].to(device)
+            alpha_bar_t = self.alpha_bars[t].to(device)
+            beta_t = self.betas[t].to(device)
+
+            if t > 0:
+                noise = torch.randn_like(x_t) * 0.5
+            else:
+                noise = torch.zeros_like(x_t)
+
+            x_t = (x_t - (1 - alpha_t)/torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_t)
+            x_t = x_t + noise * torch.sqrt(beta_t)
+
+            if progress_callback:
+                progress_callback((timesteps - t) / timesteps)
+
+        x_t = torch.clamp(x_t, -1., 1.)
+        min_val = x_t.min()
+        max_val = x_t.max()
+        x_t = (x_t - min_val) / (max_val - min_val + 1e-8)
+
+        return x_t


 def load_model(model_path, device):
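For context on the rewritten update (new lines 165-166): it is the standard DDPM ancestral-sampling step, x_{t-1} = (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * pred_noise) / sqrt(alpha_t) + sqrt(beta_t) * z, except that both the initial latent and the injected noise z are scaled by 0.5, a deliberate deviation from the textbook sampler. A minimal sketch of the unscaled step for comparison, assuming the same 1-D betas schedule as the buffers registered above (ddpm_step is a hypothetical helper, not part of app.py):

import torch

def ddpm_step(x_t, pred_noise, t, betas):
    # Hypothetical standalone helper showing one textbook DDPM step;
    # pred_noise is the model's epsilon prediction, as in sample() above.
    alphas = 1. - betas
    alpha_bars = torch.cumprod(alphas, dim=0)
    alpha_t, alpha_bar_t, beta_t = alphas[t], alpha_bars[t], betas[t]
    # Posterior mean of x_{t-1} given x_t and the predicted noise
    mean = (x_t - (1 - alpha_t) / torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_t)
    # Unscaled Gaussian noise with sigma_t = sqrt(beta_t), dropped at the final step
    noise = torch.randn_like(x_t) if t > 0 else torch.zeros_like(x_t)
    return mean + torch.sqrt(beta_t) * noise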
@@ -210,7 +200,6 @@ def load_model(model_path, device):
     # Verify model loading
     test_input = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(device)
     test_labels = torch.zeros(1, NUM_CLASSES).to(device)
-    test_labels[0, 0] = 1
     test_time = torch.tensor([1]).to(device)
     output = unet(test_input, test_labels, test_time)
     print(f"Model test output shape: {output.shape}")
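Note that with test_labels[0, 0] = 1 removed, the smoke test feeds an all-zeros label vector to the UNet. If a valid one-hot is wanted for this check, one option is torch.nn.functional.one_hot (a sketch reusing NUM_CLASSES and device from the surrounding code):

test_labels = torch.nn.functional.one_hot(
    torch.tensor([0], device=device), num_classes=NUM_CLASSES
).float()  # shape (1, NUM_CLASSES) with class 0 set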
@@ -223,6 +212,18 @@ def load_model(model_path, device):

     diffusion_model.eval()
     return diffusion_model
+
+MODEL_NAME = "model_weights.pth"
+model_path = MODEL_NAME
+print("Loading model...")
+try:
+    loaded_model = load_model(model_path, device)
+    print("Model loaded successfully!")
+except Exception as e:
+    print(f"Failed to load model: {e}")
+    # Create a dummy model if loading fails
+    print("Creating dummy model for demonstration")
+    loaded_model = DiffusionModel(UNet(num_classes=NUM_CLASSES)).to(device)

 def cancel_generation():
     cancel_event.set()
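One caveat with the new fallback (new lines 222-226): a freshly initialized DiffusionModel still samples without error, so a missing checkpoint shows up as pure-noise output rather than a visible failure. A sketch of one way to surface that state, where model_is_dummy is a hypothetical flag not present in app.py:

model_is_dummy = False  # hypothetical flag, checked later to warn the user
try:
    loaded_model = load_model(model_path, device)
except Exception as e:
    print(f"Failed to load model: {e}")
    loaded_model = DiffusionModel(UNet(num_classes=NUM_CLASSES)).to(device)
    model_is_dummy = True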
@@ -232,7 +233,6 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
     global loaded_model
     cancel_event.clear()

-    # Input validation
     if num_images < 1 or num_images > 10:
         raise gr.Error("Number of images must be between 1 and 10")

@@ -250,7 +250,6 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
             raise gr.Error("Generation was cancelled by user")

         with torch.no_grad():
-            print(f"Generating {num_images} images for {label_str}")
             images = loaded_model.sample(
                 num_images=num_images,
                 timesteps=TIMESTEPS,
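The cancellation path now spans three call sites: cancel_event.clear() at the top of generate_images, cancel_event.is_set() inside the sampling loop, and cancel_event.set() in cancel_generation. A self-contained sketch of the same threading.Event pattern (the worker function is hypothetical, standing in for sample()):

import threading
import time

cancel_event = threading.Event()

def worker(steps=100):
    for t in range(steps):
        if cancel_event.is_set():
            return None      # bail out mid-run, as sample() does
        time.sleep(0.01)     # stand-in for one denoising step
    return "done"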
@@ -264,19 +263,13 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
         if images is None:
             return None, None

-        print(f"Generated images range: {images.min().item():.3f}, {images.max().item():.3f}")
-
         processed_images = []
         for img in images:
-            # Convert to numpy and ensure proper range
             img_np = img.cpu().numpy().transpose(1, 2, 0)
             img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
-
-            # Convert to PIL Image
             pil_img = Image.fromarray(img_np)
             processed_images.append(pil_img)

-        # Return appropriate outputs based on count
         if num_images == 1:
             return processed_images[0], processed_images
         else:
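Since sample() min-max normalizes over the whole batch, one outlier pixel rescales every image in it, and the uint8 conversion above bakes that in. If per-image contrast is preferred, a per-sample variant (a sketch assuming the same (N, 3, H, W) layout as sample() returns):

flat = x_t.view(x_t.size(0), -1)                    # one row per image
min_val = flat.min(dim=1).values.view(-1, 1, 1, 1)
max_val = flat.max(dim=1).values.view(-1, 1, 1, 1)
x_t = (x_t - min_val) / (max_val - min_val + 1e-8)  # each image in [0, 1]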
@@ -287,7 +280,7 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
         raise gr.Error(f"Generation failed: {str(e)}")
     finally:
         torch.cuda.empty_cache()
-
+
 # Gradio UI
 with gr.Blocks(theme=gr.themes.Soft(
     primary_hue="violet",