import logging
from typing import List, Optional, Tuple

import numpy as np
from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)

# Cache for loaded models
_model_cache = {}


def get_model(model_id: str) -> Tuple[Optional[SentenceTransformer], Optional[str]]:
    """
    Loads a SentenceTransformer model from the Hugging Face Hub.

    Args:
        model_id (str): The identifier for the model to load (e.g., 'sentence-transformers/LaBSE').

    Returns:
        Tuple[Optional[SentenceTransformer], Optional[str]]: A tuple containing the loaded model
            and its type ('sentence-transformer'), or (None, None) if loading fails.
    """
    if model_id in _model_cache:
        logger.info(f"Returning cached model: {model_id}")
        return _model_cache[model_id], "sentence-transformer"

    logger.info(f"Loading SentenceTransformer model: {model_id}")
    try:
        model = SentenceTransformer(model_id)
        _model_cache[model_id] = model
        logger.info(f"Model '{model_id}' loaded successfully.")
        return model, "sentence-transformer"
    except Exception as e:
        logger.error(f"Failed to load SentenceTransformer model '{model_id}': {e}", exc_info=True)
        return None, None


def generate_embeddings(texts: List[str], model: SentenceTransformer) -> Optional[np.ndarray]:
    """
    Generates embeddings for a list of texts using a SentenceTransformer model.

    Args:
        texts (List[str]): A list of texts to embed.
        model (SentenceTransformer): The loaded SentenceTransformer model.

    Returns:
        Optional[np.ndarray]: A numpy array containing the embeddings. Returns None if generation fails.
    """
    if not texts or not isinstance(model, SentenceTransformer):
        logger.warning("Invalid input for generating embeddings. Texts list is empty or model is not a SentenceTransformer.")
        return None

    logger.info(f"Generating embeddings for {len(texts)} texts with {type(model).__name__}...")
    try:
        embeddings = model.encode(texts, convert_to_numpy=True, show_progress_bar=False)
        logger.info(f"Embeddings generated with shape: {embeddings.shape}")
        return embeddings
    except Exception as e:
        logger.error(f"An unexpected error occurred during embedding generation: {e}", exc_info=True)
        return None
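

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of wiring the two helpers together: load a model, then
# embed a couple of texts. The model id 'sentence-transformers/all-MiniLM-L6-v2'
# is an assumption chosen for illustration; substitute whichever Sentence
# Transformers checkpoint this Space actually uses.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    model, model_type = get_model("sentence-transformers/all-MiniLM-L6-v2")  # assumed model id
    if model is not None:
        vectors = generate_embeddings(["Hello world", "Bonjour le monde"], model)
        if vectors is not None:
            # Each row of the returned array is the embedding of one input text.
            print(f"Embedding matrix shape: {vectors.shape}")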