evstifeev.stepan
init
920d148
raw
history blame
6.65 kB
import logging
import numpy as np
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Configure root logging at INFO level and create the module-level logger
# used by the functions below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Maps an arXiv tag (e.g. "cs.AI") to a two-element list:
# [category name, subcategory name]. prepare_output() shows element 0 in the
# "Category" column and element 1 in the "Subcategory" column; tags missing
# from this mapping are shown as "Unknown".
MAPPING_FROM_TAG_TO_CATEGORY = {
"cs.AI": ["Computer Science", "Artificial Intelligence"],
"cs.CL": ["Computer Science", "Computation and Language"],
"cs.CV": ["Computer Science", "Computer Vision and Pattern Recognition"],
"cs.NE": ["Computer Science", "Neural and Evolutionary Computing"],
"stat.ML": ["Statistics", "Machine Learning"],
"cs.LG": ["Computer Science", "Machine Learning"],
"physics.soc-ph": ["Physics", "Physics and Society"],
"stat.AP": ["Statistics", "Applications"],
"cs.RO": ["Computer Science", "Robotics"],
"cs.MA": ["Computer Science", "Multiagent Systems"],
"math.OC": ["Mathematics", "Optimization and Control"],
"cs.IR": ["Computer Science", "Information Retrieval"],
"stat.ME": ["Statistics", "Methodology"],
"cs.DC": ["Computer Science", "Distributed, Parallel, and Cluster Computing"],
"stat.CO": ["Statistics", "Computation"],
"q-bio.NC": ["Quantitative Biology", "Neurons and Cognition"],
"cs.GT": ["Computer Science", "Computer Science and Game Theory"],
"cs.MM": ["Computer Science", "Multimedia"],
"cs.CR": ["Computer Science", "Cryptography and Security"],
"cs.HC": ["Computer Science", "Human-Computer Interaction"],
"cs.SD": ["Computer Science", "Sound"],
"cs.GR": ["Computer Science", "Graphics"],
"cs.CY": ["Computer Science", "Computers and Society"],
"math.ST": ["Mathematics", "Statistics Theory"],
"stat.TH": ["Statistics", "Statistics Theory"],
"cs.IT": ["Computer Science", "Information Theory"],
"math.IT": ["Mathematics", "Information Theory"],
"cs.SI": ["Computer Science", "Social and Information Networks"],
"cs.DB": ["Computer Science", "Databases"],
"cs.LO": ["Computer Science", "Logic in Computer Science"],
"cs.SY": ["Computer Science", "Systems and Control"],
"q-bio.QM": ["Quantitative Biology", "Quantitative Methods"],
"cs.DS": ["Computer Science", "Data Structures and Algorithms"],
"cs.NA": ["Computer Science", "Numerical Analysis"],
"cs.CE": ["Computer Science", "Computational Engineering, Finance, and Science"],
}
# Model identifier passed to the tokenizer/model `from_pretrained` calls.
MODEL_PATH = "minemile/arxiv-tag-classifier"
# Passed as `max_length` to the tokenizer; longer inputs are truncated.
MODEL_MAX_LENGTH = 512
# prepare_output() stops adding tags once their cumulative probability
# reaches this value.
CUM_PROB_THRESHOLD = 0.95
@st.cache_resource
def _load_model_and_tokenizer_cached(model_path):
    """Load and cache the model/tokenizer pair; raise on failure.

    Errors are deliberately NOT caught here: a cached function only stores a
    value when it returns, so letting the exception propagate keeps a failed
    load out of the cache and allows the next rerun to retry.
    """
    logger.info("Loading model and tokenizer from %s", model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path, device_map="cpu"
    )
    # Inference only: switch off dropout and other training-time behavior.
    model.eval()
    logger.info("Model and tokenizer loaded successfully.")
    return model, tokenizer


def load_model_and_tokenizer(model_path):
    """Return the cached ``(model, tokenizer)`` pair for ``model_path``.

    Args:
        model_path: Identifier or path forwarded to ``from_pretrained``.

    Returns:
        ``(model, tokenizer)`` on success. On any loading error the error is
        logged, shown in the UI via ``st.error``, and ``(None, None)`` is
        returned. Failures are not cached, so a later call retries the load.
    """
    try:
        return _load_model_and_tokenizer_cached(model_path)
    except Exception as e:  # UI boundary: report the problem instead of crashing
        logger.exception("Failed to load model and tokenizer from %s", model_path)
        st.error(f"An error occurred while loading the model: {e}")
        return None, None
