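"""Streamlit demo: multi-label arXiv category classifier for academic papers.

Loads a fine-tuned BERT checkpoint from a local `model.safetensors` file and
predicts arXiv tags from a paper's title and abstract.

Run locally with `streamlit run app.py`; the checkpoint is expected next to
this file (see MODEL_PATH below).
"""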
import streamlit as st
import torch
import os
import torch.nn as nn
from safetensors import safe_open
from transformers import BertPreTrainedModel, BertModel, BertTokenizer, BertConfig
st.set_page_config(page_title="Paper Classifier", layout="wide")
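
# Fine-tuned BERT with a dropout + linear head producing one logit per arXiv category.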
class BERTClass(BertPreTrainedModel):
    def __init__(self, config, p=0.3):
        super().__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(p)
        self.linear = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True
        )
        # Pooled [CLS] representation -> dropout -> per-label logits.
        pooled_output = outputs.pooler_output
        pooled_output = self.dropout(pooled_output)
        logits = self.linear(pooled_output)

        loss = None
        if labels is not None:
            # Multi-label setup: one independent sigmoid per category.
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits, labels)
        return {"loss": loss, "logits": logits}
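
# The 18 top-level arXiv categories the model was trained to predict;
# MAX_LEN caps inputs at BERT's 512-token limit.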
MODEL_PATH = "."
LABELS = ['astro-ph', 'cond-mat', 'cs', 'eess', 'gr-qc',
'hep-ex', 'hep-lat', 'hep-ph', 'hep-th', 'math', 'math-ph', 'nlin',
'nucl-ex', 'nucl-th', 'physics', 'q-bio', 'quant-ph', 'stat']
MAX_LEN = 512
@st.cache_resource
def load_model():
    """Load the fine-tuned BERT classifier and tokenizer once per session."""
    try:
        config = BertConfig.from_pretrained("bert-base-cased")
        config.num_labels = len(LABELS)
        model = BERTClass(config)
        # Read the fine-tuned weights from the local safetensors checkpoint.
        with safe_open(os.path.join(MODEL_PATH, "model.safetensors"), framework="pt") as f:
            state_dict = {key: f.get_tensor(key) for key in f.keys()}
        model.load_state_dict(state_dict)
        tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
        return model.eval(), tokenizer
    except Exception as e:
        st.error(f"Model loading failed: {str(e)}")
        st.stop()
@st.cache_data
def predict(title, abstract):
    """Tokenize the title + abstract and return a probability per category."""
    if not title.strip() and not abstract.strip():
        raise ValueError("Bro, do you want me to guess? Give me at least the title!")
    text = f"{title.strip()}. {abstract.strip()}".strip()
    if len(text) < 10:
        raise ValueError("The text is too short to say anything sensible")
    device = next(model.parameters()).device
    inputs = tokenizer.encode_plus(
        text,
        max_length=MAX_LEN,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs["logits"]
    # Independent sigmoid per label (multi-label classification).
    probs = torch.sigmoid(logits).cpu().numpy()[0]
    return {label: float(probs[i]) for i, label in enumerate(LABELS)}
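
# Load the (cached) model and tokenizer once at startup.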
model, tokenizer = load_model()
with st.sidebar:
    st.header("Display Settings")
    display_mode = st.radio(
        "Result filtering mode",
        ["Top-k categories", "Top-% confidence"],
        index=0
    )
    if display_mode == "Top-k categories":
        top_k = st.slider(
            "Number of categories to show",
            min_value=1,
            max_value=10,
            value=3,
            help="Select how many top categories to display"
        )
    else:
        selected_percent = st.selectbox(
            "Confidence threshold",
            ["50%", "75%", "95%"],
            index=2,
            help="Display categories until reaching this cumulative confidence"
        )
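
# Main page: title and abstract input form.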
st.title("πŸ“„ Academic Paper Classifier")
with st.form("input_form"):
    title = st.text_input("Paper Title", placeholder="Enter paper title...")
    abstract = st.text_area("Abstract", placeholder="Paste paper abstract here...", height=200)
    submitted = st.form_submit_button("Classify")
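
# On submit: classify, filter per the sidebar settings, and render the scores.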
if submitted:
    with st.spinner("Analyzing paper..."):
        try:
            full_predictions = predict(title, abstract)
            sorted_preds = sorted(full_predictions.items(),
                                  key=lambda x: x[1],
                                  reverse=True)
            # Sigmoid scores are independent, so normalize by their sum when
            # reporting cumulative coverage.
            total = sum(full_predictions.values())

            if display_mode == "Top-k categories":
                filtered = dict(sorted_preds[:top_k])
            else:
                threshold = {"50%": 0.5, "75%": 0.75, "95%": 0.95}[selected_percent]
                cumulative = 0.0
                filtered = {}
                for label, score in sorted_preds:
                    cumulative += score / total
                    filtered[label] = score
                    if cumulative >= threshold:
                        break
                    if len(filtered) >= 10:
                        break

            if not filtered:
                st.warning("No categories meet the selected criteria")
            else:
                top_class = max(filtered, key=filtered.get)
                st.success(f"Most likely category: **{top_class}**")

                st.subheader("Category Confidence Scores:")
                total_shown = sum(filtered.values())
                for label, score in filtered.items():
                    # Bar length is the category's share of the displayed scores;
                    # the text shows the raw model confidence.
                    relative_score = score / total_shown
                    st.progress(
                        relative_score,
                        text=f"{label}: {score:.1%}"
                    )
                st.caption(f"Coverage: {sum(filtered.values()) / total:.1%} of total confidence")
        except Exception as e:
            st.error(f"Error: {str(e)}")
with st.sidebar:
    st.header("About")
    st.markdown("""
    This tool predicts the arXiv tag of a research paper from its title and abstract using a fine-tuned BERT model.
    - Enter the title and abstract
    - Enjoy the magnificent classification results
    """)