babel_machine / interfaces /cap_minor.py
kovacsvi
JIT tracing
fb1a253
raw
history blame
4.78 kB
import gradio as gr
import os
import torch
import numpy as np
import pandas as pd
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from huggingface_hub import HfApi
from label_dicts import CAP_MIN_NUM_DICT, CAP_MIN_LABEL_NAMES, CAP_LABEL_NAMES
from .utils import is_disk_full, release_model
from itertools import islice
def take(n, iterable):
    """Return a list holding at most the first *n* items of *iterable*.

    Fewer items are returned when *iterable* is shorter than *n*.
    """
    return [*islice(iterable, n)]
def score_to_color(prob):
    """Map a score (expected in [0, 1]) to a CSS ``rgb(r,g,0)`` colour string.

    A score of 0 yields pure red, 1 yields pure green; each channel is
    ``255 * weight`` truncated to an int. Blue is always 0.
    """
    channels = (int(255 * (1 - prob)), int(255 * prob), 0)
    return "rgb({},{},{})".format(*channels)
# Hugging Face access token read from the "hf_read" environment variable;
# used by check_huggingface_path. Subscripting os.environ raises KeyError
# at import time if the variable is not set.
HF_TOKEN: str = os.environ["hf_read"]
# Choices shown in the "Language" dropdown of the UI.
languages = ["Multilingual"]

# UI domain name -> short domain key. predict_cap translates the dropdown
# value through this mapping before calling build_huggingface_path.
# Insertion order matters: the first key is the dropdown's default value.
domains = {
    "media": "media",
    "social media": "social",
    "parliamentary speech": "parlspeech",
    "legislative documents": "legislative",
    "executive speech": "execspeech",
    "executive order": "execorder",
    "party programs": "party",
    "judiciary": "judiciary",
    "budget": "budget",
    "public opinion": "publicopinion",
    "local government agenda": "localgovernment",
}
def convert_minor_to_major(minor_topic):
    """Return the major-topic label (from CAP_LABEL_NAMES) for a minor-topic code.

    The major code is the minor code with its last two digits removed
    (e.g. 1205 -> 12); the special code 999 is used unchanged.
    """
    major_code = 999 if minor_topic == 999 else str(minor_topic)[:-2]
    return CAP_LABEL_NAMES[int(major_code)]
def check_huggingface_path(checkpoint_path: str):
    """Return True if *checkpoint_path* can be looked up on the Hugging Face Hub.

    Uses HF_TOKEN for authentication. Any error raised while querying the
    Hub (e.g. repository not found, authorization or network failure) is
    treated as "not available" and yields False.
    """
    try:
        hf_api = HfApi(token=HF_TOKEN)
        hf_api.model_info(checkpoint_path, token=HF_TOKEN)
    except Exception:
        # Deliberate best-effort probe. Catching Exception (not a bare
        # ``except:``) lets KeyboardInterrupt / SystemExit propagate.
        return False
    return True
def build_huggingface_path(language: str, domain: str):
    """Return the Hugging Face Hub id of the CAP minor-topic model.

    The same checkpoint is returned for every language and domain; both
    arguments are accepted only to keep the call signature uniform.
    """
    model_id = "poltextlab/xlm-roberta-large-pooled-cap-minor-v3"
    return model_id
def _format_prediction_label(minor_code):
    """Return the display label "[<major>]<major label> [<minor>]<minor label>".

    The major code shown is the minor code without its last two digits,
    except for the special code 999, which is shown as-is.
    """
    minor_str = str(minor_code)
    major_str = "999" if minor_str == "999" else minor_str[:-2]
    major_label = convert_minor_to_major(minor_code)
    minor_label = CAP_MIN_LABEL_NAMES[minor_code]
    return f"[{major_str}]{major_label} [{minor_code}]{minor_label}"


