File size: 5,885 Bytes
286a978
 
 
99b6efe
286a978
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99b6efe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286a978
 
 
 
 
 
 
 
 
 
 
99b6efe
286a978
 
 
 
 
99b6efe
 
 
 
 
286a978
99b6efe
 
 
 
 
286a978
 
 
 
 
99b6efe
 
286a978
 
 
 
99b6efe
286a978
 
 
 
99b6efe
 
 
 
286a978
 
 
 
 
 
 
 
 
99b6efe
286a978
 
 
 
 
 
 
 
 
 
 
99b6efe
 
286a978
 
 
 
 
99b6efe
 
 
286a978
 
 
 
99b6efe
 
 
286a978
99b6efe
286a978
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import warnings

import pandas as pd
pd.options.display.max_columns = 10
import os
import argparse
import re
import shutil

# Abbreviation used in the table row label for each (upper-cased) model
# prefix recognised by ``row_name``.
DICT_MODEL_NAMES = {
    'BASELINE': 'BL',
    'SEGGUIDED': 'SG',
    'UW': 'UW',
}

# One-letter code appended to the row label for each (upper-cased) metric
# token found in a folder name.
# NOTE(review): ``row_name`` splits the name on whitespace and skips the
# 'MACRO' token, so the 'DICE MACRO' key does not appear to be looked up by
# the code in this file -- confirm before removing it.
DICT_METRICS_NAMES = {
    'NCC': 'N',
    'SSIM': 'S',
    'DICE': 'D',
    'DICE MACRO': 'D',
    'HD': 'H',
}


def row_name(in_path: str):
    """Build the short row label used in the metrics table from a path.

    Three kinds of paths are recognised, tried in this order:

    1. A path containing ``UW``, ``SEGGUIDED`` or ``BASELINE`` followed later
       by ``_<digit>``.  The text from the model keyword up to the last
       ``_<digit>`` is cleaned of the ``_Lsim``, ``_Lseg`` and ``_L`` markers,
       upper-cased and split on underscores.  The first token is abbreviated
       through ``DICT_MODEL_NAMES`` and every following token (except
       ``MACRO``) through ``DICT_METRICS_NAMES``; the result is
       ``'<model>-<metric letters>'``, e.g. ``BASELINE_Lsim_ncc_1`` gives
       ``'BL-N'``.
    2. A path containing ``COMET`` or ``IXI``: everything from that keyword up
       to (not including) the first following underscore is returned.
    3. A path containing ``SyNCC`` or ``SyN``: that keyword is returned.

    :param in_path: Path (or folder name) to derive the label from.
    :return: The row label as a string.
    :raises ValueError: If none of the three patterns is found in ``in_path``.
    :raises KeyError: If a model or metric token of case 1 is missing from the
        abbreviation dictionaries.
    """
    # Raw strings: '\d' inside a plain literal is an invalid escape sequence.
    trained = re.search(r'((UW|SEGGUIDED|BASELINE).*)_\d', in_path)
    if trained:
        tag = trained.group(1).rstrip('_')
        # Order matters: the longer markers must go before the bare '_L'.
        for marker in ('_Lsim', '_Lseg', '_L'):
            tag = tag.replace(marker, '')
        elements = tag.replace('_', ' ').upper().split()
        model = DICT_MODEL_NAMES[elements[0]]
        # 'MACRO' only qualifies the preceding DICE token, so it adds no letter.
        metrics = [DICT_METRICS_NAMES[m] for m in elements[1:] if m != 'MACRO']
        return '{}-{}'.format(model, ''.join(metrics))

    dataset = re.search(r'((COMET|IXI).*)', in_path)
    if dataset:
        return dataset.group(1).split('_')[0]

    classic = re.search(r'(SyNCC|SyN)', in_path)
    if classic is None:
        raise ValueError('Unknown folder name/model: ' + in_path)
    return classic.group(1)


def _str_to_bool(value):
    """Interpret a command-line token as a boolean.

    ``argparse`` with ``type=bool`` simply calls ``bool()`` on the raw string,
    so every non-empty value -- including ``"False"`` -- is parsed as True.
    Here the usual "false-like" spellings (and the empty string) map to False;
    any other value keeps mapping to True.
    """
    if isinstance(value, bool):
        return value
    return value.strip().lower() not in ('', '0', 'false', 'f', 'no', 'n', 'off')


