darragh committed on
Commit
eb469d0
·
1 Parent(s): 487d9cd
Files changed (1) hide show
  1. app.py +8 -9
app.py CHANGED
@@ -19,9 +19,11 @@ from swinunetr import SwinUnetrModelForInference, SwinUnetrConfig
19
  ffmpeg_path = shutil.which('ffmpeg')
20
  mediapy.set_ffmpeg(ffmpeg_path)
21
 
 
22
  model = SwinUnetrModelForInference.from_pretrained('darragh/swinunetr-btcv-tiny')
23
  model.eval()
24
 
 
25
  input_files = glob.glob('pmrc/SwinUNETR/BTCV/dataset/imagesSampleTs/*.nii.gz')
26
  input_files = dict((f.split('/')[-1], f) for f in input_files)
27
 
@@ -43,8 +45,6 @@ test_transform = transforms.Compose(
43
  transforms.ToTensord(keys=["image"]),
44
  ])
45
 
46
-
47
-
48
  # Create Data Loader
49
  def create_dl(test_files):
50
  ds = test_transform(test_files)
@@ -56,14 +56,17 @@ def create_dl(test_files):
56
  # Inference and video generation
57
  def generate_dicom_video(selected_file, n_frames):
58
 
 
59
  test_file = input_files[selected_file]
60
  test_files = [{'image': test_file}]
61
  dl = create_dl(test_files)
62
  batch = next(iter(dl))
63
 
 
64
  tst_inputs = batch["image"]
65
- tst_inputs = tst_inputs[:,:,:,:,-32:]
66
-
 
67
  with torch.no_grad():
68
  outputs = model(tst_inputs,
69
  (96,96,96),
@@ -75,7 +78,6 @@ def generate_dicom_video(selected_file, n_frames):
75
 
76
  # Write frames to video
77
  for inp, outp in zip(tst_inputs, tst_outputs):
78
-
79
  frames = []
80
  for idx in range(inp.shape[-1]):
81
  # Segmentation
@@ -96,7 +98,7 @@ theme = 'dark-peach'
96
  with gr.Blocks(theme=theme) as demo:
97
 
98
  gr.Markdown('''<center><h1>SwinUnetr BTCV</h1></center>
99
- This is a Gradio Blocks app of the winning transformer in the Beyond the Cranial Vault (BTCV) Segmentation Challenge, <a href="https://github.com/darraghdog/Project-MONAI-research-contributions/tree/main/SwinUNETR/BTCV">SwinUnetr (tiny version).</a>.
100
  ''')
101
  selected_dicom_key = gr.inputs.Dropdown(
102
  choices=sorted(input_files),
@@ -105,9 +107,6 @@ with gr.Blocks(theme=theme) as demo:
105
  n_frames = gr.Slider(1, 100, value=32, label="Number of dicom slices")
106
  button_gen_video = gr.Button("Generate Video")
107
  output_interpolation = gr.Video(label="Generated Video")
108
- gr.Markdown(
109
- '<center><img src="https://visitor-badge.glitch.me/badge?page_id=gradio-blocks.anime-biggan" alt="visitor badge"/></center>'
110
- )
111
  button_gen_video.click(fn=generate_dicom_video,
112
  inputs=[selected_dicom_key, n_frames],
113
  outputs=output_interpolation)
 
19
  ffmpeg_path = shutil.which('ffmpeg')
20
  mediapy.set_ffmpeg(ffmpeg_path)
21
 
22
+ # Load model
23
  model = SwinUnetrModelForInference.from_pretrained('darragh/swinunetr-btcv-tiny')
24
  model.eval()
25
 
26
+ # Pull files from github
27
  input_files = glob.glob('pmrc/SwinUNETR/BTCV/dataset/imagesSampleTs/*.nii.gz')
28
  input_files = dict((f.split('/')[-1], f) for f in input_files)
29
 
 
45
  transforms.ToTensord(keys=["image"]),
46
  ])
47
 
 
 
48
  # Create Data Loader
49
  def create_dl(test_files):
50
  ds = test_transform(test_files)
 
56
  # Inference and video generation
57
  def generate_dicom_video(selected_file, n_frames):
58
 
59
+ # Data processor
60
  test_file = input_files[selected_file]
61
  test_files = [{'image': test_file}]
62
  dl = create_dl(test_files)
63
  batch = next(iter(dl))
64
 
65
+ # Select dicom slices
66
  tst_inputs = batch["image"]
67
+ tst_inputs = tst_inputs[:,:,:,:,-n_frames:]
68
+
69
+ # Inference
70
  with torch.no_grad():
71
  outputs = model(tst_inputs,
72
  (96,96,96),
 
78
 
79
  # Write frames to video
80
  for inp, outp in zip(tst_inputs, tst_outputs):
 
81
  frames = []
82
  for idx in range(inp.shape[-1]):
83
  # Segmentation
 
98
  with gr.Blocks(theme=theme) as demo:
99
 
100
  gr.Markdown('''<center><h1>SwinUnetr BTCV</h1></center>
101
+ This is a Gradio Blocks app of the winning transformer in the Beyond the Cranial Vault (BTCV) Segmentation Challenge, <a href="https://github.com/darraghdog/Project-MONAI-research-contributions/tree/main/SwinUNETR/BTCV">SwinUnetr</a> (tiny version).
102
  ''')
103
  selected_dicom_key = gr.inputs.Dropdown(
104
  choices=sorted(input_files),
 
107
  n_frames = gr.Slider(1, 100, value=32, label="Number of dicom slices")
108
  button_gen_video = gr.Button("Generate Video")
109
  output_interpolation = gr.Video(label="Generated Video")
 
 
 
110
  button_gen_video.click(fn=generate_dicom_video,
111
  inputs=[selected_dicom_key, n_frames],
112
  outputs=output_interpolation)