Spaces:
Sleeping
Sleeping
import logging | |
import numpy as np | |
import streamlit as st | |
import torch | |
from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
# Route INFO-and-above records through the root logger and give this module
# its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# arXiv tag -> [category, subcategory], used to turn a predicted tag into
# human-readable names. Tags absent from this table are reported as "Unknown"
# by prepare_output.
MAPPING_FROM_TAG_TO_CATEGORY = {
    tag: [category, subcategory]
    for tag, category, subcategory in (
        ("cs.AI", "Computer Science", "Artificial Intelligence"),
        ("cs.CL", "Computer Science", "Computation and Language"),
        ("cs.CV", "Computer Science", "Computer Vision and Pattern Recognition"),
        ("cs.NE", "Computer Science", "Neural and Evolutionary Computing"),
        ("stat.ML", "Statistics", "Machine Learning"),
        ("cs.LG", "Computer Science", "Machine Learning"),
        ("physics.soc-ph", "Physics", "Physics and Society"),
        ("stat.AP", "Statistics", "Applications"),
        ("cs.RO", "Computer Science", "Robotics"),
        ("cs.MA", "Computer Science", "Multiagent Systems"),
        ("math.OC", "Mathematics", "Optimization and Control"),
        ("cs.IR", "Computer Science", "Information Retrieval"),
        ("stat.ME", "Statistics", "Methodology"),
        ("cs.DC", "Computer Science", "Distributed, Parallel, and Cluster Computing"),
        ("stat.CO", "Statistics", "Computation"),
        ("q-bio.NC", "Quantitative Biology", "Neurons and Cognition"),
        ("cs.GT", "Computer Science", "Computer Science and Game Theory"),
        ("cs.MM", "Computer Science", "Multimedia"),
        ("cs.CR", "Computer Science", "Cryptography and Security"),
        ("cs.HC", "Computer Science", "Human-Computer Interaction"),
        ("cs.SD", "Computer Science", "Sound"),
        ("cs.GR", "Computer Science", "Graphics"),
        ("cs.CY", "Computer Science", "Computers and Society"),
        ("math.ST", "Mathematics", "Statistics Theory"),
        ("stat.TH", "Statistics", "Statistics Theory"),
        ("cs.IT", "Computer Science", "Information Theory"),
        ("math.IT", "Mathematics", "Information Theory"),
        ("cs.SI", "Computer Science", "Social and Information Networks"),
        ("cs.DB", "Computer Science", "Databases"),
        ("cs.LO", "Computer Science", "Logic in Computer Science"),
        ("cs.SY", "Computer Science", "Systems and Control"),
        ("q-bio.QM", "Quantitative Biology", "Quantitative Methods"),
        ("cs.DS", "Computer Science", "Data Structures and Algorithms"),
        ("cs.NA", "Computer Science", "Numerical Analysis"),
        ("cs.CE", "Computer Science", "Computational Engineering, Finance, and Science"),
    )
}

# Identifier handed to `from_pretrained` for both the model and the tokenizer.
MODEL_PATH = "minemile/arxiv-tag-classifier"
# Upper bound on tokens per input; longer texts are truncated by the tokenizer.
MODEL_MAX_LENGTH = 512
# prepare_output stops listing tags once their summed probability reaches this.
CUM_PROB_THRESHOLD = 0.95
def load_model_and_tokenizer(model_path):
    """Load the sequence-classification model and its tokenizer.

    Args:
        model_path: Identifier or path passed unchanged to ``from_pretrained``
            for both the tokenizer and the model.

    Returns:
        A ``(model, tokenizer)`` pair with the model placed on CPU and switched
        to eval mode, or ``(None, None)`` if anything failed while loading. On
        failure the error is logged with its traceback and shown in the
        Streamlit page.
    """
    # NOTE(review): Streamlit re-executes the script on every interaction, so
    # this reloads the model on each rerun unless the call is cached (e.g. with
    # st.cache_resource) — consider adding that at the call site.
    logger.info("Loading model and tokenizer from %s", model_path)
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_path, device_map="cpu"
        )
        model.eval()
        logger.info("Model and tokenizer loaded successfully.")
        return model, tokenizer
    except Exception as e:
        # Top-level UI boundary: report instead of crashing the page, but keep
        # the traceback in the logs so the cause is not lost.
        logger.exception("Failed to load model or tokenizer from %s", model_path)
        st.error(f"An error occurred while loading the model or tokenizer: {e}")
        return None, None
