jpdefrutos commited on
Commit
15c9383
·
1 Parent(s): 99b6efe

Refactoring

Browse files

Improved image generation

Files changed (1) hide show
  1. Brain_study/ABSTRACT/figures.py +80 -32
Brain_study/ABSTRACT/figures.py CHANGED
@@ -7,22 +7,74 @@ import nibabel as nib
7
  import numpy as np
8
  import matplotlib.pyplot as plt
9
  from matplotlib import cm
10
- from matplotlib.colors import ListedColormap, LinearSegmentedColormap
 
11
 
12
- segm_cm = cm.get_cmap('Dark2', 256)
13
- segm_cm = segm_cm(np.linspace(0, 1, 28))
 
 
 
 
 
14
  segm_cm[0, :] = np.asarray([0, 0, 0, 0])
15
  segm_cm = ListedColormap(segm_cm)
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  if __name__ == '__main__':
18
  parser = argparse.ArgumentParser()
19
 
20
  parser.add_argument('-d', '--dir', type=str, help='Directories where the models are stored', default=None)
21
  parser.add_argument('-o', '--output', type=str, help='Output directory', default=os.getcwd())
22
  parser.add_argument('--overwrite', type=bool, default=True)
 
 
23
  args = parser.parse_args()
24
  assert args.dir is not None, "No directories provided. Stopping"
25
-
26
  list_fix_img = list()
27
  list_mov_img = list()
28
  list_fix_seg = list()
@@ -30,10 +82,12 @@ if __name__ == '__main__':
30
  list_pred_img = list()
31
  list_pred_seg = list()
32
  print('Fetching data...')
 
33
  for r, d, f in os.walk(args.dir):
34
- if os.path.split(r)[1] == 'Evaluation_paper':
 
35
  for name in f:
36
- if re.search('^050', name) and name.endswith('nii.gz'):
37
  if re.search('fix_img', name) and name.endswith('nii.gz'):
38
  list_fix_img.append(os.path.join(r, name))
39
  elif re.search('mov_img', name):
@@ -58,16 +112,16 @@ if __name__ == '__main__':
58
  list_pred_img.sort()
59
  list_pred_seg.sort()
60
  print('Making Test_data.png...')
61
- selected_slice = 30
62
- fix_img = np.asarray(nib.load(list_fix_img[0]).dataobj)[..., selected_slice, 0]
63
- mov_img = np.asarray(nib.load(list_mov_img[0]).dataobj)[..., selected_slice, 0]
64
- fix_seg = np.asarray(nib.load(list_fix_seg[0]).dataobj)[..., selected_slice, 0]
65
- mov_seg = np.asarray(nib.load(list_mov_seg[0]).dataobj)[..., selected_slice, 0]
66
 
67
  fig, ax = plt.subplots(nrows=1, ncols=4, figsize=(9, 3), dpi=200)
68
 
69
  for i, (img, title) in enumerate(zip([(fix_img, fix_seg), (mov_img, mov_seg)],
70
- [('Fixed image', 'Fixed Segms.'), ('Moving image', 'Moving Segms.')])):
71
 
72
  ax[i].imshow(img[0], origin='lower', cmap='Greys_r')
73
  ax[i+2].imshow(img[0], origin='lower', cmap='Greys_r')
@@ -84,14 +138,16 @@ if __name__ == '__main__':
84
  warnings.warn('File Test_data.png already exists. Skipping')
85
  else:
86
  plt.savefig(os.path.join(args.output, 'Test_data.png'), format='png')
 
 
87
  plt.close()
88
 
89
  print('Making Pred_data.png...')
90
- fig, ax = plt.subplots(nrows=2, ncols=6, figsize=(9, 3), dpi=200)
91
 
92
  for i, (pred_img_path, pred_seg_path) in enumerate(zip(list_pred_img, list_pred_seg)):
