# babel_machine/interfaces/illframes.py
# kovacsvi - "JIT tracing" (fb1a253), 3.49 kB
import gradio as gr
import os
import torch
import numpy as np
import pandas as pd
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from huggingface_hub import HfApi
from label_dicts import ILLFRAMES_MIGRATION_LABEL_NAMES, ILLFRAMES_COVID_LABEL_NAMES, ILLFRAMES_WAR_LABEL_NAMES
from .utils import is_disk_full, release_model
# Hugging Face access token, read from the `hf_read` environment variable.
# The lookup uses os.environ[...], so a missing variable raises KeyError at
# import time.
HF_TOKEN = os.environ["hf_read"]
# Choices for the "Language" dropdown of the Gradio interface below.
languages = [
"English"
]
# Maps the domain name shown in the UI to the lowercase slug that is appended
# to the model repository id (see build_huggingface_path). Insertion order
# matters: the first key is used as the default "Domain" dropdown value.
domains = {
"Covid": "covid",
"Migration": "migration",
"War": "war"
}
# --- DEBUG ---
import shutil
def convert_size(size):
    """Format a byte count as a human-readable string with two decimals.

    The value is divided by 1024 until it drops below 1024 or the largest
    supported unit (PB) is reached.

    Args:
        size: Number of bytes (int or float).

    Returns:
        A string such as ``"1.50 KB"``. Values of 1024 PB or more are
        expressed in PB instead of silently returning ``None``.
    """
    units = ("B", "KB", "MB", "GB", "TB", "PB")
    for unit in units[:-1]:
        if size < 1024:
            return f"{size:.2f} {unit}"
        size /= 1024
    # Largest unit: always format here so the function never falls through
    # and returns None for very large inputs.
    return f"{size:.2f} {units[-1]}"
def get_disk_space(path="/"):
    """Report disk usage for the filesystem containing *path*.

    Args:
        path: Location passed to ``shutil.disk_usage``; defaults to ``"/"``.

    Returns:
        A dict with the keys ``"Total"``, ``"Used"`` and ``"Free"``, each
        mapped to the corresponding size formatted by ``convert_size``.
    """
    usage = shutil.disk_usage(path)
    # disk_usage yields (total, used, free) in that order.
    field_names = ("Total", "Used", "Free")
    return {
        field: convert_size(amount)
        for field, amount in zip(field_names, usage)
    }
# ---
def check_huggingface_path(checkpoint_path: str) -> bool:
    """Return whether model metadata for *checkpoint_path* can be fetched.

    Args:
        checkpoint_path: Repository id passed to ``HfApi.model_info``.

    Returns:
        ``True`` if the ``model_info`` call succeeds, ``False`` if it raises
        any ``Exception``.
    """
    try:
        hf_api = HfApi(token=HF_TOKEN)
        hf_api.model_info(checkpoint_path, token=HF_TOKEN)
    except Exception:
        # Best-effort probe: any lookup failure is reported as "not
        # available". Unlike a bare `except:`, this lets KeyboardInterrupt
        # and SystemExit propagate.
        return False
    return True
def build_huggingface_path(domain: str):
    """Return the model repository id for the given ILLFRAMES domain slug.

    Args:
        domain: Domain slug appended to the fixed repository prefix.

    Returns:
        The string ``"poltextlab/xlm-roberta-large-english-ILLFRAMES-"``
        followed by *domain*.
    """
    repo_prefix = "poltextlab/xlm-roberta-large-english-ILLFRAMES-"
    return f"{repo_prefix}{domain}"
def predict(text, model_id, tokenizer_id, label_names):
    """Classify *text* with a JIT-traced model and rank the label scores.

    Args:
        text: Input passed to the tokenizer. The probabilities are flattened
            afterwards, so this presumably expects a single text - confirm
            against callers.
        model_id: Model repository id. With ``/`` replaced by ``_`` it names
            the traced file under ``/data/jit_models/``; it is also used in
            the returned attribution link.
        tokenizer_id: Identifier passed to ``AutoTokenizer.from_pretrained``.
        label_names: Mapping from label key to label name. Output index ``i``
            of the model is matched to the ``i``-th key in sorted order.

    Returns:
        A ``(scores, info_html)`` tuple. ``scores`` maps ``"[key] name"`` to
        the softmax probability, inserted from highest to lowest probability.
        ``info_html`` is an HTML paragraph linking to the model page.
    """
    cpu = torch.device("cpu")

    # The traced checkpoint is stored under the repo id with "/" flattened.
    traced_path = f"/data/jit_models/{model_id.replace('/', '_')}.pt"
    model = torch.jit.load(traced_path).to(cpu)
    model.eval()

    # The tokenizer is loaded the regular Hugging Face way (not traced).
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

    # Truncate to 256 tokens; no padding is applied.
    encoded = tokenizer(
        text,
        max_length=256,
        truncation=True,
        padding="do_not_pad",
        return_tensors="pt",
    )
    encoded = {name: tensor.to(cpu) for name, tensor in encoded.items()}

    with torch.no_grad():
        # The traced module is called positionally: ids first, then mask.
        output = model(encoded["input_ids"], encoded["attention_mask"])
        print(output) # debug
        logits = output["logits"]

    # Hand the model back to the project helper once inference is done.
    release_model(model, model_id)

    softmaxed = torch.nn.functional.softmax(logits, dim=1)
    probs = softmaxed.cpu().numpy().flatten()

    # Position in the probability vector -> label key (keys in sorted order).
    index_to_key = dict(enumerate(sorted(label_names.keys())))

    output_pred = {}
    for idx in np.argsort(probs)[::-1]:  # highest probability first
        key = index_to_key[idx]
        output_pred[f"[{key}] {label_names[key]}"] = probs[idx]

    output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
    return output_pred, output_info
def predict_illframes(text, language, domain):
    """Gradio callback: classify *text* with the model for the chosen domain.

    Args:
        text: Input text forwarded to ``predict``.
        language: Selected language; not used in this function body.
        domain: Domain name as shown in the UI; must be a key of ``domains``.

    Returns:
        Whatever ``predict`` returns for the selected model and label set.
    """
    domain_slug = domains[domain]

    # One label table per slug; covers every value currently in `domains`.
    label_tables = {
        "migration": ILLFRAMES_MIGRATION_LABEL_NAMES,
        "covid": ILLFRAMES_COVID_LABEL_NAMES,
        "war": ILLFRAMES_WAR_LABEL_NAMES,
    }
    label_names = label_tables[domain_slug]

    if is_disk_full():
        # Disk reported full: remove /data/models* and the Hugging Face hub
        # cache via the shell before running the prediction.
        for cleanup_cmd in (
            'rm -rf /data/models*',
            'rm -r ~/.cache/huggingface/hub',
        ):
            os.system(cleanup_cmd)

    return predict(
        text,
        build_huggingface_path(domain_slug),
        "xlm-roberta-large",
        label_names,
    )
# Gradio interface: a text box plus language/domain selectors wired to
# `predict_illframes`. Outputs are a label widget limited to the top 5
# classes and a Markdown area for the model-attribution note.
demo = gr.Interface(
    title="ILLFRAMES Babel Demo",
    fn=predict_illframes,
    inputs=[
        gr.Textbox(lines=6, label="Input"),
        gr.Dropdown(languages, label="Language", value=languages[0]),
        # Pass the domain names as a real list rather than a dict_keys view,
        # matching the list used to pick the default value.
        gr.Dropdown(
            list(domains.keys()),
            label="Domain",
            value=list(domains.keys())[0],
        ),
    ],
    outputs=[gr.Label(num_top_classes=5, label="Output"), gr.Markdown()],
)