skaliy commited on
Commit
89fac49
Β·
1 Parent(s): 23eb6b2
Files changed (1) hide show
  1. app.py +31 -38
app.py CHANGED
@@ -1,89 +1,82 @@
1
  import gradio as gr
2
- from fastMONAI.vision_all import *
3
- from huggingface_hub import snapshot_download
4
- from pathlib import Path
5
  import torch
6
  import cv2
 
 
7
 
8
  def initialize_system():
9
  """Initial setup of model paths and other constants."""
10
- models_path = Path(snapshot_download(repo_id="skaliy/endometrial_cancer_segmentation", cache_dir='models', revision='main'))
 
 
11
  save_dir = Path.cwd() / 'ec_pred'
12
  save_dir.mkdir(parents=True, exist_ok=True)
13
  download_example_endometrial_cancer_data(path=save_dir, multi_channel=False)
14
 
15
  return models_path, save_dir
16
 
17
- def load_system_resources(models_path):
18
- """Load necessary resources like learner and variables."""
19
 
20
- learner = load_learner(models_path / 'vibe-learner.pkl', cpu=True) # TODO: add an option to run on GPU
21
- vars_fn = models_path / 'vars.pkl'
22
- _, reorder, resample = load_variables(pkl_fn=vars_fn)
23
 
24
- return learner, reorder, resample
25
-
26
- def get_mid_slice(img, mask_data):
27
- """Extract the middle slice of the mask in a 3D array."""
28
-
29
- sums = mask_data.sum(axis=(0,1))
30
- mid_idx = np.argmax(sums)
31
- img, mask_data = img[:, :, mid_idx], mask_data[:, :, mid_idx]
32
-
33
  return np.fliplr(np.rot90(img, -1)), np.fliplr(np.rot90(mask_data, -1))
34
 
35
-
36
  def get_fused_image(img, pred_mask, alpha=0.8):
37
- """Overlay the mask on the image."""
 
38
  gray_img_colored = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
39
  mask_color = np.array([0, 0, 255])
40
  colored_mask = (pred_mask[..., None] * mask_color).astype(np.uint8)
41
-
42
  return cv2.addWeighted(gray_img_colored, alpha, colored_mask, 1 - alpha, 0)
43
 
44
- def compute_tumor_volume(mask_data):
45
- """Compute the volume of the tumor in milliliters (ml)."""
46
-
47
- dx, dy, dz = mask_data.spacing
48
- voxel_volume_ml = dx * dy * dz / 1000
49
- return np.sum(mask_data) * voxel_volume_ml
50
 
51
- def predict(fileobj, learner, reorder, resample, save_dir):
52
  """Predict function using the learner and other resources."""
53
  img_path = Path(fileobj.name)
54
 
55
  save_fn = 'pred_' + img_path.stem
56
  save_path = save_dir / save_fn
57
- org_img, input_img, org_size = med_img_reader(img_path, reorder=reorder, resample=resample, only_tensor=False)
 
 
 
58
 
59
- mask_data = inference(learner, reorder=reorder, resample=resample, org_img=org_img, input_img=input_img, org_size=org_size).data
 
 
60
 
61
  if "".join(org_img.orientation) == "LSA":
62
  mask_data = mask_data.permute(0,1,3,2)
63
  mask_data = torch.flip(mask_data[0], dims=[1])
64
  mask_data = torch.Tensor(mask_data)[None]
65
 
66
- img = org_img.data #TEMP
67
-
68
  org_img.set_data(mask_data)
69
  org_img.save(save_path)
70
 
71
- img, pred_mask = get_mid_slice(img[0], mask_data[0])
72
  img = ((img - img.min()) / (img.max() - img.min()) * 255).astype(np.uint8) #normalize
73
- volume = compute_tumor_volume(org_img)
 
74
 
75
  return get_fused_image(img, pred_mask), round(volume, 2)
76
 
77
 
78
  models_path, save_dir = initialize_system()
79
- learner, reorder, resample = load_system_resources(models_path)
 
 
80
  output_text = gr.Textbox(label="Volume of the predicted tumor:")
