daniilkk committed on
Commit
0edc174
·
verified ·
1 Parent(s): 5ab1ab8

Upload 2 files

Browse files
Files changed (2) hide show
  1. best_model.pt +3 -0
  2. torch_primitives.py +74 -0
best_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fb98898cc875973d1f028b523405feea9140410d2895dec8dfb675e68c2d08aa
3
+ size 265516836
torch_primitives.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+
3
+ import torch
4
+ from torch.utils.data import Dataset
5
+ import torch.nn as nn
6
+
7
+ from transformers import DistilBertTokenizer, DistilBertModel
8
+
9
+
10
class PaperClassifierDatasetV1(Dataset):
    """Multi-label dataset of (title, abstract) pairs tokenized for DistilBERT.

    The CSV at ``csv_path`` is read with pandas and must provide the columns
    ``title``, ``abstract`` and ``majors``. Each ``majors`` cell holds the
    textual form of a Python literal collection of category names taken from
    ``MAJORS`` (for example ``"['cs', 'stat']"``).

    All rows are tokenized eagerly in ``__init__``; ``__getitem__`` only
    indexes the pre-computed tensors.
    """

    MAJORS = ('cs', 'math', 'physics', 'q-bio', 'q-fin', 'stat', 'econ', 'eess')

    def __init__(self, csv_path: str, no_abstract_proba: float = 0., n_samples: int = 0):
        """Load the CSV, tokenize the text pairs and build multi-hot targets.

        Args:
            csv_path: Path of the CSV file to read.
            no_abstract_proba: Probability, applied independently on every
                ``__getitem__`` call, of masking out everything after the
                first separator token (i.e. the abstract segment).
            n_samples: Number of leading rows to use; ``0`` means all rows.
                Values larger than the number of rows are clamped.

        Raises:
            ValueError: If ``n_samples`` is negative, or a ``majors`` cell is
                not a valid Python literal.
            KeyError: If a ``majors`` cell names a category not in ``MAJORS``.
        """
        super().__init__()
        if n_samples < 0:
            raise ValueError(f'n_samples must be non-negative, got {n_samples}')

        self.major_to_idx = {major: idx for idx, major in enumerate(self.MAJORS)}
        self.n_classes = len(self.MAJORS)

        self._tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

        self.df = pd.read_csv(csv_path)

        n_rows = self.df.shape[0]
        # Clamp so that the target matrix never has more rows than the
        # tokenized inputs when the caller asks for more samples than exist.
        n_samples = n_rows if n_samples == 0 else min(n_samples, n_rows)

        # Passing (title, abstract) tuples makes the tokenizer encode each
        # row as a two-segment pair joined by a separator token.
        text_pairs = list(zip(
            self.df['title'].iloc[:n_samples],
            self.df['abstract'].iloc[:n_samples],
        ))
        self.x = self._tokenizer(
            text_pairs,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )

        # Multi-hot targets: one row per sample, one column per major.
        self.y = torch.zeros((n_samples, len(self.MAJORS)))
        for row_idx, raw_majors in enumerate(self.df['majors'].iloc[:n_samples]):
            col_idxs = [self.major_to_idx[major] for major in self._parse_majors(raw_majors)]
            self.y[row_idx, col_idxs] = 1

        # Position of the first separator token in every row; it marks the
        # end of the title segment. Stored as plain ints for cheap slicing.
        self.sep_token_id = self._tokenizer.sep_token_id
        self.sep_positions = [
            (row_ids == self.sep_token_id).nonzero(as_tuple=True)[0][0].item()
            for row_ids in self.x['input_ids']
        ]

        self.no_abstract_proba = no_abstract_proba

    @staticmethod
    def _parse_majors(raw: str):
        """Parse one ``majors`` cell into a collection of category names.

        ``ast.literal_eval`` is used instead of ``eval`` so that content of
        the CSV can never execute arbitrary code.
        """
        import ast

        return ast.literal_eval(raw)

    def __getitem__(self, index: int):
        """Return the ``input_ids``, ``attention_mask`` and ``target`` of one sample.

        With probability ``no_abstract_proba`` the attention mask is zeroed
        for every position after the first separator token, hiding the
        second segment from the model. The stored mask is never modified.
        """
        input_ids = self.x['input_ids'][index, ...]
        # Clone so that the random masking below does not leak into self.x.
        attention_mask = self.x['attention_mask'][index, ...].clone()

        if torch.rand(1).item() < self.no_abstract_proba:
            sep_pos = self.sep_positions[index]
            attention_mask[sep_pos + 1:] = 0

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'target': self.y[index, ...]
        }

    def __len__(self):
        """Return the number of tokenized samples."""
        return self.x['input_ids'].shape[0]
63
+
64
+
65
class PaperClassifierV1(nn.Module):
    """DistilBERT encoder with a single linear layer producing class logits."""

    def __init__(self, n_classes: int):
        """Load the pretrained encoder and size the linear head for ``n_classes``."""
        super().__init__()
        encoder = DistilBertModel.from_pretrained("distilbert-base-uncased")
        hidden_dim = encoder.config.hidden_size
        # Register the encoder first, then the head, so parameter and
        # state_dict ordering stays the same.
        self.backbone = encoder
        self.head = nn.Linear(in_features=hidden_dim, out_features=n_classes)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        """Return the logits computed from the hidden state at sequence position 0.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.

        Returns:
            The output of the linear head applied to the first position of
            the encoder's ``last_hidden_state``.
        """
        encoded = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        token_states = encoded['last_hidden_state']
        # Position 0 is presumably the [CLS] token added by the tokenizer.
        first_token_state = token_states[:, 0, ...]
        return self.head(first_token_state)