"""Streamlit app that predicts ArXiv category tags from a paper title/abstract."""

import logging

import numpy as np
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Maps an ArXiv tag to [category, subcategory] as shown in the results table.
MAPPING_FROM_TAG_TO_CATEGORY = {
    "cs.AI": ["Computer Science", "Artificial Intelligence"],
    "cs.CL": ["Computer Science", "Computation and Language"],
    "cs.CV": ["Computer Science", "Computer Vision and Pattern Recognition"],
    "cs.NE": ["Computer Science", "Neural and Evolutionary Computing"],
    "stat.ML": ["Statistics", "Machine Learning"],
    "cs.LG": ["Computer Science", "Machine Learning"],
    "physics.soc-ph": ["Physics", "Physics and Society"],
    "stat.AP": ["Statistics", "Applications"],
    "cs.RO": ["Computer Science", "Robotics"],
    "cs.MA": ["Computer Science", "Multiagent Systems"],
    "math.OC": ["Mathematics", "Optimization and Control"],
    "cs.IR": ["Computer Science", "Information Retrieval"],
    "stat.ME": ["Statistics", "Methodology"],
    "cs.DC": ["Computer Science", "Distributed, Parallel, and Cluster Computing"],
    "stat.CO": ["Statistics", "Computation"],
    "q-bio.NC": ["Quantitative Biology", "Neurons and Cognition"],
    "cs.GT": ["Computer Science", "Computer Science and Game Theory"],
    "cs.MM": ["Computer Science", "Multimedia"],
    "cs.CR": ["Computer Science", "Cryptography and Security"],
    "cs.HC": ["Computer Science", "Human-Computer Interaction"],
    "cs.SD": ["Computer Science", "Sound"],
    "cs.GR": ["Computer Science", "Graphics"],
    "cs.CY": ["Computer Science", "Computers and Society"],
    "math.ST": ["Mathematics", "Statistics Theory"],
    "stat.TH": ["Statistics", "Statistics Theory"],
    "cs.IT": ["Computer Science", "Information Theory"],
    "math.IT": ["Mathematics", "Information Theory"],
    "cs.SI": ["Computer Science", "Social and Information Networks"],
    "cs.DB": ["Computer Science", "Databases"],
    "cs.LO": ["Computer Science", "Logic in Computer Science"],
    "cs.SY": ["Computer Science", "Systems and Control"],
    "q-bio.QM": ["Quantitative Biology", "Quantitative Methods"],
    "cs.DS": ["Computer Science", "Data Structures and Algorithms"],
    "cs.NA": ["Computer Science", "Numerical Analysis"],
    "cs.CE": ["Computer Science", "Computational Engineering, Finance, and Science"],
}

MODEL_PATH = "minemile/arxiv-tag-classifier"
MODEL_MAX_LENGTH = 512
# Stop listing tags once their summed probability reaches this value ...
CUM_PROB_THRESHOLD = 0.95
# ... or once this many tags have been listed, whichever comes first.
MAX_TOP_TAGS = 5


@st.cache_resource
def _load_model_and_tokenizer_cached(model_path):
    """Load the model and tokenizer, raising on failure.

    Kept separate from the public wrapper so that a failed load propagates
    as an exception instead of being stored in the resource cache as a
    ``(None, None)`` return value.
    """
    logger.info("Loading model and tokenizer from %s", model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path, device_map="cpu"
    )
    model.eval()
    logger.info("Model and tokenizer loaded successfully.")
    return model, tokenizer


def load_model_and_tokenizer(model_path):
    """Return ``(model, tokenizer)`` for *model_path*, or ``(None, None)`` on error.

    Successful loads are cached across Streamlit reruns; failures are
    reported in the UI and logged, and are retried on the next call.
    """
    try:
        return _load_model_and_tokenizer_cached(model_path)
    except Exception as e:  # UI boundary: report the failure, do not crash the app.
        logger.exception("Failed to load model/tokenizer from %s", model_path)
        st.error(f"An error occurred while loading the model: {e}")
        return None, None


