Vedansh-7 commited on
Commit
3d428a2
·
1 Parent(s): aa082d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -27
app.py CHANGED
@@ -248,11 +248,11 @@ def generate_single_image(label_str):
248
  img_pil = Image.fromarray((img_np * 255).astype(np.uint8))
249
 
250
  return img_pil
251
-
252
- def generate_batch_images(label_str, num_images, progress=gr.Progress()):
253
  global loaded_model
254
  cancel_event.clear()
255
 
 
256
  if num_images < 1 or num_images > 10:
257
  raise gr.Error("Number of images must be between 1 and 10")
258
 
@@ -282,8 +282,9 @@ def generate_batch_images(label_str, num_images, progress=gr.Progress()):
282
  )
283
 
284
  if images is None:
285
- return None
286
 
 
287
  processed_images = []
288
  for img in images:
289
  img_np = img.cpu().permute(1, 2, 0).numpy()
@@ -291,7 +292,11 @@ def generate_batch_images(label_str, num_images, progress=gr.Progress()):
291
  pil_img = Image.fromarray((img_np * 255).astype(np.uint8))
292
  processed_images.append(pil_img)
293
 
294
- return processed_images
 
 
 
 
295
 
296
  except torch.cuda.OutOfMemoryError:
297
  torch.cuda.empty_cache()
@@ -300,7 +305,7 @@ def generate_batch_images(label_str, num_images, progress=gr.Progress()):
300
  traceback.print_exc()
301
  if str(e) != "Generation was cancelled by user":
302
  raise gr.Error(f"Generation failed: {str(e)}")
303
- return None
304
  finally:
305
  torch.cuda.empty_cache()
306
 
@@ -312,7 +317,7 @@ print("Loading model...")
312
  loaded_model = load_model(model_path, device)
313
  print("Model loaded successfully!")
314
 
315
- # --- Gradio UI (from first file with modifications) ---
316
  with gr.Blocks(theme=gr.themes.Soft(
317
  primary_hue="violet",
318
  neutral_hue="slate",
@@ -351,32 +356,44 @@ with gr.Blocks(theme=gr.themes.Soft(
351
  """)
352
 
353
  with gr.Column(scale=2):
354
- with gr.Tab("Single Image"):
355
- single_image = gr.Image(
356
- type="pil",
357
- label="Generated X-ray",
358
- height=400
359
- )
360
- with gr.Tab("Batch Images"):
361
- gallery = gr.Gallery(
362
- label="Generated X-rays",
363
- columns=3,
364
- height="auto",
365
- object_fit="contain"
366
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
 
368
- # Single image generation
369
- condition.change(
370
- fn=generate_single_image,
371
- inputs=condition,
372
- outputs=single_image
373
  )
374
 
375
- # Batch image generation
376
  submit_btn.click(
377
- fn=generate_batch_images,
378
  inputs=[condition, num_images],
379
- outputs=gallery
380
  )
381
 
382
  cancel_btn.click(
 
248
  img_pil = Image.fromarray((img_np * 255).astype(np.uint8))
249
 
250
  return img_pil
251
+ def generate_images(label_str, num_images, progress=gr.Progress()):
 
252
  global loaded_model
253
  cancel_event.clear()
254
 
255
+ # Input validation
256
  if num_images < 1 or num_images > 10:
257
  raise gr.Error("Number of images must be between 1 and 10")
258
 
 
282
  )
283
 
284
  if images is None:
285
+ return None, None
286
 
287
+ # Process all generated images
288
  processed_images = []
289
  for img in images:
290
  img_np = img.cpu().permute(1, 2, 0).numpy()
 
292
  pil_img = Image.fromarray((img_np * 255).astype(np.uint8))
293
  processed_images.append(pil_img)
294
 
295
+ # Return both single image and gallery based on count
296
+ if num_images == 1:
297
+ return processed_images[0], processed_images
298
+ else:
299
+ return None, processed_images
300
 
301
  except torch.cuda.OutOfMemoryError:
302
  torch.cuda.empty_cache()
 
305
  traceback.print_exc()
306
  if str(e) != "Generation was cancelled by user":
307
  raise gr.Error(f"Generation failed: {str(e)}")
308
+ return None, None
309
  finally:
310
  torch.cuda.empty_cache()
311
 
 
317
  loaded_model = load_model(model_path, device)
318
  print("Model loaded successfully!")
319
 
320
+ # Unified Gradio UI
321
  with gr.Blocks(theme=gr.themes.Soft(
322
  primary_hue="violet",
323
  neutral_hue="slate",
 
356
  """)
357
 
358
  with gr.Column(scale=2):
359
+ # Unified output display that adapts to single/batch
360
+ with gr.Tabs():
361
+ with gr.TabItem("Output", id="output_tab"):
362
+ single_image = gr.Image(
363
+ label="Generated X-ray",
364
+ height=400,
365
+ visible=True
366
+ )
367
+ gallery = gr.Gallery(
368
+ label="Generated X-rays",
369
+ columns=3,
370
+ height="auto",
371
+ object_fit="contain",
372
+ visible=False
373
+ )
374
+
375
+ def update_ui_based_on_count(num_images):
376
+ if num_images == 1:
377
+ return {
378
+ single_image: gr.update(visible=True),
379
+ gallery: gr.update(visible=False)
380
+ }
381
+ else:
382
+ return {
383
+ single_image: gr.update(visible=False),
384
+ gallery: gr.update(visible=True)
385
+ }
386
 
387
+ num_images.change(
388
+ fn=update_ui_based_on_count,
389
+ inputs=num_images,
390
+ outputs=[single_image, gallery]
 
391
  )
392
 
 
393
  submit_btn.click(
394
+ fn=generate_images,
395
  inputs=[condition, num_images],
396
+ outputs=[single_image, gallery]
397
  )
398
 
399
  cancel_btn.click(