Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -248,11 +248,11 @@ def generate_single_image(label_str):
|
|
248 |
img_pil = Image.fromarray((img_np * 255).astype(np.uint8))
|
249 |
|
250 |
return img_pil
|
251 |
-
|
252 |
-
def generate_batch_images(label_str, num_images, progress=gr.Progress()):
|
253 |
global loaded_model
|
254 |
cancel_event.clear()
|
255 |
|
|
|
256 |
if num_images < 1 or num_images > 10:
|
257 |
raise gr.Error("Number of images must be between 1 and 10")
|
258 |
|
@@ -282,8 +282,9 @@ def generate_batch_images(label_str, num_images, progress=gr.Progress()):
|
|
282 |
)
|
283 |
|
284 |
if images is None:
|
285 |
-
return None
|
286 |
|
|
|
287 |
processed_images = []
|
288 |
for img in images:
|
289 |
img_np = img.cpu().permute(1, 2, 0).numpy()
|
@@ -291,7 +292,11 @@ def generate_batch_images(label_str, num_images, progress=gr.Progress()):
|
|
291 |
pil_img = Image.fromarray((img_np * 255).astype(np.uint8))
|
292 |
processed_images.append(pil_img)
|
293 |
|
294 |
-
|
|
|
|
|
|
|
|
|
295 |
|
296 |
except torch.cuda.OutOfMemoryError:
|
297 |
torch.cuda.empty_cache()
|
@@ -300,7 +305,7 @@ def generate_batch_images(label_str, num_images, progress=gr.Progress()):
|
|
300 |
traceback.print_exc()
|
301 |
if str(e) != "Generation was cancelled by user":
|
302 |
raise gr.Error(f"Generation failed: {str(e)}")
|
303 |
-
return None
|
304 |
finally:
|
305 |
torch.cuda.empty_cache()
|
306 |
|
@@ -312,7 +317,7 @@ print("Loading model...")
|
|
312 |
loaded_model = load_model(model_path, device)
|
313 |
print("Model loaded successfully!")
|
314 |
|
315 |
-
#
|
316 |
with gr.Blocks(theme=gr.themes.Soft(
|
317 |
primary_hue="violet",
|
318 |
neutral_hue="slate",
|
@@ -351,32 +356,44 @@ with gr.Blocks(theme=gr.themes.Soft(
|
|
351 |
""")
|
352 |
|
353 |
with gr.Column(scale=2):
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
367 |
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
outputs=single_image
|
373 |
)
|
374 |
|
375 |
-
# Batch image generation
|
376 |
submit_btn.click(
|
377 |
-
fn=
|
378 |
inputs=[condition, num_images],
|
379 |
-
outputs=gallery
|
380 |
)
|
381 |
|
382 |
cancel_btn.click(
|
|
|
248 |
img_pil = Image.fromarray((img_np * 255).astype(np.uint8))
|
249 |
|
250 |
return img_pil
|
251 |
+
def generate_images(label_str, num_images, progress=gr.Progress()):
|
|
|
252 |
global loaded_model
|
253 |
cancel_event.clear()
|
254 |
|
255 |
+
# Input validation
|
256 |
if num_images < 1 or num_images > 10:
|
257 |
raise gr.Error("Number of images must be between 1 and 10")
|
258 |
|
|
|
282 |
)
|
283 |
|
284 |
if images is None:
|
285 |
+
return None, None
|
286 |
|
287 |
+
# Process all generated images
|
288 |
processed_images = []
|
289 |
for img in images:
|
290 |
img_np = img.cpu().permute(1, 2, 0).numpy()
|
|
|
292 |
pil_img = Image.fromarray((img_np * 255).astype(np.uint8))
|
293 |
processed_images.append(pil_img)
|
294 |
|
295 |
+
# Return both single image and gallery based on count
|
296 |
+
if num_images == 1:
|
297 |
+
return processed_images[0], processed_images
|
298 |
+
else:
|
299 |
+
return None, processed_images
|
300 |
|
301 |
except torch.cuda.OutOfMemoryError:
|
302 |
torch.cuda.empty_cache()
|
|
|
305 |
traceback.print_exc()
|
306 |
if str(e) != "Generation was cancelled by user":
|
307 |
raise gr.Error(f"Generation failed: {str(e)}")
|
308 |
+
return None, None
|
309 |
finally:
|
310 |
torch.cuda.empty_cache()
|
311 |
|
|
|
317 |
loaded_model = load_model(model_path, device)
|
318 |
print("Model loaded successfully!")
|
319 |
|
320 |
+
# Unified Gradio UI
|
321 |
with gr.Blocks(theme=gr.themes.Soft(
|
322 |
primary_hue="violet",
|
323 |
neutral_hue="slate",
|
|
|
356 |
""")
|
357 |
|
358 |
with gr.Column(scale=2):
|
359 |
+
# Unified output display that adapts to single/batch
|
360 |
+
with gr.Tabs():
|
361 |
+
with gr.TabItem("Output", id="output_tab"):
|
362 |
+
single_image = gr.Image(
|
363 |
+
label="Generated X-ray",
|
364 |
+
height=400,
|
365 |
+
visible=True
|
366 |
+
)
|
367 |
+
gallery = gr.Gallery(
|
368 |
+
label="Generated X-rays",
|
369 |
+
columns=3,
|
370 |
+
height="auto",
|
371 |
+
object_fit="contain",
|
372 |
+
visible=False
|
373 |
+
)
|
374 |
+
|
375 |
+
def update_ui_based_on_count(num_images):
|
376 |
+
if num_images == 1:
|
377 |
+
return {
|
378 |
+
single_image: gr.update(visible=True),
|
379 |
+
gallery: gr.update(visible=False)
|
380 |
+
}
|
381 |
+
else:
|
382 |
+
return {
|
383 |
+
single_image: gr.update(visible=False),
|
384 |
+
gallery: gr.update(visible=True)
|
385 |
+
}
|
386 |
|
387 |
+
num_images.change(
|
388 |
+
fn=update_ui_based_on_count,
|
389 |
+
inputs=num_images,
|
390 |
+
outputs=[single_image, gallery]
|
|
|
391 |
)
|
392 |
|
|
|
393 |
submit_btn.click(
|
394 |
+
fn=generate_images,
|
395 |
inputs=[condition, num_images],
|
396 |
+
outputs=[single_image, gallery]
|
397 |
)
|
398 |
|
399 |
cancel_btn.click(
|