def inference(model, tokenizer, title, summary=None):
    """Compute class probabilities for a paper.

    Args:
        model: Sequence-classification model returned by the loader.
        tokenizer: Tokenizer matching ``model``.
        title: Paper title (required text).
        summary: Optional abstract; ignored when ``None`` or blank.

    Returns:
        A 1-D NumPy array of softmax probabilities (one per class), or
        ``None`` if the model is unavailable or prediction fails.
    """
    if model is None or tokenizer is None:
        st.error("Model or tokenizer not loaded. Cannot predict.")
        return None

    combined_text = title
    # The text area yields "" when left empty; avoid appending a stray space.
    if summary and summary.strip():
        combined_text += " " + summary
    logger.info("Predicting for text: %s", combined_text)

    try:
        tokenized_inputs = tokenizer(
            combined_text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=MODEL_MAX_LENGTH,
        )
        with torch.no_grad():
            logits = model(**tokenized_inputs).logits
        # Drop only the batch dimension so the result is always 1-D.
        return torch.softmax(logits, dim=-1).squeeze(0).cpu().numpy()
    except Exception as e:  # UI boundary: report the failure, do not crash the app.
        logger.exception("Prediction failed")
        st.error(f"An error occurred during prediction: {e}")
        return None


def prepare_output(
    probabilities,
    id2label,
    cum_prob_threshold=CUM_PROB_THRESHOLD,
    max_tags=MAX_TOP_TAGS,
):
    """Build the table of top predicted tags.

    Tags are listed in descending probability until their cumulative
    probability reaches ``cum_prob_threshold`` or ``max_tags`` tags have
    been collected. The limits are checked after each tag is added, so a
    non-empty input always yields at least one row.

    Args:
        probabilities: 1-D array-like of class probabilities.
        id2label: Mapping from class index to ArXiv tag.
        cum_prob_threshold: Cumulative-probability cutoff.
        max_tags: Maximum number of rows to return.

    Returns:
        Dict of equal-length lists keyed by "Tag", "Category",
        "Subcategory" and "Probability".
    """
    probabilities = np.asarray(probabilities)
    top_indices = np.argsort(probabilities)[::-1]

    cum_sum = 0.0
    top_tags = []
    top_category = []
    top_subcategory = []
    top_probas = []

    for indx in top_indices:
        tag = id2label[int(indx)]
        names = MAPPING_FROM_TAG_TO_CATEGORY.get(tag)
        if names is None:
            logger.warning("Tag %s not found in mapping from tag to category.", tag)
            category, subcategory = "Unknown", "Unknown"
        else:
            category, subcategory = names[0], names[1]

        top_tags.append(f"{tag}")
        top_category.append(category)
        top_subcategory.append(subcategory)
        top_probas.append(f"{probabilities[indx]*100:.2f}%")

        cum_sum += probabilities[indx]
        if cum_sum >= cum_prob_threshold or len(top_tags) >= max_tags:
            break

    return {
        "Tag": top_tags,
        "Category": top_category,
        "Subcategory": top_subcategory,
        "Probability": top_probas,
    }


def main():
    """Render the Streamlit page and run a prediction when requested."""
    st.set_page_config(page_title="ArXiv Category Tag Classifier", layout="wide")
    st.title("ArXiv Category Tag Classifier")

    with st.spinner("Loading model and tokenizer..."):
        model, tokenizer = load_model_and_tokenizer(MODEL_PATH)

    if model is None or tokenizer is None:
        st.error("Failed to load model/tokenizer.")
        return

    st.markdown(
        f"Enter the title (required) and summary (abstract) of an ArXiv paper to predict its "
        f"ArXiv category using a transformer model. There are {len(model.config.id2label)} available categories."
    )
    st.divider()

    col1, col2 = st.columns(2)
    with col1:
        title = st.text_input(
            "Title",
            placeholder="Enter the title of the paper (Required)",
        )
    with col2:
        paper_summary = st.text_area(
            "Paper Summary (Optional):",
            placeholder="Paste the paper's abstract here. It can increase the accuracy of the prediction.",
            height=180,
        )

    predict_button = st.button("Predict Category", type="primary")
    st.markdown("---")

    if predict_button:
        if not title:
            st.error("Title of the paper is required!")
        else:
            with st.spinner("Predicting category..."):
                class_probabilities = inference(model, tokenizer, title, paper_summary)

            # inference() returns None on failure (it has already shown an error).
            if class_probabilities is not None:
                st.subheader("Top Tags Predictions:")
                output = prepare_output(class_probabilities, model.config.id2label)
                logger.info("Output: %s", output)
                st.dataframe(output, use_container_width=True)
                st.markdown("---")


if __name__ == "__main__":
    main()