81
 
82
  demo = gr.Interface(
83
- fn=lambda fileobj: predict(fileobj, learner, reorder, resample, save_dir),
84
  inputs=["file"],
85
  outputs=["image", output_text],
86
- examples=[[save_dir/"vibe.nii.gz"]]
87
- )
88
 
89
  demo.launch()
 
1
  import gradio as gr
 
 
 
2
  import torch
3
  import cv2
4
+ from huggingface_hub import snapshot_download
5
+ from fastMONAI.vision_all import *
6
 
7
  def initialize_system():
8
  """Initial setup of model paths and other constants."""
9
+ models_path = Path(snapshot_download(repo_id="skaliy/endometrial_cancer_segmentation",
10
+ cache_dir='models',
11
+ revision='main'))
12
  save_dir = Path.cwd() / 'ec_pred'
13
  save_dir.mkdir(parents=True, exist_ok=True)
14
  download_example_endometrial_cancer_data(path=save_dir, multi_channel=False)
15
 
16
  return models_path, save_dir
17
 
18
+ def extract_slice_from_mask(img, mask_data):
19
+ """Extract a slice from the 3D [W, H, D] image and mask data based on mask data."""
20
 
21
+ sums = mask_data.sum(axis=(0, 1))
22
+ idx = np.argmax(sums)
23
+ img, mask_data = img[:, :, idx], mask_data[:, :, idx]
24
 
 
 
 
 
 
 
 
 
 
25
  return np.fliplr(np.rot90(img, -1)), np.fliplr(np.rot90(mask_data, -1))
26
 
27
+ #| export
28
  def get_fused_image(img, pred_mask, alpha=0.8):
29
+ """Fuse a grayscale image with a mask overlay."""
30
+
31
  gray_img_colored = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
32
  mask_color = np.array([0, 0, 255])
33
  colored_mask = (pred_mask[..., None] * mask_color).astype(np.uint8)
34
+
35
  return cv2.addWeighted(gray_img_colored, alpha, colored_mask, 1 - alpha, 0)
36
 
 
 
 
 
 
 
37
 
38
+ def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir):
39
  """Predict function using the learner and other resources."""
40
  img_path = Path(fileobj.name)
41
 
42
  save_fn = 'pred_' + img_path.stem
43
  save_path = save_dir / save_fn
44
+ org_img, input_img, org_size = med_img_reader(img_path,
45
+ reorder=reorder,
46
+ resample=resample,
47
+ only_tensor=False)
48
 
49
+ mask_data = inference(learn, reorder=reorder, resample=resample,
50
+ org_img=org_img, input_img=input_img,
51
+ org_size=org_size).data
52
 
53
  if "".join(org_img.orientation) == "LSA":
54
  mask_data = mask_data.permute(0,1,3,2)
55
  mask_data = torch.flip(mask_data[0], dims=[1])
56
  mask_data = torch.Tensor(mask_data)[None]
57
 
58
+ img = org_img.data
 
59
  org_img.set_data(mask_data)
60
  org_img.save(save_path)
61
 
62
+ img, pred_mask = extract_slice_from_mask(img[0], mask_data[0])
63
  img = ((img - img.min()) / (img.max() - img.min()) * 255).astype(np.uint8) #normalize
64
+
65
+ volume = compute_binary_tumor_volume(org_img)
66
 
67
  return get_fused_image(img, pred_mask), round(volume, 2)
68
 
69
 
70
  models_path, save_dir = initialize_system()
71
+ learn, reorder, resample = load_system_resources(models_path=models_path,
72
+ learner_fn='vibe-learner.pkl',
73
+ variables_fn='vars.pkl')
74
  output_text = gr.Textbox(label="Volume of the predicted tumor:")
75
 
76
  demo = gr.Interface(
77
+ fn=lambda fileobj: gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir),
78
  inputs=["file"],
79
  outputs=["image", output_text],
80
+ examples=[[save_dir/"vibe.nii.gz"]])
 
81
 
82
  demo.launch()