import torch
import torch.nn as nn
import numpy as np
import streamlit as st
from transformers import DistilBertModel, DistilBertTokenizerFast
# Mapping from class index to subject-category label
TARGET_IND2LABEL = {
    0: 'Computer Science',
    1: 'Economics',
    2: 'Electrical Engineering and Systems Science',
    3: 'Mathematics',
    4: 'Physics',
    5: 'Quantitative Biology',
    6: 'Quantitative Finance',
    7: 'Statistics',
}
class DistilBERTClassifier(nn.Module):
    """DistilBERT encoder with a small classification head on top of the [CLS] token."""

    def __init__(self, num_classes=8):
        super().__init__()
        self.encoder = DistilBertModel.from_pretrained("distilbert-base-cased")
        self.pre_classifier = nn.Linear(768, 768)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, input_ids, attention_mask, labels=None):
        # `labels` is unused here; it is accepted only so that training-time
        # batches containing a labels key can be passed with **batch.
        output = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = output[0]        # (batch, seq_len, 768)
        pooler = hidden_state[:, 0]     # [CLS] token representation
        pooler = self.dropout(self.gelu(self.pre_classifier(pooler)))
        preds = self.classifier(pooler)  # raw logits, (batch, num_classes)
        return preds
@st.cache_resource
def load_tokenizer():
    return DistilBertTokenizerFast.from_pretrained('distilbert-base-cased')


@st.cache_resource
def load_model(device):
    # The checkpoint is a fully pickled model object, so DistilBERTClassifier
    # must be defined/importable at load time.
    model = torch.load('model.pt', map_location=torch.device('cpu')).to(device)
    model.eval()
    return model
def get_verdict(preds):
    # Report classes in descending probability order until the cumulative
    # probability reaches 0.95.
    inds = np.argsort(preds)[::-1]
    sum_prob = 0.0
    verdict = []
    for ind in inds:
        prob = preds[ind]
        sum_prob += prob
        verdict.append(f"{TARGET_IND2LABEL[ind]}: {prob}")
        if sum_prob >= 0.95:
            break
    return "\n\n".join(verdict)
def get_preds(text, model, tokenizer, device):
    tokens = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
    tokens['input_ids'] = tokens['input_ids'].to(device)
    tokens['attention_mask'] = tokens['attention_mask'].to(device)
    tokens['labels'] = None  # kept for training convenience; unused at inference
    with torch.no_grad():
        # Single input: take the first row of logits and softmax over the classes.
        preds = torch.softmax(model(**tokens)[0], dim=0).cpu().numpy()
    return preds
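

# --- Illustrative wiring of the helpers above into a Streamlit UI ---
# A minimal sketch, not part of the original helpers: the `main` wrapper, the
# page title, and the widget labels are hypothetical and shown only to
# demonstrate how load_tokenizer/load_model/get_preds/get_verdict fit together.
def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = load_tokenizer()
    model = load_model(device)

    st.title("Paper subject classifier")
    text = st.text_area("Paste a paper title and abstract:")
    if st.button("Classify") and text.strip():
        preds = get_preds(text, model, tokenizer, device)
        st.markdown(get_verdict(preds))


if __name__ == '__main__':
    main()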