# babel_machine / utils.py
# kovacsvi
# jit pt2
# 0a394ee
import os
import shutil
import subprocess
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from interfaces.cap import languages as languages_cap
from interfaces.cap import domains as domains_cap
from interfaces.emotion9 import languages as languages_emotion9
from interfaces.illframes import domains as domains_illframes
from interfaces.cap import build_huggingface_path as hf_cap_path
from interfaces.cap_minor import build_huggingface_path as hf_cap_minor_path
from interfaces.cap_minor_media import build_huggingface_path as hf_cap_minor_media_path
from interfaces.cap_media_demo import build_huggingface_path as hf_cap_media_path # why... just follow the name template the next time pls
from interfaces.manifesto import build_huggingface_path as hf_manifesto_path
from interfaces.sentiment import build_huggingface_path as hf_sentiment_path
from interfaces.emotion import build_huggingface_path as hf_emotion_path
from interfaces.emotion9 import build_huggingface_path as hf_emotion9_path
from interfaces.ontolisst import build_huggingface_path as hf_ontlisst_path
from interfaces.illframes import build_huggingface_path as hf_illframes_path
from interfaces.ontolisst import build_huggingface_path as hf_ontolisst_path
from huggingface_hub import scan_cache_dir
# Directory where download_hf_models() stores TorchScript traces.
JIT_DIR = "/data/jit_models"
# Hugging Face Hub read token; importing this module raises KeyError if "hf_read" is unset.
HF_TOKEN = os.environ["hf_read"]
# Tokenizer checkpoint(s); download_hf_models() loads xlm-roberta-large for every model.
tokenizers = ["xlm-roberta-large"]

# Registry of Hugging Face repo ids to download and trace -- should be a temporary solution.
# Tasks with a single model variant are built with placeholder ("") arguments.
models = [
    hf_manifesto_path(""),
    hf_sentiment_path(""),
    hf_emotion_path(""),
    hf_cap_minor_path("", ""),
    hf_ontolisst_path(""),
]

# CAP needs one repo id per (language, domain) pair; only the mapping's values are used.
domains_cap = list(domains_cap.values())
for language in languages_cap:
    for domain in domains_cap:
        models.append(hf_cap_path(language, domain))

# CAP media, then CAP minor media.
models += [
    hf_cap_media_path("", ""),
    hf_cap_minor_media_path("", "", False),
]

# emotion9: one repo id per language.
for language in languages_emotion9:
    models.append(hf_emotion9_path(language))

# illframes: domains_illframes is a mapping, so iterate its values.
for domain in domains_illframes.values():
    models.append(hf_illframes_path(domain))
def _trace_and_save(model, tokenizer, traced_model_path):
    """Trace *model* on a dummy batch and atomically save it to *traced_model_path*.

    The tokenizer returns CPU tensors, while ``device_map="auto"`` may have placed
    the model on an accelerator, so the example inputs are moved to
    ``model.device`` before tracing.
    """
    model.eval()
    # Pad all the way to max_length: with padding=True the single short sentence
    # would only be ~6 tokens long and max_length would have no effect. Tracing
    # with an input as long as the longest expected one is the safer choice.
    dummy_input = tokenizer(
        "Hello, world!",
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=256,
    )
    example_inputs = (
        dummy_input["input_ids"].to(model.device),
        dummy_input["attention_mask"].to(model.device),
    )
    # strict=False lets the tracer record mutable container (dict-like) outputs,
    # such as the output object returned by transformers models.
    with torch.no_grad():
        traced_model = torch.jit.trace(model, example_inputs, strict=False)
    # Save to a temporary file and rename it into place, so an interrupted save
    # never leaves a truncated .pt that the exists-check would skip on later runs.
    tmp_path = traced_model_path + ".tmp"
    traced_model.save(tmp_path)
    os.replace(tmp_path, traced_model_path)


