# coding=utf-8
# Copyleft 2019 project LXRT.

import copy
import json

import torch


class AnswerTable:
    # Map common answer variants to a canonical form
    # (articles, spelled-out numbers, spelling differences).
    ANS_CONVERT = {
        "a man": "man",
        "the man": "man",
        "a woman": "woman",
        "the woman": "woman",
        "one": "1",
        "two": "2",
        "three": "3",
        "four": "4",
        "five": "5",
        "six": "6",
        "seven": "7",
        "eight": "8",
        "nine": "9",
        "ten": "10",
        "grey": "gray",
    }

    def __init__(self, dsets=None):
        # Load the full answer vocabulary gathered from the pre-training datasets.
        with open("data/lxmert/all_ans.json") as f:
            self.all_ans = json.load(f)
        if dsets is not None:
            dsets = set(dsets)
            # Keep only answers used in at least one of the requested datasets.
            self.anss = [ans['ans'] for ans in self.all_ans
                         if len(set(ans['dsets']) & dsets) > 0]
        else:
            self.anss = [ans['ans'] for ans in self.all_ans]
        self.ans_set = set(self.anss)

        self._id2ans_map = self.anss
        self._ans2id_map = {ans: ans_id for ans_id, ans in enumerate(self.anss)}

        # Sanity check: the id <-> answer mappings must be mutually consistent.
        assert len(self._id2ans_map) == len(self._ans2id_map)
        for ans_id, ans in enumerate(self._id2ans_map):
            assert self._ans2id_map[ans] == ans_id

    def convert_ans(self, ans):
        """Normalize an answer: lowercase, drop a trailing period and any
        leading article, then map known variants to their canonical form."""
        if len(ans) == 0:
            return ""
        ans = ans.lower()
        if ans[-1] == '.':
            ans = ans[:-1].strip()
        if ans.startswith("a "):
            ans = ans[2:].strip()
        if ans.startswith("an "):
            ans = ans[3:].strip()
        if ans.startswith("the "):
            ans = ans[4:].strip()
        if ans in self.ANS_CONVERT:
            ans = self.ANS_CONVERT[ans]
        return ans

    def ans2id(self, ans):
        return self._ans2id_map[ans]

    def id2ans(self, ans_id):
        return self._id2ans_map[ans_id]

    def ans2id_map(self):
        return self._ans2id_map.copy()

    def id2ans_map(self):
        return self._id2ans_map.copy()

    def used(self, ans):
        return ans in self.ans_set

    def all_answers(self):
        return self.anss.copy()

    def num_answers(self):
        return len(self.anss)
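
# A minimal usage sketch of AnswerTable (illustrative only; it assumes
# data/lxmert/all_ans.json exists and contains entries such as
# {"ans": "cat", "dsets": ["vqa", "gqa"]}):
#
#     table = AnswerTable(dsets=["vqa"])
#     norm = table.convert_ans("A Cat.")   # -> "cat"
#     if table.used(norm):
#         label = table.ans2id(norm)       # stable integer id for this answer
#         assert table.id2ans(label) == norm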


def load_lxmert_qa(path, model, label2ans):
    """
    Load model weights from LXMERT pre-training.
    The answers of the fine-tuned QA task (given by label2ans) are also
    properly initialized from the LXMERT pre-trained QA head.

    :param path: Path prefix of the LXMERT snapshot (the file loaded is
        "<path>_LXRT.pth").
    :param model: LXRT model instance.
    :param label2ans: The label2ans dict of the fine-tuned QA dataset, e.g.
        {0: 'cat', 1: 'dog', ...}
    """
    print("Load QA pre-trained LXMERT from %s" % path)
    loaded_state_dict = torch.load("%s_LXRT.pth" % path)
    model_state_dict = model.state_dict()

    # Handle multi-GPU pre-training --> single-GPU fine-tuning:
    # strip the "module." prefix that DataParallel adds to parameter names.
    for key in list(loaded_state_dict.keys()):
        loaded_state_dict[key.replace("module.", '')] = loaded_state_dict.pop(key)

    # Isolate the BERT (LXRT encoder) weights.
    bert_state_dict = {key: value for key, value in loaded_state_dict.items()
                       if key.startswith('bert.')}

    # Isolate the answer head, dropping its "answer_head." prefix.
    answer_state_dict = {key.replace('answer_head.', ''): value
                         for key, value in loaded_state_dict.items()
                         if key.startswith('answer_head.')}

    # Do surgery on the answer head: re-index the pre-trained answer logits
    # ("logit_fc.3" is the final linear layer of the head) to match the
    # label space of the fine-tuning dataset.
    ans_weight = answer_state_dict['logit_fc.3.weight']
    ans_bias = answer_state_dict['logit_fc.3.bias']
    new_answer_weight = copy.deepcopy(model_state_dict['logit_fc.3.weight'])
    new_answer_bias = copy.deepcopy(model_state_dict['logit_fc.3.bias'])
    answer_table = AnswerTable()
    loaded = 0
    unloaded = 0
    if isinstance(label2ans, list):
        label2ans = dict(enumerate(label2ans))
    for label, ans in label2ans.items():
        new_ans = answer_table.convert_ans(ans)
        if answer_table.used(new_ans):
            # Copy the pre-trained row for this answer into the new head.
            ans_id = answer_table.ans2id(new_ans)
            new_answer_weight[label] = ans_weight[ans_id]
            new_answer_bias[label] = ans_bias[ans_id]
            loaded += 1
        else:
            # Answer unseen in pre-training: zero-initialize its logit row.
            new_answer_weight[label] = 0.
            new_answer_bias[label] = 0.
            unloaded += 1
    print("Loaded weights for %d answers from LXMERT QA pre-training; "
          "%d answers were not covered." % (loaded, unloaded))
    print()
    answer_state_dict['logit_fc.3.weight'] = new_answer_weight
    answer_state_dict['logit_fc.3.bias'] = new_answer_bias

    # Load the BERT (LXRT encoder) weights. strict=False because the loaded
    # dict may hold extra pre-training keys, but every model key must be
    # covered, which the assert checks.
    bert_model_keys = set(model.lxrt_encoder.model.state_dict().keys())
    bert_loaded_keys = set(bert_state_dict.keys())
    assert len(bert_model_keys - bert_loaded_keys) == 0
    model.lxrt_encoder.model.load_state_dict(bert_state_dict, strict=False)

    # Load the answer-head (logit FC) weights; every loaded key must exist
    # in the model.
    model_keys = set(model.state_dict().keys())
    ans_loaded_keys = set(answer_state_dict.keys())
    assert len(ans_loaded_keys - model_keys) == 0
    model.load_state_dict(answer_state_dict, strict=False)
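

# A minimal calling sketch (illustrative only; the snapshot prefix, the
# label2ans path, and VQAModel are hypothetical placeholders for a real
# snapshot and an LXRT model with a logit_fc answer head):
#
#     label2ans = json.load(open("data/vqa/trainval_label2ans.json"))
#     model = VQAModel(num_answers=len(label2ans))
#     # Expects snap/pretrained/model_LXRT.pth on disk.
#     load_lxmert_qa("snap/pretrained/model", model, label2ans=label2ans)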