# utils/helpers.py
"""Helper functions for model loading and embedding generation."""
import gc
from typing import Dict, List, Optional

import torch
import torch.nn.functional as F
from transformers import (
    AutoTokenizer, AutoModel,
    RobertaTokenizer, RobertaModel,
    BertTokenizer, BertModel,
)

def load_models(model_names: Optional[List[str]] = None) -> Dict:
    """
    Load specific embedding models with memory optimization.

    Args:
        model_names: List of model names to load. If None, loads all models.

    Returns:
        Dict containing loaded models and tokenizers.
    """
    models_cache = {}

    # Default to all models if none specified
    if model_names is None:
        model_names = ["jina", "robertalex", "jina-v3", "legal-bert", "roberta-ca"]

    # Set device; use half precision on GPU to roughly halve memory usage
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    try:
        # Load Jina v2 Spanish model
        if "jina" in model_names:
            print("Loading Jina embeddings v2 Spanish model...")
            jina_tokenizer = AutoTokenizer.from_pretrained(
                'jinaai/jina-embeddings-v2-base-es',
                trust_remote_code=True
            )
            jina_model = AutoModel.from_pretrained(
                'jinaai/jina-embeddings-v2-base-es',
                trust_remote_code=True,
                torch_dtype=dtype
            ).to(device)
            jina_model.eval()
            models_cache['jina'] = {
                'tokenizer': jina_tokenizer,
                'model': jina_model,
                'device': device,
                'pooling': 'mean'
            }

        # Load RoBERTalex model
        if "robertalex" in model_names:
            print("Loading RoBERTalex model...")
            robertalex_tokenizer = RobertaTokenizer.from_pretrained('PlanTL-GOB-ES/RoBERTalex')
            robertalex_model = RobertaModel.from_pretrained(
                'PlanTL-GOB-ES/RoBERTalex',
                torch_dtype=dtype
            ).to(device)
            robertalex_model.eval()
            models_cache['robertalex'] = {
                'tokenizer': robertalex_tokenizer,
                'model': robertalex_model,
                'device': device,
                'pooling': 'cls'
            }

        # Load Jina v3 model
        if "jina-v3" in model_names:
            print("Loading Jina embeddings v3 model...")
            jina_v3_tokenizer = AutoTokenizer.from_pretrained(
                'jinaai/jina-embeddings-v3',
                trust_remote_code=True
            )
            jina_v3_model = AutoModel.from_pretrained(
                'jinaai/jina-embeddings-v3',
                trust_remote_code=True,
                torch_dtype=dtype
            ).to(device)
            jina_v3_model.eval()
            models_cache['jina-v3'] = {
                'tokenizer': jina_v3_tokenizer,
                'model': jina_v3_model,
                'device': device,
                'pooling': 'mean'
            }

        # Load Legal BERT model
        if "legal-bert" in model_names:
            print("Loading Legal BERT model...")
            legal_bert_tokenizer = BertTokenizer.from_pretrained('nlpaueb/legal-bert-base-uncased')
            legal_bert_model = BertModel.from_pretrained(
                'nlpaueb/legal-bert-base-uncased',
                torch_dtype=dtype
            ).to(device)
            legal_bert_model.eval()
            models_cache['legal-bert'] = {
                'tokenizer': legal_bert_tokenizer,
                'model': legal_bert_model,
                'device': device,
                'pooling': 'cls'
            }

        # Load Catalan RoBERTa model
        if "roberta-ca" in model_names:
            print("Loading Catalan RoBERTa-large model...")
            roberta_ca_tokenizer = AutoTokenizer.from_pretrained('projecte-aina/roberta-large-ca-v2')
            roberta_ca_model = AutoModel.from_pretrained(
                'projecte-aina/roberta-large-ca-v2',
                torch_dtype=dtype
            ).to(device)
            roberta_ca_model.eval()
            models_cache['roberta-ca'] = {
                'tokenizer': roberta_ca_tokenizer,
                'model': roberta_ca_model,
                'device': device,
                'pooling': 'cls'
            }

        # Force garbage collection after loading
        gc.collect()
        return models_cache

    except Exception as e:
        print(f"Error loading models: {str(e)}")
        raise


def mean_pooling(model_output, attention_mask):
    """
    Apply mean pooling to get sentence embeddings.

    Args:
        model_output: Model output containing token embeddings.
        attention_mask: Attention mask for valid tokens.

    Returns:
        Pooled embeddings.
    """
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # Sum the unmasked token embeddings, then divide by the (clamped) count
    # of real tokens so padded positions do not dilute the average
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
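

# Illustrative sanity check for mean_pooling (a hypothetical helper, not part
# of the original module): with a padded batch, masked positions contribute
# nothing and each sequence is averaged over its real tokens only.
def _mean_pooling_example():
    token_embeddings = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)  # (batch, seq, hidden)
    attention_mask = torch.tensor([[1, 1, 0],   # first sequence has one pad position
                                   [1, 1, 1]])  # second sequence is unpadded
    # mean_pooling indexes model_output[0], so a 1-tuple stands in for the output
    pooled = mean_pooling((token_embeddings,), attention_mask)
    assert pooled.shape == (2, 4)  # one embedding vector per sequence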


def get_embeddings(
    texts: List[str],
    model_name: str,
    models_cache: Dict,
    normalize: bool = True,
    max_length: Optional[int] = None
) -> List[List[float]]:
    """
    Generate embeddings for texts using the specified model.

    Args:
        texts: List of texts to embed.
        model_name: Name of the model to use.
        models_cache: Dictionary containing loaded models.
        normalize: Whether to L2-normalize the embeddings.
        max_length: Maximum sequence length.

    Returns:
        List of embedding vectors.
    """
    if model_name not in models_cache:
        raise ValueError(f"Model {model_name} not available. Choose from: {list(models_cache.keys())}")

    tokenizer = models_cache[model_name]['tokenizer']
    model = models_cache[model_name]['model']
    device = models_cache[model_name]['device']
    pooling_strategy = models_cache[model_name]['pooling']

    # Set max length based on model capabilities
    if max_length is None:
        if model_name in ['jina', 'jina-v3']:
            max_length = 8192
        else:  # robertalex, legal-bert, roberta-ca
            max_length = 512

    # Process in batches for memory efficiency;
    # use smaller batches for the larger models
    if model_name in ['jina-v3', 'roberta-ca']:
        batch_size = max(1, min(len(texts), 4))
    else:
        batch_size = max(1, min(len(texts), 8))

    all_embeddings = []
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]

        # Tokenize inputs
        encoded_input = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        ).to(device)

        # Generate embeddings
        with torch.no_grad():
            model_output = model(**encoded_input)

            if pooling_strategy == 'mean':
                # Mean pooling for Jina models
                embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
            else:
                # CLS token for BERT-based models
                embeddings = model_output.last_hidden_state[:, 0, :]

            # Normalize if requested
            if normalize:
                embeddings = F.normalize(embeddings, p=2, dim=1)

        # Convert to CPU and list
        batch_embeddings = embeddings.cpu().numpy().tolist()
        all_embeddings.extend(batch_embeddings)

    return all_embeddings


def cleanup_memory():
    """Force garbage collection and clear the CUDA cache."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def validate_input_texts(texts: List[str]) -> List[str]:
    """
    Validate and clean input texts.

    Args:
        texts: List of input texts.

    Returns:
        Cleaned texts.
    """
    cleaned_texts = []
    for text in texts:
        # Remove excess whitespace
        text = ' '.join(text.split())
        # Skip empty texts
        if text:
            cleaned_texts.append(text)

    if not cleaned_texts:
        raise ValueError("No valid texts provided after cleaning")

    return cleaned_texts


def get_model_info(model_name: str) -> Dict:
    """
    Get detailed information about a model.

    Args:
        model_name: Model identifier.

    Returns:
        Dictionary with model information (empty if the model is unknown).
    """
    model_info = {
        'jina': {
            'full_name': 'jinaai/jina-embeddings-v2-base-es',
            'dimensions': 768,
            'max_length': 8192,
            'pooling': 'mean',
            'languages': ['Spanish', 'English']
        },
        'robertalex': {
            'full_name': 'PlanTL-GOB-ES/RoBERTalex',
            'dimensions': 768,
            'max_length': 512,
            'pooling': 'cls',
            'languages': ['Spanish']
        },
        'jina-v3': {
            'full_name': 'jinaai/jina-embeddings-v3',
            'dimensions': 1024,
            'max_length': 8192,
            'pooling': 'mean',
            'languages': ['Multilingual']
        },
        'legal-bert': {
            'full_name': 'nlpaueb/legal-bert-base-uncased',
            'dimensions': 768,
            'max_length': 512,
            'pooling': 'cls',
            'languages': ['English']
        },
        'roberta-ca': {
            'full_name': 'projecte-aina/roberta-large-ca-v2',
            'dimensions': 1024,
            'max_length': 512,
            'pooling': 'cls',
            'languages': ['Catalan']
        }
    }
    return model_info.get(model_name, {})
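

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumptions: run as a script, with network access to
# the Hugging Face Hub or the models already cached locally). Only helpers
# defined in this module are used; the Spanish sample sentences are made up.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Load a single lightweight model to keep memory usage low
    cache = load_models(["robertalex"])

    sample_texts = validate_input_texts([
        "El contrato será válido desde la fecha de su firma.",
        "   ",  # blank entry, dropped by validate_input_texts
        "La parte compradora asume los gastos de notaría.",
    ])

    vectors = get_embeddings(sample_texts, "robertalex", cache)
    info = get_model_info("robertalex")
    print(f"Model: {info['full_name']}")
    print(f"Got {len(vectors)} embeddings of dimension {len(vectors[0])} "
          f"(expected {info['dimensions']})")

    cleanup_memory()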