Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
import os | |
import torch.nn as nn | |
from safetensors import safe_open | |
from transformers import BertPreTrainedModel, BertModel, BertTokenizer, BertConfig | |
# Page-wide settings: browser-tab title and wide layout.
st.set_page_config(layout="wide", page_title="Paper Classifier")
class BERTClass(BertPreTrainedModel):
    """Multi-label classifier: BERT pooled output -> dropout -> linear head.

    The head emits one logit per label (``config.num_labels``). When labels are
    supplied, the loss is ``BCEWithLogitsLoss``, i.e. every label is scored
    independently with its own sigmoid.
    """

    def __init__(self, config, p=0.3):
        """Build the encoder and the classification head.

        Args:
            config: ``BertConfig``; ``hidden_size`` and ``num_labels`` size the head.
            p: Dropout probability applied to the pooled output.
        """
        super().__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(p)
        self.linear = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None):
        """Run BERT and the head, optionally computing the multi-label loss.

        Args:
            input_ids: Token ids passed straight to ``BertModel``.
            attention_mask: Mask passed straight to ``BertModel``.
            token_type_ids: Segment ids passed straight to ``BertModel``.
            labels: Optional multi-hot targets; must have the same shape as the
                logits. Integer or boolean tensors are accepted.

        Returns:
            dict with ``"logits"`` and ``"loss"`` (``None`` when ``labels`` is None).
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )
        pooled_output = self.dropout(outputs.pooler_output)
        logits = self.linear(pooled_output)

        loss = None
        if labels is not None:
            # BCEWithLogitsLoss needs floating-point targets matching the logits;
            # integer multi-hot label tensors would otherwise raise a dtype error.
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits, labels.to(logits.dtype))
        return {"loss": loss, "logits": logits}
# Directory that holds the fine-tuned ``model.safetensors``. It defaults to the
# folder containing this script rather than the process's working directory,
# which depends on where ``streamlit run`` was launched. Set the MODEL_PATH
# environment variable to override it.
MODEL_PATH = os.environ.get("MODEL_PATH", os.path.dirname(os.path.abspath(__file__)))

# arXiv categories; index i names output unit i of the classifier, so this
# order must match the label order used when the model was trained.
LABELS = [
    'astro-ph', 'cond-mat', 'cs', 'eess', 'gr-qc',
    'hep-ex', 'hep-lat', 'hep-ph', 'hep-th', 'math', 'math-ph', 'nlin',
    'nucl-ex', 'nucl-th', 'physics', 'q-bio', 'quant-ph', 'stat',
]

# Maximum number of tokens fed to BERT; longer inputs are truncated.
MAX_LEN = 512
@st.cache_resource
def load_model():
    """Build the BERT classifier, load its fine-tuned weights, and load the tokenizer.

    Decorated with ``st.cache_resource`` so that the expensive setup runs once
    per server process instead of on every Streamlit rerun. That setup is the
    ``bert-base-cased`` config and tokenizer lookups plus reading
    ``model.safetensors``. A call that raises is not cached.

    Returns:
        ``(model, tokenizer)``. The model is in eval mode and sits on CUDA when
        available, otherwise on CPU.

    On any failure, the error is shown on the page and the script run is
    halted with ``st.stop()``.
    """
    try:
        config = BertConfig.from_pretrained("bert-base-cased")
        config.num_labels = len(LABELS)
        model = BERTClass(config)
        weights_path = os.path.join(MODEL_PATH, "model.safetensors")
        with safe_open(weights_path, framework="pt") as f:
            state_dict = {key: f.get_tensor(key) for key in f.keys()}
        model.load_state_dict(state_dict)
        # predict() reads the device from the model's parameters, so placing
        # the model here is all that is needed to run inference on a GPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
        return model.eval(), tokenizer
    except Exception as e:
        st.error(f"Model loading failed: {e}")
        st.stop()
def _build_input_text(title, abstract):
    """Combine title and abstract into the single string fed to the tokenizer.

    Uses the ``"<title>. <abstract>"`` layout. With a blank title only the
    abstract is used, so the text never starts with a stray ``". "``.
    """
    title, abstract = title.strip(), abstract.strip()
    if not title:
        return abstract
    return f"{title}. {abstract}".strip()


def predict(title, abstract):
    """Score a paper against every label in ``LABELS``.

    Args:
        title: Paper title (may be blank if an abstract is given).
        abstract: Paper abstract (may be blank).

    Returns:
        dict mapping each label to a probability. The values are independent
        sigmoid outputs, so they need not sum to 1.

    Raises:
        ValueError: if both inputs are blank, or the combined text is shorter
            than 10 characters.
    """
    if not title.strip() and not abstract.strip():
        raise ValueError("Bro, do you want me to guess?) Give me at least the title!")
    text = _build_input_text(title, abstract)
    if len(text) < 10:
        raise ValueError("Too short text to say anything sensible")

    device = next(model.parameters()).device
    # A single sequence needs no padding. Truncation still caps it at MAX_LEN,
    # and padding up to MAX_LEN would only add masked-out tokens to compute on.
    inputs = tokenizer(
        text,
        max_length=MAX_LEN,
        truncation=True,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        logits = model(**inputs)["logits"]
    probs = torch.sigmoid(logits)[0].cpu().tolist()
    return dict(zip(LABELS, probs))
# Module-level model/tokenizer used by predict(); load_model() shows an error
# and stops the run if loading fails.
model, tokenizer = load_model()

# Sidebar controls that decide how many predicted categories are displayed.
sidebar = st.sidebar
sidebar.header("Display Settings")
display_mode = sidebar.radio(
    "Result filtering mode",
    ["Top-k categories", "Top-% confidence"],
    index=0,
)
if display_mode != "Top-k categories":
    # Cumulative-confidence mode: pick the coverage target.
    selected_percent = sidebar.selectbox(
        "Confidence threshold",
        ["50%", "75%", "95%"],
        index=2,
        help="Display categories until reaching this cumulative confidence",
    )
else:
    # Fixed-count mode: pick how many of the highest-scoring labels to show.
    top_k = sidebar.slider(
        "Number of categories to show",
        min_value=1,
        max_value=10,
        value=3,
        help="Select how many top categories to display",
    )
st.title("π Academic Paper Classifier")

# Title/abstract inputs are submitted together via the form's button, so
# typing does not trigger a classification on every keystroke.
input_form = st.form("input_form")
title = input_form.text_input("Paper Title", placeholder="Enter paper title...")
abstract = input_form.text_area(
    "Abstract",
    placeholder="Paste paper abstract here...",
    height=200,
)
submitted = input_form.form_submit_button("Classify")
def _select_by_cumulative_share(sorted_preds, threshold, max_items=10):
    """Take top predictions until they cover ``threshold`` of the total score.

    Sigmoid scores from this multi-label model do not sum to 1. Each running
    sum is therefore divided by the total over all labels before it is compared
    with ``threshold``. At most ``max_items`` labels are returned.

    Args:
        sorted_preds: ``(label, score)`` pairs sorted by descending score.
        threshold: Fraction in [0, 1] of the total score to cover.
        max_items: Upper bound on the number of labels returned.

    Returns:
        Insertion-ordered ``{label: score}`` dict of the selected predictions.
    """
    total = sum(score for _, score in sorted_preds)
    selected = {}
    cumulative = 0.0
    for label, score in sorted_preds:
        selected[label] = score
        cumulative += score
        if len(selected) >= max_items:
            break
        if total > 0 and cumulative / total >= threshold:
            break
    return selected


if submitted:
    with st.spinner("Analyzing paper..."):
        try:
            full_predictions = predict(title, abstract)
            sorted_preds = sorted(full_predictions.items(), key=lambda x: x[1], reverse=True)
            total_score = sum(full_predictions.values())
            if display_mode == "Top-k categories":
                filtered = dict(sorted_preds[:top_k])
            else:
                threshold = {"50%": 0.5, "75%": 0.75, "95%": 0.95}[selected_percent]
                filtered = _select_by_cumulative_share(sorted_preds, threshold)
            if not filtered:
                st.warning("No categories meet the selected criteria")
            else:
                top_class = max(filtered, key=filtered.get)
                st.success(f"Most likely category: **{top_class}**")
                st.subheader("Category Confidence Scores:")
                total_shown = sum(filtered.values())
                for label, score in filtered.items():
                    # Bar length is the label's share of the displayed scores;
                    # guarded because st.progress needs a value in [0, 1].
                    relative_score = score / total_shown if total_shown > 0 else 0.0
                    st.progress(relative_score, text=f"{label}: {score:.1%}")
                # Share of the summed score over all labels that the shown ones cover.
                coverage = total_shown / total_score if total_score > 0 else 0.0
                st.caption(f"Coverage: {coverage:.1%} of total confidence")
        except Exception as e:
            st.error(f"Error: {e}")
# Static description shown in the sidebar after the display settings.
with st.sidebar:
    st.header("About")
    st.markdown("""
This tool predicts the arXiv category of research papers from their title and abstract using a fine-tuned BERT model.
- Enter title and abstract
- Enjoy the magnificent classification results
""")