# NOTE: this file was captured from a repository web viewer (reported size:
# 5,198 bytes). The viewer's blame/line-number gutter has been collapsed into
# this comment; revisions listed there: 1130e24, 04f8f8b, 70be539, 67fd23e,
# b2b1d1f, a1eb2dd, 5f6b004, 90fa7ec, d6c2bce, 7d912ec.
import gradio as gr

import os
import torch
import numpy as np
import pandas as pd
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
import torch.nn.functional as F
from huggingface_hub import HfApi
from collections import defaultdict

from label_dicts import (CAP_MEDIA_NUM_DICT, CAP_MEDIA_LABEL_NAMES,
                        CAP_MIN_NUM_DICT, CAP_MIN_LABEL_NAMES)

from .utils import is_disk_full

# Hugging Face access token; a missing "hf_read" environment variable raises
# KeyError while this module is being imported.
HF_TOKEN = os.environ["hf_read"]

# Choices shown in the "Language" dropdown of the interface.
languages = ["Multilingual"]

# Domain dropdown choice -> internal domain key (looked up in predict_cap).
domains = {"media": "media"}


# Code lists, in the iteration order of the label dictionaries' values.
CAP_MEDIA_CODES = [*CAP_MEDIA_NUM_DICT.values()]
CAP_MIN_CODES = [*CAP_MIN_NUM_DICT.values()]

# Position in the code list <-> code, used to translate model output indices.
major_index_to_id = dict(enumerate(CAP_MEDIA_CODES))
minor_index_to_id = dict(enumerate(CAP_MIN_CODES))
minor_id_to_index = {code: i for i, code in minor_index_to_id.items()}

# Group minor codes under the major id obtained by dropping the last two
# characters of the code's string form.
# NOTE(review): presumably minor codes are "<major><two digits>" -- confirm
# against label_dicts; a code shorter than three characters makes int() fail.
major_to_minor_map = {}
for code in CAP_MIN_CODES:
    major_id = int(str(code)[:-2])
    major_to_minor_map.setdefault(major_id, []).append(code)

def normalize_probs(probs: dict):
    """Min-max rescale the values of *probs* into the range [0, 1].

    Args:
        probs: Mapping from any key to a numeric score.

    Returns:
        A new dict with the same keys in the same order. The smallest value
        maps to 0.0 and the largest to 1.0. When all values are equal every
        key maps to 1.0. An empty mapping yields an empty dict.
    """
    if not probs:
        # min()/max() raise ValueError on an empty sequence.
        return {}

    values = probs.values()
    min_val = min(values)
    range_val = max(values) - min_val

    if range_val == 0:
        # All scores identical: avoid dividing by zero.
        return {k: 1.0 for k in probs}

    return {k: (v - min_val) / range_val for k, v in probs.items()}

def check_huggingface_path(checkpoint_path: str):
    """Report whether model info for *checkpoint_path* can be fetched from the Hub.

    Args:
        checkpoint_path: Model repository id passed to ``HfApi.model_info``.

    Returns:
        True if the lookup completes without raising, False if it raises any
        ``Exception`` (the cause is not distinguished).
    """
    try:
        hf_api = HfApi(token=HF_TOKEN)
        hf_api.model_info(checkpoint_path, token=HF_TOKEN)
    except Exception:
        # Deliberately broad best-effort check, but KeyboardInterrupt and
        # SystemExit are no longer swallowed as a bare ``except:`` would.
        return False
    return True

def build_huggingface_path(language: str, domain: str):
    """Return the ``(major, minor)`` model repository ids used for prediction.

    Both arguments are currently ignored: the same pair of checkpoints is
    returned for every language and domain.
    """
    major_checkpoint = "poltextlab/xlm-roberta-large-pooled-cap-media"
    minor_checkpoint = "poltextlab/xlm-roberta-large-pooled-cap-minor-v3"
    return major_checkpoint, minor_checkpoint

def _load_classifier(model_id, token, device):
    """Load the sequence-classification checkpoint *model_id* on *device* in eval mode."""
    model = AutoModelForSequenceClassification.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        device_map="auto",
        offload_folder="offload",
        token=token,
    ).to(device)
    model.eval()
    return model