def inference(model, tokenizer, title, summary=None):
    """Compute class probabilities for a paper title and optional abstract.

    Args:
        model: Loaded classification model; called with the tokenizer output
            and expected to return an object exposing ``.logits``.
        tokenizer: Tokenizer matching ``model``.
        title: Paper title.
        summary: Optional abstract. When non-empty it is appended to the title,
            separated by a single space.

    Returns:
        A NumPy array of softmax probabilities with the batch axis removed, or
        ``None`` when the model/tokenizer is missing or prediction raised.
        Failures are reported in the Streamlit page.
    """
    # Identity check rather than truthiness: arbitrary model/tokenizer objects
    # may define __len__/__bool__ and evaluate falsy while being perfectly valid.
    if model is None or tokenizer is None:
        st.error("Model or tokenizer not loaded. Cannot predict.")
        # A single None (not a tuple) so callers can test `is not None`.
        return None

    # An empty summary (e.g. an untouched text area) adds nothing to the input.
    combined_text = f"{title} {summary}" if summary else title
    logger.info("Predicting for text: %s", combined_text)
    try:
        tokenized_inputs = tokenizer(
            combined_text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=MODEL_MAX_LENGTH,
        )
        with torch.no_grad():
            logits = model(**tokenized_inputs).logits
        # squeeze(0) removes only the leading (batch) axis, so the result stays
        # one-dimensional even if the model has a single output class.
        return torch.softmax(logits, dim=-1).squeeze(0).cpu().numpy()
    except Exception as e:
        # Top-level UI boundary: surface the error and keep the traceback.
        logger.exception("Prediction failed")
        st.error(f"An error occurred during prediction: {e}")
        return None
def prepare_output(probabilities, id2label):
    """Assemble the table of most probable tags.

    Classes are visited from highest to lowest probability. Collection stops
    as soon as the accumulated probability reaches ``CUM_PROB_THRESHOLD`` or
    five tags have been gathered, whichever happens first.

    Args:
        probabilities: Per-class probabilities, indexable by class id.
        id2label: Lookup from class id to tag string.

    Returns:
        A dict of equal-length column lists keyed ``"Tag"``, ``"Category"``,
        ``"Subcategory"`` and ``"Probability"`` (percentages formatted with two
        decimals). Tags missing from ``MAPPING_FROM_TAG_TO_CATEGORY`` get
        ``"Unknown"`` for both category columns and a warning is logged.
    """
    table = {"Tag": [], "Category": [], "Subcategory": [], "Probability": []}
    accumulated = 0.0

    for class_id in np.argsort(probabilities)[::-1]:
        tag = id2label[class_id]
        names = MAPPING_FROM_TAG_TO_CATEGORY.get(tag)

        table["Tag"].append(f"{tag}")
        if names is None:
            table["Category"].append("Unknown")
            table["Subcategory"].append("Unknown")
            logger.warning("Tag %s not found in mapping from tag to category.", tag)
        else:
            table["Category"].append(names[0])
            table["Subcategory"].append(names[1])

        class_probability = probabilities[class_id]
        table["Probability"].append(f"{class_probability*100:.2f}%")
        accumulated += class_probability

        if accumulated >= CUM_PROB_THRESHOLD or len(table["Tag"]) >= 5:
            break

    return table
# Streamlit entry point: configures the page, loads the model and tokenizer,
# collects a paper title and optional abstract, and on request shows the
# predicted tags as a table.
# NOTE(review): this listing has lost its original indentation, so the nesting
# of the statements below is inferred from syntax only — confirm against the
# real file before editing.
def main(): | |
st.set_page_config(page_title="ArXiv Category Tag Classifier", layout="wide") | |
st.title("ArXiv Category Tag Classifier") | |
# Load the model before rendering the form; stop here if loading failed.
with st.spinner("Loading model and tokenizer..."): | |
model, tokenizer = load_model_and_tokenizer(MODEL_PATH) | |
if model is None or tokenizer is None: | |
st.error("Failed to load model/tokenizer.") | |
return | |
# Intro text; the number of categories is read from the model config's id2label.
st.markdown( | |
f"Enter the title (required) and summary (abstract) of an ArXiv paper to predict its " | |
f"ArXiv category using a transformer model. There are {len(model.config.id2label)} available categories." | |
) | |
st.divider() | |
# Two columns: the title input in the first, the abstract input in the second.
col1, col2 = st.columns(2) | |
with col1: | |
title = st.text_input( | |
"Title", | |
placeholder="Enter the title of the paper (Required)", | |
) | |
with col2: | |
paper_summary = st.text_area( | |
"Paper Summary (Optional):", | |
placeholder="Paste the paper's abstract here. It can increase the accuracy of the prediction.", | |
height=180, | |
) | |
predict_button = st.button("Predict Category", type="primary") | |
st.markdown("---") | |
# Run inference only when the button reports a click; an empty title is rejected.
if predict_button: | |
if not title: | |
st.error("Title of the paper is required!") | |
else: | |
with st.spinner("Predicting category..."): | |
class_probabilities = inference(model, tokenizer, title, paper_summary) | |
# Results are rendered only when inference did not return None.
if class_probabilities is not None: | |
st.subheader("Top Tags Predictions:") | |
output = prepare_output(class_probabilities, model.config.id2label) | |
logger.info("Output: %s", output) | |
st.dataframe(output, use_container_width=True) | |
# NOTE(review): the nesting level of this closing separator cannot be
# determined from this listing — confirm.
st.markdown("---") | |
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__": | |
main() | |