# coding=utf-8
# Copyright 2021 The IDEA Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from logging import basicConfig
import torch
from torch import nn
import json
from tqdm import tqdm
import os
import numpy as np
from transformers import BertTokenizer, AutoTokenizer
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import loggers
from torch.utils.data import Dataset, DataLoader
from transformers.optimization import get_linear_schedule_with_warmup
from transformers import BertForMaskedLM, AlbertTokenizer
from transformers import AutoConfig
from transformers import MegatronBertForMaskedLM
import argparse
import copy
import streamlit as st

# os.environ["CUDA_VISIBLE_DEVICES"] = '6'
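
# UniMC turns zero-shot classification into a cloze task: every candidate
# label ("choice") is prefixed with a [MASK] token and concatenated in front
# of the question/text, and the model is trained to emit a "yes" token at the
# [MASK] of the correct choice and a "no" token at the [MASK] of every other
# choice (see UniMCDataset.encode and UniMCModel.forward below).
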
class UniMCDataset(Dataset):
    def __init__(self, data, yes_token, no_token, tokenizer, args, used_mask=True):
        super().__init__()

        self.tokenizer = tokenizer
        self.max_length = args.max_length
        self.num_labels = args.num_labels
        self.used_mask = used_mask
        self.data = data
        self.args = args
        self.yes_token = yes_token
        self.no_token = no_token

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.encode(self.data[index], self.used_mask)

    def get_token_type(self, sep_idx, max_length):
        token_type_ids = np.zeros(shape=(max_length,))
        for i in range(len(sep_idx)-1):
            if i % 2 == 0:
                ty = np.ones(shape=(sep_idx[i+1]-sep_idx[i],))
            else:
                ty = np.zeros(shape=(sep_idx[i+1]-sep_idx[i],))
            token_type_ids[sep_idx[i]:sep_idx[i+1]] = ty

        return token_type_ids
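
    # Position ids are constructed so that every choice span restarts at the
    # same offset right after the leading [CLS]/question span; the main text
    # then continues after the longest choice position. Positions beyond 511
    # are clipped to stay inside BERT's absolute position embedding table.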
    def get_position_ids(self, label_idx, max_length, question_len):
        question_position_ids = np.arange(question_len)
        label_position_ids = np.arange(question_len, label_idx[-1])
        for i in range(len(label_idx)-1):
            label_position_ids[label_idx[i]-question_len:label_idx[i+1]-question_len] = np.arange(
                question_len, question_len+label_idx[i+1]-label_idx[i])
        max_len_label = max(label_position_ids)
        text_position_ids = np.arange(
            max_len_label+1, max_length+max_len_label+1-label_idx[-1])
        position_ids = list(question_position_ids) + \
            list(label_position_ids) + list(text_position_ids)
        if max_length <= 512:
            return position_ids[:max_length]
        else:
            for i in range(512, max_length):
                if position_ids[i] > 511:
                    position_ids[i] = 511
            return position_ids[:max_length]
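
    # Choices must not attend to each other, otherwise one option could "see"
    # its competitors. The mask below first zeroes attention over the whole
    # choice region and then restores a block-diagonal pattern so each choice
    # only attends to itself (and to the question/text outside that region).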
    def get_att_mask(self, attention_mask, label_idx, question_len):
        max_length = len(attention_mask)
        attention_mask = np.array(attention_mask)
        attention_mask = np.tile(attention_mask[None, :], (max_length, 1))

        zeros = np.zeros(
            shape=(label_idx[-1]-question_len, label_idx[-1]-question_len))
        attention_mask[question_len:label_idx[-1],
                       question_len:label_idx[-1]] = zeros

        for i in range(len(label_idx)-1):
            label_token_length = label_idx[i+1]-label_idx[i]
            if label_token_length <= 0:
                print('label_idx', label_idx)
                print('question_len', question_len)
                continue
            ones = np.ones(shape=(label_token_length, label_token_length))
            attention_mask[label_idx[i]:label_idx[i+1],
                           label_idx[i]:label_idx[i+1]] = ones

        return attention_mask
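
    # BERT-style dynamic masking applied only to tokens from mask_start_idx
    # onwards: of the selected tokens, 80% become [MASK], 10% keep the original
    # token, and 10% are replaced by a random vocabulary id. Unselected
    # positions get target -100 so the MLM loss ignores them.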
    def random_masking(self, token_ids, mask_rate, mask_start_idx, max_length, mask_id, tokenizer):
        rands = np.random.random(len(token_ids))
        source, target = [], []
        for i, (r, t) in enumerate(zip(rands, token_ids)):
            if i < mask_start_idx:
                source.append(t)
                target.append(-100)
                continue
            if r < mask_rate * 0.8:
                source.append(mask_id)
                target.append(t)
            elif r < mask_rate * 0.9:
                source.append(t)
                target.append(t)
            elif r < mask_rate:
                source.append(np.random.choice(tokenizer.vocab_size - 1) + 1)
                target.append(t)
            else:
                source.append(t)
                target.append(-100)
        while len(source) < max_length:
            source.append(0)
            target.append(-100)
        return source[:max_length], target[:max_length]
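
    # Builds one example. The input is laid out as
    #   [CLS][MASK]choice_1[MASK]choice_2...[SEP]question[SEP]texta[SEP]textb
    # label_idx records where each choice's [MASK] sits; the MLM target puts
    # yes_token at the correct choice's [MASK] and no_token at all the others.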
    def encode(self, item, used_mask=False):
        while len(self.tokenizer.encode('[MASK]'.join(item['choice']))) > self.max_length-32:
            item['choice'] = [c[:int(len(c)/2)] for c in item['choice']]

        if 'textb' in item.keys() and item['textb'] != '':
            if 'question' in item.keys() and item['question'] != '':
                texta = '[MASK]' + '[MASK]'.join(item['choice']) + '[SEP]' + \
                    item['question'] + '[SEP]' + \
                    item['texta'] + '[SEP]' + item['textb']
            else:
                texta = '[MASK]' + '[MASK]'.join(item['choice']) + '[SEP]' + \
                    item['texta'] + '[SEP]' + item['textb']
        else:
            if 'question' in item.keys() and item['question'] != '':
                texta = '[MASK]' + '[MASK]'.join(item['choice']) + '[SEP]' + \
                    item['question'] + '[SEP]' + item['texta']
            else:
                texta = '[MASK]' + '[MASK]'.join(item['choice']) + \
                    '[SEP]' + item['texta']

        encode_dict = self.tokenizer.encode_plus(texta,
                                                 max_length=self.max_length,
                                                 padding='max_length',
                                                 truncation='longest_first')

        encode_sent = encode_dict['input_ids']
        token_type_ids = encode_dict['token_type_ids']
        attention_mask = encode_dict['attention_mask']
        sample_max_length = sum(encode_dict['attention_mask'])

        if 'label' not in item.keys():
            item['label'] = 0
            item['answer'] = ''

        question_len = 1
        label_idx = [question_len]
        for choice in item['choice']:
            cur_mask_idx = label_idx[-1] + \
                len(self.tokenizer.encode(choice, add_special_tokens=False)) + 1
            label_idx.append(cur_mask_idx)

        token_type_ids = [0]*question_len + [1] * \
            (label_idx[-1]-label_idx[0]+1) + [0]*self.max_length
        token_type_ids = token_type_ids[:self.max_length]

        attention_mask = self.get_att_mask(
            attention_mask, label_idx, question_len)

        position_ids = self.get_position_ids(
            label_idx, self.max_length, question_len)

        clslabels_mask = np.zeros(shape=(len(encode_sent),))
        clslabels_mask[label_idx[:-1]] = 10000
        clslabels_mask = clslabels_mask - 10000

        mlmlabels_mask = np.zeros(shape=(len(encode_sent),))
        mlmlabels_mask[label_idx[0]] = 1

        # Note: masking is hard-coded off here, so the used_mask argument has
        # no effect and the random-masking branch below is never taken.
        used_mask = False
        if used_mask:
            mask_rate = 0.1*np.random.choice(4, p=[0.3, 0.3, 0.25, 0.15])
            source, target = self.random_masking(token_ids=encode_sent, mask_rate=mask_rate,
                                                 mask_start_idx=label_idx[-1], max_length=self.max_length,
                                                 mask_id=self.tokenizer.mask_token_id, tokenizer=self.tokenizer)
        else:
            source, target = encode_sent[:], encode_sent[:]

        source = np.array(source)
        target = np.array(target)
        source[label_idx[:-1]] = self.tokenizer.mask_token_id
        target[label_idx[:-1]] = self.no_token
        target[label_idx[item['label']]] = self.yes_token

        input_ids = source[:sample_max_length]
        token_type_ids = token_type_ids[:sample_max_length]
        attention_mask = attention_mask[:sample_max_length, :sample_max_length]
        position_ids = position_ids[:sample_max_length]
        mlmlabels = target[:sample_max_length]
        clslabels = label_idx[item['label']]
        clslabels_mask = clslabels_mask[:sample_max_length]
        mlmlabels_mask = mlmlabels_mask[:sample_max_length]

        return {
            "input_ids": torch.tensor(input_ids).long(),
            "token_type_ids": torch.tensor(token_type_ids).long(),
            "attention_mask": torch.tensor(attention_mask).float(),
            "position_ids": torch.tensor(position_ids).long(),
            "mlmlabels": torch.tensor(mlmlabels).long(),
            "clslabels": torch.tensor(clslabels).long(),
            "clslabels_mask": torch.tensor(clslabels_mask).float(),
            "mlmlabels_mask": torch.tensor(mlmlabels_mask).float(),
        }


class UniMCDataModel(pl.LightningDataModule):

    @staticmethod
    def add_data_specific_args(parent_args):
        parser = parent_args.add_argument_group('TASK NAME DataModel')
        parser.add_argument('--num_workers', default=8, type=int)
        parser.add_argument('--batchsize', default=16, type=int)
        parser.add_argument('--max_length', default=512, type=int)
        return parent_args

    def __init__(self, train_data, val_data, yes_token, no_token, tokenizer, args):
        super().__init__()
        self.batchsize = args.batchsize

        self.train_data = UniMCDataset(
            train_data, yes_token, no_token, tokenizer, args, True)
        self.valid_data = UniMCDataset(
            val_data, yes_token, no_token, tokenizer, args, False)

    def train_dataloader(self):
        return DataLoader(self.train_data, shuffle=True, collate_fn=self.collate_fn, batch_size=self.batchsize, pin_memory=False)

    def val_dataloader(self):
        return DataLoader(self.valid_data, shuffle=False, collate_fn=self.collate_fn, batch_size=self.batchsize, pin_memory=False)

    def collate_fn(self, batch):
        '''
        Aggregate a batch of data.
        batch = [ins1_dict, ins2_dict, ..., insN_dict]
        batch_data = {'sentence': [ins1_sentence, ins2_sentence, ...], 'input_ids': [ins1_input_ids, ins2_input_ids, ...], ...}
        '''
        batch_data = {}
        for key in batch[0]:
            batch_data[key] = [example[key] for example in batch]

        batch_data['input_ids'] = nn.utils.rnn.pad_sequence(batch_data['input_ids'],
                                                            batch_first=True,
                                                            padding_value=0)
        batch_data['clslabels_mask'] = nn.utils.rnn.pad_sequence(batch_data['clslabels_mask'],
                                                                 batch_first=True,
                                                                 padding_value=-10000)

        batch_size, batch_max_length = batch_data['input_ids'].shape
        for k, v in batch_data.items():
            if k == 'input_ids' or k == 'clslabels_mask':
                continue
            if k == 'clslabels':
                batch_data[k] = torch.tensor(v).long()
                continue
            if k != 'attention_mask':
                batch_data[k] = nn.utils.rnn.pad_sequence(v,
                                                          batch_first=True,
                                                          padding_value=0)
            else:
                # Pad the per-sample 2-D attention masks into one (batch, L, L) tensor.
                attention_mask = torch.zeros(
                    (batch_size, batch_max_length, batch_max_length))
                for i, att in enumerate(v):
                    sample_length, _ = att.shape
                    attention_mask[i, :sample_length, :sample_length] = att
                batch_data[k] = attention_mask

        return batch_data


class UniMCModel(nn.Module):
    def __init__(self, pre_train_dir, yes_token):
        super().__init__()
        self.config = AutoConfig.from_pretrained(pre_train_dir)
        if self.config.model_type == 'megatron-bert':
            self.bert = MegatronBertForMaskedLM.from_pretrained(pre_train_dir)
        else:
            self.bert = BertForMaskedLM.from_pretrained(pre_train_dir)

        self.loss_func = torch.nn.CrossEntropyLoss()
        self.yes_token = yes_token
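
    # The classification head reuses the MLM head: cls_logits takes the logit
    # of the "yes" verbalizer token at every position and adds clslabels_mask
    # (0 at each choice's [MASK] position, -10000 elsewhere), so the argmax
    # over sequence positions selects the most likely choice.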
    def forward(self, input_ids, attention_mask, token_type_ids, position_ids=None, mlmlabels=None, clslabels=None, clslabels_mask=None, mlmlabels_mask=None):
        batch_size, seq_len = input_ids.shape
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            position_ids=position_ids,
                            token_type_ids=token_type_ids,
                            labels=mlmlabels)  # (bsz, seq, dim)
        mask_loss = outputs.loss
        mlm_logits = outputs.logits
        cls_logits = mlm_logits[:, :,
                                self.yes_token].view(-1, seq_len) + clslabels_mask

        if mlmlabels is None:
            return 0, mlm_logits, cls_logits
        else:
            cls_loss = self.loss_func(cls_logits, clslabels)
            all_loss = mask_loss + cls_loss
            return all_loss, mlm_logits, cls_logits


class UniMCLitModel(pl.LightningModule):

    @staticmethod
    def add_model_specific_args(parent_args):
        parser = parent_args.add_argument_group('BaseModel')

        parser.add_argument('--learning_rate', default=1e-5, type=float)
        parser.add_argument('--weight_decay', default=0.1, type=float)
        parser.add_argument('--warmup', default=0.01, type=float)
        parser.add_argument('--num_labels', default=2, type=int)

        return parent_args

    def __init__(self, args, yes_token, num_data=100):
        super().__init__()
        self.args = args
        self.num_data = num_data
        self.model = UniMCModel(self.args.pretrained_model_path, yes_token)

    def setup(self, stage) -> None:
        if stage == 'fit':
            num_gpus = self.trainer.gpus if self.trainer.gpus is not None else 0
            self.total_step = int(self.trainer.max_epochs * self.num_data /
                                  (max(1, num_gpus) * self.trainer.accumulate_grad_batches))
            print('Total training step:', self.total_step)

    def training_step(self, batch, batch_idx):
        loss, logits, cls_logits = self.model(**batch)
        cls_acc = self.comput_metrix(
            cls_logits, batch['clslabels'], batch['mlmlabels_mask'])
        self.log('train_loss', loss)
        self.log('train_acc', cls_acc)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logits, cls_logits = self.model(**batch)
        cls_acc = self.comput_metrix(
            cls_logits, batch['clslabels'], batch['mlmlabels_mask'])
        self.log('val_loss', loss)
        self.log('val_acc', cls_acc)

    def configure_optimizers(self):
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        paras = list(
            filter(lambda p: p[1].requires_grad, self.named_parameters()))
        paras = [{
            'params':
            [p for n, p in paras if not any(nd in n for nd in no_decay)],
            'weight_decay': self.args.weight_decay
        }, {
            'params': [p for n, p in paras if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }]
        optimizer = torch.optim.AdamW(paras, lr=self.args.learning_rate)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, int(self.total_step * self.args.warmup),
            self.total_step)

        return [{
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',
                'frequency': 1
            }
        }]

    def comput_metrix(self, logits, labels, mlmlabels_mask):
        logits = torch.nn.functional.softmax(logits, dim=-1)
        logits = torch.argmax(logits, dim=-1)
        y_pred = logits.view(size=(-1,))
        y_true = labels.view(size=(-1,))
        corr = torch.eq(y_pred, y_true).float()
        return torch.sum(corr.float()) / labels.size(0)


class TaskModelCheckpoint:

    @staticmethod
    def add_argparse_args(parent_args):
        parser = parent_args.add_argument_group('BaseModel')

        parser.add_argument('--monitor', default='val_acc', type=str)
        parser.add_argument('--mode', default='max', type=str)
        parser.add_argument('--dirpath', default='./log/', type=str)
        parser.add_argument(
            '--filename', default='model-{epoch:02d}-{val_acc:.4f}', type=str)

        # ModelCheckpoint expects integer values for these options.
        parser.add_argument('--save_top_k', default=3, type=int)
        parser.add_argument('--every_n_epochs', default=1, type=int)
        parser.add_argument('--every_n_train_steps', default=100, type=int)
        parser.add_argument('--save_weights_only', default=True, type=bool)

        return parent_args

    def __init__(self, args):
        self.callbacks = ModelCheckpoint(monitor=args.monitor,
                                         save_top_k=args.save_top_k,
                                         mode=args.mode,
                                         save_last=True,
                                         every_n_train_steps=args.every_n_train_steps,
                                         save_weights_only=args.save_weights_only,
                                         dirpath=args.dirpath,
                                         filename=args.filename)


class UniMCPredict:
    def __init__(self, yes_token, no_token, model, tokenizer, args):
        self.tokenizer = tokenizer
        self.args = args
        self.data_model = UniMCDataModel(
            [], [], yes_token, no_token, tokenizer, args)
        self.model = model

    def predict(self, batch_data):
        batch = [self.data_model.train_data.encode(
            sample) for sample in batch_data]
        batch = self.data_model.collate_fn(batch)
        # Only move inputs to GPU when the model itself is on GPU, so that
        # CPU-only inference (predict(..., cuda=False)) also works.
        if next(self.model.parameters()).is_cuda:
            batch = {k: v.cuda() for k, v in batch.items()}
        _, _, logits = self.model.model(**batch)

        soft_logits = torch.nn.functional.softmax(logits, dim=-1)
        logits = torch.argmax(soft_logits, dim=-1).detach().cpu().numpy()

        soft_logits = soft_logits.detach().cpu().numpy()
        clslabels_mask = batch['clslabels_mask'].detach(
        ).cpu().numpy().tolist()
        clslabels = batch['clslabels'].detach().cpu().numpy().tolist()

        for i, v in enumerate(batch_data):
            label_idx = [idx for idx, v in enumerate(
                clslabels_mask[i]) if v == 0.]
            label = label_idx.index(logits[i])
            answer = batch_data[i]['choice'][label]
            score = {}
            for c in range(len(batch_data[i]['choice'])):
                score[batch_data[i]['choice'][c]] = float(
                    soft_logits[i][label_idx[c]])

            batch_data[i]['label_ori'] = copy.deepcopy(batch_data[i]['label'])
            batch_data[i]['label'] = label
            batch_data[i]['answer'] = answer
            batch_data[i]['score'] = score

        return batch_data


class UniMCPipelines:

    @staticmethod
    def pipelines_args(parent_args):
        total_parser = parent_args.add_argument_group("pipelines args")
        total_parser.add_argument(
            '--pretrained_model_path', default='', type=str)
        total_parser.add_argument('--load_checkpoints_path',
                                  default='', type=str)
        total_parser.add_argument('--train', action='store_true')
        total_parser.add_argument('--language',
                                  default='chinese', type=str)

        total_parser = UniMCDataModel.add_data_specific_args(total_parser)
        total_parser = TaskModelCheckpoint.add_argparse_args(total_parser)
        total_parser = UniMCLitModel.add_model_specific_args(total_parser)
        total_parser = pl.Trainer.add_argparse_args(parent_args)
        return parent_args

    def __init__(self, args):
        self.args = args
        self.checkpoint_callback = TaskModelCheckpoint(args).callbacks
        self.logger = loggers.TensorBoardLogger(save_dir=args.default_root_dir)
        self.trainer = pl.Trainer.from_argparse_args(args,
                                                     logger=self.logger,
                                                     callbacks=[self.checkpoint_callback])
        self.config = AutoConfig.from_pretrained(args.pretrained_model_path)
        if self.config.model_type == 'albert':
            self.tokenizer = AlbertTokenizer.from_pretrained(
                args.pretrained_model_path)
        else:
            if args.language == 'chinese':
                self.tokenizer = BertTokenizer.from_pretrained(
                    args.pretrained_model_path)
            else:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    args.pretrained_model_path, is_split_into_words=True, add_prefix_space=True)

        if args.language == 'chinese':
            # Verbalizer tokens '是' ("yes") and '非' ("no"); index 1 skips [CLS].
            self.yes_token = self.tokenizer.encode('是')[1]
            self.no_token = self.tokenizer.encode('非')[1]
        else:
            self.yes_token = self.tokenizer.encode('yes')[1]
            self.no_token = self.tokenizer.encode('no')[1]

        if args.load_checkpoints_path != '':
            self.model = UniMCLitModel.load_from_checkpoint(
                args.load_checkpoints_path, args=args, yes_token=self.yes_token)
            print('load model from: ', args.load_checkpoints_path)
        else:
            self.model = UniMCLitModel(args, yes_token=self.yes_token)

    def fit(self, train_data, dev_data, process=True):
        if process:
            train_data = self.preprocess(train_data)
            dev_data = self.preprocess(dev_data)
        data_model = UniMCDataModel(
            train_data, dev_data, self.yes_token, self.no_token, self.tokenizer, self.args)
        self.model.num_data = len(train_data)
        self.trainer.fit(self.model, data_model)

    def predict(self, test_data, cuda=True, process=True):
        if process:
            test_data = self.preprocess(test_data)

        result = []
        start = 0
        if cuda:
            self.model = self.model.cuda()
        self.model.model.eval()
        predict_model = UniMCPredict(
            self.yes_token, self.no_token, self.model, self.tokenizer, self.args)
        while start < len(test_data):
            batch_data = test_data[start:start+self.args.batchsize]
            start += self.args.batchsize
            batch_result = predict_model.predict(batch_data)
            result.extend(batch_result)
        if process:
            result = self.postprocess(result)

        return result

    def preprocess(self, data):
        # Rewrite task-specific samples into the generic choice format.
        # task_type '语义匹配' = semantic matching, '自然语言推理' = natural language inference.
        for i, line in enumerate(data):
            if 'task_type' in line.keys() and line['task_type'] == '语义匹配':
                # Choices: "cannot be understood as: textb" / "can be understood as: textb".
                data[i]['choice'] = ['不能理解为:'+data[i]
                                     ['textb'], '可以理解为:'+data[i]['textb']]
                # data[i]['question'] = '怎么理解这段话?'
                data[i]['textb'] = ''

            if 'task_type' in line.keys() and line['task_type'] == '自然语言推理':
                # Choices: "cannot infer" / "hard to infer" / "can infer" + textb.
                data[i]['choice'] = ['不能推断出:'+data[i]['textb'],
                                     '很难推断出:'+data[i]['textb'], '可以推断出:'+data[i]['textb']]
                # data[i]['question'] = '根据这段话'
                data[i]['textb'] = ''

        return data

    def postprocess(self, data):
        # Map the prompt-style choices back to the task's label space.
        for i, line in enumerate(data):
            if 'task_type' in line.keys() and line['task_type'] == '语义匹配':
                data[i]['textb'] = data[i]['choice'][0].replace('不能理解为:', '')
                data[i]['choice'] = ['不相似', '相似']  # not similar / similar
                ns = {}
                for k, v in data[i]['score'].items():
                    if '不能' in k:
                        k = '不相似'
                    if '可以' in k:
                        k = '相似'
                    ns[k] = v
                data[i]['score'] = ns
                data[i]['answer'] = data[i]['choice'][data[i]['label']]

            if 'task_type' in line.keys() and line['task_type'] == '自然语言推理':
                data[i]['textb'] = data[i]['choice'][0].replace('不能推断出:', '')
                data[i]['choice'] = ['矛盾', '自然', '蕴含']  # contradiction / neutral / entailment
                ns = {}
                for k, v in data[i]['score'].items():
                    if '不能' in k:
                        k = '矛盾'
                    if '很难' in k:
                        k = '自然'
                    if '可以' in k:
                        k = '蕴含'
                    ns[k] = v
                data[i]['score'] = ns
                data[i]['answer'] = data[i]['choice'][data[i]['label']]

        return data


def load_data(data_path):
    with open(data_path, 'r', encoding='utf8') as f:
        lines = f.readlines()
        samples = [json.loads(line) for line in tqdm(lines)]
    return samples


def comp_acc(pred_data, test_data):
    corr = 0
    for i in range(len(pred_data)):
        if pred_data[i]['label'] == test_data[i]['label']:
            corr += 1
    return corr / len(pred_data)


def load_model():
    total_parser = argparse.ArgumentParser("TASK NAME")
    total_parser = UniMCPipelines.pipelines_args(total_parser)
    args = total_parser.parse_args()

    args.pretrained_model_path = 'IDEA-CCNL/Erlangshen-UniMC-RoBERTa-110M-Chinese'
    args.max_length = 512
    args.batchsize = 8
    args.default_root_dir = './'

    model = UniMCPipelines(args)
    return model


def main():
    model = load_model()

    # Streamlit UI. Labels are kept in Chinese as in the original demo;
    # English glosses are given in the comments.
    st.subheader("UniMC Zero-shot 体验")  # "UniMC zero-shot demo"
    st.info("请输入以下信息...")  # "Please enter the following information..."

    sentences = st.text_area(
        "请输入句子:",  # "Enter a sentence:"
        """彭于晏不着急,胡歌也不着急,他俩都不着急,那我也不着急""")
    question = st.text_input(
        "请输入问题(不输入问题也可以):",  # "Enter a question (optional):"
        "请问下面的新闻属于哪个类别?")  # default: "Which category does this news belong to?"
    choice = st.text_input(
        "输入标签(以中文;分割):",  # "Enter labels, separated by the Chinese ';'"
        "娱乐;军事;体育;财经")  # default: entertainment; military; sports; finance
    choice = choice.split(';')

    data = [{"texta": sentences,
             "textb": "",
             "question": question,
             "choice": choice,
             "answer": "", "label": 0,
             "id": 0}]

    if st.button("点击一下,开始预测!"):  # "Click to run the prediction!"
        result = model.predict(data, cuda=False)
        st.success("预测成功!")  # "Prediction finished!"
        st.json(result[0])
    else:
        st.info(
            "**Enter a text** above and **press the button** to predict the category."
        )


if __name__ == "__main__":
    main()
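
# Assumed launch command (standard Streamlit usage; not specified in this
# file): streamlit run app.py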