File size: 3,486 Bytes
4bba8df
 
 
 
 
 
 
 
 
 
 
 
853f29a
4bba8df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb1a253
 
 
 
 
 
4bba8df
 
fb1a253
 
 
 
 
 
 
 
 
4bba8df
 
fb1a253
 
 
 
853f29a
4bba8df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b41a25
 
4bba8df
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import gradio as gr

import os
import torch
import numpy as np
import pandas as pd
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from huggingface_hub import HfApi

from label_dicts import ILLFRAMES_MIGRATION_LABEL_NAMES, ILLFRAMES_COVID_LABEL_NAMES, ILLFRAMES_WAR_LABEL_NAMES

from .utils import is_disk_full, release_model

# Hugging Face access token, read from the `hf_read` environment variable.
# Indexing os.environ directly means a missing variable raises KeyError as
# soon as this module is imported.
HF_TOKEN = os.environ["hf_read"]

# Choices offered in the "Language" dropdown of the demo.
languages = [
    "English"
]

# Maps the name shown in the "Domain" dropdown to the lowercase suffix used
# when building the model repository id.
domains = {
    "Covid": "covid",
    "Migration": "migration",
    "War": "war"
}


# --- DEBUG ---
import shutil

def convert_size(size):
    """Format a byte count as a human-readable string with two decimals.

    The value is divided by 1024 until it drops below 1024 or the largest
    known unit (PB) is reached.

    Args:
        size: Number of bytes (int or float).

    Returns:
        A string such as ``"1.50 KB"``.
    """
    units = ('B', 'KB', 'MB', 'GB', 'TB', 'PB')
    for unit in units[:-1]:
        if size < 1024:
            return f"{size:.2f} {unit}"
        size /= 1024
    # Whatever is left is expressed in the largest unit, even when it is
    # still >= 1024, instead of falling off the loop and returning None.
    return f"{size:.2f} {units[-1]}"

def get_disk_space(path="/"):
    """Report disk usage for the filesystem containing *path*.

    Args:
        path: Any path on the filesystem to inspect; defaults to the root.

    Returns:
        A dict with the keys "Total", "Used" and "Free", each mapped to the
        corresponding amount formatted by ``convert_size``.
    """
    # shutil.disk_usage yields (total, used, free) in that order.
    usage = shutil.disk_usage(path)
    labels = ("Total", "Used", "Free")
    return {label: convert_size(amount) for label, amount in zip(labels, usage)}

# ---

def check_huggingface_path(checkpoint_path: str) -> bool:
    """Return whether model info can be fetched for *checkpoint_path*.

    Args:
        checkpoint_path: Model repository id to look up on the Hugging Face
            Hub, using the module-level ``HF_TOKEN`` for authentication.

    Returns:
        True if ``HfApi.model_info`` succeeds, False if the lookup raises
        any ordinary exception.
    """
    try:
        hf_api = HfApi(token=HF_TOKEN)
        hf_api.model_info(checkpoint_path, token=HF_TOKEN)
    except Exception:
        # Best-effort probe: every lookup failure counts as "not available".
        # `Exception` (rather than a bare `except:`) lets KeyboardInterrupt
        # and SystemExit propagate instead of being reported as False.
        return False
    return True

def build_huggingface_path(domain: str):
    """Build the model repository id for the given ILLFRAMES domain.

    Args:
        domain: Domain suffix appended to the repository name.

    Returns:
        The repository id string for that domain.
    """
    repo_template = "poltextlab/xlm-roberta-large-english-ILLFRAMES-{}"
    return repo_template.format(domain)

def predict(text, model_id, tokenizer_id, label_names):
    """Classify *text* with a JIT-traced model and return ranked label scores.

    Args:
        text: Input text to classify.
        model_id: Model repository id. It is used to derive the file name of
            the traced model under ``/data/jit_models`` (with ``/`` replaced
            by ``_``) and to build the link shown in the info HTML.
        tokenizer_id: Identifier passed to ``AutoTokenizer.from_pretrained``.
        label_names: Mapping from label key to display name. The sorted keys
            are matched by position to the model's output probabilities.

    Returns:
        A tuple ``(output_pred, output_info)``. ``output_pred`` maps
        ``"[<key>] <name>"`` to its probability, ordered from most to least
        probable. ``output_info`` is an HTML snippet naming the model used.
    """
    device = torch.device("cpu")

    # Load JIT-traced model
    jit_model_path = f"/data/jit_models/{model_id.replace('/', '_')}.pt"
    model = torch.jit.load(jit_model_path).to(device)

    # Once the model is loaded, release it even when tokenization or
    # inference raises; previously a failure skipped release_model entirely.
    # NOTE(review): release_model is a project helper whose exact effect is
    # not visible here - confirm it is safe to call after a failed run.
    try:
        model.eval()

        # Load tokenizer (still regular HF)
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

        # Tokenize input
        inputs = tokenizer(
            text,
            max_length=256,
            truncation=True,
            padding="do_not_pad",
            return_tensors="pt"
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            output = model(inputs["input_ids"], inputs["attention_mask"])
            print(output)  # debug
            logits = output["logits"]
    finally:
        release_model(model, model_id)

    probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()

    # Position i of the probability vector corresponds to the i-th label key
    # in sorted order.
    index_to_key = dict(enumerate(sorted(label_names.keys())))

    # Highest probability first; values are converted to plain Python floats
    # so the mapping does not carry NumPy scalar types to the caller.
    output_pred = {
        f"[{index_to_key[i]}] {label_names[index_to_key[i]]}": float(probs[i])
        for i in np.argsort(probs)[::-1]
    }
    output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
    return output_pred, output_info

def predict_illframes(text, language, domain):
    """Run the ILLFRAMES classifier that matches the selected domain.

    Args:
        text: Input text to classify.
        language: Selected language; accepted for the UI signature but not
            used in this function.
        domain: Display name of the domain, a key of the module-level
            ``domains`` mapping.

    Returns:
        Whatever ``predict`` returns for the chosen model and label set.
    """
    domain = domains[domain]

    # Label set for each domain slug.
    label_names_by_domain = {
        "migration": ILLFRAMES_MIGRATION_LABEL_NAMES,
        "covid": ILLFRAMES_COVID_LABEL_NAMES,
        "war": ILLFRAMES_WAR_LABEL_NAMES,
    }
    label_names = label_names_by_domain[domain]

    # When the disk is reported full, remove downloaded models and the
    # Hugging Face hub cache before predicting.
    if is_disk_full():
        for command in ('rm -rf /data/models*', 'rm -r ~/.cache/huggingface/hub'):
            os.system(command)

    model_id = build_huggingface_path(domain)
    tokenizer_id = "xlm-roberta-large"
    return predict(text, model_id, tokenizer_id, label_names)

# Gradio interface for the demo: a text box plus language and domain
# dropdowns feed predict_illframes; the result is shown as a top-5 label
# widget and a Markdown note.
demo = gr.Interface(
    title="ILLFRAMES Babel Demo",
    fn=predict_illframes,
    inputs=[
        gr.Textbox(lines=6, label="Input"),
        gr.Dropdown(languages, label="Language", value=languages[0]),
        # Pass an actual list of choices rather than a dict_keys view, and
        # default to the first domain without building a throwaway list.
        gr.Dropdown(list(domains), label="Domain", value=next(iter(domains))),
    ],
    outputs=[gr.Label(num_top_classes=5, label="Output"), gr.Markdown()],
)