arxiv_papers_classifier / torch_primitives.py
daniilkk's picture
Upload 2 files
0edc174 verified
import ast

import pandas as pd
import torch
from torch.utils.data import Dataset
import torch.nn as nn
from transformers import DistilBertTokenizer, DistilBertModel
class PaperClassifierDatasetV1(Dataset):
    """Multi-label paper dataset tokenized as (title, abstract) pairs.

    The CSV must provide ``title``, ``abstract`` and ``majors`` columns.
    Each ``majors`` cell is a string holding a Python literal list of
    category names taken from ``MAJORS`` (e.g. ``"['cs', 'stat']"``).
    Targets are multi-hot float vectors of length ``len(MAJORS)``.
    """

    MAJORS = ('cs', 'math', 'physics', 'q-bio', 'q-fin', 'stat', 'econ', 'eess')

    def __init__(self, csv_path: str, no_abstract_proba: float = 0., n_samples: int = 0):
        """Load the CSV, tokenize the first ``n_samples`` rows and build targets.

        Args:
            csv_path: path of the CSV file to read with ``pandas.read_csv``.
            no_abstract_proba: probability, evaluated on every ``__getitem__``
                call, of masking out everything after the title's separator
                token so the model only attends to the title.
            n_samples: number of leading rows to use. ``0``, or a value larger
                than the number of rows, means "use all rows".

        Raises:
            ValueError: if ``n_samples`` is negative, or if a tokenized row
                contains no separator token.
            KeyError: if a ``majors`` entry names a category not in ``MAJORS``.
        """
        super().__init__()
        self.major_to_idx = {major: idx for idx, major in enumerate(self.MAJORS)}
        self.n_classes = len(self.MAJORS)
        self._tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        self.df = pd.read_csv(csv_path)

        if n_samples < 0:
            raise ValueError(f"n_samples must be >= 0, got {n_samples}")
        n_rows = self.df.shape[0]
        # Clamp so that self.x and self.y always have the same number of rows.
        if n_samples == 0 or n_samples > n_rows:
            n_samples = n_rows
        subset = self.df.iloc[:n_samples]

        # The tokenizer rejects non-str inputs, so missing cells become ''.
        titles = subset['title'].fillna('')
        abstracts = subset['abstract'].fillna('')
        self.x = self._tokenizer(
            list(zip(titles, abstracts)),
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )
        self.y = self._build_targets(subset['majors'])

        self.sep_token_id = self._tokenizer.sep_token_id
        # The first separator in each row closes the title segment.
        self.sep_positions = self._find_sep_positions(self.x['input_ids'], self.sep_token_id)
        self.no_abstract_proba = no_abstract_proba

    def _build_targets(self, majors_column) -> torch.Tensor:
        """Return a ``(len(majors_column), len(MAJORS))`` multi-hot float tensor.

        ``majors_column`` is a sized iterable of strings, each a Python literal
        list of major names.
        """
        targets = torch.zeros((len(majors_column), len(self.MAJORS)))
        for row_idx, raw_majors in enumerate(majors_column):
            # literal_eval only accepts Python literals; unlike eval() it cannot
            # execute code smuggled into a CSV cell.
            majors = ast.literal_eval(raw_majors)
            col_idxs = [self.major_to_idx[major] for major in majors]
            if col_idxs:
                targets[row_idx, col_idxs] = 1
        return targets

    @staticmethod
    def _find_sep_positions(input_ids: torch.Tensor, sep_token_id: int) -> list:
        """Return, for each row of 2-D ``input_ids``, the index of the first ``sep_token_id``.

        Raises:
            ValueError: if some row contains no ``sep_token_id``.
        """
        is_sep = input_ids == sep_token_id
        missing = ~is_sep.any(dim=1)
        if bool(missing.any()):
            bad_row = int(missing.nonzero()[0].item())
            raise ValueError(f"no separator token found in row {bad_row}")
        # argmax returns the index of the first maximal value, i.e. the first True.
        return is_sep.long().argmax(dim=1).tolist()

    def __getitem__(self, index: int):
        """Return the sample at ``index`` as a dict of tensors.

        Keys are ``'input_ids'``, ``'attention_mask'`` and ``'target'``. With
        probability ``no_abstract_proba``, attention after the first separator
        is zeroed. This is done on a copy, so the stored mask is unchanged.
        """
        input_ids = self.x['input_ids'][index, ...]
        attention_mask = self.x['attention_mask'][index, ...].clone()
        if torch.rand(1).item() < self.no_abstract_proba:
            sep_pos = self.sep_positions[index]
            attention_mask[sep_pos + 1:] = 0
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'target': self.y[index, ...]
        }

    def __len__(self):
        """Return the number of tokenized samples."""
        return self.x['input_ids'].shape[0]
class PaperClassifierV1(nn.Module):
    """Pretrained DistilBERT encoder topped with a single linear layer.

    The linear head reads the encoder's hidden state at sequence position 0
    (presumably the [CLS] token of DistilBERT-tokenized input) and returns
    unnormalized logits, one per class.
    """

    def __init__(self, n_classes: int):
        """Load ``distilbert-base-uncased`` and build an ``n_classes``-way head."""
        super().__init__()
        self.backbone = DistilBertModel.from_pretrained("distilbert-base-uncased")
        hidden_size = self.backbone.config.hidden_size
        self.head = nn.Linear(in_features=hidden_size, out_features=n_classes)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        """Encode the batch and return logits of shape ``(batch, n_classes)``."""
        encoded = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoded['last_hidden_state']
        first_token_state = hidden_states[:, 0, ...]
        return self.head(first_token_state)