def _render_predictions_html(top_predictions):
    """Render (label, probability) pairs as the HTML result panel.

    The first pair is additionally shown as a bold headline. Every pair gets
    a green bar whose width is its probability in percent, followed by the
    label and that percentage.
    """
    text_color = "black"
    parts = ['<div style="background-color: white">']
    for rank, (label, prob) in enumerate(top_predictions):
        percent = int(prob * 100)
        if rank == 0:
            parts.append(f"""
        <div style="text-align: center; font-weight: bold; font-size: 27px; margin-bottom: 10px; margin-left: 10px; margin-right: 10px;">
            <span style="color: {text_color};">{label}</span>
        </div>""")
        # A space separates the label from the percentage (previously they
        # were rendered glued together, e.g. "...General64%").
        parts.append(f"""
        <div style="height: 4px; background-color: green; width: {percent}%; margin-bottom: 8px;"></div>
        <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 4px;">
            <span style="color: {text_color};">{label} {percent}%</span>
        </div>
        """)
    parts.append('</div>')
    return "".join(parts)


def predict(text, model_id, tokenizer_id):
    """Classify *text* with the JIT-traced model for *model_id* and render the top 5.

    Args:
        text: Input text to classify.
        model_id: Hugging Face Hub model id; also determines the local
            JIT model file under /data/jit_models/.
        tokenizer_id: Id passed to AutoTokenizer.from_pretrained.

    Returns:
        A tuple (html, output_info): an HTML panel with the five most
        probable labels, and an HTML paragraph linking to the model page.
    """
    device = torch.device("cpu")

    # Load the JIT-traced model from local disk. map_location forces CPU
    # placement even if the trace was saved from a GPU.
    jit_model_path = f"/data/jit_models/{model_id.replace('/', '_')}.pt"
    model = torch.jit.load(jit_model_path, map_location=device)
    model.eval()

    # The tokenizer is still loaded through the regular Hugging Face API.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
    inputs = tokenizer(
        text,
        max_length=256,
        truncation=True,
        padding="do_not_pad",
        return_tensors="pt",
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        # NOTE(review): assumes the traced model takes (input_ids,
        # attention_mask) positionally and returns a mapping with "logits".
        output = model(inputs["input_ids"], inputs["attention_mask"])
        logits = output["logits"]
        release_model(model, model_id)

    probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
    # Class indices ordered by descending probability; only the top five
    # are displayed. Pairs are kept in a list, so identical labels cannot
    # overwrite each other as they could when keyed in a dict.
    top_indices = np.argsort(probs)[::-1][:5]
    top_predictions = [
        (_format_prediction_label(CAP_MIN_NUM_DICT[i]), probs[i])
        for i in top_indices
    ]

    html = _render_predictions_html(top_predictions)
    output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
    return html, output_info
def predict_cap(text, language, domain):
    """Gradio callback: map the UI selections to a model and run predict().

    Args:
        text: Input text from the textbox.
        language: Selected language (forwarded to build_huggingface_path).
        domain: Selected UI domain name; must be a key of ``domains``.

    Returns:
        Whatever predict() returns (HTML panel and model-info HTML).

    When is_disk_full() reports a full disk, cached model files are deleted
    first, on a best-effort basis (missing paths are not an error).
    """
    import contextlib
    import glob
    import shutil

    domain = domains[domain]
    model_id = build_huggingface_path(language, domain)
    tokenizer_id = "xlm-roberta-large"

    if is_disk_full():
        # Equivalent of `rm -rf /data/models*` without spawning a shell.
        for path in glob.glob("/data/models*"):
            if os.path.isdir(path) and not os.path.islink(path):
                shutil.rmtree(path, ignore_errors=True)
            else:
                with contextlib.suppress(OSError):
                    os.remove(path)
        # Equivalent of `rm -r ~/.cache/huggingface/hub`; its failure was
        # ignored by os.system before, so errors are ignored here too.
        shutil.rmtree(
            os.path.expanduser("~/.cache/huggingface/hub"), ignore_errors=True
        )

    return predict(text, model_id, tokenizer_id)
# Gradio UI: a text box plus language/domain dropdowns feed predict_cap,
# which returns the HTML result panel and a model-info paragraph.
demo = gr.Interface(
    title="CAP Minor Topics Babel Demo",
    fn=predict_cap,
    inputs=[
        gr.Textbox(lines=6, label="Input"),
        gr.Dropdown(languages, label="Language", value=languages[0]),
        # Pass a real list: Dropdown documents `choices` as a sequence,
        # whereas dict.keys() is only an iterable view.
        gr.Dropdown(list(domains), label="Domain", value=next(iter(domains))),
    ],
    outputs=[gr.HTML(label="Output"), gr.Markdown()],
)