Update app.py

app.py CHANGED
@@ -19,25 +19,20 @@ cancel_event = Event()
 # Device Configuration
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# --- Model Definitions
+# --- Model Definitions ---
 class SinusoidalPositionEmbeddings(nn.Module):
     def __init__(self, dim):
         super().__init__()
         self.dim = dim
-        self.register_buffer('embeddings', self._precompute_embeddings(dim))
-
-    def _precompute_embeddings(self, dim):
         half_dim = dim // 2
         emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim) * -emb)
-        return emb
+        emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+        self.register_buffer('embeddings', emb)
 
     def forward(self, time):
-        device = time.device
-        embeddings = self.embeddings.to(device)
-        embeddings = time[:, None] * embeddings[None, :]
-        output = torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)
-        return output
+        embeddings = self.embeddings.to(time.device)
+        embeddings = time.float()[:, None] * embeddings[None, :]
+        return torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)
 
 class UNet(nn.Module):
     def __init__(self, in_channels=3, out_channels=3, num_classes=2, time_dim=256):
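Note: both versions precompute the frequency table once in __init__; the new code inlines it, pins float32, and has forward scale the cached frequencies by the timestep batch. A minimal standalone sketch of what forward computes (dim=256, the default time_dim here, is assumed):

import math
import torch

dim = 256                       # assumed: matches time_dim=256 above
half_dim = dim // 2
emb = math.log(10000) / (half_dim - 1)
freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)  # (128,) geometric frequencies

t = torch.tensor([0.0, 10.0, 999.0])       # a batch of three timesteps
angles = t[:, None] * freqs[None, :]       # broadcast to (3, 128)
out = torch.cat([angles.sin(), angles.cos()], dim=-1)
print(out.shape)                           # torch.Size([3, 256])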
@@ -125,95 +120,97 @@ class UNet(nn.Module):
         return output
 
 class DiffusionModel(nn.Module):
-    def __init__(self, model, timesteps=1000, time_dim=256):
+    def __init__(self, model, timesteps=TIMESTEPS, time_dim=256):
         super().__init__()
         self.model = model
         self.timesteps = timesteps
         self.time_dim = time_dim
 
-        self.betas = self.linear_schedule(timesteps)
-        self.alphas = 1. - self.betas
-        self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0).float())
-
-    def linear_schedule(self, timesteps):
+        # Fix 1: Ensure consistent float32 types
         scale = 1000 / timesteps
         beta_start = scale * 0.0001
         beta_end = scale * 0.02
-        return torch.linspace(beta_start, beta_end, timesteps)
-
-    def forward_diffusion(self, x_0, t, noise):
-        x_0 = x_0.float()
-        noise = noise.float()
-        alpha_bar_t = self.alpha_bars[t].view(-1, 1, 1, 1)
-        x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1. - alpha_bar_t) * noise
-        return x_t
-
-    def forward(self, x_0, labels):
-        t = torch.randint(0, self.timesteps, (x_0.shape[0],), device=x_0.device).long()
-        noise = torch.randn_like(x_0)
-        x_t = self.forward_diffusion(x_0, t, noise)
-        predicted_noise = self.model(x_t, labels, t.float())
-        return predicted_noise, noise, t
-
-@torch.no_grad()
-def sample(model, num_images, timesteps, img_size, num_classes, labels, device, progress_callback=None):
-    x_t = torch.randn(num_images, 3, img_size, img_size).to(device)
-
-    if labels.ndim == 1:
-        labels_one_hot = torch.zeros(num_images, num_classes).to(device)
-        labels_one_hot[torch.arange(num_images), labels] = 1
-        labels = labels_one_hot
-    else:
-        labels = labels.to(device)
-
-    for t in reversed(range(timesteps)):
-        if cancel_event.is_set():
-            return None
-
-        t_tensor = torch.full((num_images,), t, device=device, dtype=torch.float)
-
-        predicted_noise = model.model(x_t, labels, t_tensor)
-
-        beta_t = model.betas[t].to(device)
-        alpha_t = model.alphas[t].to(device)
-        alpha_bar_t = model.alpha_bars[t].to(device)
-
-        mean = (1 / torch.sqrt(alpha_t)) * (x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * predicted_noise)
-        variance = beta_t
-
-        if t > 0:
-            noise = torch.randn_like(x_t)
-        else:
-            noise = torch.zeros_like(x_t)
-
-        x_t = mean + torch.sqrt(variance) * noise
-
-        if progress_callback:
-            progress_callback((timesteps - t) / timesteps)
-
-    return x_t
+        self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
+        self.alphas = 1. - self.betas
+        self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0))
+
+    @torch.no_grad()
+    def sample(self, num_images, img_size, num_classes, labels, device, progress_callback=None):
+        # Initialize with noise
+        x_t = torch.randn((num_images, 3, img_size, img_size), device=device, dtype=torch.float32)
+
+        # Convert labels to proper format
+        if labels.ndim == 1:
+            labels_one_hot = torch.zeros(num_images, num_classes, device=device)
+            labels_one_hot[torch.arange(num_images), labels] = 1
+            labels = labels_one_hot
+        else:
+            labels = labels.to(device)
+
+        for i in reversed(range(0, self.timesteps)):
+            if cancel_event.is_set():
+                return None
+
+            t = torch.full((num_images,), i, device=device, dtype=torch.long)
+
+            # Model prediction with type stability
+            pred_noise = self.model(x_t, labels, t.float())
+
+            # Calculate diffusion parameters
+            beta_t = self.betas[t].view(-1, 1, 1, 1).to(device)
+            alpha_t = self.alphas[t].view(-1, 1, 1, 1).to(device)
+            alpha_bar_t = self.alpha_bars[t].view(-1, 1, 1, 1).to(device)
+
+            # Improved denoising step (Fix 2)
+            if i > 0:
+                noise = torch.randn_like(x_t)
+            else:
+                noise = torch.zeros_like(x_t)
+
+            x_t = (x_t - (1 - alpha_t)/torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_t)
+            x_t += noise * torch.sqrt(beta_t)
+
+            if progress_callback:
+                progress_callback((self.timesteps - i) / self.timesteps)
+
+        # Fix 3: Simplified scaling
+        x_t = torch.clamp(x_t, -1., 1.)
+        return (x_t + 1) / 2  # Scale to [0,1]
 
 def load_model(model_path, device):
-    diffusion_model = DiffusionModel(
+    unet = UNet(num_classes=NUM_CLASSES).to(device)
+    diffusion_model = DiffusionModel(unet).to(device)
+
+    if os.path.exists(model_path):
+        try:
+            checkpoint = torch.load(model_path, map_location=device)
+
+            # Handle both full model and state_dict loading
+            if 'model_state_dict' in checkpoint:
+                state_dict = checkpoint['model_state_dict']
+            else:
+                state_dict = checkpoint
+
+            # Handle both prefixed and non-prefixed state dicts
+            if all(k.startswith('model.') for k in state_dict.keys()):
+                state_dict = {k[6:]: v for k, v in state_dict.items()}
+
+            unet.load_state_dict(state_dict, strict=False)
+            print("Model loaded successfully")
+
+            # Verify model loading
+            test_input = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(device)
+            test_labels = torch.zeros(1, NUM_CLASSES).to(device)
+            test_labels[0, 0] = 1
+            test_time = torch.tensor([1]).to(device)
+            output = unet(test_input, test_labels, test_time)
+            print(f"Model test output shape: {output.shape}")
+
+        except Exception as e:
+            traceback.print_exc()
+            raise ValueError(f"Error loading model: {str(e)}")
+    else:
+        raise FileNotFoundError(f"Model weights not found at {model_path}")
 
     diffusion_model.eval()
     return diffusion_model
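Note: both the removed module-level sample() and the new DiffusionModel.sample() method implement the standard DDPM ancestral update (Ho et al., 2020), with a conditional noise predictor:

$$ x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\Big(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, c, t)\Big) + \sqrt{\beta_t}\,z, \qquad z \sim \mathcal{N}(0, I) \text{ for } t > 0,\ z = 0 \text{ at } t = 0. $$

The old mean used the coefficient beta_t / sqrt(1 - alpha_bar_t); since beta_t = 1 - alpha_t, that is algebraically identical to the new (1 - alpha_t) / sqrt(1 - alpha_bar_t). The substantive changes are the long-dtype timestep index, the per-sample .view(-1, 1, 1, 1) broadcasting, and the final clamp to [-1, 1] followed by rescaling to [0, 1]. The scale = 1000 / timesteps factor keeps the betas equivalent to the standard linear schedule (1e-4 to 0.02 when timesteps is 1000) at other step counts.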
@@ -222,32 +219,6 @@ def cancel_generation():
     cancel_event.set()
     return "Generation cancelled"
 
-def generate_single_image(label_str):
-    label_map = {'Pneumonia': 0, 'Pneumothorax': 1}
-    try:
-        label_index = label_map[label_str]
-    except KeyError:
-        raise gr.Error(f"Invalid label '{label_str}'. Please select either 'Pneumonia' or 'Pneumothorax'.")
-
-    labels = torch.zeros(1, NUM_CLASSES, device=device)
-    labels[0, label_index] = 1
-
-    with torch.no_grad():
-        generated_image = sample(
-            model=loaded_model,
-            num_images=1,
-            timesteps=TIMESTEPS,
-            img_size=IMG_SIZE,
-            num_classes=NUM_CLASSES,
-            labels=labels,
-            device=device
-        )
-
-    img_np = generated_image.squeeze(0).cpu().permute(1, 2, 0).numpy()
-    img_np = np.clip(img_np, 0, 1)
-    img_pil = Image.fromarray((img_np * 255).astype(np.uint8))
-
-    return img_pil
 def generate_images(label_str, num_images, progress=gr.Progress()):
     global loaded_model
     cancel_event.clear()
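Note: the deleted generate_single_image path is subsumed by generate_images, which (per the return hunk further down) hands back the first PIL image directly when num_images == 1. A hypothetical equivalent call:

# hypothetical replacement for the deleted helper
img, gallery = generate_images('Pneumonia', 1)   # img is a PIL.Image, gallery a one-element list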
@@ -260,7 +231,7 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
     if label_str not in label_map:
         raise gr.Error("Invalid condition selected")
 
-    labels = torch.zeros(num_images, NUM_CLASSES, device=device)
+    labels = torch.zeros(num_images, NUM_CLASSES, device=device, dtype=torch.float32)
     labels[:, label_map[label_str]] = 1
 
     try:
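Note: the only change here pins the one-hot conditioning tensor to float32. A small sketch of what it holds for a batch of three (class indices assumed from the label_map above: Pneumonia=0, Pneumothorax=1):

import torch

NUM_CLASSES = 2
labels = torch.zeros(3, NUM_CLASSES, dtype=torch.float32)
labels[:, 1] = 1                  # three 'Pneumothorax' requests
print(labels)
# tensor([[0., 1.],
#         [0., 1.],
#         [0., 1.]])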
@@ -270,10 +241,11 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
         raise gr.Error("Generation was cancelled by user")
 
     with torch.no_grad():
-        images = sample(
-            model=loaded_model,
+        print(f"Generating {num_images} images for {label_str}")
+        print(f"Labels shape: {labels.shape}, device: {labels.device}")
+
+        images = loaded_model.sample(
             num_images=num_images,
-            timesteps=TIMESTEPS,
             img_size=IMG_SIZE,
             num_classes=NUM_CLASSES,
             labels=labels,
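Note: because sample is now a bound method, the model= and timesteps= arguments disappear; self supplies both. A sketch of the full call in the app's context (loaded_model and device are the app's globals; the constant values are assumptions):

import torch

IMG_SIZE, NUM_CLASSES = 64, 2     # assumed values of the app's constants
labels = torch.zeros(4, NUM_CLASSES, device=device)
labels[:, 0] = 1                  # four 'Pneumonia' conditions

images = loaded_model.sample(
    num_images=4,
    img_size=IMG_SIZE,
    num_classes=NUM_CLASSES,
    labels=labels,
    device=device,
    progress_callback=lambda frac: print(f"{frac:.0%} denoised"),
)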
@@ -284,15 +256,21 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
     if images is None:
         return None, None
 
-    #
+    # Diagnostic print
+    print(f"Generated images range: {images.min().item():.3f}, {images.max().item():.3f}")
+
     processed_images = []
     for img in images:
-        img_np = img.cpu().permute(1, 2, 0).numpy()
-        img_np = np.clip(img_np, 0, 1)
-        pil_img = Image.fromarray((img_np * 255).astype(np.uint8))
+        # Fix 3: Improved image conversion
+        img_np = (img.cpu().numpy().transpose(1, 2, 0) * 255).clip(0, 255).astype(np.uint8)
+        print(f"Image range after conversion: {img_np.min()}, {img_np.max()}")
+
+        if img_np.shape[2] == 1:  # Handle grayscale if needed
+            img_np = img_np.squeeze(-1)
+        pil_img = Image.fromarray(img_np)
         processed_images.append(pil_img)
 
-    # Return
+    # Return appropriate outputs based on count
     if num_images == 1:
         return processed_images[0], processed_images
     else:
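Note: the new conversion relies on sample() already returning floats in [0, 1] (the clamp-and-rescale at the end of the sampling loop above). A self-contained sketch of the round trip, with a random stand-in tensor:

import numpy as np
import torch
from PIL import Image

img = torch.rand(3, 64, 64)       # stand-in for one sampled image in [0, 1]
arr = (img.cpu().numpy().transpose(1, 2, 0) * 255).clip(0, 255).astype(np.uint8)
pil = Image.fromarray(arr)        # CHW float -> HWC uint8 -> PIL RGB
pil.save('sample.png')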
@@ -317,7 +295,7 @@ print("Loading model...")
 loaded_model = load_model(model_path, device)
 print("Model loaded successfully!")
 
-#
+# Gradio UI
 with gr.Blocks(theme=gr.themes.Soft(
     primary_hue="violet",
     neutral_hue="slate",
@@ -356,7 +334,6 @@ with gr.Blocks(theme=gr.themes.Soft(
     """)
 
     with gr.Column(scale=2):
-        # Unified output display that adapts to single/batch
         with gr.Tabs():
             with gr.TabItem("Output", id="output_tab"):
                 single_image = gr.Image(
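Note on the 'model.' prefix handling in load_model: checkpoints saved from the wrapper (DiffusionModel holds the UNet as self.model) store UNet weights under keys like 'model.conv1.weight', while the bare UNet expects 'conv1.weight'; k[6:] drops the six-character 'model.' prefix. A minimal sketch with hypothetical key names:

state_dict = {'model.conv1.weight': 0, 'model.conv1.bias': 1}   # hypothetical checkpoint keys
if all(k.startswith('model.') for k in state_dict):
    state_dict = {k[6:]: v for k, v in state_dict.items()}
print(state_dict)   # {'conv1.weight': 0, 'conv1.bias': 1}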