def download_hf_models():
    """Download every repo id in ``models`` and cache a TorchScript trace of each.

    Each model is loaded with ``from_pretrained`` (using ``HF_TOKEN``). If
    ``JIT_DIR/<repo id with "/" replaced by "_">.pt`` does not exist yet, the
    model is traced and saved there; otherwise tracing is skipped.
    """
    # Ensure the JIT model directory exists
    os.makedirs(JIT_DIR, exist_ok=True)
    # The same tokenizer is used for every model, so load it once, outside the loop.
    tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large")
    for model_id in models:
        print(f"Downloading + JIT tracing model: {model_id}")
        model = AutoModelForSequenceClassification.from_pretrained(
            model_id,
            token=HF_TOKEN,
            device_map="auto",
        )
        safe_model_name = model_id.replace("/", "_")
        traced_model_path = os.path.join(JIT_DIR, f"{safe_model_name}.pt")
        if os.path.exists(traced_model_path):
            print(f"⏩ Skipping JIT — already exists: {traced_model_path}")
        else:
            print(f"⚙️ Tracing and saving: {traced_model_path}")
            _trace_and_save(model, tokenizer, traced_model_path)
            print(f"✔️ Saved JIT model to: {traced_model_path}")
        # Drop this model before the next from_pretrained call so two checkpoints
        # are not held in (GPU) memory at the same time.
        del model
def df_h():
    """Run ``df -H`` and print its standard output (disk usage per filesystem)."""
    command = ["df", "-H"]
    completed = subprocess.run(command, capture_output=True, text=True)
    report = completed.stdout
    print(report)
def scan_cache():
    """Print a size report for the Hugging Face cache and the TorchScript JIT directory.

    The Hugging Face cache scanned is ``$TRANSFORMERS_CACHE`` when set (as
    ``set_hf_cache_dir`` does); otherwise huggingface_hub's default hub cache.
    If that cache cannot be scanned, a note is printed and the JIT report still
    runs. Only ``*.pt`` files directly inside ``JIT_DIR`` are counted.
    """
    # Local import keeps the file-level import block unchanged.
    from huggingface_hub.utils import CacheNotFound

    # None makes scan_cache_dir use its default hub cache. The previous fallback,
    # ~/.cache/huggingface/transformers, is transformers' legacy cache location
    # rather than the hub-layout cache that scan_cache_dir understands.
    cache_dir = os.environ.get("TRANSFORMERS_CACHE")
    print("=== 🤗 Hugging Face Model Cache ===")
    try:
        scan_result = scan_cache_dir(cache_dir)
    except (CacheNotFound, ValueError) as err:
        # Raised when the directory is missing or is a file; report it instead of
        # aborting before the JIT section below.
        print(f"(Could not scan Hugging Face cache: {err})")
    else:
        print(f"Cache size: {scan_result.size_on_disk / 1e6:.2f} MB")
        print(f"Number of repos: {len(scan_result.repos)}")
        for repo in scan_result.repos:
            print(f"- {repo.repo_id} ({repo.repo_type}) — {repo.size_on_disk / 1e6:.2f} MB")

    print("\n=== 🧊 TorchScript JIT Cache ===")
    if not os.path.exists(JIT_DIR):
        print(f"(Directory does not exist: {JIT_DIR})")
        return
    total_size = 0
    # Sorted so the listing is stable across runs.
    for filename in sorted(os.listdir(JIT_DIR)):
        if not filename.endswith(".pt"):
            continue
        size = os.path.getsize(os.path.join(JIT_DIR, filename))
        total_size += size
        print(f"- {filename}: {size / 1e6:.2f} MB")
    print(f"Total JIT cache size: {total_size / 1e6:.2f} MB")
def set_hf_cache_dir(path: "str | os.PathLike[str]"):
    """Point the Hugging Face and torch cache environment variables at *path*.

    Sets ``TRANSFORMERS_CACHE``, ``HF_HOME``, ``HF_DATASETS_CACHE`` and
    ``TORCH_HOME`` in ``os.environ`` to the same value, for this process and
    any subprocesses started afterwards. Accepts ``str`` or any path-like
    object (e.g. ``pathlib.Path``).

    NOTE(review): transformers / huggingface_hub appear to read these variables
    when they are imported. This module imports transformers at the top, so
    calling this afterwards may not redirect downloads in this process -- verify.
    """
    # os.environ only accepts str values, so normalise path-like objects first.
    value = os.fspath(path)
    for var in ("TRANSFORMERS_CACHE", "HF_HOME", "HF_DATASETS_CACHE", "TORCH_HOME"):
        os.environ[var] = value
def is_disk_full(min_free_space_in_GB=10, path="/"):
    """Return whether the filesystem holding *path* is below a free-space threshold.

    Args:
        min_free_space_in_GB: minimum acceptable free space, measured in GiB
            (1024**3 bytes).
        path: any path on the filesystem to check; defaults to the root "/".

    Returns:
        True if free space is below ``min_free_space_in_GB``, otherwise False.
    """
    free_gb = shutil.disk_usage(path).free / (1024 ** 3)
    return free_gb < min_free_space_in_GB