Spaces:
Build error
Build error
| import json | |
| from typing import List, Tuple | |
| import pandas as pd | |
| from sftp import SpanPredictor | |
def main(data_file=None, models=None):
    """Evaluate every configured model on one JSONL data file.

    For each model the predictor is loaded once and then used for frame
    detection (one run), boundary detection (runs 1-2) and argument
    classification (runs 1-3).

    Args:
        data_file: Path to the JSONL file to evaluate on. When ``None`` the
            ``evalita_test.jsonl`` path below is used.
        models: Iterable of ``(model_name, model_path)`` pairs. When ``None``
            the two model archives listed below are used.
    """
    if data_file is None:
        # Other inputs that were used with this script:
        #   /home/p289731/cloned/lome/preproc/evalita_jsonl/evalita_dev.jsonl
        #   /home/p289731/cloned/lome/preproc/svm_challenge.jsonl
        data_file = "/home/p289731/cloned/lome/preproc/evalita_jsonl/evalita_test.jsonl"

    if models is None:
        models = [
            (
                "lome-en",
                "/data/p289731/cloned/lome-models/models/spanfinder/model.mod.tar.gz",
            ),
            (
                "lome-it-best",
                "/scratch/p289731/lome-training-files/train-evalita-plus-fn-vanilla/model.tar.gz",
            ),
            # Currently disabled:
            #   lome-it-freeze: /data/p289731/cloned/lome/train-evalita-plus-fn-freeze/model.tar.gz
            #   lome-it-mono:   /data/p289731/cloned/lome/train-evalita-it_mono/model.tar.gz
        ]

    for model_name, model_path in models:
        print("testing model: ", model_name)
        predictor = SpanPredictor.from_path(model_path)

        print("=== FD (run 1) ===")
        eval_frame_detection(data_file, predictor, model_name=model_name)

        for run in (1, 2):
            print(f"=== BD (run {run}) ===")
            eval_boundary_detection(data_file, predictor, run=run)

        for run in (1, 2, 3):
            print(f"=== AC (run {run}) ===")
            eval_argument_classification(data_file, predictor, run=run)
def predict_frame(
    predictor: SpanPredictor, tokens: List[str], predicate_span: Tuple[int, int]
):
    """Return the label the predictor assigns to the predicate span.

    The predictor is force-decoded with ``predicate_span`` as its only child
    span; the first entry of the returned label list is handed back (callers
    in this file treat it as the frame label).
    """
    decoded = predictor.force_decode(tokens, child_spans=[predicate_span])
    # force_decode yields a 3-tuple; only the label list is needed here.
    _spans, span_labels, _rest = decoded
    return span_labels[0]
def eval_frame_detection(data_file, predictor, verbose=False, model_name="_"):
    """Measure frame-detection accuracy and dump per-sentence predictions.

    Each non-blank line of ``data_file`` is parsed as JSON. The first entry
    of its ``"annotations"`` list supplies the predicate ``"span"``
    (inclusive start/end token indices) and the gold ``"label"``; the
    predicted label comes from ``predict_frame``.

    Side effects:
        Prints ``ACC = <accuracy>`` and writes
        ``frame_prediction_output_{model_name}_{data_sect}.csv`` to the
        current working directory, where ``data_sect`` is derived from
        substrings of ``data_file`` ("svm_challenge" -> "rai", "dev" -> "dev",
        otherwise "test").

    Args:
        data_file: Path to the JSONL evaluation file.
        predictor: Object passed through to ``predict_frame``.
        verbose: When true, print gold and predicted frame for each sentence.
        model_name: Tag used in the output CSV file name.
    """
    n_correct = 0
    n_wrong = 0
    rows = []

    with open(data_file, encoding="utf-8") as f:
        for sent_id, line in enumerate(f):
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            sent_data = json.loads(line)

            tokens = sent_data["tokens"]
            annotation = sent_data["annotations"][0]

            predicate_span = tuple(annotation["span"])
            # Span ends are inclusive, hence the +1 on the slice end.
            predicate = tokens[predicate_span[0] : predicate_span[1] + 1]

            frame_gold = annotation["label"]
            frame_pred = predict_frame(predictor, tokens, predicate_span)

            if frame_pred == frame_gold:
                n_correct += 1
            else:
                n_wrong += 1

            rows.append(
                {
                    "sentence": " ".join(tokens),
                    "predicate": predicate,
                    "frame_gold": frame_gold,
                    "frame_pred": frame_pred,
                }
            )

            if verbose:
                print(f"Sentence #{sent_id:03}: {' '.join(tokens)}")
                print(f"\tpredicate: {predicate}")
                print(f"\t gold: {frame_gold}")
                print(f"\tpredicted: {frame_pred}")
                print()

    # Guard against an empty data file instead of raising ZeroDivisionError.
    n_total = n_correct + n_wrong
    acc_score = n_correct / n_total if n_total else 0.0
    print("ACC =", acc_score)

    if "svm_challenge" in data_file:
        data_sect = "rai"
    elif "dev" in data_file:
        data_sect = "dev"
    else:
        data_sect = "test"

    df_out = pd.DataFrame(rows)
    df_out.to_csv(f"frame_prediction_output_{model_name}_{data_sect}.csv")
