File size: 6,654 Bytes
920d148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import logging

import numpy as np
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Configure root logging at INFO so this module's info messages (model loading,
# prediction inputs and outputs) are actually emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Lookup from an arXiv tag (the label strings found in the model's id2label)
# to a [category, subcategory] pair used for display.
# prepare_output shows tags missing from this table as "Unknown" and logs a warning.
MAPPING_FROM_TAG_TO_CATEGORY = {
    "cs.AI": ["Computer Science", "Artificial Intelligence"],
    "cs.CL": ["Computer Science", "Computation and Language"],
    "cs.CV": ["Computer Science", "Computer Vision and Pattern Recognition"],
    "cs.NE": ["Computer Science", "Neural and Evolutionary Computing"],
    "stat.ML": ["Statistics", "Machine Learning"],
    "cs.LG": ["Computer Science", "Machine Learning"],
    "physics.soc-ph": ["Physics", "Physics and Society"],
    "stat.AP": ["Statistics", "Applications"],
    "cs.RO": ["Computer Science", "Robotics"],
    "cs.MA": ["Computer Science", "Multiagent Systems"],
    "math.OC": ["Mathematics", "Optimization and Control"],
    "cs.IR": ["Computer Science", "Information Retrieval"],
    "stat.ME": ["Statistics", "Methodology"],
    "cs.DC": ["Computer Science", "Distributed, Parallel, and Cluster Computing"],
    "stat.CO": ["Statistics", "Computation"],
    "q-bio.NC": ["Quantitative Biology", "Neurons and Cognition"],
    "cs.GT": ["Computer Science", "Computer Science and Game Theory"],
    "cs.MM": ["Computer Science", "Multimedia"],
    "cs.CR": ["Computer Science", "Cryptography and Security"],
    "cs.HC": ["Computer Science", "Human-Computer Interaction"],
    "cs.SD": ["Computer Science", "Sound"],
    "cs.GR": ["Computer Science", "Graphics"],
    "cs.CY": ["Computer Science", "Computers and Society"],
    "math.ST": ["Mathematics", "Statistics Theory"],
    "stat.TH": ["Statistics", "Statistics Theory"],
    "cs.IT": ["Computer Science", "Information Theory"],
    "math.IT": ["Mathematics", "Information Theory"],
    "cs.SI": ["Computer Science", "Social and Information Networks"],
    "cs.DB": ["Computer Science", "Databases"],
    "cs.LO": ["Computer Science", "Logic in Computer Science"],
    "cs.SY": ["Computer Science", "Systems and Control"],
    "q-bio.QM": ["Quantitative Biology", "Quantitative Methods"],
    "cs.DS": ["Computer Science", "Data Structures and Algorithms"],
    "cs.NA": ["Computer Science", "Numerical Analysis"],
    "cs.CE": ["Computer Science", "Computational Engineering, Finance, and Science"],
}

# Identifier passed to from_pretrained for both tokenizer and model
# (looks like a Hugging Face Hub repo id — a local path would also work).
MODEL_PATH = "minemile/arxiv-tag-classifier"
# Maximum tokenized input length; longer title+summary text is truncated.
MODEL_MAX_LENGTH = 512
# prepare_output stops adding tags once their cumulative probability reaches this value.
CUM_PROB_THRESHOLD = 0.95


@st.cache_resource
def _load_model_and_tokenizer_cached(model_path):
    """Load the tokenizer and model once per process; raise on failure.

    Exceptions propagate on purpose: st.cache_resource only stores successful
    return values, so a failed (e.g. network) load is retried on the next rerun
    instead of a (None, None) result being cached forever.
    """
    logger.info("Loading model and tokenizer from %s", model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path, device_map="cpu"
    )
    # Inference only: disable dropout and other training-mode behaviour.
    model.eval()
    logger.info("Model and tokenizer loaded successfully.")
    return model, tokenizer


def load_model_and_tokenizer(model_path):
    """Return a cached ``(model, tokenizer)`` pair for ``model_path``.

    Args:
        model_path: identifier accepted by ``from_pretrained``.

    Returns:
        ``(model, tokenizer)`` on success. On failure, ``(None, None)`` is
        returned after the error is logged and shown in the UI.
    """
    try:
        return _load_model_and_tokenizer_cached(model_path)
    except Exception as e:  # UI boundary: report the failure instead of crashing the page
        logger.exception("Failed to load model and tokenizer from %s", model_path)
        st.error(f"An error occurred while loading the model: {e}")
        return None, None


# Keep the cache-clearing hook that the previously decorated function exposed.
load_model_and_tokenizer.clear = _load_model_and_tokenizer_cached.clear


