Vedansh-7 commited on
Commit
a9dbf38
·
1 Parent(s): 6b37aa7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -5
app.py CHANGED
@@ -8,6 +8,8 @@ import os
8
  from threading import Event
9
  import traceback
10
  import cv2 # Added for bilateral filtering
 
 
11
 
12
  # Constants
13
  IMG_SIZE = 128
@@ -300,13 +302,31 @@ def generate_images(label_str, num_images, progress=gr.Progress()):
300
 
301
  processed_images = []
302
  for img in images:
303
- # Convert to grayscale (X-rays are naturally grayscale)
304
- # Take the mean across RGB channels and convert to uint8
305
  img_np = img.cpu().permute(1, 2, 0).mean(dim=-1).numpy()
306
- img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
307
 
308
- # Create PIL image in grayscale mode
309
- pil_img = Image.fromarray(img_np, mode='L')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
 
311
  processed_images.append(pil_img)
312
 
 
8
  from threading import Event
9
  import traceback
10
  import cv2 # Added for bilateral filtering
11
+ import matplotlib.pyplot as plt
12
+ from io import BytesIO
13
 
14
  # Constants
15
  IMG_SIZE = 128
 
302
 
303
  processed_images = []
304
  for img in images:
305
+ # Convert to grayscale and apply bone colormap
 
306
  img_np = img.cpu().permute(1, 2, 0).mean(dim=-1).numpy()
 
307
 
308
+ # Normalize to 0-1
309
+ img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)
310
+
311
+ # Apply additional sharpening with OpenCV
312
+ img_np_uint8 = (img_np * 255).astype(np.uint8)
313
+
314
+ # Apply unsharp mask for additional sharpness
315
+ blurred = cv2.GaussianBlur(img_np_uint8, (0, 0), 2.0)
316
+ sharpened = cv2.addWeighted(img_np_uint8, 1.5, blurred, -0.5, 0)
317
+
318
+ # Apply bone colormap using matplotlib
319
+ plt.figure(figsize=(2.56, 2.56), dpi=50)
320
+ plt.imshow(sharpened, cmap='bone')
321
+ plt.axis('off')
322
+ plt.tight_layout(pad=0)
323
+
324
+ # Save to buffer
325
+ buf = BytesIO()
326
+ plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=50)
327
+ buf.seek(0)
328
+ pil_img = Image.open(buf)
329
+ plt.close()
330
 
331
  processed_images.append(pil_img)
332