def predict_boundaries(predictor: SpanPredictor, tokens, predicate_span, frame):
    """Return the child spans predicted under the given predicate and frame.

    The predictor is force-decoded with ``predicate_span`` as parent span and
    ``frame`` as parent label. Every returned span is converted to a tuple;
    a span equal to ``predicate_span`` that is labelled ``"Target"`` is
    dropped. The result is a list in the order the predictor returned them.
    """
    decoded = predictor.force_decode(
        tokens, parent_span=predicate_span, parent_label=frame
    )
    raw_spans, span_labels, _rest = decoded
    return [
        span
        for span, label in zip(map(tuple, raw_spans), span_labels)
        # Keep everything except the predicate's own "Target" span.
        if span != predicate_span or label != "Target"
    ]
def get_gold_boundaries(annotation, predicate_span):
    """Collect the gold child spans of an annotation as a set of tuples.

    Every entry of ``annotation["children"]`` contributes its ``"span"``,
    except a child whose span equals ``predicate_span`` and whose ``"label"``
    is ``"Target"``.
    """
    gold_spans = set()
    for child in annotation["children"]:
        span = tuple(child["span"])
        # Skip the predicate's own "Target" entry; it is not an argument.
        if span == predicate_span and child["label"] == "Target":
            continue
        gold_spans.add(span)
    return gold_spans
