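"""Benchmark utterance-boundary segmentation models.

For each model, utterances from the test CSV are concatenated into fixed-size chunks,
the model predicts a 0/1 boundary label per word, and precision/recall/F1 against the
gold boundaries are reported and written to benchmark_result/segmentation/.
"""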
import os
import re

import pandas as pd
from sklearn.metrics import precision_recall_fscore_support
from tqdm import tqdm

from segmentation import (
    segment_batchalign,
    segment_SaT,
    segment_SaT_cunit_3l,
    segment_SaT_cunit_12l,
    segment_SaT_cunit_3l_r32a64,
    segment_SaT_cunit_3l_r64a128,
    segment_SaT_cunit_3l_no_shuffle,
)

def clean_text(text):
    # Lowercase, drop punctuation (anything that is not a word character or whitespace),
    # and trim surrounding whitespace, e.g. "Hey, the dog ran!" -> "hey the dog ran".
    return re.sub(r"[^\w\s]", "", text.lower()).strip()


def eval_segmentation(dataset_path, segmentation_model, model_name="unknown", chunk_num=10):
    """Evaluate a word-level segmentation model against gold utterance boundaries.

    Utterances are concatenated in chunks of `chunk_num` rows; the model receives the
    space-joined words and must return one 0/1 label per word (1 = segment-final word).
    Per-chunk predictions are saved to CSV and precision/recall/F1 are reported.
    """
    os.makedirs("benchmark_result/segmentation", exist_ok=True)

    df = pd.read_csv(dataset_path)
    results = []

    for i in tqdm(range(0, len(df), chunk_num), desc="Evaluating chunks"):
        chunk = df.iloc[i:i + chunk_num]
        # Skip the trailing chunk if it has fewer than chunk_num rows.
        if len(chunk) < chunk_num:
            continue

        word_sequence = []
        gt_label_sequence = []

        for row in chunk["cleaned_transcription"]:
            if pd.isna(row):
                continue
            cleaned = clean_text(row)
            words = cleaned.split()
            if not words:
                continue
            word_sequence.extend(words)
            # Gold labels: 0 for every word except the utterance-final word, which is labeled 1.
            gt_label_sequence.extend([0] * (len(words) - 1) + [1])

        # Guard against chunks whose rows were all empty or NaN.
        if not word_sequence:
            continue

        input_text = " ".join(word_sequence)
        predicted_labels = segmentation_model(input_text)

        if len(predicted_labels) != len(gt_label_sequence):
            print(f"Label length mismatch at chunk {i}: "
                  f"{len(predicted_labels)} predicted vs {len(gt_label_sequence)} gold. Skipping...")
            continue

        results.append({
            "word_sequence": input_text,
            "gt_label_sequence": " ".join(map(str, gt_label_sequence)),
            "predict_label_sequence": " ".join(map(str, predicted_labels))
        })

    result_df = pd.DataFrame(results)
    result_df.to_csv(f"benchmark_result/segmentation/{model_name}_results.csv", index=False)

    all_gt = []
    all_pred = []

    for row in results:
        all_gt.extend(map(int, row["gt_label_sequence"].split()))
        all_pred.extend(map(int, row["predict_label_sequence"].split()))

    # Raw boundary counts for reference; precision/recall/F1 below come from scikit-learn.
    tp = sum((g == 1 and p == 1) for g, p in zip(all_gt, all_pred))
    fp = sum((g == 0 and p == 1) for g, p in zip(all_gt, all_pred))
    fn = sum((g == 1 and p == 0) for g, p in zip(all_gt, all_pred))

    precision, recall, f1, _ = precision_recall_fscore_support(all_gt, all_pred, average='binary', zero_division=0)

    print(f"{model_name} - TP: {tp}, FP: {fp}, FN: {fn}")
    print(f"{model_name} - Precision: {precision:.3f}, Recall: {recall:.3f}, F1: {f1:.3f}")
    
    return precision, recall, f1

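# Illustrative sketch (not part of the benchmark): eval_segmentation assumes a
# segmentation model is any callable that takes the space-joined word string and
# returns one 0/1 label per word, where 1 marks the final word of a predicted segment.
# A hypothetical fixed-length baseline satisfying that contract could look like this:
def segment_every_n_words(text, n=10):
    # Hypothetical baseline: mark a boundary after every n-th word and after the last word.
    words = text.split()
    return [1 if (i + 1) % n == 0 or i == len(words) - 1 else 0 for i in range(len(words))]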

if __name__ == "__main__":
    dataset_path = "./data/enni_salt_for_segmentation/test.csv"

    # (segmentation function, display name) pairs to benchmark.
    # Uncomment the batchalign entry to include it in the comparison.
    models = [
        # (segment_batchalign, "batchalign"),
        (segment_SaT, "SaT"),
        (segment_SaT_cunit_3l, "SaT_cunit_3l"),
        (segment_SaT_cunit_12l, "SaT_cunit_12l"),
        (segment_SaT_cunit_3l_r32a64, "SaT_cunit_3l_r32a64"),
        (segment_SaT_cunit_3l_r64a128, "SaT_cunit_3l_r64a128"),
        (segment_SaT_cunit_3l_no_shuffle, "SaT_cunit_3l_no_shuffle"),
    ]

    scores = {}
    for model_fn, model_name in models:
        print(f"\nEvaluating {model_name} segmentation model...")
        scores[model_name] = eval_segmentation(dataset_path, model_fn, model_name)

    print("\n" + "=" * 80)
    print("COMPARISON RESULTS:")
    print("=" * 80)
    name_width = max(len(name) for name in scores)
    for model_name, (precision, recall, f1) in scores.items():
        print(f"{model_name:<{name_width}} - Precision: {precision:.3f}, Recall: {recall:.3f}, F1: {f1:.3f}")