93
- img = np.asarray(nib.load(pred_img_path).dataobj)[..., selected_slice, 0]
94
- seg = np.asarray(nib.load(pred_seg_path).dataobj)[..., selected_slice, 0]
95
 
96
  ax[0, i].imshow(img, origin='lower', cmap='Greys_r')
97
  ax[1, i].imshow(img, origin='lower', cmap='Greys_r')
@@ -100,13 +156,7 @@ if __name__ == '__main__':
100
  ax[0, i].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False)
101
  ax[1, i].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False)
102
 
103
- model = re.search('((UW|SEGGUIDED|BASELINE).*)_{2,}MET', pred_img_path).group(1).rstrip('_')
104
- model = model.replace('_Lsim', ' ')
105
- model = model.replace('_Lseg', ' ')
106
- model = model.replace('_L', ' ')
107
- model = model.replace('_', ' ')
108
- model = model.upper()
109
- model = ' '.join(model.split())
110
 
111
  ax[1, i].set_xlabel(model, fontsize=9)
112
  plt.tight_layout()
@@ -114,18 +164,20 @@ if __name__ == '__main__':
114
  warnings.warn('File Pred_data.png already exists. Skipping')
115
  else:
116
  plt.savefig(os.path.join(args.output, 'Pred_data.png'), format='png')
 
 
117
  plt.close()
118
 
119
  print('Making Pred_data_large.png...')
120
- fig, ax = plt.subplots(nrows=2, ncols=8, figsize=(9, 3), dpi=200)
121
  list_pred_img = [list_mov_img[0]] + list_pred_img
122
  list_pred_img = [list_fix_img[0]] + list_pred_img
123
  list_pred_seg = [list_mov_seg[0]] + list_pred_seg
124
  list_pred_seg = [list_fix_seg[0]] + list_pred_seg
125
 
126
  for i, (pred_img_path, pred_seg_path) in enumerate(zip(list_pred_img, list_pred_seg)):
127
- img = np.asarray(nib.load(pred_img_path).dataobj)[..., selected_slice, 0]
128
- seg = np.asarray(nib.load(pred_seg_path).dataobj)[..., selected_slice, 0]
129
 
130
  ax[0, i].imshow(img, origin='lower', cmap='Greys_r')
131
  ax[1, i].imshow(img, origin='lower', cmap='Greys_r')
@@ -135,13 +187,7 @@ if __name__ == '__main__':
135
  ax[1, i].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False)
136
 
137
  if i > 1:
138
- model = re.search('((UW|SEGGUIDED|BASELINE).*)_{2,}MET', pred_img_path).group(1).rstrip('_')
139
- model = model.replace('_Lsim', ' ')
140
- model = model.replace('_Lseg', ' ')
141
- model = model.replace('_L', ' ')
142
- model = model.replace('_', ' ')
143
- model = model.upper()
144
- model = ' '.join(model.split())
145
  elif i == 0:
146
  model = 'Moving image'
147
  else:
@@ -153,6 +199,8 @@ if __name__ == '__main__':
153
  warnings.warn('File Pred_data.png already exists. Skipping')
154
  else:
155
  plt.savefig(os.path.join(args.output, 'Pred_data_large.png'), format='png')
 
 
156
  plt.close()
157
 
158
  print('...done!')
 
7
  import numpy as np
8
  import matplotlib.pyplot as plt
9
  from matplotlib import cm
10
+ from matplotlib.colors import ListedColormap, LinearSegmentedColormap, to_rgba, CSS4_COLORS
11
+ import tikzplotlib
12
 
13
+ from DeepDeformationMapRegistration.utils.misc import segmentation_ohe_to_cardinal
14
+
15
+ # segm_cm = np.asarray([to_rgba(CSS4_COLORS[c], 1) for c in CSS4_COLORS.keys()])
16
+ # # segm_cm.sort()
17
+ # segm_cm = segm_cm[np.linspace(0, len(segm_cm), 4, endpoint=False).astype(int), ...]
18
+ segm_cm = cm.get_cmap('jet').reversed()
19
+ segm_cm = segm_cm(np.linspace(0, 1, 30))
20
  segm_cm[0, :] = np.asarray([0, 0, 0, 0])