def eval_boundary_detection(data_file, predictor, run=1, verbose=False):
    """Score predicted argument boundaries against gold, per span and per token.

    Each non-blank line of ``data_file`` is parsed as JSON and its first
    annotation is evaluated. Spans use inclusive start/end token indices.

    Runs:
        1: the frame fed to ``predict_boundaries`` comes from ``predict_frame``.
        2: the frame is the gold ``"label"`` of the annotation.

    Prints precision/recall/F1 for exact span matches (``P/R/F``) and for
    covered token indices (``Pt/Rt/Ft``). A score whose denominator is zero
    is reported as ``0.0`` rather than raising ``ZeroDivisionError``.

    Args:
        data_file: Path to the JSONL evaluation file.
        predictor: Object passed through to the prediction helpers.
        run: Evaluation setting, 1 or 2 (see above).
        verbose: When true, print per-sentence spans and counts.
    """
    assert run in [1, 2]

    def _prf(tp, fp, fn):
        # Precision, recall and F1 with zero-denominator protection.
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        if precision + recall:
            f1 = 2 * ((precision * recall) / (precision + recall))
        else:
            f1 = 0.0
        return precision, recall, f1

    def _covered_tokens(spans):
        # Token indices covered by a collection of inclusive (start, stop) spans.
        return {
            tok_idx for (start, stop) in spans for tok_idx in range(start, stop + 1)
        }

    true_pos = false_pos = false_neg = 0
    true_pos_tok = false_pos_tok = false_neg_tok = 0

    with open(data_file, encoding="utf-8") as f:
        for sent_id, line in enumerate(f):
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            sent_data = json.loads(line)

            tokens = sent_data["tokens"]
            annotation = sent_data["annotations"][0]

            predicate_span = tuple(annotation["span"])
            predicate = tokens[predicate_span[0] : predicate_span[1] + 1]

            if run == 1:
                frame = predict_frame(predictor, tokens, predicate_span)
            else:
                frame = annotation["label"]

            boundaries_gold = get_gold_boundaries(annotation, predicate_span)
            boundaries_pred = set(
                predict_boundaries(predictor, tokens, predicate_span, frame)
            )

            # -- exact span matches
            sent_true_pos = len(boundaries_gold & boundaries_pred)
            sent_false_pos = len(boundaries_pred - boundaries_gold)
            sent_false_neg = len(boundaries_gold - boundaries_pred)
            true_pos += sent_true_pos
            false_pos += sent_false_pos
            false_neg += sent_false_neg

            # -- token overlap
            boundary_toks_gold = _covered_tokens(boundaries_gold)
            boundary_toks_pred = _covered_tokens(boundaries_pred)
            sent_tok_true_pos = len(boundary_toks_gold & boundary_toks_pred)
            sent_tok_false_pos = len(boundary_toks_pred - boundary_toks_gold)
            sent_tok_false_neg = len(boundary_toks_gold - boundary_toks_pred)
            true_pos_tok += sent_tok_true_pos
            false_pos_tok += sent_tok_false_pos
            false_neg_tok += sent_tok_false_neg

            if verbose:
                print(f"Sentence #{sent_id:03}: {' '.join(tokens)}")
                print(f"\tpredicate: {predicate}")
                print(f"\t frame: {frame}")
                print(f"\t gold: {boundaries_gold}")
                print(f"\tpredicted: {boundaries_pred}")
                print(f"\ttp={sent_true_pos}\tfp={sent_false_pos}\tfn={sent_false_neg}")
                print(
                    f"\ttp_t={sent_tok_true_pos}\tfp_t={sent_tok_false_pos}\tfn_t={sent_tok_false_neg}"
                )
                print()

    prec, rec, f1_score = _prf(true_pos, false_pos, false_neg)
    print(f"P/R/F=\n{prec}\t{rec}\t{f1_score}")

    tok_prec, tok_rec, tok_f1 = _prf(true_pos_tok, false_pos_tok, false_neg_tok)
    print(f"Pt/Rt/Ft=\n{tok_prec}\t{tok_rec}\t{tok_f1}")
def predict_arguments(
    predictor: SpanPredictor, tokens, predicate_span, frame, boundaries
):
    """Label the given argument spans under a predicate and frame.

    ``boundaries`` is normalized to tuples and sorted by start index, then
    passed to the predictor as ``child_spans`` together with the parent span
    and label. The returned labels are paired with the sorted spans in order.
    A span equal to ``predicate_span`` that is labelled ``"Target"`` is
    dropped.

    Spans are converted to tuples (as ``predict_boundaries`` does) so that
    the Target filter and later set/equality comparisons behave the same
    whether callers pass lists or tuples.

    Returns:
        List of ``(span, label)`` pairs, with each span a tuple.
    """
    target_span = tuple(predicate_span)
    ordered_spans = sorted((tuple(b) for b in boundaries), key=lambda span: span[0])

    _, labels, _ = predictor.force_decode(
        tokens,
        parent_span=predicate_span,
        parent_label=frame,
        child_spans=ordered_spans,
    )

    # NOTE(review): assumes labels come back in the same order as child_spans.
    return [
        (span, label)
        for span, label in zip(ordered_spans, labels)
        if not (span == target_span and label == "Target")
    ]
