# File size: 2,102 Bytes
# 9a179e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import torch.nn as nn
import numpy as np
import streamlit as st
from transformers import DistilBertModel, DistilBertTokenizerFast


# Maps a class index (position in the classifier's output vector) to the
# human-readable label shown to the user by get_verdict().
TARGET_IND2LABEL = dict(enumerate((
    'Computer Science',
    'Economics',
    'Electrical Engineering and Systems Science',
    'Mathematics',
    'Physics',
    'Quantitative Biology',
    'Quantitative Finance',
    'Statistics',
)))

class DistilBERTClassifier(nn.Module):
    """DistilBERT encoder with a small feed-forward classification head.

    The head takes the encoder's hidden state at sequence position 0, passes
    it through Linear -> GELU -> Dropout, and projects it to ``num_classes``
    raw scores (no softmax is applied inside the module).
    """

    def __init__(self, num_classes=8):
        """Create the encoder and the classification head.

        Args:
            num_classes: Number of output scores produced by the final layer.
        """
        super().__init__()
        self.encoder = DistilBertModel.from_pretrained("distilbert-base-cased")
        # 768 is hard-coded as the encoder's hidden width.
        # NOTE(review): assumed to match the distilbert-base-cased config.
        self.pre_classifier = nn.Linear(768, 768)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, input_ids, attention_mask, labels=None):
        """Compute raw class scores for a batch of tokenized inputs.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.
            labels: Unused by this method. It is accepted so that a batch
                dict containing a ``labels`` entry can be splatted into the
                call; it defaults to ``None`` so inference callers may omit it.

        Returns:
            The output of the final linear layer, one row of ``num_classes``
            scores per input row.
        """
        output = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # First element of the encoder output is the per-token hidden states.
        hidden_state = output[0]
        # Keep only sequence position 0 for every item in the batch
        # (presumably the [CLS] token -- confirm against the tokenizer).
        pooler = hidden_state[:, 0]
        pooler = self.dropout(self.gelu(self.pre_classifier(pooler)))
        preds = self.classifier(pooler)
        return preds
    
@st.cache_resource
def load_tokenizer():
    """Return the fast DistilBERT tokenizer for 'distilbert-base-cased'.

    Decorated with ``st.cache_resource`` so the tokenizer object is created
    once and reused instead of being rebuilt on every call.
    """
    tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-cased')
    return tokenizer

@st.cache_resource
def load_model(device):
    """Load the pickled classifier from ``model.pt`` and put it in eval mode.

    Decorated with ``st.cache_resource`` so the model is loaded once per
    distinct ``device`` value and then reused.

    Args:
        device: Target device the loaded model is moved to.

    Returns:
        The loaded model, moved to ``device`` and switched to eval mode.
    """
    # model.pt holds a whole pickled module (the result is used directly via
    # .to()/.eval()), not a state_dict. PyTorch >= 2.6 defaults to
    # weights_only=True, which refuses such files, so full unpickling has to
    # be requested explicitly.
    # SECURITY: unpickling can execute arbitrary code -- only ship a model.pt
    # that comes from a trusted source.
    model = torch.load(
        'model.pt',
        map_location=torch.device('cpu'),  # deserialize on CPU first, then move
        weights_only=False,
    )
    model = model.to(device)
    model.eval()
    return model

def get_verdict(preds, threshold=0.95):
    """Format the most probable labels until their total reaches ``threshold``.

    Labels are listed from highest to lowest probability. A label is always
    added before the cumulative total is checked, so the entry that crosses
    the threshold is included in the output.

    Args:
        preds: 1-D array of per-class probabilities, indexed by the keys of
            ``TARGET_IND2LABEL``.
        threshold: Cumulative probability at which listing stops
            (0.95 by default).

    Returns:
        A string with one ``"<label>: <probability>"`` entry per selected
        class, entries separated by a blank line.
    """
    entries = []
    cumulative = 0.0
    # argsort is ascending; reverse it to visit the largest probability first.
    for ind in np.argsort(preds)[::-1]:
        prob = preds[ind]
        cumulative += prob
        # int() turns the numpy index into a plain int for the dict lookup.
        entries.append(f"{TARGET_IND2LABEL[int(ind)]}: {prob}")
        if cumulative >= threshold:
            break
    return "\n\n".join(entries)

def get_preds(text, model, tokenizer, device):
    """Tokenize ``text``, run ``model`` on it, and return softmax scores.

    Args:
        text: Input passed to ``tokenizer``.
        model: Callable invoked with ``input_ids``, ``attention_mask`` and
            ``labels`` keyword arguments (plus anything else the tokenizer
            returns).
        tokenizer: Callable returning a mapping that holds ``input_ids`` and
            ``attention_mask`` tensors.
        device: Device the input tensors are moved to before the forward pass.

    Returns:
        A NumPy array: softmax over dimension 0 of the first element of the
        model output (presumably the class probabilities for a single text).
    """
    batch = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
    for key in ('input_ids', 'attention_mask'):
        batch[key] = batch[key].to(device)
    # The model is called with a `labels` keyword as well; there is nothing
    # to supply at inference time, so pass None.
    batch['labels'] = None
    with torch.no_grad():
        scores = model(**batch)
        probs = torch.softmax(scores[0], 0)
    return probs.cpu().numpy()