Spaces:
Sleeping
Sleeping
import logging | |
from typing import List, Any, Optional, Tuple | |
import numpy as np | |
from sentence_transformers import SentenceTransformer | |
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# Loaded models, keyed by the model id string that get_model() was called with.
_model_cache = dict()
def get_model(model_id: str) -> Tuple[Optional[SentenceTransformer], Optional[str]]:
    """
    Loads a SentenceTransformer model from the Hugging Face Hub, reusing a cached instance
    when the same model id was loaded before.

    Args:
        model_id (str): The identifier for the model to load (e.g., 'sentence-transformers/LaBSE').

    Returns:
        Tuple[Optional[SentenceTransformer], Optional[str]]: A tuple containing the loaded model
        and its type ('sentence-transformer'), or (None, None) if ``model_id`` is empty or
        loading fails. Failed loads are not cached, so a later call retries.
    """
    # SentenceTransformer does not reject an empty/None id: it builds a model with no
    # modules. Without this guard that useless model would be cached and only fail
    # later, at encode time.
    if not model_id:
        logger.warning("Cannot load a SentenceTransformer model: model_id is empty.")
        return None, None

    # Only successfully loaded models are stored, so None means "not cached yet".
    # `is not None` (rather than truthiness) avoids calling the model's __len__.
    cached = _model_cache.get(model_id)
    if cached is not None:
        logger.info("Returning cached model: %s", model_id)
        return cached, "sentence-transformer"

    logger.info("Loading SentenceTransformer model: %s", model_id)
    try:
        model = SentenceTransformer(model_id)
    except Exception as e:
        # Any exception raised while loading is logged with its traceback and
        # reported to the caller as (None, None).
        logger.exception("Failed to load SentenceTransformer model '%s': %s", model_id, e)
        return None, None

    _model_cache[model_id] = model
    logger.info("Model '%s' loaded successfully.", model_id)
    return model, "sentence-transformer"
def generate_embeddings(texts: List[str], model: SentenceTransformer) -> Optional[np.ndarray]:
    """
    Encode a batch of texts into embeddings with an already-loaded SentenceTransformer.

    Args:
        texts (list[str]): The texts to embed; an empty list is rejected.
        model (SentenceTransformer): The model used for encoding; any other type is rejected.

    Returns:
        Optional[np.ndarray]: The numpy array produced by ``model.encode``, or None when the
        inputs are rejected or encoding raises an exception.
    """
    inputs_ok = bool(texts) and isinstance(model, SentenceTransformer)
    if not inputs_ok:
        logger.warning("Invalid input for generating embeddings. Texts list is empty or model is not a SentenceTransformer.")
        return None

    count = len(texts)
    model_name = type(model).__name__
    logger.info(f"Generating embeddings for {count} texts with {model_name}...")

    try:
        # The progress bar is suppressed so that only this module's own log lines are emitted.
        vectors = model.encode(texts, convert_to_numpy=True, show_progress_bar=False)
        logger.info(f"Embeddings generated with shape: {vectors.shape}")
    except Exception as err:
        logger.error(f"An unexpected error occurred during embedding generation: {err}", exc_info=True)
        return None
    return vectors