def eval_argument_classification(data_file, predictor, run=1, verbose=False):
    """Score predicted argument labels against gold, per span and per token.

    Each non-blank line of ``data_file`` is parsed as JSON and its first
    annotation is evaluated. Spans use inclusive start/end token indices.

    Runs:
        1: predicted frame, predicted boundaries.
        2: gold frame, predicted boundaries.
        3: gold frame, gold boundaries.

    A prediction counts as correct only when both span and label match a gold
    argument; at token level, when a ``(token index, label)`` pair matches.
    Prints precision/recall/F1 for spans (``P/R/F``) and tokens
    (``Pt/Rt/Ft``). A score whose denominator is zero is reported as ``0.0``
    rather than raising ``ZeroDivisionError``.

    Args:
        data_file: Path to the JSONL evaluation file.
        predictor: Object passed through to the prediction helpers.
        run: Evaluation setting, 1, 2 or 3 (see above).
        verbose: When true, print per-sentence gold and predicted arguments.
    """
    assert run in [1, 2, 3]

    def _prf(tp, fp, fn):
        # Precision, recall and F1 with zero-denominator protection.
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        if precision + recall:
            f1 = 2 * ((precision * recall) / (precision + recall))
        else:
            f1 = 0.0
        return precision, recall, f1

    def _token_labels(arguments):
        # Expand ((start, end), label) pairs into a set of (token index, label).
        return {
            (token, label)
            for ((bnd_start, bnd_end), label) in arguments
            for token in range(bnd_start, bnd_end + 1)
        }

    true_pos = false_pos = false_neg = 0
    true_pos_tok = false_pos_tok = false_neg_tok = 0

    with open(data_file, encoding="utf-8") as f:
        for sent_id, line in enumerate(f):
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            sent_data = json.loads(line)

            tokens = sent_data["tokens"]
            annotation = sent_data["annotations"][0]

            predicate_span = tuple(annotation["span"])
            predicate = tokens[predicate_span[0] : predicate_span[1] + 1]

            # gold or predicted frames?
            if run == 1:
                frame = predict_frame(predictor, tokens, predicate_span)
            else:
                frame = annotation["label"]

            # gold or predicted argument boundaries?
            if run in [1, 2]:
                boundaries = set(
                    predict_boundaries(predictor, tokens, predicate_span, frame)
                )
            else:
                boundaries = get_gold_boundaries(annotation, predicate_span)

            pred_arguments = predict_arguments(
                predictor, tokens, predicate_span, frame, boundaries
            )
            gold_arguments = {
                (tuple(c["span"]), c["label"])
                for c in annotation["children"]
                if not (tuple(c["span"]) == predicate_span and c["label"] == "Target")
            }

            if verbose:
                print(f"Sentence #{sent_id:03}: {' '.join(tokens)}")
                print(f"\tpredicate: {predicate}")
                print(f"\t frame: {frame}")
                print(f"\t gold: {gold_arguments}")
                print(f"\tpredicted: {pred_arguments}")
                print()

            # -- full spans version
            # Boundaries come from a set, so predictions hold no duplicate
            # spans and set arithmetic gives the same counts as pairwise
            # membership tests, without the linear list scans.
            pred_set = set(pred_arguments)
            true_pos += len(gold_arguments & pred_set)
            false_neg += len(gold_arguments - pred_set)
            false_pos += len(pred_set - gold_arguments)

            # -- token based
            tok_gold_labels = _token_labels(gold_arguments)
            tok_pred_labels = _token_labels(pred_arguments)
            true_pos_tok += len(tok_gold_labels & tok_pred_labels)
            false_neg_tok += len(tok_gold_labels - tok_pred_labels)
            false_pos_tok += len(tok_pred_labels - tok_gold_labels)

    prec, rec, f1_score = _prf(true_pos, false_pos, false_neg)
    print(f"P/R/F=\n{prec}\t{rec}\t{f1_score}")

    tok_prec, tok_rec, tok_f1 = _prf(true_pos_tok, false_pos_tok, false_neg_tok)
    print(f"Pt/Rt/Ft=\n{tok_prec}\t{tok_rec}\t{tok_f1}")
# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()