def inference(model, tokenizer, title, summary=None):
    """Predict class probabilities for a paper from its title and summary.

    Args:
        model: Sequence-classification model called as ``model(**inputs)``;
            its output must expose ``.logits``.
        tokenizer: Tokenizer callable that produces the model inputs.
        title: Paper title.
        summary: Optional abstract; appended to the title (separated by a
            space) when it contains non-whitespace text.

    Returns:
        A NumPy array of softmax probabilities with the batch axis removed,
        or ``None`` when the model/tokenizer is missing or prediction fails
        (the problem is reported through ``st.error``). A single value is
        returned in every case so callers can test ``result is not None``.
    """
    if model is None or tokenizer is None:
        st.error("Model or tokenizer not loaded. Cannot predict.")
        return None

    text_parts = [title]
    # Skip empty/blank summaries so no stray trailing space is appended.
    if summary and summary.strip():
        text_parts.append(summary)
    combined_text = " ".join(text_parts)
    logger.info("Predicting for text: %s", combined_text)

    try:
        tokenized_inputs = tokenizer(
            combined_text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=MODEL_MAX_LENGTH,
        )
        with torch.no_grad():
            logits = model(**tokenized_inputs).logits
        # squeeze(0) drops only the batch axis (one input text), so the class
        # axis survives even if the model has a single label.
        return torch.softmax(logits, dim=-1).squeeze(0).cpu().numpy()
    except Exception as e:  # UI boundary: report the problem instead of crashing
        logger.exception("Prediction failed")
        st.error(f"An error occurred during prediction: {e}")
        return None
def prepare_output(probabilities, id2label, max_tags=5):
    """Build the table of most probable tags for display.

    Tags are taken in order of decreasing probability until their cumulative
    probability reaches ``CUM_PROB_THRESHOLD`` or ``max_tags`` tags have been
    collected, whichever happens first. The check runs after each tag is
    added, so a non-empty input always yields at least one row.

    Args:
        probabilities: 1-D sequence/array of per-class probabilities.
        id2label: Mapping (or sequence) from integer class index to tag.
        max_tags: Maximum number of rows to return (default 5).

    Returns:
        Dict of equal-length lists keyed by "Tag", "Category", "Subcategory"
        and "Probability" (formatted as a percentage with two decimals).
        Tags missing from ``MAPPING_FROM_TAG_TO_CATEGORY`` get "Unknown" for
        both category columns and a warning is logged.
    """
    top_tags = []
    top_category = []
    top_subcategory = []
    top_probas = []
    cum_sum = 0.0

    for raw_index in np.argsort(probabilities)[::-1]:
        # Use a plain int so the lookup does not depend on NumPy scalar types.
        index = int(raw_index)
        tag = id2label[index]

        names = MAPPING_FROM_TAG_TO_CATEGORY.get(tag)
        if names is None:
            logger.warning("Tag %s not found in mapping from tag to category.", tag)
            names = ("Unknown", "Unknown")

        top_tags.append(str(tag))
        top_category.append(names[0])
        top_subcategory.append(names[1])
        top_probas.append(f"{probabilities[index]*100:.2f}%")

        cum_sum += probabilities[index]
        if cum_sum >= CUM_PROB_THRESHOLD or len(top_tags) >= max_tags:
            break

    return {
        "Tag": top_tags,
        "Category": top_category,
        "Subcategory": top_subcategory,
        "Probability": top_probas,
    }
def main():
    """Render the Streamlit page and run a prediction on button press.

    Loads the model and tokenizer, shows the title/summary inputs, and, when
    the "Predict Category" button is pressed with a non-blank title, displays
    the top predicted tags as a table.
    """
    st.set_page_config(page_title="ArXiv Category Tag Classifier", layout="wide")
    st.title("ArXiv Category Tag Classifier")

    with st.spinner("Loading model and tokenizer..."):
        model, tokenizer = load_model_and_tokenizer(MODEL_PATH)
    if model is None or tokenizer is None:
        st.error("Failed to load model/tokenizer.")
        return

    st.markdown(
        f"Enter the title (required) and summary (abstract) of an ArXiv paper to predict its "
        f"ArXiv category using a transformer model. There are {len(model.config.id2label)} available categories."
    )
    st.divider()

    col1, col2 = st.columns(2)
    with col1:
        title = st.text_input(
            "Title",
            placeholder="Enter the title of the paper (Required)",
        )
    with col2:
        paper_summary = st.text_area(
            "Paper Summary (Optional):",
            placeholder="Paste the paper's abstract here. It can increase the accuracy of the prediction.",
            height=180,
        )

    predict_button = st.button("Predict Category", type="primary")
    st.markdown("---")

    if not predict_button:
        return

    # Strip first so a whitespace-only title is rejected as missing.
    clean_title = (title or "").strip()
    if not clean_title:
        st.error("Title of the paper is required!")
        return

    # A blank summary is passed as None so it is left out of the model input.
    clean_summary = (paper_summary or "").strip() or None

    with st.spinner("Predicting category..."):
        class_probabilities = inference(model, tokenizer, clean_title, clean_summary)

    if class_probabilities is not None:
        st.subheader("Top Tags Predictions:")
        output = prepare_output(class_probabilities, model.config.id2label)
        logger.info("Output: %s", output)
        st.dataframe(output, use_container_width=True)
        st.markdown("---")


# Run the app only when this file is executed as a script.
if __name__ == "__main__":
    main()