Stanpie3 commited on
Commit
1432a8c
·
verified ·
1 Parent(s): d596d88

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +172 -0
app.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import os
4
+ import torch.nn as nn
5
+ from safetensors import safe_open
6
+ from transformers import BertPreTrainedModel, BertModel, BertTokenizer, BertConfig
7
+
8
+ st.set_page_config(page_title="Paper Classifier", layout="wide")
9
+
10
+ class BERTClass(BertPreTrainedModel):
11
+ def __init__(self, config, p=0.3):
12
+ super().__init__(config)
13
+ self.bert = BertModel(config)
14
+ self.dropout = nn.Dropout(p)
15
+ self.linear = nn.Linear(config.hidden_size, config.num_labels)
16
+ self.init_weights()
17
+
18
+ def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None):
19
+ outputs = self.bert(
20
+ input_ids,
21
+ attention_mask=attention_mask,
22
+ token_type_ids=token_type_ids,
23
+ return_dict=True
24
+ )
25
+ pooled_output = outputs.pooler_output
26
+ pooled_output = self.dropout(pooled_output)
27
+ logits = self.linear(pooled_output)
28
+ loss = None
29
+ if labels is not None:
30
+ loss_fct = nn.BCEWithLogitsLoss()
31
+ loss = loss_fct(logits, labels)
32
+ return {"loss": loss, "logits": logits}
33
+
34
+ MODEL_PATH = "./"
35
+ LABELS = ['astro-ph', 'cond-mat', 'cs', 'eess', 'gr-qc',
36
+ 'hep-ex', 'hep-lat', 'hep-ph', 'hep-th', 'math', 'math-ph', 'nlin',
37
+ 'nucl-ex', 'nucl-th', 'physics', 'q-bio', 'quant-ph', 'stat']
38
+ MAX_LEN = 512
39
+
40
+ @st.cache_resource
41
+ def load_model():
42
+ try:
43
+ config = BertConfig.from_pretrained("bert-base-cased")
44
+ config.num_labels = len(LABELS)
45
+ model = BERTClass(config)
46
+
47
+ with safe_open(f"{MODEL_PATH}/model.safetensors", framework="pt") as f:
48
+ state_dict = {key: f.get_tensor(key) for key in f.keys()}
49
+
50
+ model.load_state_dict(state_dict)
51
+ tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
52
+ return model.eval(), tokenizer
53
+
54
+ except Exception as e:
55
+ st.error(f"Model loading failed: {str(e)}")
56
+ st.stop()
57
+
58
+
59
+ @st.cache_data
60
+ def predict(title, abstract):
61
+ if not title.strip() and not abstract.strip():
62
+ raise ValueError("Bro, do you want me to guess?) Give me at least the title!")
63
+
64
+ text = f"{title.strip()}. {abstract.strip()}".strip()
65
+ if len(text) < 10:
66
+ raise ValueError("Too short text to say anything sensible")
67
+
68
+ device = next(model.parameters()).device
69
+ inputs = tokenizer.encode_plus(
70
+ text,
71
+ max_length=MAX_LEN,
72
+ padding="max_length",
73
+ truncation=True,
74
+ return_tensors="pt"
75
+ ).to(device)
76
+
77
+ with torch.no_grad():
78
+ outputs = model(**inputs)
79
+
80
+ logits = outputs['logits']
81
+ probs = torch.sigmoid(logits).cpu().numpy()[0]
82
+ return {label: float(probs[i]) for i, label in enumerate(LABELS)}
83
+
84
+ model, tokenizer = load_model()
85
+
86
+ with st.sidebar:
87
+ st.header("Display Settings")
88
+ display_mode = st.radio(
89
+ "Result filtering mode",
90
+ ["Top-k categories", "Top-% confidence"],
91
+ index=0
92
+ )
93
+
94
+ if display_mode == "Top-k categories":
95
+ top_k = st.slider(
96
+ "Number of categories to show",
97
+ min_value=1,
98
+ max_value=10,
99
+ value=3,
100
+ help="Select how many top categories to display"
101
+ )
102
+ else:
103
+ selected_percent = st.selectbox(
104
+ "Confidence threshold",
105
+ ["50%", "75%", "95%"],
106
+ index=2,
107
+ help="Display categories until reaching this cumulative confidence"
108
+ )
109
+
110
+ st.markdown(f"""
111
+ This tool predicts the academic category of research papers using AI.
112
+ """)
113
+
114
+ st.title("📄 Academic Paper Classifier")
115
+
116
+ with st.form("input_form"):
117
+ title = st.text_input("Paper Title", placeholder="Enter paper title...")
118
+ abstract = st.text_area("Abstract", placeholder="Paste paper abstract here...", height=200)
119
+ submitted = st.form_submit_button("Classify")
120
+
121
+ if submitted:
122
+ with st.spinner("Analyzing paper..."):
123
+ try:
124
+ full_predictions = predict(title, abstract)
125
+ sorted_preds = sorted(full_predictions.items(),
126
+ key=lambda x: x[1],
127
+ reverse=True)
128
+
129
+ if display_mode == "Top-k categories":
130
+ filtered = dict(sorted_preds[:top_k])
131
+ else:
132
+ threshold = {"50%": 0.5, "75%": 0.75, "95%": 0.95}[selected_percent]
133
+ total = sum(score for _, score in sorted_preds)
134
+ cumulative = 0
135
+ filtered = {}
136
+
137
+ for label, score in sorted_preds:
138
+ cumulative += score
139
+ filtered[label] = score
140
+ if cumulative >= threshold:
141
+ break
142
+ if len(filtered) >= 10:
143
+ break
144
+
145
+ if not filtered:
146
+ st.warning("No categories meet the selected criteria")
147
+ else:
148
+ top_class = max(filtered, key=filtered.get)
149
+ st.success(f"Most likely category: **{top_class}**")
150
+
151
+ st.subheader("Category Confidence Scores:")
152
+ total_shown = sum(filtered.values())
153
+
154
+ for label, score in filtered.items():
155
+ relative_score = score / total_shown
156
+ st.progress(
157
+ relative_score,
158
+ text=f"{label}: {score:.1%}"
159
+ )
160
+
161
+ st.caption(f"Coverage: {sum(filtered.values()):.1%} of total confidence")
162
+
163
+ except Exception as e:
164
+ st.error(f"Error: {str(e)}")
165
+
166
+ with st.sidebar:
167
+ st.header("About")
168
+ st.markdown(f"""
169
+ This tool predicts the arxiv of research papers by their title and abstarct via fine-tuned BERT.
170
+ - Enter title and abstract
171
+ - Enjoy the magnificent classification results
172
+ """)