def inference(model, tokenizer, title, summary=None):
    """Predict class probabilities for a paper from its title and optional summary.

    Args:
        model: sequence-classification model (as returned by load_model_and_tokenizer).
        tokenizer: the tokenizer matching ``model``.
        title: paper title.
        summary: optional abstract; ignored when None or empty.

    Returns:
        A 1-D numpy array with one softmax probability per class, or ``None``
        when the model/tokenizer is missing or prediction fails (the error is
        shown in the UI before returning).
    """
    if model is None or tokenizer is None:
        st.error("Model or tokenizer not loaded. Cannot predict.")
        return None

    # Streamlit's text_area yields "" (not None) when left blank; skip it so the
    # input does not gain a stray trailing space.
    combined_text = f"{title} {summary}" if summary else title

    logger.info("Predicting for text: %s", combined_text)

    try:
        tokenized_inputs = tokenizer(
            combined_text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=MODEL_MAX_LENGTH,
        )
        with torch.no_grad():
            logits = model(**tokenized_inputs).logits
        # A single string was tokenized, so take row 0 of the batch rather than
        # squeeze(), which would collapse to a 0-d array for a one-label model.
        probabilities = torch.softmax(logits, dim=-1)[0].cpu().numpy()
    except Exception as e:  # UI boundary: surface the error, keep the app alive
        logger.exception("Prediction failed")
        st.error(f"An error occurred during prediction: {e}")
        return None
    return probabilities


def prepare_output(
    probabilities, id2label, *, threshold=None, max_tags=5, tag_mapping=None
):
    """Pick the most probable tags and format them as table columns.

    Tags are taken in descending probability order. Collection stops after the
    tag that brings the cumulative probability to at least ``threshold``, or
    once ``max_tags`` tags are collected. At least one tag is returned for
    non-empty input.

    Args:
        probabilities: per-class probabilities, indexed by class id.
        id2label: mapping from class id to tag string (e.g. ``model.config.id2label``).
        threshold: cumulative-probability cut-off; defaults to CUM_PROB_THRESHOLD.
        max_tags: maximum number of tags to return.
        tag_mapping: tag -> [category, subcategory]; defaults to
            MAPPING_FROM_TAG_TO_CATEGORY.

    Returns:
        dict with equal-length lists under "Tag", "Category", "Subcategory" and
        "Probability" (the latter formatted as percentage strings, e.g. "42.00%").
    """
    # Resolve module-level defaults at call time so callers can override them.
    if threshold is None:
        threshold = CUM_PROB_THRESHOLD
    if tag_mapping is None:
        tag_mapping = MAPPING_FROM_TAG_TO_CATEGORY

    # ravel() guards against a 0-d array (e.g. a squeezed single-class output),
    # which argsort()[::-1] cannot slice.
    probs = np.asarray(probabilities).ravel()

    output = {"Tag": [], "Category": [], "Subcategory": [], "Probability": []}
    cum_sum = 0.0
    for idx in np.argsort(probs)[::-1]:
        tag = id2label[int(idx)]
        categories = tag_mapping.get(tag)
        if categories is None:
            logger.warning("Tag %s not found in mapping from tag to category.", tag)
            categories = ("Unknown", "Unknown")
        output["Tag"].append(str(tag))
        output["Category"].append(categories[0])
        output["Subcategory"].append(categories[1])
        output["Probability"].append(f"{probs[idx] * 100:.2f}%")
        cum_sum += probs[idx]
        if cum_sum >= threshold or len(output["Tag"]) >= max_tags:
            break
    return output


def main():
    """Render the Streamlit page: load the model, collect paper text, show top tags."""
    st.set_page_config(page_title="ArXiv Category Tag Classifier", layout="wide")
    st.title("ArXiv Category Tag Classifier")
    with st.spinner("Loading model and tokenizer..."):
        model, tokenizer = load_model_and_tokenizer(MODEL_PATH)
    if model is None or tokenizer is None:
        st.error("Failed to load model/tokenizer.")
        return

    st.markdown(
        "Enter the title (required) and summary (abstract) of an ArXiv paper to predict its "
        f"ArXiv category using a transformer model. There are {len(model.config.id2label)} available categories."
    )
    st.divider()
    col1, col2 = st.columns(2)
    with col1:
        title = st.text_input(
            "Title",
            placeholder="Enter the title of the paper (Required)",
        )
    with col2:
        paper_summary = st.text_area(
            "Paper Summary (Optional):",
            placeholder="Paste the paper's abstract here. It can increase the accuracy of the prediction.",
            height=180,
        )

    predict_button = st.button("Predict Category", type="primary")

    st.markdown("---")

    if predict_button:
        # A whitespace-only title is as good as no title.
        title = (title or "").strip()
        if not title:
            st.error("Title of the paper is required!")
        else:
            # Pass None for a blank summary so nothing is appended to the title.
            summary = (paper_summary or "").strip() or None
            with st.spinner("Predicting category..."):
                class_probabilities = inference(model, tokenizer, title, summary)

            # Success is an ndarray of probabilities. Any other value means
            # inference already reported an error in the UI.
            if isinstance(class_probabilities, np.ndarray):
                st.subheader("Top Tags Predictions:")
                output = prepare_output(class_probabilities, model.config.id2label)
                logger.info("Output: %s", output)
                st.dataframe(output, use_container_width=True)
    st.markdown("---")


# Entry point when the file is executed as a script (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()