jpdefrutos commited on
Commit
99b6efe
·
1 Parent(s): 92caf3a

Minor improvements

Browse files

Now the user can select the folder where the metrics.csv file is using the flag --metrics-folder

Brain_study/ABSTRACT/format_tables_abstract.py CHANGED
@@ -1,6 +1,7 @@
1
  import warnings
2
 
3
  import pandas as pd
 
4
  import os
5
  import argparse
6
  import re
@@ -18,21 +19,33 @@ DICT_METRICS_NAMES = {'NCC': 'N',
18
 
19
 
20
  def row_name(in_path: str):
21
- model = re.search('((UW|SEGGUIDED|BASELINE).*)_\d', in_path).group(1).rstrip('_')
22
- model = model.replace('_Lsim', '')
23
- model = model.replace('_Lseg', '')
24
- model = model.replace('_L', '')
25
- model = model.replace('_', ' ')
26
- model = model.upper()
27
- elements = model.split()
28
- model = elements[0]
29
- metrics = list()
30
- model = DICT_MODEL_NAMES[model]
31
- for m in elements[1:]:
32
- if m != 'MACRO':
33
- metrics.append(DICT_METRICS_NAMES[m])
34
-
35
- return '{}-{}'.format(model, ''.join(metrics))
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
 
38
  if __name__ == '__main__':
@@ -44,32 +57,43 @@ if __name__ == '__main__':
44
  parser.add_argument('--overwrite', type=bool, default=True)
45
  parser.add_argument('--filename', type=str, help='Output file name', default='metrics')
46
  parser.add_argument('--removemetrics', nargs='+', type=str, default=None)
 
47
  args = parser.parse_args()
48
  assert args.dir is not None, "No directories provided. Stopping"
49
 
50
  if len(args.dir) == 1:
51
  list_files = list()
 
 
 
 
 
52
  for r, d, f in os.walk(args.dir[0]):
53
- for name in f:
54
- if 'metrics.csv' == name: # and os.path.split(r)[1] == 'Evaluation_paper':
55
- list_files.append(os.path.join(r, name))
 
 
56
  else:
57
  list_files = [os.path.join(d, 'metrics.csv') for d in args.dir]
58
 
59
  for d in list_files:
60
  assert os.path.exists(d), "Missing metrics.csv file in: " + os.path.split(d)[0]
61
-
62
- print('Metric files found: {}'.format(list_files))
63
 
64
  dataframes = list()
65
  if len(list_files):
66
  for d in list_files:
67
- df = pd.read_csv(d, sep=';', header=0)
68
  model = row_name(d)
69
 
70
  df.insert(0, "Model", model)
71
  df.drop(columns=list(df.filter(regex='Unnamed')), inplace=True)
72
- df.drop(columns=['File', 'MSE', 'No_missing_lbls'], inplace=True)
 
 
 
73
  dataframes.append(df)
74
 
75
  full_table = pd.concat(dataframes)
@@ -79,7 +103,7 @@ if __name__ == '__main__':
79
  # mean_table.insert(column='Type', value='Avg.', loc=1)
80
  # mean_table = mean_table.groupby(['Type', 'Model']).mean().round(3)
81
  mean_table = mean_table.groupby(['Model'])
82
- hd95 = mean_table.HD.quantile(0.95).map('{:.2f}'.format)
83
  mean_table = mean_table.mean().round(3)
84
 
85
  std_table = full_table.copy()
@@ -91,21 +115,25 @@ if __name__ == '__main__':
91
  metrics_table = mean_table.applymap('{:.2f}'.format) + u"\u00B1" + std_table.applymap('{:.2f}'.format)
92
  time_col = metrics_table.pop('Time')
93
  metrics_table.insert(len(metrics_table.columns), 'Time', time_col)
94
- metrics_table.insert(5, 'HD 95%ile', hd95)
 
95
 
96
  metrics_file = os.path.join(args.output, args.filename + '.tex')
97
  if os.path.exists(metrics_file) and args.overwrite:
98
  shutil.rmtree(metrics_file, ignore_errors=True)
99
  metrics_table.to_latex(metrics_file,
100
- column_format='l' + 'c' * len(metrics_table.columns),
101
- caption='Average and standard deviation of the metrics: MSE, NCC, SSIM, DICE and HD. As well as the number of missing labels in the predicted images.')
 
102
  elif os.path.exists(metrics_file):
103
  warnings.warn('File {} already exists. Skipping'.format(metrics_file))
104
  else:
105
  metrics_table.to_latex(metrics_file,
106
- column_format='l' + 'c' * len(metrics_table.columns),
107
- caption='Average and standard deviation of the metrics: MSE, NCC, SSIM, DICE and HD. As well as the number of missing labels in the predicted images.')
 
108
 
 
109
  print('Done')
110
  else:
111
  print('No files found in {}!'.format(args.dir))
 
1
  import warnings
2
 
3
  import pandas as pd
4
+ pd.options.display.max_columns = 10
5
  import os
6
  import argparse
7
  import re
 
19
 
20
 
21
  def row_name(in_path: str):
22
+ model = re.search('((UW|SEGGUIDED|BASELINE).*)_\d', in_path)
23
+ ret_val = None
24
+ if model:
25
+ model = model.group(1).rstrip('_')
26
+ model = model.replace('_Lsim', '')
27
+ model = model.replace('_Lseg', '')
28
+ model = model.replace('_L', '')
29
+ model = model.replace('_', ' ')
30
+ model = model.upper()
31
+ elements = model.split()
32
+ model = elements[0]
33
+ metrics = list()
34
+ model = DICT_MODEL_NAMES[model]
35
+ for m in elements[1:]:
36
+ if m != 'MACRO':
37
+ metrics.append(DICT_METRICS_NAMES[m])
38
+
39
+ ret_val = '{}-{}'.format(model, ''.join(metrics))
40
+ elif re.search('((COMET|IXI).*)', in_path):
41
+ model = re.search('((COMET|IXI).*)', in_path)
42
+ ret_val = model.group(1).split('_')[0]
43
+ else:
44
+ try:
45
+ ret_val = re.search('(SyNCC|SyN)', in_path).group(1)
46
+ except AttributeError:
47
+ raise ValueError('Unknown folder name/model: '+ in_path)
48
+ return ret_val
49
 
50
 
51
  if __name__ == '__main__':
 
57
  parser.add_argument('--overwrite', type=bool, default=True)
58
  parser.add_argument('--filename', type=str, help='Output file name', default='metrics')
59
  parser.add_argument('--removemetrics', nargs='+', type=str, default=None)
60
+ parser.add_argument('--metrics-folder', type=str, default=None)
61
  args = parser.parse_args()
62
  assert args.dir is not None, "No directories provided. Stopping"
63
 
64
  if len(args.dir) == 1:
65
  list_files = list()
66
+ if args.metrics_folder:
67
+ file_found_condition = lambda name: 'metrics.csv' == name and args.metrics_folder in r.split(os.sep)
68
+ else:
69
+ file_found_condition = lambda name: 'metrics.csv' == name
70
+ starting_level = args.dir[0].count(os.sep)
71
  for r, d, f in os.walk(args.dir[0]):
72
+ level = r.count(os.sep) - starting_level
73
+ if level < 3:
74
+ for name in f:
75
+ if file_found_condition(name):
76
+ list_files.append(os.path.join(r, name))
77
  else:
78
  list_files = [os.path.join(d, 'metrics.csv') for d in args.dir]
79
 
80
  for d in list_files:
81
  assert os.path.exists(d), "Missing metrics.csv file in: " + os.path.split(d)[0]
82
+ list_files.sort()
83
+ print('Metric files found ({}):\n\t{}'.format(len(list_files), '\n\t'.join(list_files)))
84
 
85
  dataframes = list()
86
  if len(list_files):
87
  for d in list_files:
88
+ df = pd.read_csv(d, sep=';', header=0, dtype={'TRE':float})
89
  model = row_name(d)
90
 
91
  df.insert(0, "Model", model)
92
  df.drop(columns=list(df.filter(regex='Unnamed')), inplace=True)
93
+ if not 'SyN' in model:
94
+ df.drop(columns=['File', 'MSE', 'No_missing_lbls'], inplace=True)
95
+ else:
96
+ df.drop(columns=['File', 'MSE'], inplace=True)
97
  dataframes.append(df)
98
 
99
  full_table = pd.concat(dataframes)
 
103
  # mean_table.insert(column='Type', value='Avg.', loc=1)
104
  # mean_table = mean_table.groupby(['Type', 'Model']).mean().round(3)
105
  mean_table = mean_table.groupby(['Model'])
106
+ # hd95 = mean_table.HD.quantile(0.95).map('{:.2f}'.format)
107
  mean_table = mean_table.mean().round(3)
108
 
109
  std_table = full_table.copy()
 
115
  metrics_table = mean_table.applymap('{:.2f}'.format) + u"\u00B1" + std_table.applymap('{:.2f}'.format)
116
  time_col = metrics_table.pop('Time')
117
  metrics_table.insert(len(metrics_table.columns), 'Time', time_col)
118
+ # metrics_table.insert(4, 'HD95', hd95)
119
+ metrics_table.rename(columns={'DICE_MACRO': 'DSC', 'Time': 'Runtime'}, inplace=True)
120
 
121
  metrics_file = os.path.join(args.output, args.filename + '.tex')
122
  if os.path.exists(metrics_file) and args.overwrite:
123
  shutil.rmtree(metrics_file, ignore_errors=True)
124
  metrics_table.to_latex(metrics_file,
125
+ bold_rows=True,
126
+ column_format='r' + 'c' * len(metrics_table.columns),
127
+ caption='Average and standard deviation of the metrics: MSE, NCC, SSIM, DSC and HD.')
128
  elif os.path.exists(metrics_file):
129
  warnings.warn('File {} already exists. Skipping'.format(metrics_file))
130
  else:
131
  metrics_table.to_latex(metrics_file,
132
+ bold_rows=True,
133
+ column_format='r' + 'c' * len(metrics_table.columns),
134
+ caption='Average and standard deviation of the metrics: MSE, NCC, SSIM, DSC and HD.')
135
 
136
+ print(metrics_table)
137
  print('Done')
138
  else:
139
  print('No files found in {}!'.format(args.dir))