Spaces:
Sleeping
Sleeping
File size: 6,654 Bytes
920d148 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
import logging
import numpy as np
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Configure the root logger at INFO level so the app's informational
# messages (model loading, predictions, outputs) are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Display lookup: arXiv tag -> [top-level archive name, subject-class name].
# prepare_output looks up each predicted tag (taken from model.config.id2label)
# here; tags missing from this table are shown with "Unknown" category and
# subcategory. NOTE(review): keys are expected to match the model's label
# strings exactly — confirm against the model config when the model changes.
MAPPING_FROM_TAG_TO_CATEGORY = {
    "cs.AI": ["Computer Science", "Artificial Intelligence"],
    "cs.CL": ["Computer Science", "Computation and Language"],
    "cs.CV": ["Computer Science", "Computer Vision and Pattern Recognition"],
    "cs.NE": ["Computer Science", "Neural and Evolutionary Computing"],
    "stat.ML": ["Statistics", "Machine Learning"],
    "cs.LG": ["Computer Science", "Machine Learning"],
    "physics.soc-ph": ["Physics", "Physics and Society"],
    "stat.AP": ["Statistics", "Applications"],
    "cs.RO": ["Computer Science", "Robotics"],
    "cs.MA": ["Computer Science", "Multiagent Systems"],
    "math.OC": ["Mathematics", "Optimization and Control"],
    "cs.IR": ["Computer Science", "Information Retrieval"],
    "stat.ME": ["Statistics", "Methodology"],
    "cs.DC": ["Computer Science", "Distributed, Parallel, and Cluster Computing"],
    "stat.CO": ["Statistics", "Computation"],
    "q-bio.NC": ["Quantitative Biology", "Neurons and Cognition"],
    "cs.GT": ["Computer Science", "Computer Science and Game Theory"],
    "cs.MM": ["Computer Science", "Multimedia"],
    "cs.CR": ["Computer Science", "Cryptography and Security"],
    "cs.HC": ["Computer Science", "Human-Computer Interaction"],
    "cs.SD": ["Computer Science", "Sound"],
    "cs.GR": ["Computer Science", "Graphics"],
    "cs.CY": ["Computer Science", "Computers and Society"],
    "math.ST": ["Mathematics", "Statistics Theory"],
    "stat.TH": ["Statistics", "Statistics Theory"],
    "cs.IT": ["Computer Science", "Information Theory"],
    "math.IT": ["Mathematics", "Information Theory"],
    "cs.SI": ["Computer Science", "Social and Information Networks"],
    "cs.DB": ["Computer Science", "Databases"],
    "cs.LO": ["Computer Science", "Logic in Computer Science"],
    "cs.SY": ["Computer Science", "Systems and Control"],
    "q-bio.QM": ["Quantitative Biology", "Quantitative Methods"],
    "cs.DS": ["Computer Science", "Data Structures and Algorithms"],
    "cs.NA": ["Computer Science", "Numerical Analysis"],
    "cs.CE": ["Computer Science", "Computational Engineering, Finance, and Science"],
}
# Identifier passed to `from_pretrained` for both the tokenizer and the model
# (a Hugging Face Hub repo id or a local path).
MODEL_PATH = "minemile/arxiv-tag-classifier"
# Maximum token length passed to the tokenizer; longer inputs are truncated.
MODEL_MAX_LENGTH = 512
# prepare_output stops adding tags once their cumulative probability reaches
# this value (the tag that crosses it is still included).
CUM_PROB_THRESHOLD = 0.95
@st.cache_resource
def _load_model_and_tokenizer_cached(model_path):
    """Load the classifier model and tokenizer for *model_path* (cached).

    Any exception propagates to the caller. ``st.cache_resource`` only stores
    values that are actually returned, so a failed load (e.g. a transient
    network error while downloading) is retried on the next rerun instead of
    being cached for the lifetime of the server process.
    """
    logger.info("Loading model and tokenizer from %s", model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path, device_map="cpu"
    )
    # Used for inference only: switch off training-time behaviour such as dropout.
    model.eval()
    logger.info("Model and tokenizer loaded successfully.")
    return model, tokenizer


def load_model_and_tokenizer(model_path):
    """Return ``(model, tokenizer)`` for *model_path*, or ``(None, None)`` on failure.

    Successful loads are cached across Streamlit reruns by
    ``_load_model_and_tokenizer_cached``; failures are reported in the UI and
    logged with a traceback, but are not cached.

    Args:
        model_path: identifier passed to ``from_pretrained`` (Hub repo id or
            local path).

    Returns:
        ``(model, tokenizer)`` on success, ``(None, None)`` if loading raised.
    """
    try:
        return _load_model_and_tokenizer_cached(model_path)
    except Exception as e:  # UI boundary: report the failure instead of crashing the page
        logger.exception("Failed to load model and tokenizer from %s", model_path)
        st.error(f"An error occurred while loading the model: {e}")
        return None, None