def _find_metric_files(directories, metrics_folder=None, max_depth=3):
    """Return the sorted list of ``metrics.csv`` paths to aggregate.

    With a single directory, the tree below it is searched for files named
    ``metrics.csv`` in folders whose depth (number of path separators relative
    to the starting directory) is below ``max_depth``.  If ``metrics_folder``
    is given, only folders having that name as one of their path components
    are accepted.  With several directories, each one is expected to contain
    a ``metrics.csv`` directly.

    :raises FileNotFoundError: If one of several directories has no
        ``metrics.csv``.
    """
    if len(directories) == 1:
        root = directories[0]
        root_depth = root.count(os.sep)
        found = []
        for current, subdirs, filenames in os.walk(root):
            depth = current.count(os.sep) - root_depth
            if depth >= max_depth - 1:
                # Anything below this folder is too deep; do not walk it.
                subdirs[:] = []
            if depth >= max_depth:
                continue
            if metrics_folder and metrics_folder not in current.split(os.sep):
                continue
            if 'metrics.csv' in filenames:
                found.append(os.path.join(current, 'metrics.csv'))
    else:
        found = [os.path.join(d, 'metrics.csv') for d in directories]
        for path in found:
            if not os.path.exists(path):
                raise FileNotFoundError("Missing metrics.csv file in: " + os.path.split(path)[0])
    found.sort()
    return found


def _load_metrics(csv_path):
    """Read one ``;``-separated metrics file and label its rows with the model.

    A ``Model`` column (from ``row_name``) is inserted first, ``Unnamed``
    columns are dropped, and so are ``File`` and ``MSE``; ``No_missing_lbls``
    is additionally dropped for every model whose label does not contain
    ``SyN``.
    """
    table = pd.read_csv(csv_path, sep=';', header=0, dtype={'TRE': float})
    model = row_name(csv_path)

    table.insert(0, "Model", model)
    table.drop(columns=list(table.filter(regex='Unnamed')), inplace=True)
    unused = ['File', 'MSE']
    if 'SyN' not in model:
        unused.append('No_missing_lbls')
    table.drop(columns=unused, inplace=True)
    return table


def _summarise(full_table):
    """Return a per-model table of ``mean±std`` strings (two decimals)."""
    grouped = full_table.groupby(['Model'])
    mean_table = grouped.mean().round(3)
    std_table = grouped.std().round(3)

    def two_decimals(frame):
        # Column-wise Series.map instead of DataFrame.applymap, which is
        # deprecated in recent pandas releases.
        return frame.apply(lambda column: column.map('{:.2f}'.format))

    summary = two_decimals(mean_table) + u"\u00B1" + two_decimals(std_table)
    # Move the runtime to the last column; it may have been removed on request.
    if 'Time' in summary.columns:
        time_col = summary.pop('Time')
        summary.insert(len(summary.columns), 'Time', time_col)
    summary.rename(columns={'DICE_MACRO': 'DSC', 'Time': 'Runtime'}, inplace=True)
    return summary


if __name__ == '__main__':
    # Collect metrics.csv files, aggregate them per model and export the
    # mean±std table as a LaTeX file.
    parser = argparse.ArgumentParser()

    parser.add_argument('-d', '--dir', nargs='+', type=str, help='List of directories where metrics.csv file is',
                        default=None)
    parser.add_argument('-o', '--output', type=str, help='Output directory', default=os.getcwd())
    parser.add_argument('--overwrite', type=_str_to_bool, default=True)
    parser.add_argument('--filename', type=str, help='Output file name', default='metrics')
    parser.add_argument('--removemetrics', nargs='+', type=str, default=None)
    parser.add_argument('--metrics-folder', type=str, default=None)
    args = parser.parse_args()
    if args.dir is None:
        # Explicit check: an ``assert`` would vanish when running with -O.
        parser.error("No directories provided. Stopping")

    list_files = _find_metric_files(args.dir, args.metrics_folder)
    print('Metric files found ({}):\n\t{}'.format(len(list_files), '\n\t'.join(list_files)))

    if list_files:
        full_table = pd.concat([_load_metrics(path) for path in list_files])
        if args.removemetrics is not None:
            full_table = full_table.drop(columns=args.removemetrics)
        metrics_table = _summarise(full_table)

        metrics_file = os.path.join(args.output, args.filename + '.tex')
        if os.path.exists(metrics_file) and not args.overwrite:
            warnings.warn('File {} already exists. Skipping'.format(metrics_file))
        else:
            if args.output:
                os.makedirs(args.output, exist_ok=True)
            # to_latex truncates an existing file, so no prior removal is needed.
            metrics_table.to_latex(metrics_file,
                                   bold_rows=True,
                                   column_format='r' + 'c' * len(metrics_table.columns),
                                   caption='Average and standard deviation of the metrics: MSE, NCC, SSIM, DSC and HD.')

        print(metrics_table)
        print('Done')
    else:
        print('No files found in {}!'.format(args.dir))