import copy
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from vocab import PepVocab
def create_vocab():
    # Build the peptide vocabulary used by the masked-language model.
    vocab_mlm = PepVocab()
    vocab_mlm.vocab_from_txt('vocab.txt')
    # vocab_mlm.token_to_idx['-'] = 23
    return vocab_mlm
def show_parameters(model: nn.Module, show_all=False, show_trainable=True):
    # Map each named parameter to whether it requires gradients.
    mlp_pa = {name: param.requires_grad for name, param in model.named_parameters()}
    if show_all:
        print('All parameters:')
        print(mlp_pa)
    if show_trainable:
        print('Trainable parameters:')
        print(list(filter(lambda x: x[1], list(mlp_pa.items()))))
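# A minimal usage sketch (not part of the original file): inspect which parameters of
# a model remain trainable, e.g. after freezing a layer. The tiny MLP below is only an
# illustrative stand-in.
def _show_parameters_example():
    mlp = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))
    mlp[0].weight.requires_grad = False  # freeze one weight matrix
    show_parameters(mlp, show_all=False, show_trainable=True)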
class ContraLoss(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super(ContraLoss, self).__init__(*args, **kwargs)
        self.temp = 0.07  # temperature for the similarity logits

    def contrastive_loss(self, proj1, proj2):
        # Cosine-similarity logits between the two row-aligned projection batches.
        proj1 = F.normalize(proj1, dim=1)
        proj2 = F.normalize(proj2, dim=1)
        dot = torch.matmul(proj1, proj2.T) / self.temp
        # Subtract the row-wise max for numerical stability before exponentiating.
        dot_max, _ = torch.max(dot, dim=1, keepdim=True)
        dot = dot - dot_max.detach()
        exp_dot = torch.exp(dot)
        # Log-probability of the matching (diagonal) pair within each row.
        log_prob = torch.diag(dot, 0) - torch.log(exp_dot.sum(1))
        cont_loss = -log_prob.mean()
        return cont_loss

    def forward(self, x, y, label=None):
        return self.contrastive_loss(x, y)
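# A minimal usage sketch (not part of the original file): ContraLoss expects two
# row-aligned batches of projections, e.g. two encodings of the same peptides; the
# shapes below are illustrative assumptions.
def _contra_loss_example():
    contra_loss = ContraLoss()
    proj1 = torch.randn(8, 128)  # hypothetical batch of 8 projections
    proj2 = torch.randn(8, 128)  # second view of the same 8 items, row-aligned
    loss = contra_loss(proj1, proj2)
    print(f'[INFO] contrastive loss: {loss.item():.4f}')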
from tqdm import tqdm
import random
from transformers import set_seed
def extract_args(text):
    # Split a call-like string into bare tokens, dropping delimiter characters.
    str_list = []
    substr = ""
    for s in text:
        if s in ('(', ')', '=', ',', ' ', '\n', "'"):
            if substr != '':
                str_list.append(substr)
                substr = ''
        else:
            substr += s
    if substr != '':
        # Keep a trailing token that is not followed by a delimiter.
        str_list.append(substr)
    return str_list
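# A minimal usage sketch (not part of the original file): extract_args splits a
# call-like string into bare tokens, dropping parentheses, '=', commas, quotes and
# whitespace. The input string is illustrative only.
def _extract_args_example():
    tokens = extract_args("func(a=1, b='x')\n")
    print(tokens)  # expected: ['func', 'a', '1', 'b', 'x']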
def eval_one_epoch(loader, cono_encoder):
    cono_encoder.eval()
    batch_loss = []
    # No gradients are needed during evaluation.
    with torch.no_grad():
        for i, data in enumerate(tqdm(loader)):
            loss = cono_encoder.contra_forward(data)
            batch_loss.append(loss.item())
            print(f'[INFO] Test batch {i} loss: {loss.item()}')
    total_loss = np.mean(batch_loss)
    print(f'[INFO] Total loss: {total_loss}')
    return total_loss
def setup_seed(seed):
    # Seed every RNG in use (PyTorch, CUDA, NumPy, Python, transformers) for reproducibility.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    set_seed(seed)
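# A minimal usage sketch (not part of the original file): after re-seeding with the
# same value, random draws are reproduced exactly.
def _setup_seed_example():
    setup_seed(42)
    first = torch.rand(3)
    setup_seed(42)
    second = torch.rand(3)
    assert torch.equal(first, second)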
class CrossEntropyLossWithMask(torch.nn.Module):
    def __init__(self, weight=None):
        super(CrossEntropyLossWithMask, self).__init__()
        # Per-token loss; the optional class weights are forwarded to CrossEntropyLoss.
        self.criterion = nn.CrossEntropyLoss(weight=weight, reduction='none')

    def forward(self, y_pred, y_true, mask):
        (pos_mask, label_mask, seq_mask) = mask
        loss = self.criterion(y_pred, y_true)  # per-token loss, shape (n_tokens,), e.g. (6912,)
        # Average the loss separately over masked, label and remaining sequence positions.
        pos_loss = (loss * pos_mask).sum() / torch.sum(pos_mask)
        label_loss = (loss * label_mask).sum() / torch.sum(label_mask)
        seq_loss = (loss * seq_mask).sum() / torch.sum(seq_mask)
        # Weighted combination: masked positions count most, plain sequence least.
        loss = pos_loss + label_loss / 2 + seq_loss / 3
        return loss
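# A minimal usage sketch (not part of the original file): y_pred holds flattened
# per-token logits, y_true the flattened target ids, and the three masks select
# masked positions, label positions and the rest of the sequence. All shapes and
# mask layouts below are illustrative assumptions.
def _masked_ce_example():
    n_tokens, vocab_size = 12, 30
    y_pred = torch.randn(n_tokens, vocab_size)
    y_true = torch.randint(0, vocab_size, (n_tokens,))
    pos_mask = torch.zeros(n_tokens)
    pos_mask[:4] = 1.0                      # e.g. [MASK] positions
    label_mask = torch.zeros(n_tokens)
    label_mask[4:6] = 1.0                   # e.g. label token positions
    seq_mask = 1.0 - pos_mask - label_mask  # everything else
    criterion = CrossEntropyLossWithMask()
    loss = criterion(y_pred, y_true, (pos_mask, label_mask, seq_mask))
    print(f'[INFO] masked CE loss: {loss.item():.4f}')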
def mask(x, start, end, time):
    # Positions of cysteine ('C') tokens and of the two label tokens, relative to `start`.
    ske_pos = np.where(np.array(x) == 'C')[0] - start
    labels_pos = np.array([1, 2]) - start
    ske_pos = list(filter(lambda p: end - start >= p >= 0, ske_pos))
    labels_pos = list(filter(lambda p: p >= 0, labels_pos))
    # Sampling weights over the window [start, end]; half of the time the label
    # positions are made overwhelmingly likely to be chosen.
    weight = np.ones(end - start + 1)
    rand = np.random.rand()
    if rand < 0.5:
        weight[labels_pos] = 100000
    else:
        weight[labels_pos] = 1
    # Draw `time` distinct positions and replace their tokens with [MASK].
    mask_pos = np.random.choice(range(start, end + 1), time, p=weight / np.sum(weight), replace=False)
    for idx in mask_pos:
        x[idx] = '[MASK]'
    return x
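# A minimal usage sketch (not part of the original file): mask() edits the token list
# in place and returns it; here 3 positions in the window [2, 9] are replaced with
# '[MASK]'. The toy sequence is illustrative only.
def _mask_example():
    setup_seed(42)
    seq = ['[CLS]', 'X1', 'X2', 'G', 'C', 'C', 'S', 'D', 'P', 'R', '[SEP]']
    print(mask(seq, start=2, end=9, time=3))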