daniilkk commited on
Commit
d5fc19a
·
verified ·
1 Parent(s): 2fdf707

add app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -0
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import pandas as pd
4
+ from transformers import DistilBertTokenizer, DistilBertConfig, DistilBertModel
5
+
6
+ from .torch_primitives import PaperClassifierV1, PaperClassifierDatasetV1
7
+
8
+
9
+ @st.cache_resource
10
+ def load_everything():
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+ tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
13
+
14
+ # DistilBertTokenizer.from_pretrained('distilbert-base-uncased') doesn't work from my laptop, but we don't need
15
+ # that checkpoint anymore so we will use this class instead.
16
+ class EmptyPaperClassifier(PaperClassifierV1):
17
+ def __init__(self, n_classes):
18
+ super(PaperClassifierV1, self).__init__()
19
+ self.backbone = DistilBertModel(DistilBertConfig())
20
+ self.head = torch.nn.Linear(in_features=self.backbone.config.hidden_size, out_features=n_classes)
21
+
22
+ model = EmptyPaperClassifier(n_classes=len(PaperClassifierDatasetV1.MAJORS))
23
+ model.load_state_dict(torch.load('best_model.pt', map_location=device))
24
+ model.to(device)
25
+ model.eval()
26
+
27
+ return model, tokenizer, device
28
+
29
+
30
+ def classify_paper(title, abstract, model, tokenizer, device):
31
+ if abstract.strip() == "":
32
+ inputs = tokenizer(
33
+ title,
34
+ padding=True,
35
+ truncation=True,
36
+ max_length=512,
37
+ return_tensors='pt'
38
+ )
39
+ else:
40
+ inputs = tokenizer(
41
+ [title],
42
+ [abstract],
43
+ padding=True,
44
+ truncation=True,
45
+ max_length=512,
46
+ return_tensors='pt'
47
+ )
48
+
49
+ inputs = {k: v.to(device) for k, v in inputs.items()}
50
+
51
+ with torch.no_grad():
52
+ outputs = model(**inputs)
53
+ probabilities = torch.sigmoid(outputs).cpu().numpy()[0]
54
+
55
+ return pd.DataFrame({
56
+ 'Category': PaperClassifierDatasetV1.MAJORS,
57
+ 'Probability': probabilities
58
+ }).sort_values('Probability', ascending=False)
59
+
60
+
61
+ def main(threshold: float = 0.5):
62
+ st.set_page_config(page_title="ArXiv Paper Classifier", page_icon="🦈")
63
+ st.title("ArXiv Paper Classifier")
64
+
65
+ model, tokenizer, device = load_everything()
66
+
67
+ col1, col2 = st.columns([1, 1])
68
+ with col1:
69
+ title = st.text_area("Title", height=200, placeholder="Enter paper title here...", )
70
+ with col2:
71
+ abstract = st.text_area("Abstract (optional)", height=200, placeholder="Enter paper abstract here...")
72
+
73
+ if st.button("Classify", type='primary', use_container_width=True):
74
+ if not title:
75
+ st.error("Please enter a paper title")
76
+ else:
77
+ with st.spinner('In progress...'):
78
+ results = classify_paper(title, abstract, model, tokenizer, device)
79
+
80
+ st.subheader("Results")
81
+
82
+ predicted = results[results['Probability'] > threshold]['Category'].tolist()
83
+ results['Probability'] = results['Probability'].apply(lambda x: f"{x:.2%}")
84
+
85
+ if len(predicted) == 0:
86
+ st.info("Hmm, I am not sure about this one.")
87
+ else:
88
+ st.success(f"Predicted categories: {', '.join(predicted)}")
89
+
90
+ with st.expander("Show details"):
91
+ st.dataframe(results, use_container_width=True, hide_index=True)
92
+ st.caption("All categories with their confidence scores")
93
+
94
+ if __name__ == "__main__":
95
+ main()