import streamlit as st
import torch
import os
import torch.nn as nn
from safetensors import safe_open
from transformers import BertPreTrainedModel, BertModel, BertTokenizer, BertConfig
st.set_page_config(page_title="Paper Classifier", layout="wide")
class BERTClass(BertPreTrainedModel):
    """BERT encoder with a dropout + linear head for multi-label classification."""

    def __init__(self, config, p=0.3):
        super().__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(p)
        self.linear = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True
        )
        # Classify from the pooled [CLS] representation.
        pooled_output = outputs.pooler_output
        pooled_output = self.dropout(pooled_output)
        logits = self.linear(pooled_output)

        loss = None
        if labels is not None:
            # Multi-label objective: one sigmoid/BCE term per category.
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits, labels)
        return {"loss": loss, "logits": logits}

MODEL_PATH = "."
LABELS = ['astro-ph', 'cond-mat', 'cs', 'eess', 'gr-qc',
'hep-ex', 'hep-lat', 'hep-ph', 'hep-th', 'math', 'math-ph', 'nlin',
'nucl-ex', 'nucl-th', 'physics', 'q-bio', 'quant-ph', 'stat']
MAX_LEN = 512
@st.cache_resource
def load_model():
    try:
        config = BertConfig.from_pretrained("bert-base-cased")
        config.num_labels = len(LABELS)
        model = BERTClass(config)
        # Load the fine-tuned weights from the safetensors checkpoint.
        with safe_open(f"{MODEL_PATH}/model.safetensors", framework="pt") as f:
            state_dict = {key: f.get_tensor(key) for key in f.keys()}
        model.load_state_dict(state_dict)
        tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
        return model.eval(), tokenizer
    except Exception as e:
        st.error(f"Model loading failed: {str(e)}")
        st.stop()

@st.cache_data
def predict(title, abstract):
    if not title.strip() and not abstract.strip():
        raise ValueError("Do you want me to guess? Give me at least the title!")
    text = f"{title.strip()}. {abstract.strip()}".strip()
    if len(text) < 10:
        raise ValueError("The text is too short to say anything sensible")
    device = next(model.parameters()).device
    inputs = tokenizer.encode_plus(
        text,
        max_length=MAX_LEN,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs['logits']
    # Sigmoid per label: independent probabilities for multi-label output.
    probs = torch.sigmoid(logits).cpu().numpy()[0]
    return {label: float(probs[i]) for i, label in enumerate(LABELS)}

model, tokenizer = load_model()
with st.sidebar:
    st.header("Display Settings")
    display_mode = st.radio(
        "Result filtering mode",
        ["Top-k categories", "Top-% confidence"],
        index=0
    )
    if display_mode == "Top-k categories":
        top_k = st.slider(
            "Number of categories to show",
            min_value=1,
            max_value=10,
            value=3,
            help="Select how many top categories to display"
        )
    else:
        selected_percent = st.selectbox(
            "Confidence threshold",
            ["50%", "75%", "95%"],
            index=2,
            help="Display categories until reaching this cumulative confidence"
        )

st.title("π Academic Paper Classifier")
with st.form("input_form"):
    title = st.text_input("Paper Title", placeholder="Enter paper title...")
    abstract = st.text_area("Abstract", placeholder="Paste paper abstract here...", height=200)
    submitted = st.form_submit_button("Classify")

if submitted:
    with st.spinner("Analyzing paper..."):
        try:
            full_predictions = predict(title, abstract)
            sorted_preds = sorted(full_predictions.items(),
                                  key=lambda x: x[1],
                                  reverse=True)
            # Total sigmoid mass across all labels, used to express coverage.
            total = sum(score for _, score in sorted_preds)
            if display_mode == "Top-k categories":
                filtered = dict(sorted_preds[:top_k])
            else:
                threshold = {"50%": 0.5, "75%": 0.75, "95%": 0.95}[selected_percent]
                # Accumulate labels until the chosen share of total confidence is reached.
                cumulative = 0
                filtered = {}
                for label, score in sorted_preds:
                    cumulative += score / total
                    filtered[label] = score
                    if cumulative >= threshold:
                        break
                    if len(filtered) >= 10:
                        break
            if not filtered:
                st.warning("No categories meet the selected criteria")
            else:
                top_class = max(filtered, key=filtered.get)
                st.success(f"Most likely category: **{top_class}**")
                st.subheader("Category Confidence Scores:")
                total_shown = sum(filtered.values())
                for label, score in filtered.items():
                    relative_score = score / total_shown
                    st.progress(
                        relative_score,
                        text=f"{label}: {score:.1%}"
                    )
                st.caption(f"Coverage: {sum(filtered.values()) / total:.1%} of total confidence")
        except Exception as e:
            st.error(f"Error: {str(e)}")

with st.sidebar:
    st.header("About")
    st.markdown("""
    This tool predicts the arXiv category of a research paper from its title and abstract via a fine-tuned BERT.
    - Enter the title and abstract
    - Enjoy the magnificent classification results
    """)