import logging
from pathlib import Path
from functools import lru_cache

import torch
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM,
    AutoModelForMaskedLM,
)

from sentence_transformers import SentenceTransformer

from app.core.config import (
    MODELS_DIR, SPACY_MODEL_ID, SENTENCE_TRANSFORMER_MODEL_ID,
    OFFLINE_MODE
)
from app.core.exceptions import ModelNotDownloadedError

logger = logging.getLogger(__name__)

# ─────────────────────────────────────────────────────────────────────────────
# 🧠 SpaCy
# ─────────────────────────────────────────────────────────────────────────────

@lru_cache(maxsize=1)
def load_spacy_model(model_id: str = SPACY_MODEL_ID):
    import spacy
    from spacy.util import is_package

    logger.info(f"Loading spaCy model: {model_id}")

    if is_package(model_id):
        return spacy.load(model_id)

    possible_path = MODELS_DIR / model_id
    if possible_path.exists():
        return spacy.load(str(possible_path))

    raise RuntimeError(f"Could not find spaCy model '{model_id}' at {possible_path}")
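
# Usage sketch (illustrative; assumes SPACY_MODEL_ID names an installed
# package such as "en_core_web_sm", or a directory under MODELS_DIR):
#
#     nlp = load_spacy_model()
#     doc = nlp("Model loading is cached, so repeated calls are cheap.")
#     tokens = [(t.text, t.pos_) for t in doc]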

# ─────────────────────────────────────────────────────────────────────────────
# πŸ”€ Sentence Transformers
# ─────────────────────────────────────────────────────────────────────────────

@lru_cache(maxsize=1)
def load_sentence_transformer_model(model_id: str = SENTENCE_TRANSFORMER_MODEL_ID) -> SentenceTransformer:
    logger.info(f"Loading SentenceTransformer: {model_id}")
    return SentenceTransformer(model_name_or_path=model_id, cache_folder=MODELS_DIR)
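
# Illustrative helper (a minimal sketch, not part of the original API):
# cosine similarity between two texts via the cached encoder above.
def _embedding_similarity(text_a: str, text_b: str) -> float:
    model = load_sentence_transformer_model()
    # normalize_embeddings=True makes the dot product equal cosine similarity.
    emb = model.encode([text_a, text_b], convert_to_tensor=True, normalize_embeddings=True)
    return float(emb[0] @ emb[1])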

# ─────────────────────────────────────────────────────────────────────────────
# πŸ€— Hugging Face Pipelines (T5 models, classifiers, etc.)
# ─────────────────────────────────────────────────────────────────────────────

def _check_model_downloaded(model_id: str, cache_dir: str) -> bool:
    # huggingface_hub stores each repo as "models--{org}--{name}" inside the
    # cache directory, with "/" in the repo id replaced by "--".
    model_path = Path(cache_dir) / f"models--{model_id.replace('/', '--')}"
    return model_path.exists()

def _timed_load(name: str, fn):
    import time
    start = time.time()
    model = fn()
    elapsed = round(time.time() - start, 2)
    logger.info(f"[{name}] model loaded in {elapsed}s")
    return model

@lru_cache(maxsize=2)
def load_hf_pipeline(model_id: str, task: str, feature_name: str, **kwargs):
    if OFFLINE_MODE and not _check_model_downloaded(model_id, str(MODELS_DIR)):
        raise ModelNotDownloadedError(model_id, feature_name, "Model not found locally in offline mode.")

    # Choose the appropriate AutoModel class up front, so an unsupported task
    # raises a clear ValueError instead of being wrapped as a
    # ModelNotDownloadedError below.
    if task == "text-classification":
        model_loader = AutoModelForSequenceClassification
    elif task == "text2text-generation" or task.startswith("translation"):
        model_loader = AutoModelForSeq2SeqLM
    elif task == "fill-mask":
        model_loader = AutoModelForMaskedLM
    else:
        raise ValueError(f"Unsupported task type '{task}' for feature '{feature_name}'.")

    try:
        model = _timed_load(
            f"{feature_name}:{model_id} (model)",
            lambda: model_loader.from_pretrained(
                model_id,
                cache_dir=MODELS_DIR,
                local_files_only=OFFLINE_MODE
            )
        )

        tokenizer = _timed_load(
            f"{feature_name}:{model_id} (tokenizer)",
            lambda: AutoTokenizer.from_pretrained(
                model_id,
                cache_dir=MODELS_DIR,
                local_files_only=OFFLINE_MODE
            )
        )

        return pipeline(
            task=task,
            model=model,
            tokenizer=tokenizer,
            device=0 if torch.cuda.is_available() else -1,
            **kwargs
        )

    except Exception as e:
        logger.error(f"Failed to load pipeline for '{feature_name}' - {model_id}: {e}", exc_info=True)
        raise ModelNotDownloadedError(model_id, feature_name, str(e)) from e
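
# Usage sketch (illustrative; "t5-small" stands in for whatever model ID the
# calling feature actually configures):
#
#     summarizer = load_hf_pipeline("t5-small", "text2text-generation", "summarizer")
#     out = summarizer("summarize: " + article_text, max_length=60)
#     summary = out[0]["generated_text"]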

# ─────────────────────────────────────────────────────────────────────────────
# πŸ“š NLTK
# ─────────────────────────────────────────────────────────────────────────────

@lru_cache(maxsize=1)
def ensure_nltk_resource(resource_name: str = "wordnet") -> None:
    # Import outside the try block so `nltk` stays bound for the download
    # branch even when the resource lookup fails.
    import nltk

    try:
        nltk.data.find(f"corpora/{resource_name}")
    except LookupError:
        if OFFLINE_MODE:
            raise RuntimeError(f"NLTK resource '{resource_name}' not found in offline mode.")
        nltk.download(resource_name)
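
# Usage sketch (illustrative): make sure the corpus is present, then query it.
#
#     ensure_nltk_resource("wordnet")
#     from nltk.corpus import wordnet
#     synonyms = {l.name() for s in wordnet.synsets("quick") for l in s.lemmas()}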

# ─────────────────────────────────────────────────────────────────────────────
# 🎯 Ready-to-use Loaders (for your app use)
# ─────────────────────────────────────────────────────────────────────────────
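
# The loader below is an illustrative sketch of what this section could
# contain, composing the generic helpers above. The model ID is a hypothetical
# placeholder, not a value taken from the app's real config.

@lru_cache(maxsize=1)
def load_example_sentiment_classifier():
    # Placeholder model ID; swap in whichever classifier the app configures.
    return load_hf_pipeline(
        "distilbert-base-uncased-finetuned-sst-2-english",
        task="text-classification",
        feature_name="example_sentiment",
    )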