import argparse
import ast
import logging
import os
import re
from datetime import datetime

import numpy as np
import torch
from nltk import bleu, meteor
from rouge_score.rouge_scorer import RougeScorer
from tqdm import tqdm

from inductor import BartInductor, CometInductor
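
# Task name -> evaluation data file. Paths are relative to the repo root.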
FILES = {
    'amie-yago2': 'data/RE-datasets/AMIE-yago2.txt',
    'rules-yago2': 'data/RE-datasets/RuLES-yago2.txt',
    'openrule155': 'data/OpenRule155.txt',
    'fewrel': 'data/RE/fewrel-5.txt',
    'semeval': 'data/RE/semeval-5.txt',
    'TREx': 'data/RE/trex-5.txt',
    'nyt10': 'data/RE/nyt10-5.txt',
    'google-re': 'data/RE/google-re-5.txt',
    'wiki80': 'data/RE/wiki80-5.txt',
}

os.makedirs('logs', exist_ok=True)
logging.basicConfig(
    # Timestamped log file; strftime avoids characters such as ':' that are
    # illegal in filenames on some platforms.
    filename='logs/evaluation-{}.log'.format(datetime.now().strftime('%Y%m%d-%H%M%S')),
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO)
logger = logging.getLogger(__name__)
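

# Log every parsed command-line argument so each run is reproducible from its log.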
def print_config(config):
    config = vars(config)
    logger.info("**************** MODEL CONFIGURATION ****************")
    for key in sorted(config.keys()):
        val = config[key]
        keystr = "{}".format(key) + (" " * (25 - len(key)))
        logger.info("{} --> {}".format(keystr, val))
    logger.info("**************** MODEL CONFIGURATION ****************")
scorer = RougeScorer(['rougeL'], use_stemmer=True)


def rouge(references, hypothesis):
    scores = []
    for reference in references:
        # score() returns a (precision, recall, fmeasure) named tuple per metric.
        scores.append(scorer.score(reference, hypothesis)['rougeL'].fmeasure)
    return max(scores)
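

# Drives evaluation for one task: generates hypotheses with the chosen
# inductor and accumulates BLEU / METEOR / ROUGE-L / self-BLEU statistics.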
class RelationExtractionEvaluator(object):
    def __init__(self, args):
        self.args = args
        if self.args.inductor == 'rule':
            self.inductor = BartInductor(
                group_beam=self.args.group_beam,
                continue_pretrain_instance_generator=self.args.mlm_training,
                continue_pretrain_hypo_generator=self.args.bart_training,
                if_then=self.args.if_then,
            )
        elif self.args.inductor == 'comet':
            self.inductor = CometInductor()
        else:
            raise ValueError("Unknown inductor: {}".format(self.args.inductor))
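
    # Truncate anything generated after the second <mask>, keeping the
    # sentence-final period.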
    def clean(self, text):
        segments = text.split('<mask>')
        if len(segments) == 3 and segments[2].startswith('.'):
            return '<mask>'.join(segments[:2]) + '<mask>.'
        else:
            return text
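
    # Normalize references that end with a space before the final period.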
    def clean_references(self, texts):
        for i, text in enumerate(texts):
            if text.endswith(" ."):
                texts[i] = text.replace(" .", ".")
        return texts
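
    # Mean BLEU-2 of each hypothesis against the remaining hypotheses;
    # lower self-BLEU indicates more diverse generations.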
    def self_bleu(self, hypothesis):
        bleus = []
        for i in range(len(hypothesis)):
            # Tokenize both sides: NLTK's bleu otherwise treats a plain
            # string as a sequence of characters.
            bleus.append(bleu(
                [h.split() for h in hypothesis[:i] + hypothesis[i + 1:]],
                hypothesis[i].split(),
                weights=(0.5, 0.5)))
        return np.mean(bleus)
    def evaluate(self, task):
        with torch.no_grad():
            self.metrics = {
                "bleu-4": [],
                "bleu-3": [],
                "bleu-2": [],
                "bleu-1": [],
                "METEOR": [],
                "ROUGE-L": [],
                "self-BLEU-2": [],
            }
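            # Each data line is tab-separated: input sentence, head entity,
            # tail entity, and the reference relation(s) -- either a
            # Python-style list using <A>/<B> placeholders or a single
            # template using [X]/[Y] placeholders.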
            with open(FILES[task], 'r', encoding='utf-8') as file:
                data = file.readlines()
                with tqdm(total=len(data)) as pbar:
                    for row in data:
                        pbar.update(1)
                        row = row.strip().split('\t')
                        inputs, head, tail, relations = row[0], row[1], row[2], row[3]
                        inputs = inputs.strip()
                        if relations.startswith('[') and relations.endswith(']'):
                            inputs = re.sub("<A>|<B>", "<mask>", inputs)
                            # literal_eval parses the list safely, unlike eval().
                            references = [
                                relation.replace('<A>', '<mask>').replace('<B>', '<mask>').lower().strip()
                                for relation in ast.literal_eval(relations)
                            ]
                        else:
                            references = [relations.replace('[X]', '<mask>').replace('[Y]', '<mask>').lower().strip()]
                        references = self.clean_references(references)
                        hypothesis = self.inductor.generate(inputs, k=10, topk=10)
| logger.info("***********Input************") | |
| logger.info(inputs) | |
| logger.info("*********Hypothesis*********") | |
| for i, hypo in enumerate(hypothesis): | |
| hypothesis[i] = self.clean(hypo.lower().strip()) | |
| logger.info(hypo) | |
| logger.info("****************************") | |
| logger.info("*********References*********") | |
| logger.info(references) | |
| logger.info("****************************") | |
                        if len(hypothesis) == 0:
                            for k in self.metrics.keys():
                                if k != 'self-BLEU-2':
                                    self.metrics[k].append(0.)
                        else:
                            reference_tokens = [reference.split() for reference in references]
                            for hypo in hypothesis:
                                hypo_tokens = hypo.split()
                                try:
                                    self.metrics['bleu-4'].append(
                                        bleu(reference_tokens, hypo_tokens, weights=(0.25, 0.25, 0.25, 0.25)))
                                except Exception:
                                    logger.warning("Skip bleu-4 in example: {}".format(inputs))
                                try:
                                    self.metrics['bleu-3'].append(
                                        bleu(reference_tokens, hypo_tokens, weights=(1 / 3,) * 3))
                                except Exception:
                                    logger.warning("Skip bleu-3 in example: {}".format(inputs))
                                try:
                                    self.metrics['bleu-2'].append(
                                        bleu(reference_tokens, hypo_tokens, weights=(0.5, 0.5)))
                                except Exception:
                                    logger.warning("Skip bleu-2 in example: {}".format(inputs))
                                try:
                                    self.metrics['bleu-1'].append(
                                        bleu(reference_tokens, hypo_tokens, weights=(1.0,)))
                                except Exception:
                                    logger.warning("Skip bleu-1 in example: {}".format(inputs))
                                try:
                                    # Recent NLTK releases require tokenized input for METEOR.
                                    self.metrics['METEOR'].append(meteor(reference_tokens, hypo_tokens))
                                except Exception:
                                    logger.warning("Skip METEOR in example: {}".format(inputs))
                                try:
                                    self.metrics['ROUGE-L'].append(rouge(references, hypo))
                                except Exception:
                                    logger.warning("Skip ROUGE-L in example: {}".format(inputs))
                            # self-BLEU is a property of the whole hypothesis set,
                            # so compute it once per example rather than per hypothesis.
                            try:
                                self.metrics['self-BLEU-2'].append(self.self_bleu(hypothesis))
                            except Exception:
                                logger.warning("Skip self-BLEU-2 in example: {}.".format(inputs))
            self.print_metrics(task, self.metrics)
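
    # Summarize the run: mean of every metric list, written to the log.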
    def print_metrics(self, task, metrics):
        logger.info("Task: {}".format(str(task)))
        for k, v in metrics.items():
            logger.info("{}: {}".format(k, str(np.mean(v))))
        logger.info("*******************************************************")
        logger.info("*******************************************************")
        logger.info("*******************************************************")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--inductor", type=str, default='rule', choices=['rule', 'comet'])
    # Boolean flags use store_true: argparse's type=bool treats any
    # non-empty string (including "False") as True.
    parser.add_argument("--group_beam", action='store_true')
    parser.add_argument("--mlm_training", action='store_true')
    parser.add_argument("--bart_training", action='store_true')
    parser.add_argument("--if_then", action='store_true')
    parser.add_argument("--task", type=str, default='openrule155', choices=list(FILES))
    args = parser.parse_args()

    print_config(args)
    evaluator = RelationExtractionEvaluator(args)
    evaluator.evaluate(args.task)
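
# Example invocation (assuming this script is saved as evaluation.py and the
# data/ directory from the repo is present):
#   python evaluation.py --inductor rule --task openrule155 --group_beam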