File size: 3,843 Bytes
1dd63b2
 
e36317e
1dd63b2
e36317e
 
 
1dd63b2
e36317e
 
 
 
 
1dd63b2
e36317e
 
 
 
 
 
 
 
 
 
1dd63b2
836e6dc
e36317e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1dd63b2
836e6dc
 
e36317e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
836e6dc
e36317e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
836e6dc
 
 
1dd63b2
 
e36317e
1dd63b2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import gradio as gr
from transformers import pipeline
import logging

# Send INFO-and-above records to the root handler; this module logs via `logger`
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face model pages, keyed by the publishing account
MODEL_LINKS = dict(
    OpenAlex=(
        "https://huggingface.co/OpenAlex/"
        "bert-base-multilingual-cased-finetuned-openalex-topic-classification-title-abstract"
    ),
    albertmartinez=(
        "https://huggingface.co/albertmartinez/"
        "openalex-topic-classification-title-abstract"
    ),
)

# Build both text-classification pipelines a single time, at import, so that
# every call to classify_text reuses the same loaded models.
try:
    model = pipeline(
        "text-classification",
        model="OpenAlex/bert-base-multilingual-cased-finetuned-openalex-topic-classification-title-abstract",
    )
    model2 = pipeline(
        "text-classification",
        model="albertmartinez/openalex-topic-classification-title-abstract",
    )
    logger.info("Models loaded successfully")
except Exception as load_error:
    # Log the failure, then re-raise: the app cannot work without its models.
    logger.error(f"Error loading models: {str(load_error)}")
    raise

def classify_text(text, top_k):
    """
    Classify the given text with both loaded models.

    Args:
        text (str): Text to classify, formatted as a "<TITLE> ..." line
            followed by an "<ABSTRACT> ..." line. Empty and whitespace-only
            strings are rejected.
        top_k (int): Number of classifications to return from each model.
            Whole-number floats such as ``5.0`` are accepted and converted.

    Returns:
        list: Two dictionaries, one per model (``model`` first, then
        ``model2``), each mapping a predicted label to its score.

    Raises:
        gr.Error: If the input is invalid or a model call fails. The
            original exception is attached as the cause.
    """
    try:
        # Whitespace-only input carries nothing to classify, so reject it
        # together with empty and non-string values.
        if not isinstance(text, str) or not text.strip():
            raise ValueError("Input text must be a non-empty string")

        # Accept whole-number floats so a caller passing 5.0 is not rejected
        # purely because of the numeric representation.
        if isinstance(top_k, float) and top_k.is_integer():
            top_k = int(top_k)
        if not isinstance(top_k, int) or top_k < 1:
            raise ValueError("top_k must be a positive integer")

        # Same call options for both models; inputs are truncated to 512 tokens.
        return [
            {
                prediction["label"]: prediction["score"]
                for prediction in classifier(
                    text, top_k=top_k, truncation=True, max_length=512
                )
            }
            for classifier in (model, model2)
        ]
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Classification error: %s", e)
        raise gr.Error(f"Classification error: {str(e)}") from e

# Sample input in the expected "<TITLE> ... / <ABSTRACT> ..." two-line layout
EXAMPLE_TEXT = (
    "<TITLE> Machine Learning Applications in Healthcare\n"
    "<ABSTRACT> This paper explores the use of machine learning algorithms"
    " in healthcare systems for disease prediction and diagnosis."
)

# Input widgets, listed in the positional order classify_text(text, top_k)
# receives them.
_text_input = gr.Textbox(
    value=EXAMPLE_TEXT,
    placeholder="<TITLE> {title}\n<ABSTRACT> {abstract}",
    label="Text",
    lines=5,
)
_top_k_input = gr.Number(
    value=10,
    minimum=1,
    maximum=20,
    precision=0,
    label="Number of classifications (top_k)",
)

# One label panel per model, in the same order classify_text returns results.
_model_outputs = [
    gr.Label(label="Model 1: OpenAlex"),
    gr.Label(label="Model 2: albertmartinez"),
]

# Description text, spelled out line by line so the four-space indentation
# (including the whitespace-only separator lines) is explicit in the source.
_description_template = "\n".join(
    [
        "",
        "    Enter a text with title and abstract to get its topic classification.",
        "    ",
        "    Input format:",
        "    ```",
        "    <TITLE> Your title here",
        "    <ABSTRACT> Your abstract here",
        "    ```",
        "    ",
        "    The system uses two different models to provide a more robust classification:",
        "    ",
        "    1. [OpenAlex Model]({openalex_link}): Based on BERT multilingual model, fine-tuned on OpenAlex data",
        "    2. [AlbertMartinez Model]({albert_link}): Based on BERT multilingual model, fine-tuned on [OpenAlex data](https://huggingface.co/datasets/albertmartinez/openalex-topic-title-abstract)",
        "    ",
        "    For more information about the models and their performance, visit their Hugging Face pages.",
        "    ",
    ]
)

# Example rows: each is [text, top_k], matching the two inputs above.
_example_rows = [
    [EXAMPLE_TEXT, 5],
    ["<TITLE> Climate Change Impact\n<ABSTRACT> Study of global warming effects on biodiversity", 3],
]

demo = gr.Interface(
    fn=classify_text,
    inputs=[_text_input, _top_k_input],
    outputs=_model_outputs,
    title="OpenAlex Topic Classification",
    description=_description_template.format(
        openalex_link=MODEL_LINKS["OpenAlex"],
        albert_link=MODEL_LINKS["albertmartinez"],
    ),
    examples=_example_rows,
    flagging_mode="never",
    api_name="classify",
)

if __name__ == "__main__":
    # Record the installed Gradio version, then launch the interface.
    gradio_version = gr.__version__
    logger.info(f"Gradio version: {gradio_version}")
    demo.launch()