# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dora import Explorer
import treetable as tt


class MyExplorer(Explorer):
    """Explorer describing the tracking table columns and how the history of
    one experiment is condensed into the values shown for it.
    """
    # Names of the leaves displayed under the "test" group of the table.
    test_metrics = ['nsdr', 'sdr_med']

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table.

        The table is made of three groups, "train", "valid" and "test". Each
        leaf is given a name and, for most of them, a format spec. The leaves
        of the "test" group are generated from `self.test_metrics`.
        """
        return [
            tt.group("train", [
                tt.leaf("epoch"),
                tt.leaf("reco", ".3f"),
            ], align=">"),
            tt.group("valid", [
                tt.leaf("penalty", ".1f"),
                tt.leaf("ms", ".1f"),
                tt.leaf("reco", ".2%"),
                tt.leaf("breco", ".2%"),
                tt.leaf("b_nsdr", ".2f"),
                # Extra "b_nsdr_*" columns, currently disabled:
                # tt.leaf("b_nsdr_drums", ".2f"),
                # tt.leaf("b_nsdr_bass", ".2f"),
                # tt.leaf("b_nsdr_other", ".2f"),
                # tt.leaf("b_nsdr_vocals", ".2f"),
            ], align=">"),
            tt.group("test", [
                tt.leaf(name, ".2f")
                for name in self.test_metrics
            ], align=">"),
        ]

    def process_history(self, history):
        """Condense the history of one experiment into a single set of values.

        Args:
            history: sequence of dicts, one per entry of the history. Each
                entry must hold a 'train' and a 'valid' dict, and may hold a
                'test' dict.

        Returns:
            A dict with the keys 'train', 'valid' and 'test':

            - 'train': 'epoch' (the number of entries in `history`) plus the
              merged 'train' metrics, later entries overriding earlier ones.
            - 'valid': the merged 'valid' metrics, plus 'bmain' (lowest
              `valid['main']['loss']` seen so far), 'breco' (lowest
              `valid['reco']` seen so far) and 'b_*' entries copied from the
              latest entry whose 'loss' or 'nsdr' equals its 'best'.
            - 'test': the merged 'test' metrics of the entries that have some.

            For an empty `history`, 'valid' and 'test' are empty dicts.
        """
        train = {'epoch': len(history)}
        valid = {}
        test = {}
        # Running minima; they stay at +inf until a matching metric shows up.
        best_main = float('inf')
        best_reco = float('inf')

        for metrics in history:
            epoch_valid = metrics['valid']
            train.update(metrics['train'])
            valid.update(epoch_valid)

            if 'main' in epoch_valid:
                best_main = min(best_main, epoch_valid['main']['loss'])
            # 'reco' is treated as optional, like 'main', so that one entry
            # without it does not prevent the whole row from being displayed.
            if 'reco' in epoch_valid:
                best_reco = min(best_reco, epoch_valid['reco'])
            # Written after the update above so that the running minima take
            # precedence over same-named keys coming from the entry itself.
            valid['bmain'] = best_main
            valid['breco'] = best_reco

            # An entry is the best one so far when its 'loss' or its 'nsdr'
            # equals the 'best' value stored alongside. Entries that do not
            # carry a 'best' value are never considered as such.
            is_best = False
            if 'best' in epoch_valid:
                best = epoch_valid['best']
                is_best = (epoch_valid.get('loss') == best or
                           epoch_valid.get('nsdr') == best)
            if is_best:
                for key, value in epoch_valid.items():
                    if key.startswith('reco_'):
                        # 'reco_xxx' is exposed as 'b_xxx'.
                        valid['b_' + key[len('reco_'):]] = value
                    elif key.startswith('nsdr'):
                        # 'nsdr' and 'nsdr_xxx' are exposed as 'b_nsdr' and
                        # 'b_nsdr_xxx'.
                        valid[f'b_{key}'] = value

            if 'test' in metrics:
                test.update(metrics['test'])

        return {"train": train, "valid": valid, "test": test}