# babel_machine/interfaces/cap_minor.py
# Last commit 4bba8df by kovacsvi: "up-to-date prod demo" (file size 2.78 kB).
import glob
import os
import shutil

import gradio as gr
import torch
import numpy as np
import pandas as pd
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from huggingface_hub import HfApi

from label_dicts import CAP_MIN_NUM_DICT, CAP_MIN_LABEL_NAMES
from .utils import is_disk_full
# Hugging Face access token read from the "hf_read" environment variable.
# os.environ[...] raises KeyError at import time if the variable is unset.
HF_TOKEN = os.environ["hf_read"]

# Choices offered in the "Language" dropdown: one pooled multilingual option.
languages = ["Multilingual"]

# Human-readable domain label (used as the "Domain" dropdown choices) mapped
# to the short internal key it is translated to before model lookup.
domains = {
    "media": "media",
    "social media": "social",
    "parliamentary speech": "parlspeech",
    "legislative documents": "legislative",
    "executive speech": "execspeech",
    "executive order": "execorder",
    "party programs": "party",
    "judiciary": "judiciary",
    "budget": "budget",
    "public opinion": "publicopinion",
    "local government agenda": "localgovernment",
}
def check_huggingface_path(checkpoint_path: str) -> bool:
    """Report whether *checkpoint_path* is a model repo reachable with ``HF_TOKEN``.

    Args:
        checkpoint_path: Hugging Face repo id, e.g. ``"org/model-name"``.

    Returns:
        True if ``HfApi.model_info`` succeeds for the repo; False if building
        the client or fetching the info raises for any reason (missing repo,
        no access, network failure, ...).
    """
    try:
        hf_api = HfApi(token=HF_TOKEN)
        hf_api.model_info(checkpoint_path, token=HF_TOKEN)
    except Exception:
        # Catch Exception, not everything: KeyboardInterrupt and SystemExit
        # must still propagate instead of being reported as "path missing".
        return False
    return True
def build_huggingface_path(language: str, domain: str):
    """Return the Hugging Face repo id of the CAP minor-topic model.

    Both *language* and *domain* are ignored: the same pooled checkpoint is
    returned for every combination.
    """
    del language, domain  # intentionally unused
    pooled_model_repo = "poltextlab/xlm-roberta-large-pooled-cap-minor"
    return pooled_model_repo
def predict(text, model_id, tokenizer_id):
    """Score *text* against every CAP minor-topic label.

    Args:
        text: The text to classify.
        model_id: Hugging Face repo id of the sequence-classification
            checkpoint; loaded with ``HF_TOKEN``.
        tokenizer_id: Repo id of the tokenizer paired with the model.

    Returns:
        ``(output_pred, output_info)``. ``output_pred`` maps
        ``"[<code>] <label name>"`` to that class's softmax probability,
        inserted from most to least probable. ``output_info`` is an HTML
        paragraph linking to the model's page on huggingface.co.
    """
    cpu = torch.device("cpu")

    # Model and tokenizer are loaded on every call; nothing is kept in memory.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        device_map="auto",
        offload_folder="offload",  # folder used if weights get offloaded to disk
        token=HF_TOKEN,
    )
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

    # NOTE(review): inputs are pinned to CPU while device_map="auto" chooses
    # where the weights live; presumably accelerate's dispatch hooks move the
    # inputs as needed -- confirm if the Space ever gets a GPU.
    encoded = tokenizer(
        text,
        max_length=256,
        truncation=True,
        padding="do_not_pad",
        return_tensors="pt",
    ).to(cpu)

    model.eval()
    with torch.no_grad():
        class_logits = model(**encoded).logits

    # Softmax over the class axis, then flatten the (presumably single-row)
    # batch into one probability vector.
    probabilities = torch.nn.functional.softmax(class_logits, dim=1).cpu().numpy().flatten()

    output_pred = {}
    for class_idx in np.argsort(probabilities)[::-1]:  # most probable first
        cap_code = CAP_MIN_NUM_DICT[class_idx]
        display_label = f"[{cap_code}] {CAP_MIN_LABEL_NAMES[cap_code]}"
        output_pred[display_label] = probabilities[class_idx]

    output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
    return output_pred, output_info
def clear_model_caches(patterns=("/data/models*", "~/.cache/huggingface/hub")):
    """Delete cached model files to free disk space (best effort).

    Each pattern is ``~``-expanded and glob-expanded. Matching directories are
    removed recursively; matching files and symlinks are unlinked. Failures
    are ignored, mirroring the ``rm -rf`` semantics this replaces.

    Args:
        patterns: Iterable of path/glob patterns to delete.
    """
    for pattern in patterns:
        for path in glob.glob(os.path.expanduser(pattern)):
            if os.path.isdir(path) and not os.path.islink(path):
                shutil.rmtree(path, ignore_errors=True)
            else:
                try:
                    os.remove(path)
                except OSError:
                    pass  # deliberate best effort, like rm -f


def predict_cap(text, language, domain):
    """Gradio callback: classify *text* with the CAP minor-topic model.

    Args:
        text: Text to classify.
        language: Selected language; forwarded to ``build_huggingface_path``.
        domain: Human-readable domain label, expected to be a key of
            ``domains``.

    Returns:
        The ``(label -> probability dict, HTML info)`` pair from ``predict``.

    Raises:
        gr.Error: If *domain* is not a known domain label (e.g. nothing was
            selected), so the UI shows a readable message instead of a
            bare KeyError.
    """
    if domain not in domains:
        raise gr.Error("Please select a domain.")
    domain = domains[domain]
    model_id = build_huggingface_path(language, domain)
    tokenizer_id = "xlm-roberta-large"

    # Free space before from_pretrained possibly downloads weights again.
    # Done in-process (no shell), so a missing directory is not an error.
    if is_disk_full():
        clear_model_caches()

    return predict(text, model_id, tokenizer_id)
# Gradio UI: free text plus Language/Domain dropdowns in; a top-5 label
# chart and a Markdown slot (filled with the model-attribution HTML that
# predict_cap returns) out.
demo = gr.Interface(
    fn=predict_cap,
    inputs=[
        gr.Textbox(lines=6, label="Input"),
        gr.Dropdown(languages, label="Language"),
        gr.Dropdown(domains.keys(), label="Domain"),
    ],
    outputs=[
        gr.Label(num_top_classes=5, label="Output"),
        gr.Markdown(),
    ],
    title="CAP Minor Topics Babel Demo",
)