skaliy commited on
Commit
cf8f815
Β·
1 Parent(s): 5ded0d9

Dev: gradio app

Browse files
Files changed (2) hide show
  1. app.py +89 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from fastMONAI.vision_all import *
3
+ from huggingface_hub import snapshot_download
4
+ from pathlib import Path
5
+ import torch
6
+ import cv2
7
+
8
+ def initialize_system():
9
+ """Initial setup of model paths and other constants."""
10
+ models_path = Path(snapshot_download(repo_id="skaliy/endometrial_cancer_segmentation", cache_dir='models', revision='main'))
11
+ save_dir = Path.cwd() / 'ec_pred'
12
+ save_dir.mkdir(parents=True, exist_ok=True)
13
+ download_example_endometrial_cancer_data(path=save_dir, multi_channel=False)
14
+
15
+ return models_path, save_dir
16
+
17
+ def load_system_resources(models_path):
18
+ """Load necessary resources like learner and variables."""
19
+
20
+ learner = load_learner(models_path / 'vibe-learner.pkl', cpu=True) # TODO: add an option to run on GPU
21
+ vars_fn = models_path / 'vars.pkl'
22
+ _, reorder, resample = load_variables(pkl_fn=vars_fn)
23
+
24
+ return learner, reorder, resample
25
+
26
+ def get_mid_slice(img, mask_data):
27
+ """Extract the middle slice of the mask in a 3D array."""
28
+
29
+ sums = mask_data.sum(axis=(0,1))
30
+ mid_idx = np.argmax(sums)
31
+ img, mask_data = img[:, :, mid_idx], mask_data[:, :, mid_idx]
32
+
33
+ return np.fliplr(np.rot90(img, -1)), np.fliplr(np.rot90(mask_data, -1))
34
+
35
+
36
+ def get_fused_image(img, pred_mask, alpha=0.8):
37
+ """Overlay the mask on the image."""
38
+ gray_img_colored = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
39
+ mask_color = np.array([0, 0, 255])
40
+ colored_mask = (pred_mask[..., None] * mask_color).astype(np.uint8)
41
+
42
+ return cv2.addWeighted(gray_img_colored, alpha, colored_mask, 1 - alpha, 0)
43
+
44
+ def compute_tumor_volume(mask_data):
45
+ """Compute the volume of the tumor in milliliters (ml)."""
46
+
47
+ dx, dy, dz = mask_data.spacing
48
+ voxel_volume_ml = dx * dy * dz / 1000
49
+ return np.sum(mask_data) * voxel_volume_ml
50
+
51
+ def predict(fileobj, learner, reorder, resample, save_dir):
52
+ """Predict function using the learner and other resources."""
53
+ img_path = Path(fileobj.name)
54
+
55
+ save_fn = 'pred_' + img_path.stem
56
+ save_path = save_dir / save_fn
57
+ org_img, input_img, org_size = med_img_reader(img_path, reorder=reorder, resample=resample, only_tensor=False)
58
+
59
+ mask_data = inference(learner, reorder=reorder, resample=resample, org_img=org_img, input_img=input_img, org_size=org_size).data
60
+
61
+ if "".join(org_img.orientation) == "LSA":
62
+ mask_data = mask_data.permute(0,1,3,2)
63
+ mask_data = torch.flip(mask_data[0], dims=[1])
64
+ mask_data = torch.Tensor(mask_data)[None]
65
+
66
+ img = org_img.data #TEMP
67
+
68
+ org_img.set_data(mask_data)
69
+ org_img.save(save_path)
70
+
71
+ img, pred_mask = get_mid_slice(img[0], mask_data[0])
72
+ img = ((img - img.min()) / (img.max() - img.min()) * 255).astype(np.uint8) #normalize
73
+ volume = compute_tumor_volume(org_img)
74
+
75
+ return get_fused_image(img, pred_mask), round(volume, 2)
76
+
77
+
78
+ models_path, save_dir = initialize_system()
79
+ learner, reorder, resample = load_system_resources(models_path)
80
+ output_text = gr.Textbox(label="Volume of the predicted tumor:")
81
+
82
+ demo = gr.Interface(
83
+ fn=lambda fileobj: predict(fileobj, learner, reorder, resample, save_dir),
84
+ inputs=["file"],
85
+ outputs=["image", output_text],
86
+ examples=[[save_dir/"vibe.nii.gz"]]
87
+ )
88
+
89
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ fastMONAI
2
+ cv2