21
  segm_cm = ListedColormap(segm_cm)
22
 
23
+ DICT_MODEL_NAMES = {'BASELINE': 'BL',
24
+ 'SEGGUIDED': 'SG',
25
+ 'UW': 'UW'}
26
+
27
+ DICT_METRICS_NAMES = {'NCC': 'N',
28
+ 'SSIM': 'S',
29
+ 'DICE': 'D',
30
+ 'DICE_MACRO': 'D',
31
+ 'HD': 'H', }
32
+
33
+
34
+ def get_model_name(in_path: str):
35
+ model = re.search('((UW|SEGGUIDED|BASELINE).*)_\d+-\d+', in_path)
36
+ if model:
37
+ model = model.group(1).rstrip('_')
38
+ model = model.replace('_Lsim', '')
39
+ model = model.replace('_Lseg', '')
40
+ model = model.replace('_L', '')
41
+ model = model.replace('_', ' ')
42
+ model = model.upper()
43
+ elements = model.split()
44
+ model = elements[0]
45
+ metrics = list()
46
+ model = DICT_MODEL_NAMES[model]
47
+ for m in elements[1:]:
48
+ if m != 'MACRO':
49
+ metrics.append(DICT_METRICS_NAMES[m])
50
+
51
+ return '{}-{}'.format(model, ''.join(metrics))
52
+ else:
53
+ try:
54
+ model = re.search('(SyNCC|SyN)', in_path).group(1)
55
+ except AttributeError:
56
+ raise ValueError('Unknown folder name/model: '+ in_path)
57
+ return model
58
+
59
+
60
+ def load_segmentation(file_path) -> np.ndarray:
61
+ segm = np.asarray(nib.load(file_path).dataobj)
62
+ if segm.shape[-1] > 1:
63
+ segm = segmentation_ohe_to_cardinal(segm)
64
+ return segm
65
+
66
+
67
  if __name__ == '__main__':
68
  parser = argparse.ArgumentParser()
69
 
70
  parser.add_argument('-d', '--dir', type=str, help='Directories where the models are stored', default=None)
71
  parser.add_argument('-o', '--output', type=str, help='Output directory', default=os.getcwd())
72
  parser.add_argument('--overwrite', type=bool, default=True)
73
+ parser.add_argument('--fileno', type=int, default=2)
74
+ parser.add_argument('--tikz', type=bool, default=False)
75
  args = parser.parse_args()
76
  assert args.dir is not None, "No directories provided. Stopping"
77
+ os.makedirs(args.output, exist_ok=True)
78
  list_fix_img = list()
79
  list_mov_img = list()
80
  list_fix_seg = list()
 
82
  list_pred_img = list()
83
  list_pred_seg = list()
84
  print('Fetching data...')
85
+ init_lvl = args.dir.count(os.sep)
86
  for r, d, f in os.walk(args.dir):
87
+ current_lvl = r.count(os.sep) - init_lvl
88
+ if current_lvl < 3:
89
  for name in f:
90
+ if re.search('^{:03d}'.format(args.fileno), name) and name.endswith('nii.gz'):
91
  if re.search('fix_img', name) and name.endswith('nii.gz'):
92
  list_fix_img.append(os.path.join(r, name))
93
  elif re.search('mov_img', name):
 
112
  list_pred_img.sort()
113
  list_pred_seg.sort()
114
  print('Making Test_data.png...')
115
+ selected_slice = 64
116
+ fix_img = np.asarray(nib.load(list_fix_img[0]).dataobj)[selected_slice, ..., 0].T
117
+ mov_img = np.asarray(nib.load(list_mov_img[0]).dataobj)[selected_slice, ..., 0].T
118
+ fix_seg = load_segmentation(list_fix_seg[0])[selected_slice, ..., 0].T
119
+ mov_seg = load_segmentation(list_mov_seg[0])[selected_slice, ..., 0].T
120
 