def predict(text, major_model_id, minor_model_id, tokenizer_id, HF_TOKEN=None):
    """Classify *text* into a major topic and, where possible, its minor topics.

    The major-topic model runs first. If its top-scoring code is a key of
    ``major_to_minor_map``, the minor-topic model runs as well and the result
    is restricted to the minor codes listed under that major code; otherwise
    the major-topic scores are returned.

    Args:
        text: Input handed to the tokenizer (truncated to 256 tokens).
        major_model_id: Hub id of the major-topic classifier.
        minor_model_id: Hub id of the minor-topic classifier. It is only
            loaded when the minor-topic step is actually needed.
        tokenizer_id: Hub id of the tokenizer.
        HF_TOKEN: Token forwarded to ``from_pretrained`` for both models.
            This parameter shadows the module-level constant of the same
            name, so callers must pass the token explicitly.

    Returns:
        A tuple ``(output_pred, output_info)``. ``output_pred`` maps
        ``"[code] label"`` strings to min-max normalized scores, ordered from
        highest to lowest. ``output_info`` is an HTML snippet linking both
        model ids.
    """
    # Function-scope import: the module's import block does not provide it.
    import logging

    logger = logging.getLogger(__name__)
    device = torch.device("cpu")

    major_model = _load_classifier(major_model_id, HF_TOKEN, device)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

    # Tokenize input
    inputs = tokenizer(
        text,
        max_length=256,
        truncation=True,
        padding="do_not_pad",
        return_tensors="pt",
    ).to(device)

    # Predict major topic
    with torch.no_grad():
        major_logits = major_model(**inputs).logits
        major_probs = F.softmax(major_logits, dim=-1)
    major_probs_np = major_probs.cpu().numpy().flatten()
    # The major model is finished; release it so that it is not kept alive
    # while a second checkpoint is loaded below.
    del major_model

    top_major_index = int(np.argmax(major_probs_np))
    top_major_id = major_index_to_id[top_major_index]
    logger.debug("major probabilities: %s", major_probs_np)

    # Default: show major topic predictions, keyed by output index and
    # inserted from most to least probable.
    filtered_probs = {
        int(i): float(major_probs_np[i])
        for i in np.argsort(major_probs_np)[::-1]
    }
    filtered_probs = normalize_probs(filtered_probs)
    logger.debug("normalized major probabilities: %s", filtered_probs)

    output_pred = {
        f"[{major_index_to_id[k]}] {CAP_MEDIA_LABEL_NAMES[k]}": v
        for k, v in sorted(filtered_probs.items(), key=lambda item: item[1], reverse=True)
    }
    logger.debug("major prediction: %s", output_pred)

    # If eligible for minor prediction
    if top_major_id in major_to_minor_map:
        valid_minor_ids = major_to_minor_map[top_major_id]
        minor_model = _load_classifier(minor_model_id, HF_TOKEN, device)
        with torch.no_grad():
            minor_logits = minor_model(**inputs).logits
            minor_probs = F.softmax(minor_logits, dim=-1)

        # Restrict to valid minor codes. The softmax above is taken over all
        # minor classes; only the scores of this major topic's codes are kept.
        valid_indices = [minor_id_to_index[mid] for mid in valid_minor_ids if mid in minor_id_to_index]
        filtered_probs = {minor_index_to_id[i]: float(minor_probs[0][i]) for i in valid_indices}
        filtered_probs = normalize_probs(filtered_probs)

        output_pred = {
            f"[{k}] {CAP_MIN_LABEL_NAMES[k]}": v
            for k, v in sorted(filtered_probs.items(), key=lambda item: item[1], reverse=True)
        }

    output_info = f'<p style="text-align: center; display: block">Prediction used <a href="https://huggingface.co/{major_model_id}">{major_model_id}</a> and <a href="https://huggingface.co/{minor_model_id}">{minor_model_id}</a>.</p>'

    return output_pred, output_info

def predict_cap(text, language, domain):
    """Gradio callback: resolve the model ids and run ``predict`` on *text*.

    Args:
        text: Text to classify.
        language: Selected language; forwarded to ``build_huggingface_path``.
        domain: Selected domain; must be a key of ``domains``.

    Returns:
        The ``(output_pred, output_info)`` tuple produced by ``predict``.

    Raises:
        KeyError: If *domain* is not a key of ``domains``.
    """
    domain = domains[domain]
    major_model_id, minor_model_id = build_huggingface_path(language, domain)
    tokenizer_id = "xlm-roberta-large"

    if is_disk_full():
        # Free space by deleting stored models and the hub download cache.
        # NOTE(review): presumably the checkpoints are simply downloaded
        # again by the next from_pretrained call -- confirm.
        os.system('rm -rf /data/models*')
        os.system('rm -r ~/.cache/huggingface/hub')

    # Pass the token explicitly: predict's HF_TOKEN parameter defaults to
    # None and shadows the module-level constant, so omitting it would load
    # the models without the configured token.
    return predict(text, major_model_id, minor_model_id, tokenizer_id, HF_TOKEN=HF_TOKEN)

# Gradio interface: a text box plus language/domain dropdowns feed
# predict_cap; its two return values fill the label widget (top 5 classes)
# and the Markdown component.
demo = gr.Interface(
    fn=predict_cap,
    title="CAP Media/Minor Topics Babel Demo",
    inputs=[
        gr.Textbox(lines=6, label="Input"),
        gr.Dropdown(languages, label="Language"),
        gr.Dropdown(domains.keys(), label="Domain"),
    ],
    outputs=[
        gr.Label(num_top_classes=5, label="Output"),
        gr.Markdown(),
    ],
)