import logging
from pathlib import Path
from functools import lru_cache

import torch
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM,
    AutoModelForMaskedLM,
)
from sentence_transformers import SentenceTransformer

from app.core.config import (
    MODELS_DIR, SPACY_MODEL_ID, SENTENCE_TRANSFORMER_MODEL_ID,
    OFFLINE_MODE
)
from app.core.exceptions import ModelNotDownloadedError

logger = logging.getLogger(__name__)

# ─────────────────────────────────────────────────────────────────────────────
# SpaCy
# ─────────────────────────────────────────────────────────────────────────────
def load_spacy_model(model_id: str = SPACY_MODEL_ID):
    """Load a spaCy pipeline, either as an installed package or from MODELS_DIR."""
    import spacy
    from spacy.util import is_package

    logger.info(f"Loading spaCy model: {model_id}")

    # Installed as a Python package (e.g. via `spacy download`)
    if is_package(model_id):
        return spacy.load(model_id)

    # Otherwise look for a saved pipeline directory under MODELS_DIR
    possible_path = MODELS_DIR / model_id
    if possible_path.exists():
        return spacy.load(str(possible_path))

    raise RuntimeError(f"Could not find spaCy model '{model_id}' at {possible_path}")
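
# Illustrative usage sketch (not part of the loader API): how the spaCy loader
# above might be exercised. It assumes SPACY_MODEL_ID points at a standard
# pipeline with an NER component; the sample sentence is for demonstration only.
def _example_spacy_entities():
    nlp = load_spacy_model()
    doc = nlp("Barack Obama was born in Hawaii.")
    # Each recognised entity exposes its surface text and predicted label.
    return [(ent.text, ent.label_) for ent in doc.ents]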

# ─────────────────────────────────────────────────────────────────────────────
# Sentence Transformers
# ─────────────────────────────────────────────────────────────────────────────
def load_sentence_transformer_model(model_id: str = SENTENCE_TRANSFORMER_MODEL_ID) -> SentenceTransformer:
    """Load a SentenceTransformer embedding model, caching its files under MODELS_DIR."""
    logger.info(f"Loading SentenceTransformer: {model_id}")
    return SentenceTransformer(model_name_or_path=model_id, cache_folder=str(MODELS_DIR))
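
# Illustrative usage sketch (not part of the app's public API): encode a couple
# of sentences with the loader above and report the embedding shape. The sample
# sentences are assumptions for demonstration only.
def _example_sentence_embeddings():
    model = load_sentence_transformer_model()
    embeddings = model.encode([
        "How do I reset my password?",
        "Password reset instructions",
    ])
    # encode() returns a (num_sentences, embedding_dim) numpy array by default.
    return embeddings.shape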

# ─────────────────────────────────────────────────────────────────────────────
# Hugging Face Pipelines (T5 models, classifiers, etc.)
# ─────────────────────────────────────────────────────────────────────────────
def _check_model_downloaded(model_id: str, cache_dir: str) -> bool:
    """Return True if the model appears in the local cache (repo id flattened, '/' replaced by '_')."""
    model_path = Path(cache_dir) / model_id.replace("/", "_")
    return model_path.exists()

def _timed_load(name: str, fn):
    """Call a zero-argument loader and log how long it took."""
    import time

    start = time.time()
    model = fn()
    elapsed = round(time.time() - start, 2)
    logger.info(f"[{name}] model loaded in {elapsed}s")
    return model

def load_hf_pipeline(model_id: str, task: str, feature_name: str, **kwargs):
    """Build a transformers pipeline for the given task, respecting the offline/cache settings."""
    if OFFLINE_MODE and not _check_model_downloaded(model_id, str(MODELS_DIR)):
        raise ModelNotDownloadedError(model_id, feature_name, "Model not found locally in offline mode.")

    # Choose the appropriate AutoModel loader based on the task. This happens
    # before the try block so an unsupported task surfaces as a ValueError
    # instead of being reported as a download failure.
    if task == "text-classification":
        model_loader = AutoModelForSequenceClassification
    elif task == "text2text-generation" or task.startswith("translation"):
        model_loader = AutoModelForSeq2SeqLM
    elif task == "fill-mask":
        model_loader = AutoModelForMaskedLM
    else:
        raise ValueError(f"Unsupported task type '{task}' for feature '{feature_name}'.")

    try:
        model = _timed_load(
            f"{feature_name}:{model_id} (model)",
            lambda: model_loader.from_pretrained(
                model_id,
                cache_dir=MODELS_DIR,
                local_files_only=OFFLINE_MODE,
            ),
        )
        tokenizer = _timed_load(
            f"{feature_name}:{model_id} (tokenizer)",
            lambda: AutoTokenizer.from_pretrained(
                model_id,
                cache_dir=MODELS_DIR,
                local_files_only=OFFLINE_MODE,
            ),
        )
        return pipeline(
            task=task,
            model=model,
            tokenizer=tokenizer,
            device=0 if torch.cuda.is_available() else -1,
            **kwargs,
        )
    except Exception as e:
        logger.error(f"Failed to load pipeline for '{feature_name}' - {model_id}: {e}", exc_info=True)
        raise ModelNotDownloadedError(model_id, feature_name, str(e)) from e
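
# Illustrative usage sketch: wiring load_hf_pipeline() to a text-classification
# model. The model ID and feature name below are assumptions for demonstration;
# the app's real IDs live in its config module.
def _example_classification_pipeline():
    clf = load_hf_pipeline(
        model_id="distilbert-base-uncased-finetuned-sst-2-english",  # hypothetical example model
        task="text-classification",
        feature_name="example-sentiment",
    )
    # A transformers pipeline returns a list of {"label": ..., "score": ...} dicts.
    return clf("The loader picked the right AutoModel class for this task.")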

# ─────────────────────────────────────────────────────────────────────────────
# NLTK
# ─────────────────────────────────────────────────────────────────────────────
def ensure_nltk_resource(resource_name: str = "wordnet") -> None:
    """Ensure an NLTK corpus is available locally, downloading it unless running offline."""
    # Import outside the try block: a missing nltk install cannot be fixed by
    # nltk.download(), so an ImportError should propagate as-is.
    import nltk

    try:
        nltk.data.find(f"corpora/{resource_name}")
    except LookupError:
        if OFFLINE_MODE:
            raise RuntimeError(f"NLTK resource '{resource_name}' not found in offline mode.")
        nltk.download(resource_name)
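
# Illustrative usage sketch: guarantee WordNet is present before querying it.
# The lookup word is arbitrary and for demonstration only.
def _example_wordnet_synonyms():
    ensure_nltk_resource("wordnet")
    from nltk.corpus import wordnet
    # Collect lemma names across all synsets of the word.
    return sorted({lemma.name() for syn in wordnet.synsets("model") for lemma in syn.lemmas()})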

# ─────────────────────────────────────────────────────────────────────────────
# Ready-to-use Loaders (for your app use)
# ─────────────────────────────────────────────────────────────────────────────
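
# Illustrative sketch only, not the app's actual loaders (which are not shown
# here): the lru_cache imported at the top of the module is the usual way to
# memoize these loaders so each model is created once per process.
@lru_cache(maxsize=1)
def _example_cached_sentence_transformer() -> SentenceTransformer:
    # Repeated calls return the same in-memory model instance.
    return load_sentence_transformer_model()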