# NOTE(review): removed HTML-scrape artifacts (Hugging Face Spaces page chrome:
# status text, commit hashes, line-number gutter) that were not part of this
# module and made the file invalid Python.
import logging
from pathlib import Path
from functools import lru_cache
import torch
from transformers import (
pipeline,
AutoTokenizer,
AutoModelForSequenceClassification,
AutoModelForSeq2SeqLM,
AutoModelForMaskedLM,
)
from sentence_transformers import SentenceTransformer
from app.core.config import (
MODELS_DIR, SPACY_MODEL_ID, SENTENCE_TRANSFORMER_MODEL_ID,
OFFLINE_MODE
)
from app.core.exceptions import ModelNotDownloadedError
logger = logging.getLogger(__name__)
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# π§ SpaCy
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
@lru_cache(maxsize=1)
def load_spacy_model(model_id: str = SPACY_MODEL_ID):
    """Load and cache a spaCy pipeline.

    Resolution order: an installed spaCy package named *model_id* first,
    then a directory of the same name under MODELS_DIR.

    Raises:
        RuntimeError: if the model is found in neither location.
    """
    import spacy
    from spacy.util import is_package

    logger.info(f"Loading spaCy model: {model_id}")

    # Installed-package case: spacy can load it by name directly.
    if is_package(model_id):
        return spacy.load(model_id)

    # Fallback: a local copy under the app's models directory.
    local_dir = MODELS_DIR / model_id
    if not local_dir.exists():
        raise RuntimeError(f"Could not find spaCy model '{model_id}' at {local_dir}")
    return spacy.load(str(local_dir))
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# π€ Sentence Transformers
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
@lru_cache(maxsize=1)
def load_sentence_transformer_model(model_id: str = SENTENCE_TRANSFORMER_MODEL_ID) -> SentenceTransformer:
    """Load and cache a SentenceTransformer model, storing weights under MODELS_DIR."""
    logger.info(f"Loading SentenceTransformer: {model_id}")
    model = SentenceTransformer(
        model_name_or_path=model_id,
        cache_folder=MODELS_DIR,
    )
    return model
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# π€ Hugging Face Pipelines (T5 models, classifiers, etc.)
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def _check_model_downloaded(model_id: str, cache_dir: str) -> bool:
model_path = Path(cache_dir) / model_id.replace("/", "_")
return model_path.exists()
def _timed_load(name: str, fn):
    """Invoke *fn*, log how long it took (tagged with *name*), return its result.

    Fix: use ``time.perf_counter()`` — a monotonic high-resolution clock —
    instead of ``time.time()``, whose value can jump backwards/forwards with
    system clock adjustments and yield nonsense elapsed times.

    Args:
        name: Label used in the log line.
        fn: Zero-argument callable performing the (slow) load.

    Returns:
        Whatever *fn* returns.
    """
    import time
    start = time.perf_counter()
    model = fn()
    elapsed = round(time.perf_counter() - start, 2)
    logger.info(f"[{name}] model loaded in {elapsed}s")
    return model
def _resolve_model_loader(task: str, feature_name: str):
    """Return the AutoModel class for *task*; raise ValueError if unsupported."""
    if task == "text-classification":
        return AutoModelForSequenceClassification
    if task == "text2text-generation" or task.startswith("translation"):
        return AutoModelForSeq2SeqLM
    if task == "fill-mask":
        return AutoModelForMaskedLM
    raise ValueError(f"Unsupported task type '{task}' for feature '{feature_name}'.")


@lru_cache(maxsize=2)
def load_hf_pipeline(model_id: str, task: str, feature_name: str, **kwargs):
    """Build (and cache, via lru_cache) a Hugging Face pipeline.

    Args:
        model_id: Hub id or local identifier of the model.
        task: Pipeline task — "text-classification", "text2text-generation",
            any "translation*" task, or "fill-mask".
        feature_name: Human-readable feature label used in logs and errors.
        **kwargs: Forwarded to ``transformers.pipeline``. NOTE: values must
            be hashable because of the ``lru_cache`` decorator.

    Raises:
        ValueError: for an unsupported *task*. (Fix: previously this was
            caught by the broad handler below and misreported as
            ModelNotDownloadedError — a programming error disguised as a
            missing download.)
        ModelNotDownloadedError: when the model/tokenizer cannot be loaded,
            or is absent in offline mode.
    """
    if OFFLINE_MODE and not _check_model_downloaded(model_id, str(MODELS_DIR)):
        raise ModelNotDownloadedError(model_id, feature_name, "Model not found locally in offline mode.")

    # Resolve the loader class *before* the try-block so a bad task name
    # surfaces immediately as ValueError rather than a download error.
    model_loader = _resolve_model_loader(task, feature_name)

    try:
        model = _timed_load(
            f"{feature_name}:{model_id} (model)",
            lambda: model_loader.from_pretrained(
                model_id,
                cache_dir=MODELS_DIR,
                local_files_only=OFFLINE_MODE,
            ),
        )
        tokenizer = _timed_load(
            f"{feature_name}:{model_id} (tokenizer)",
            lambda: AutoTokenizer.from_pretrained(
                model_id,
                cache_dir=MODELS_DIR,
                local_files_only=OFFLINE_MODE,
            ),
        )
        return pipeline(
            task=task,
            model=model,
            tokenizer=tokenizer,
            device=0 if torch.cuda.is_available() else -1,  # GPU 0 if available, else CPU
            **kwargs,
        )
    except Exception as e:
        logger.error(f"Failed to load pipeline for '{feature_name}' - {model_id}: {e}", exc_info=True)
        # Chain the cause so tracebacks show the underlying load failure.
        raise ModelNotDownloadedError(model_id, feature_name, str(e)) from e
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# π NLTK
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
@lru_cache(maxsize=1)
def ensure_nltk_resource(resource_name: str = "wordnet") -> None:
    """Ensure the given NLTK corpus is present, downloading it unless offline.

    Bug fix: the original put ``import nltk`` inside the ``try`` and caught
    ``ImportError`` — but the handler then called ``nltk.download(...)``,
    which would crash with an unbound-name error because the import never
    completed. The import now happens up front, so a missing nltk package
    raises a clear ImportError, and only the resource lookup is guarded.

    Args:
        resource_name: Corpus name under ``corpora/`` (default "wordnet").

    Raises:
        ImportError: if the nltk package is not installed.
        RuntimeError: if the resource is missing and OFFLINE_MODE is set.
    """
    import nltk  # kept function-local so nltk stays an optional dependency

    try:
        nltk.data.find(f"corpora/{resource_name}")
    except LookupError as e:
        if OFFLINE_MODE:
            raise RuntimeError(
                f"NLTK resource '{resource_name}' not found in offline mode."
            ) from e
        nltk.download(resource_name)
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# π― Ready-to-use Loaders (for your app use)
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# NOTE(review): removed stray "|" — HTML-scrape artifact, not Python.