import ast

import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from transformers import DistilBertTokenizer, DistilBertModel


class PaperClassifierDatasetV1(Dataset):
    """arXiv paper dataset: tokenizes (title, abstract) pairs and builds
    multi-hot targets over the major subject categories."""

    MAJORS = ('cs', 'math', 'physics', 'q-bio', 'q-fin', 'stat', 'econ', 'eess')

    def __init__(self, csv_path: str, no_abstract_proba: float = 0.0, n_samples: int = 0):
        super().__init__()
        self.major_to_idx = {major: idx for idx, major in enumerate(self.MAJORS)}
        self.n_classes = len(self.MAJORS)
        self._tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        self.df = pd.read_csv(csv_path)
        # n_samples == 0 means "use the whole CSV".
        if n_samples == 0:
            n_samples = self.df.shape[0]
        # Tokenize (title, abstract) as text pairs: [CLS] title [SEP] abstract [SEP].
        self.x = self._tokenizer(
            list(zip(self.df['title'], self.df['abstract']))[:n_samples],
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )
        # Multi-hot targets: a paper can belong to several majors at once.
        self.y = torch.zeros((n_samples, len(self.MAJORS)))
        for row_idx, majors in enumerate(self.df['majors'][:n_samples]):
            # Safely parse the stringified list, e.g. "['cs', 'stat']" (eval would run arbitrary code).
            majors = ast.literal_eval(majors)
            col_idxs = [self.major_to_idx[major] for major in majors]
            self.y[row_idx, col_idxs] = 1
        # Cache the position of the first [SEP] per example: it ends the title
        # segment, so masking everything after it hides the abstract.
        self.sep_token_id = self._tokenizer.sep_token_id
        self.sep_positions = []
        for row_idx in range(len(self.x['input_ids'])):
            input_ids = self.x['input_ids'][row_idx]
            sep_pos = (input_ids == self.sep_token_id).nonzero(as_tuple=True)[0][0].item()
            self.sep_positions.append(sep_pos)
        self.no_abstract_proba = no_abstract_proba

    def __getitem__(self, index: int):
        input_ids = self.x['input_ids'][index, ...]
        attention_mask = self.x['attention_mask'][index, ...].clone()
        # Augmentation: with probability no_abstract_proba, mask out the
        # abstract so the model also learns to classify from the title alone.
        if torch.rand(1).item() < self.no_abstract_proba:
            sep_pos = self.sep_positions[index]
            attention_mask[sep_pos + 1:] = 0
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'target': self.y[index, ...]
        }

    def __len__(self):
        return self.x['input_ids'].shape[0]
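

# A minimal usage sketch, kept as a comment so the module stays import-safe.
# Assumptions (not confirmed by this repo): 'papers.csv' is a hypothetical
# path whose file has 'title', 'abstract', and 'majors' columns, with 'majors'
# holding a stringified list such as "['cs', 'stat']".
#
#     dataset = PaperClassifierDatasetV1('papers.csv', no_abstract_proba=0.1)
#     sample = dataset[0]
#     sample['input_ids'].shape       # (padded_seq_len,) token ids: title [SEP] abstract
#     sample['attention_mask'].shape  # same shape; abstract positions may be zeroed out
#     sample['target']                # multi-hot vector over the 8 majors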


class PaperClassifierV1(nn.Module):
    """DistilBERT backbone with a linear classification head over the [CLS] state."""

    def __init__(self, n_classes: int):
        super().__init__()
        self.backbone = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.head = nn.Linear(in_features=self.backbone.config.hidden_size, out_features=n_classes)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        backbone_output = self.backbone(input_ids=input_ids, attention_mask=attention_mask)['last_hidden_state']
        # Pool by taking the hidden state of the [CLS] token (position 0).
        logits = self.head(backbone_output[:, 0, ...])
        return logits
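

if __name__ == '__main__':
    # Minimal training sketch. Hedged assumptions: 'papers.csv' is a
    # hypothetical path (the CSV needs 'title', 'abstract', and 'majors'
    # columns), and the batch size, learning rate, and epoch count are
    # illustrative, not the author's settings. BCEWithLogitsLoss fits the
    # multi-hot targets the dataset produces (multi-label classification).
    from torch.utils.data import DataLoader

    dataset = PaperClassifierDatasetV1('papers.csv', no_abstract_proba=0.1)
    loader = DataLoader(dataset, batch_size=16, shuffle=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = PaperClassifierV1(n_classes=dataset.n_classes).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

    model.train()
    for epoch in range(3):
        for batch in loader:
            optimizer.zero_grad()
            logits = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
            )
            loss = criterion(logits, batch['target'].to(device))
            loss.backward()
            optimizer.step()
        print(f'epoch {epoch}: last-batch loss {loss.item():.4f}')

    # At inference, torch.sigmoid(logits) > 0.5 yields per-major predictions.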