import streamlit as st
import torch
import os
import torch.nn as nn
from safetensors import safe_open
from transformers import BertPreTrainedModel, BertModel, BertTokenizer, BertConfig

st.set_page_config(page_title="Paper Classifier", layout="wide")


class BERTClass(BertPreTrainedModel):
    """BERT encoder followed by dropout and a linear head with one logit per label.

    The head is applied to BERT's pooled output. When ``labels`` are passed,
    ``forward`` also computes a multi-label BCE-with-logits loss.
    """

    def __init__(self, config, p=0.3):
        super().__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(p)
        self.linear = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None):
        """Run BERT + head and return ``{"loss": loss_or_None, "logits": logits}``.

        ``logits`` has ``config.num_labels`` columns. ``loss`` is only computed
        when ``labels`` is given.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )
        pooled_output = self.dropout(outputs.pooler_output)
        logits = self.linear(pooled_output)

        loss = None
        if labels is not None:
            # BCEWithLogitsLoss requires floating-point targets; multi-hot labels
            # are often stored as integers, so cast them to the logits' dtype.
            loss = nn.BCEWithLogitsLoss()(logits, labels.to(logits.dtype))

        return {"loss": loss, "logits": logits}


MODEL_PATH = "."
LABELS = ['astro-ph', 'cond-mat', 'cs', 'eess', 'gr-qc', 'hep-ex', 'hep-lat', 'hep-ph',
          'hep-th', 'math', 'math-ph', 'nlin', 'nucl-ex', 'nucl-th', 'physics', 'q-bio',
          'quant-ph', 'stat']
MAX_LEN = 512


@st.cache_resource
def load_model():
    """Build the classifier, load fine-tuned weights, and return ``(model, tokenizer)``.

    The architecture config and tokenizer come from ``bert-base-cased``. The
    weights are read from ``MODEL_PATH/model.safetensors``. On any failure the
    error is shown in the UI and the Streamlit script run is stopped.
    """
    try:
        config = BertConfig.from_pretrained("bert-base-cased")
        config.num_labels = len(LABELS)
        model = BERTClass(config)

        weights_path = os.path.join(MODEL_PATH, "model.safetensors")
        with safe_open(weights_path, framework="pt") as f:
            state_dict = {key: f.get_tensor(key) for key in f.keys()}
        model.load_state_dict(state_dict)

        tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
        return model.eval(), tokenizer
    except Exception as e:
        st.error(f"Model loading failed: {e}")
        st.stop()


@st.cache_data
def predict(title, abstract):
    """Return ``{label: probability}`` for every label in ``LABELS``.

    Probabilities are independent sigmoids (multi-label), so they do not sum
    to 1. Uses the module-level ``model`` and ``tokenizer``.

    Raises:
        ValueError: if both fields are blank, or the combined text is shorter
            than 10 characters.
    """
    title = title.strip()
    abstract = abstract.strip()
    if not title and not abstract:
        raise ValueError("Bro, do you want me to guess?) Give me at least the title!")

    # Build "Title.\nAbstract", but only join the parts that are present so an
    # abstract-only submission does not start with a stray ". ".
    parts = []
    if title:
        parts.append(f"{title}.")
    if abstract:
        parts.append(abstract)
    text = "\n".join(parts)

    if len(text) < 10:
        raise ValueError("Too short text to say anything sensible")

    device = next(model.parameters()).device
    inputs = tokenizer(
        text,
        max_length=MAX_LEN,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        logits = model(**inputs)["logits"]

    # Single example: take row 0 and convert to plain Python floats.
    probs = torch.sigmoid(logits)[0].cpu().tolist()
    return dict(zip(LABELS, probs))


def _take_until_share(sorted_preds, share, limit=10):
    """Select labels in descending score order until they cover ``share`` of the summed score.

    ``sorted_preds`` is a list of ``(label, score)`` pairs already sorted by
    score, highest first. Scores are multi-label sigmoids, so the cutoff is
    measured against their sum rather than against 1. At most ``limit``
    labels are returned. If the total is zero, the first ``limit`` labels are
    returned.
    """
    total = sum(score for _, score in sorted_preds)
    selected = {}
    covered = 0.0
    for label, score in sorted_preds[:limit]:
        selected[label] = score
        covered += score
        if total > 0 and covered / total >= share:
            break
    return selected


model, tokenizer = load_model()

with st.sidebar:
    st.header("Display Settings")
    display_mode = st.radio(
        "Result filtering mode",
        ["Top-k categories", "Top-% confidence"],
        index=0,
    )
    if display_mode == "Top-k categories":
        top_k = st.slider(
            "Number of categories to show",
            min_value=1,
            max_value=10,
            value=3,
            help="Select how many top categories to display",
        )
    else:
        selected_percent = st.selectbox(
            "Confidence threshold",
            ["50%", "75%", "95%"],
            index=2,
            help="Display categories until reaching this cumulative confidence",
        )

st.title("📄 Academic Paper Classifier")

with st.form("input_form"):
    title = st.text_input("Paper Title", placeholder="Enter paper title...")
    abstract = st.text_area("Abstract", placeholder="Paste paper abstract here...", height=200)
    submitted = st.form_submit_button("Classify")

if submitted:
    with st.spinner("Analyzing paper..."):
        try:
            full_predictions = predict(title, abstract)
            sorted_preds = sorted(full_predictions.items(), key=lambda x: x[1], reverse=True)
            # Sum of all sigmoid scores; the denominator for "share of confidence".
            total = sum(full_predictions.values())

            if display_mode == "Top-k categories":
                filtered = dict(sorted_preds[:top_k])
            else:
                threshold = {"50%": 0.5, "75%": 0.75, "95%": 0.95}[selected_percent]
                filtered = _take_until_share(sorted_preds, threshold)

            if not filtered:
                st.warning("No categories meet the selected criteria")
            else:
                top_class = max(filtered, key=filtered.get)
                st.success(f"Most likely category: **{top_class}**")

                st.subheader("Category Confidence Scores:")
                for label, score in filtered.items():
                    # Bar length matches the printed probability; clamp guards
                    # against float rounding just outside [0, 1].
                    st.progress(
                        min(max(score, 0.0), 1.0),
                        text=f"{label}: {score:.1%}",
                    )

                coverage = sum(filtered.values()) / total if total > 0 else 0.0
                st.caption(f"Coverage: {coverage:.1%} of total confidence")
        except Exception as e:
            st.error(f"Error: {e}")

with st.sidebar:
    st.header("About")
    st.markdown("""
    This tool predicts the arXiv tag of research papers by their title and abstract via fine-tuned BERT.

    - Enter title and abstract
    - Enjoy the magnificent classification results
    """)