from typing import Tuple


def classify_predictions(
    gold: dict, pred: dict, union: bool = False
) -> Tuple[int, int, int]:
    """
    Return the counts of true positives, false positives, and false negatives
    for one example. If union is True, tag types are disregarded and only the
    union of all tagged phrases is compared.
    """
    n_tp = 0
    n_fp = 0
    n_fn = 0
    if union:
        gold_phrases = set(phrase for phrases in gold.values() for phrase in phrases)
        pred_phrases = set(phrase for phrases in pred.values() for phrase in phrases)
        n_tp = len(gold_phrases & pred_phrases)
        n_fp = len(pred_phrases - gold_phrases)
        n_fn = len(gold_phrases - pred_phrases)
        return n_tp, n_fp, n_fn

    for tag in set(gold.keys()).union(pred.keys()):
        gold_phrases = set(gold.get(tag, []))
        pred_phrases = set(pred.get(tag, []))

        n_tp += len(gold_phrases & pred_phrases)
        n_fp += len(pred_phrases - gold_phrases)
        n_fn += len(gold_phrases - pred_phrases)

    return n_tp, n_fp, n_fn


def compute_metrics(running_time, pred_times, runtype, eval_metrics=None):
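    """
    Aggregate timing metrics for a run and, when runtype is "eval", the
    per-tag and tag-agnostic ("union") precision/recall/F1 scores.

    eval_metrics is expected to be the 6-tuple
    (n_tp, n_fp, n_fn, n_tp_union, n_fp_union, n_fn_union).
    """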
    metrics = {}
    metrics["avg_pred_response_time_per_sentence"] = (
        round(sum(pred_times) / len(pred_times), 4) if pred_times else 0
    )
    metrics["total_time"] = round(running_time, 4)

    if runtype == "eval" and eval_metrics is not None:
        n_tp, n_fp, n_fn, n_tp_union, n_fp_union, n_fn_union = eval_metrics

        precision = round(n_tp / (n_tp + n_fp) if (n_tp + n_fp) > 0 else 0, 4)
        recall = round(n_tp / (n_tp + n_fn) if (n_tp + n_fn) > 0 else 0, 4)
        f1 = round(
            (
                2 * (precision * recall) / (precision + recall)
                if (precision + recall) > 0
                else 0
            ),
            4,
        )
        union_precision = round(
            (
                n_tp_union / (n_tp_union + n_fp_union)
                if (n_tp_union + n_fp_union) > 0
                else 0
            ),
            4,
        )
        union_recall = round(
            (
                n_tp_union / (n_tp_union + n_fn_union)
                if (n_tp_union + n_fn_union) > 0
                else 0
            ),
            4,
        )
        union_f1 = round(
            (
                2 * (union_precision * union_recall) / (union_precision + union_recall)
                if (union_precision + union_recall) > 0
                else 0
            ),
            4,
        )

        metrics.update(
            {
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "union_precision": union_precision,
                "union_recall": union_recall,
                "union_f1": union_f1,
            }
        )

    return metrics
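

# A minimal usage sketch (illustrative only; the tags, phrases, and timings
# below are hypothetical and not taken from any dataset in this repository):
if __name__ == "__main__":
    gold = {"PER": ["Ada Lovelace"], "ORG": ["Analytical Engine Co"]}
    pred = {"PER": ["Ada Lovelace"], "ORG": ["Analytical Society"]}

    # Per-tag counts and tag-agnostic ("union") counts for one example.
    tp, fp, fn = classify_predictions(gold, pred)
    tp_u, fp_u, fn_u = classify_predictions(gold, pred, union=True)

    metrics = compute_metrics(
        running_time=1.23,
        pred_times=[0.4, 0.5, 0.33],
        runtype="eval",
        eval_metrics=(tp, fp, fn, tp_u, fp_u, fn_u),
    )
    print(metrics)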