TiRex-demo / tirex /base.py
Nikita
added tirex as model
14d91dc
# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
import os
from abc import ABC, abstractmethod
from typing import TypeVar
from huggingface_hub import hf_hub_download
# Type variable bound to PretrainedModel; used in the `from_pretrained` signature
# so the return type is the concrete subclass the classmethod is invoked on.
T = TypeVar("T", bound="PretrainedModel")
def parse_hf_repo_id(path):
    """Return the first two "/"-separated segments of ``path`` joined by "/".

    For a path such as ``"NX-AI/TiRex/sub/dir"`` this yields ``"NX-AI/TiRex"``.
    A path with fewer than two segments is returned unchanged.
    """
    # Splitting at most twice is enough: only the first two segments are kept.
    leading_segments = path.split("/", 2)[:2]
    return "/".join(leading_segments)
class PretrainedModel(ABC):
    """Abstract base for models that can be restored from a pretrained checkpoint.

    Each subclass is recorded in ``REGISTRY`` under the name returned by its
    ``register_name()``. Subclasses are expected to provide a
    ``load_from_checkpoint`` classmethod; it is called by ``from_pretrained``
    but is not defined on this base class.
    """

    # Maps a registered model name to the subclass (the class object itself,
    # not an instance) that implements it.
    REGISTRY: dict[str, type["PretrainedModel"]] = {}

    def __init_subclass__(cls, **kwargs):
        """Register the newly created subclass in ``REGISTRY``."""
        super().__init_subclass__(**kwargs)
        name = cls.register_name()
        # An intermediate subclass that does not override `register_name`
        # inherits the abstract placeholder, which returns None. Skip it so the
        # registry is not polluted with a meaningless None key.
        if name is not None:
            cls.REGISTRY[name] = cls

    @classmethod
    def from_pretrained(cls: type[T], path, device: str = "cuda:0", hf_kwargs=None, ckp_kwargs=None) -> T:
        """Load a model instance from a local checkpoint or the Hugging Face Hub.

        Args:
            path: If this path exists on the local filesystem it is used directly
                as the checkpoint path. Otherwise its first two "/"-separated
                segments are taken as a Hugging Face repo id and the file
                ``model.ckpt`` is fetched from that repo.
            device: Passed to ``load_from_checkpoint`` as ``map_location``.
            hf_kwargs: Extra keyword arguments forwarded to ``hf_hub_download``.
            ckp_kwargs: Extra keyword arguments forwarded to ``load_from_checkpoint``.

        Returns:
            The loaded model, after ``after_load_from_checkpoint()`` has been called on it.
        """
        if hf_kwargs is None:
            hf_kwargs = {}
        if ckp_kwargs is None:
            ckp_kwargs = {}

        if os.path.exists(path):
            print("Loading weights from local directory")
            checkpoint_path = path
        else:
            repo_id = parse_hf_repo_id(path)
            checkpoint_path = hf_hub_download(repo_id=repo_id, filename="model.ckpt", **hf_kwargs)

        model = cls.load_from_checkpoint(checkpoint_path, map_location=device, **ckp_kwargs)
        model.after_load_from_checkpoint()
        return model

    @classmethod
    @abstractmethod
    def register_name(cls) -> str:
        """Return the name under which this class is stored in ``REGISTRY``."""
        pass

    def after_load_from_checkpoint(self):
        """Hook called by ``from_pretrained`` after loading; does nothing by default."""
        pass
def load_model(path: str, device: str = "cuda:0", hf_kwargs=None, ckp_kwargs=None) -> PretrainedModel:
    """Loads a TiRex model.

    The model class is looked up in ``PretrainedModel.REGISTRY`` using the second
    segment of the repo id derived from ``path`` (e.g. ``"TiRex"`` for
    ``"NX-AI/TiRex"``), and loading is delegated to that class's ``from_pretrained``.

    Args:
        path (str): Hugging Face path to the model (e.g. NX-AI/TiRex)
        device (str, optional): The device on which to load the model (e.g., "cuda:0", "cpu").
            If you want to use "cpu" you need to deactivate the sLSTM CUDA kernels (check repository FAQ!).
        hf_kwargs (dict, optional): Keyword arguments to pass to the Hugging Face Hub download method.
        ckp_kwargs (dict, optional): Keyword arguments to pass when loading the checkpoint.

    Returns:
        PretrainedModel: The loaded model.

    Raises:
        ValueError: If ``path`` does not yield a two-segment repo id, or if no
            model class is registered under the resulting model id.

    Examples:
        model: ForecastModel = load_model("NX-AI/TiRex")
    """
    try:
        _, model_id = parse_hf_repo_id(path).split("/")
    except (AttributeError, TypeError, ValueError) as err:
        # AttributeError/TypeError: `path` is not a str; ValueError: the repo id
        # does not have exactly two segments. Other exceptions (e.g.
        # KeyboardInterrupt) are deliberately left to propagate.
        raise ValueError(f"Invalid model path {path}") from err

    model_cls = PretrainedModel.REGISTRY.get(model_id)
    if model_cls is None:
        raise ValueError(f"Invalid model id {model_id}")

    return model_cls.from_pretrained(path, device=device, hf_kwargs=hf_kwargs, ckp_kwargs=ckp_kwargs)