Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files- best_model.pt +3 -0
- torch_primitives.py +74 -0
best_model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fb98898cc875973d1f028b523405feea9140410d2895dec8dfb675e68c2d08aa
|
3 |
+
size 265516836
|
torch_primitives.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch.utils.data import Dataset
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from transformers import DistilBertTokenizer, DistilBertModel
|
8 |
+
|
9 |
+
|
10 |
+
class PaperClassifierDatasetV1(Dataset):
|
11 |
+
MAJORS = ('cs', 'math', 'physics', 'q-bio', 'q-fin', 'stat', 'econ', 'eess')
|
12 |
+
def __init__(self, csv_path: str, no_abstract_proba: float = 0., n_samples: int = 0):
|
13 |
+
super().__init__()
|
14 |
+
self.major_to_idx = {major : idx for idx, major in enumerate(self.MAJORS)}
|
15 |
+
self.n_classes = len(self.MAJORS)
|
16 |
+
|
17 |
+
self._tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
|
18 |
+
|
19 |
+
self.df = pd.read_csv(csv_path)
|
20 |
+
|
21 |
+
if n_samples == 0:
|
22 |
+
n_samples = self.df.shape[0]
|
23 |
+
|
24 |
+
self.x = self._tokenizer(
|
25 |
+
list(zip(self.df['title'], self.df['abstract']))[:n_samples],
|
26 |
+
padding=True,
|
27 |
+
truncation=True,
|
28 |
+
max_length=512,
|
29 |
+
return_tensors='pt'
|
30 |
+
)
|
31 |
+
|
32 |
+
self.y = torch.zeros((n_samples, len(self.MAJORS)))
|
33 |
+
for row_idx, majors in enumerate(self.df['majors'][:n_samples]):
|
34 |
+
majors = eval(majors)
|
35 |
+
col_idxs = [self.major_to_idx[major] for major in majors]
|
36 |
+
self.y[row_idx, col_idxs] = 1
|
37 |
+
|
38 |
+
self.sep_token_id = self._tokenizer.sep_token_id
|
39 |
+
self.sep_positions = list()
|
40 |
+
for row_idx in range(len(self.x['input_ids'])):
|
41 |
+
input_ids = self.x['input_ids'][row_idx]
|
42 |
+
sep_pos = (input_ids == self.sep_token_id).nonzero(as_tuple=True)[0][0]
|
43 |
+
self.sep_positions.append(sep_pos)
|
44 |
+
|
45 |
+
self.no_abstract_proba = no_abstract_proba
|
46 |
+
|
47 |
+
def __getitem__(self, index: int):
|
48 |
+
input_ids = self.x['input_ids'][index, ...]
|
49 |
+
attention_mask = self.x['attention_mask'][index, ...].clone()
|
50 |
+
|
51 |
+
if torch.rand(1).item() < self.no_abstract_proba:
|
52 |
+
sep_pos = self.sep_positions[index]
|
53 |
+
attention_mask[sep_pos+1:] = 0
|
54 |
+
|
55 |
+
return {
|
56 |
+
'input_ids': input_ids,
|
57 |
+
'attention_mask': attention_mask,
|
58 |
+
'target': self.y[index, ...]
|
59 |
+
}
|
60 |
+
|
61 |
+
def __len__(self):
|
62 |
+
return self.x['input_ids'].shape[0]
|
63 |
+
|
64 |
+
|
65 |
+
class PaperClassifierV1(nn.Module):
|
66 |
+
def __init__(self, n_classes: int):
|
67 |
+
super().__init__()
|
68 |
+
self.backbone = DistilBertModel.from_pretrained("distilbert-base-uncased")
|
69 |
+
self.head = nn.Linear(in_features=self.backbone.config.hidden_size, out_features=n_classes)
|
70 |
+
|
71 |
+
def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
|
72 |
+
backbone_output = self.backbone(input_ids=input_ids, attention_mask=attention_mask)['last_hidden_state']
|
73 |
+
logits = self.head(backbone_output[:, 0, ...])
|
74 |
+
return logits
|