121
  fig, ax = plt.subplots(nrows=1, ncols=4, figsize=(9, 3), dpi=200)
122
 
123
  for i, (img, title) in enumerate(zip([(fix_img, fix_seg), (mov_img, mov_seg)],
124
+ [('Fixed image', 'Fixed segms.'), ('Moving image', 'Moving segms.')])):
125
 
126
  ax[i].imshow(img[0], origin='lower', cmap='Greys_r')
127
  ax[i+2].imshow(img[0], origin='lower', cmap='Greys_r')
 
138
  warnings.warn('File Test_data.png already exists. Skipping')
139
  else:
140
  plt.savefig(os.path.join(args.output, 'Test_data.png'), format='png')
141
+ if args.tikz:
142
+ tikzplotlib.save(os.path.join(args.output, 'Test_data.tex'))
143
  plt.close()
144
 
145
  print('Making Pred_data.png...')
146
+ fig, ax = plt.subplots(nrows=2, ncols=len(list_pred_img), figsize=(9, 3), dpi=200)
147
 
148
  for i, (pred_img_path, pred_seg_path) in enumerate(zip(list_pred_img, list_pred_seg)):
149
+ img = np.asarray(nib.load(pred_img_path).dataobj)[selected_slice, ..., 0].T
150
+ seg = load_segmentation(pred_seg_path)[selected_slice, ..., 0].T
151
 
152
  ax[0, i].imshow(img, origin='lower', cmap='Greys_r')
153
  ax[1, i].imshow(img, origin='lower', cmap='Greys_r')
 
156
  ax[0, i].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False)
157
  ax[1, i].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False)
158
 
159
+ model = get_model_name(pred_img_path)
 
 
 
 
 
 
160
 
161
  ax[1, i].set_xlabel(model, fontsize=9)
162
  plt.tight_layout()
 
164
  warnings.warn('File Pred_data.png already exists. Skipping')
165
  else:
166
  plt.savefig(os.path.join(args.output, 'Pred_data.png'), format='png')
167
+ if args.tikz:
168
+ tikzplotlib.save(os.path.join(args.output, 'Pred_data.tex'))
169
  plt.close()
170
 
171
  print('Making Pred_data_large.png...')
172
+ fig, ax = plt.subplots(nrows=2, ncols=len(list_pred_img) + 2, figsize=(9, 3), dpi=200)
173
  list_pred_img = [list_mov_img[0]] + list_pred_img
174
  list_pred_img = [list_fix_img[0]] + list_pred_img
175
  list_pred_seg = [list_mov_seg[0]] + list_pred_seg
176
  list_pred_seg = [list_fix_seg[0]] + list_pred_seg
177
 
178
  for i, (pred_img_path, pred_seg_path) in enumerate(zip(list_pred_img, list_pred_seg)):
179
+ img = np.asarray(nib.load(pred_img_path).dataobj)[selected_slice, ..., 0].T
180
+ seg = load_segmentation(pred_seg_path)[selected_slice, ..., 0].T
181
 
182
  ax[0, i].imshow(img, origin='lower', cmap='Greys_r')
183
  ax[1, i].imshow(img, origin='lower', cmap='Greys_r')
 
187
  ax[1, i].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False)
188
 
189
  if i > 1:
190
+ model = get_model_name(pred_img_path)
 
 
 
 
 
 
191
  elif i == 0:
192
  model = 'Moving image'
193
  else:
 
199
  warnings.warn('File Pred_data.png already exists. Skipping')
200
  else:
201
  plt.savefig(os.path.join(args.output, 'Pred_data_large.png'), format='png')
202
+ if args.tikz:
203
+ tikzplotlib.save(os.path.join(args.output, 'Pred_data_large.png'))
204
  plt.close()
205
 
206
  print('...done!')