def inference(model, tokenizer, title, summary=None):
    """Return the model's class probabilities for a paper title and optional summary.

    The title and a non-blank summary are joined with a single space. The text is
    tokenized with truncation to ``MODEL_MAX_LENGTH`` tokens and run through
    *model* without gradient tracking.

    Args:
        model: sequence-classification model whose output exposes ``.logits``.
        tokenizer: tokenizer matching *model*.
        title: paper title.
        summary: optional abstract. ``None``, ``""`` or whitespace-only is ignored.

    Returns:
        A 1-D numpy array with one probability per class (softmax over the logits
        of the single input), or ``None`` when the model/tokenizer is missing or
        prediction fails. In both failure cases an error is shown in the UI.
        (Assumes logits of shape ``(1, num_labels)`` for a single input.)
    """
    if model is None or tokenizer is None:
        st.error("Model or tokenizer not loaded. Cannot predict.")
        return None

    # st.text_area yields "" when left empty; don't append a stray separator.
    combined_text = f"{title} {summary}" if summary and summary.strip() else title
    logger.info("Predicting for text: %s", combined_text)

    try:
        tokenized_inputs = tokenizer(
            combined_text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=MODEL_MAX_LENGTH,
        )
        with torch.no_grad():
            logits = model(**tokenized_inputs).logits
        # Batch of one: drop only the batch axis so even a single-label model
        # yields a 1-D array (a bare squeeze() would collapse it to 0-D).
        probabilities = torch.softmax(logits, dim=-1)[0].cpu().numpy()
    except Exception as e:  # UI boundary: report the failure instead of crashing the page
        logger.exception("Prediction failed")
        st.error(f"An error occurred during prediction: {e}")
        # A single None (not a tuple) so callers' `is not None` check works.
        return None
    return probabilities
def prepare_output(probabilities, id2label, max_tags=5, cum_prob_threshold=None):
    """Build the results table of the most probable tags.

    Classes are taken in descending probability order. Collection stops once the
    cumulative probability reaches *cum_prob_threshold* (the tag that crosses it
    is included) or *max_tags* rows have been collected.

    Args:
        probabilities: 1-D array of per-class probabilities indexed by class id.
        id2label: mapping from class id to tag, e.g. ``model.config.id2label``.
        max_tags: maximum number of rows to return (default 5).
        cum_prob_threshold: cumulative probability at which to stop. ``None``
            means ``CUM_PROB_THRESHOLD``.

    Returns:
        A dict of parallel lists with keys "Tag", "Category", "Subcategory" and
        "Probability", suitable for ``st.dataframe``. "Probability" is formatted
        as a percentage string with two decimals. Tags absent from
        ``MAPPING_FROM_TAG_TO_CATEGORY`` get "Unknown" category and subcategory.
    """
    if cum_prob_threshold is None:
        cum_prob_threshold = CUM_PROB_THRESHOLD

    output = {"Tag": [], "Category": [], "Subcategory": [], "Probability": []}
    cum_sum = 0.0
    # Most probable classes first; the slice caps the table at max_tags rows.
    for class_id in np.argsort(probabilities)[::-1][: max(max_tags, 0)]:
        # Plain int so the key matches int-keyed id2label mappings exactly.
        class_id = int(class_id)
        tag = id2label[class_id]

        categories = MAPPING_FROM_TAG_TO_CATEGORY.get(tag)
        if categories is None:
            logger.warning("Tag %s not found in mapping from tag to category.", tag)
            categories = ("Unknown", "Unknown")
        category, subcategory = categories
        probability = float(probabilities[class_id])

        output["Tag"].append(f"{tag}")
        output["Category"].append(category)
        output["Subcategory"].append(subcategory)
        output["Probability"].append(f"{probability * 100:.2f}%")

        cum_sum += probability
        if cum_sum >= cum_prob_threshold:
            break
    return output
def main():
    """Render the Streamlit page: load the model, read paper text, show top tags.

    On a "Predict Category" click with a non-blank title, runs ``inference`` and
    displays the ``prepare_output`` table. Errors are reported via ``st.error``.
    """
    st.set_page_config(page_title="ArXiv Category Tag Classifier", layout="wide")
    st.title("ArXiv Category Tag Classifier")

    with st.spinner("Loading model and tokenizer..."):
        model, tokenizer = load_model_and_tokenizer(MODEL_PATH)
    if model is None or tokenizer is None:
        st.error("Failed to load model/tokenizer.")
        return

    st.markdown(
        f"Enter the title (required) and summary (abstract) of an ArXiv paper to predict its "
        f"ArXiv category using a transformer model. There are {len(model.config.id2label)} available categories."
    )
    st.divider()

    col1, col2 = st.columns(2)
    with col1:
        title = st.text_input(
            "Title",
            placeholder="Enter the title of the paper (Required)",
        )
    with col2:
        paper_summary = st.text_area(
            "Paper Summary (Optional):",
            placeholder="Paste the paper's abstract here. It can increase the accuracy of the prediction.",
            height=180,
        )

    predict_button = st.button("Predict Category", type="primary")
    st.markdown("---")
    if not predict_button:
        return

    # A whitespace-only title must not count as "provided".
    title = title.strip()
    if not title:
        st.error("Title of the paper is required!")
        return

    with st.spinner("Predicting category..."):
        class_probabilities = inference(model, tokenizer, title, paper_summary)
    if class_probabilities is None:
        # inference has already reported the error in the UI.
        return

    st.subheader("Top Tags Predictions:")
    output = prepare_output(class_probabilities, model.config.id2label)
    logger.info("Output: %s", output)
    st.dataframe(output, use_container_width=True)
    st.markdown("---")


if __name__ == "__main__":
    main()
|