Spaces:
Sleeping
Sleeping
import os | |
from abc import ABC, abstractmethod | |
from transformers import pipeline | |
# Maps a model-id prefix to the wrapper class registered for it.
LLM_MODEL_REGISTRY = dict()
# Value of the HF_TOKEN environment variable, or None when it is not set.
hf_token = os.environ.get("HF_TOKEN")
class AbstractLLMModel(ABC):
    """Common base class for text-generation backends.

    The constructor only records the shared settings; concrete subclasses
    load their backend and must implement :meth:`generate`.
    """

    def __init__(
        self, model_id: str, device: str = "cpu", cache_dir: str = "cache", **kwargs
    ):
        """Record the settings shared by every backend.

        Args:
            model_id: Identifier of the model being loaded.
            device: Device name the subclass should run on.
            cache_dir: Directory the subclass may use for cached files.
            **kwargs: Accepted so subclasses can forward their extra options
                unchanged; this base class does not use them.
        """
        print(f"Loading LLM model {model_id}...")
        self.model_id = model_id
        self.device = device
        self.cache_dir = cache_dir

    @abstractmethod
    def generate(self, prompt: str, **kwargs) -> str:
        """Return the text generated for *prompt*.

        Declared abstract so that a subclass which forgets to override it
        fails at instantiation instead of silently returning ``None``.
        """
def register_llm_model(prefix: str):
    """Create a class decorator that registers an LLM wrapper under *prefix*.

    Args:
        prefix: Leading part of the model ids the decorated class handles.

    Returns:
        A decorator that stores the class in ``LLM_MODEL_REGISTRY`` under
        *prefix* (replacing any entry already stored for that prefix) and
        returns the class unchanged.

    Raises:
        AssertionError: When the decorated class does not inherit
            ``AbstractLLMModel``.
    """

    def wrapper(cls):
        # Raise explicitly instead of using ``assert`` so the check is not
        # stripped when Python runs with -O; the exception type and message
        # are the same as before.
        if not issubclass(cls, AbstractLLMModel):
            raise AssertionError(f"{cls} must inherit AbstractLLMModel")
        LLM_MODEL_REGISTRY[prefix] = cls
        return cls

    return wrapper
def get_llm_model(model_id: str, device="cpu", **kwargs) -> AbstractLLMModel:
    """Instantiate the registered wrapper whose prefix matches *model_id*.

    Prefixes are tried in registry order and the first one that *model_id*
    starts with wins. ``device`` and any extra keyword arguments are passed
    on to the wrapper's constructor.

    Raises:
        ValueError: When no registered prefix matches *model_id*.
    """
    not_found = object()
    wrapper_cls = next(
        (
            candidate
            for known_prefix, candidate in LLM_MODEL_REGISTRY.items()
            if model_id.startswith(known_prefix)
        ),
        not_found,
    )
    if wrapper_cls is not_found:
        raise ValueError(f"No LLM wrapper found for model: {model_id}")
    return wrapper_cls(model_id, device=device, **kwargs)
# e.g., Falcon | |
class HFTextGenerationLLM(AbstractLLMModel):
    """LLM wrapper backed by a transformers ``text-generation`` pipeline."""

    def __init__(
        self, model_id: str, device: str = "cpu", cache_dir: str = "cache", **kwargs
    ):
        """Build the text-generation pipeline for *model_id*.

        Args:
            model_id: Model identifier, passed to ``pipeline`` as ``model``.
            device: ``"cuda"`` selects device index 0; every other value
                selects ``-1``.
            cache_dir: Written into ``model_kwargs["cache_dir"]``, replacing
                any ``cache_dir`` the caller supplied there.
            **kwargs: Extra keyword arguments forwarded to ``pipeline``.
        """
        super().__init__(model_id, device, cache_dir, **kwargs)
        # Work on a copy: a ``model_kwargs`` dict passed by the caller is the
        # caller's own object and must not gain a ``cache_dir`` entry behind
        # their back. ``or {}`` also tolerates ``model_kwargs=None``.
        model_kwargs = dict(kwargs.get("model_kwargs") or {})
        model_kwargs["cache_dir"] = cache_dir
        kwargs["model_kwargs"] = model_kwargs
        self.pipe = pipeline(
            "text-generation",
            model=model_id,
            device=0 if device == "cuda" else -1,
            return_full_text=False,
            token=hf_token,
            **kwargs,
        )

    def generate(self, prompt: str, **kwargs) -> str:
        """Run the pipeline on *prompt* and return the first generated text.

        Extra keyword arguments are forwarded to the pipeline call. An empty
        string is returned when the pipeline produces no outputs.
        """
        outputs = self.pipe(prompt, **kwargs)
        return outputs[0]["generated_text"] if outputs else ""