File size: 1,794 Bytes
7f7285f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import json
import os
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--gold_file')
    parser.add_argument('--retrieval_file')
    parser.add_argument('--output')
    parser.add_argument('--test', action='store_true', default=False)
    args = parser.parse_args()
    filter_dict = dict()
    data_dict = dict()
    golden_dict = dict()
    with open(args.gold_file) as f:
        for line in f:
            data = json.loads(line)
            data_dict[data["id"]] = {"id": data["id"], "evidence":[], "claim": data["claim"]}
            if "label" in data:
                data_dict[data["id"]]["label"] = data["label"]
            if not args.test:
                for evidence in data["evidence"]:
                    data_dict[data["id"]]["evidence"].append([evidence[0], evidence[1], evidence[2], 1.0])
                    string = str(data["id"]) + "_" + evidence[0] + "_" + str(evidence[1])
                    golden_dict[string] = 1
    with open(args.retrieval_file) as f:
        for line in f:
            data = json.loads(line)
            for step, evidence in enumerate(data["evidence"]):
                string = str(data["id"]) + "_" + str(evidence[0]) + "_" + str(evidence[1])
                if string not in golden_dict and string not in filter_dict:
                    data_dict[data["id"]]["evidence"].append([evidence[0], evidence[1], evidence[2], evidence[4]])
                    filter_dict[string] = 1
    with open(args.output, "w") as out:
        for data in data_dict.values():
            evidence_tmp = data["evidence"]
            evidence_tmp = sorted(evidence_tmp, key=lambda x:x[3], reverse=True)
            data["evidence"] = evidence_tmp[:5]
            out.write(json.dumps(data) + "\n")