""" | |
Embedding Generator Module | |
This module is responsible for generating vector embeddings for text chunks | |
using Gemini Embedding v3 with complete API integration. | |
Technology: Gemini Embedding v3 (gemini-embedding-exp-03-07) | |
""" | |
import logging | |
import os | |
import time | |
import hashlib | |
from datetime import datetime, timedelta | |
from typing import Dict, List, Any, Optional, Union | |
import json | |
# Import Gemini API and caching libraries
try:
    import google.generativeai as genai
    from cachetools import TTLCache
except ImportError as e:
    logging.warning(f"Some embedding libraries are not installed: {e}")
    # Fall back to None so the generator can degrade gracefully without them
    genai = None
    TTLCache = None

from utils.error_handler import EmbeddingError, error_handler, ErrorType


class EmbeddingGenerator:
    """
    Generates vector embeddings for text chunks using Gemini Embedding v3 with full functionality.

    Features:
    - Gemini Embedding v3 API integration
    - Batch processing with rate limiting
    - Intelligent retry logic with exponential backoff
    - Embedding caching mechanism
    - Cost optimization
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Initialize the EmbeddingGenerator with configuration.

        Args:
            config: Configuration dictionary with API parameters
        """
        self.config = config or {}
        self.logger = logging.getLogger(__name__)

        # API configuration
        self.api_key = self.config.get("api_key", os.environ.get("GEMINI_API_KEY"))
        self.model = self.config.get("model", "gemini-embedding-exp-03-07")
        self.batch_size = self.config.get("batch_size", 5)
        self.max_retries = self.config.get("max_retries", 3)
        self.retry_delay = self.config.get("retry_delay", 1)

        # Performance settings
        self.rate_limit_delay = self.config.get("rate_limit_delay", 0.1)
        self.max_text_length = self.config.get(
            "max_text_length", 8192
        )  # applied as a character cap per chunk, mirroring the latest model's ~8K-token limit
        self.enable_caching = self.config.get("enable_caching", True)
        self.cache_ttl = self.config.get("cache_ttl", 3600)  # 1 hour

        # Statistics tracking
        self.stats = {
            "total_requests": 0,
            "successful_requests": 0,
            "failed_requests": 0,
            "cache_hits": 0,
            "total_tokens_processed": 0,
            "start_time": datetime.now(),
        }

        # Initialize cache (skipped if cachetools is unavailable)
        if self.enable_caching and TTLCache is not None:
            self.cache = TTLCache(maxsize=1000, ttl=self.cache_ttl)
        else:
            self.cache = None

        # Validate and initialize API client
        self._initialize_client()

    def _initialize_client(self):
        """Initialize Gemini API client with validation."""
        if not self.api_key:
            self.logger.warning(
                "No Gemini API key provided. Embeddings will not be generated."
            )
            self.client = None
            return

        try:
            # Configure Gemini API
            genai.configure(api_key=self.api_key)
            # Test API connection
            self._test_api_connection()
            self.client = genai
            self.logger.info("Gemini API client initialized successfully")
        except Exception as e:
            self.logger.error(f"Failed to initialize Gemini API client: {str(e)}")
            self.client = None

    def _test_api_connection(self):
        """Test API connection with a simple request."""
        try:
            # Test with a simple embedding request
            test_result = genai.embed_content(
                model=self.model,
                content="test connection",
                task_type="retrieval_document",
            )
            if not test_result.get("embedding"):
                raise Exception("No embedding returned from test request")
            self.logger.info("API connection test successful")
        except Exception as e:
            raise EmbeddingError(f"API connection test failed: {str(e)}")

    def generate_embeddings(self, texts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Generate embeddings for a list of text chunks with full functionality.

        Args:
            texts: List of dictionaries containing text chunks and metadata.
                Each dict should have 'content' and 'metadata' keys.

        Returns:
            List of dictionaries with original content, metadata, and embeddings
        """
        if not self.client or not texts:
            self.logger.warning("No API client or empty text list")
            return texts

        self.logger.info(f"Generating embeddings for {len(texts)} text chunks")
        start_time = time.time()

        # Filter and validate texts
        valid_texts = self._validate_texts(texts)
        if not valid_texts:
            self.logger.warning("No valid texts to process")
            return texts

        # Process in batches to respect API limits
        results = []
        total_batches = (len(valid_texts) + self.batch_size - 1) // self.batch_size
        for i in range(0, len(valid_texts), self.batch_size):
            batch_num = (i // self.batch_size) + 1
            batch = valid_texts[i : i + self.batch_size]
            self.logger.info(
                f"Processing batch {batch_num}/{total_batches} ({len(batch)} items)"
            )
            try:
                batch_results = self._process_batch(batch)
                results.extend(batch_results)
                # Rate limiting between batches
                if i + self.batch_size < len(valid_texts):
                    time.sleep(self.rate_limit_delay)
            except Exception as e:
                self.logger.error(f"Batch {batch_num} failed: {str(e)}")
                # Add original items without embeddings
                for item in batch:
                    item_copy = item.copy()
                    item_copy["embedding"] = []
                    item_copy["embedding_error"] = str(e)
                    results.append(item_copy)

        # Report total processing time
        processing_time = time.time() - start_time
        self.logger.info(f"Embedding generation completed in {processing_time:.2f}s")
        return results

    def _validate_texts(self, texts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Validate and filter text inputs.

        Args:
            texts: List of text dictionaries

        Returns:
            List of valid text dictionaries
        """
        valid_texts = []
        for i, item in enumerate(texts):
            if not isinstance(item, dict) or "content" not in item:
                self.logger.warning(f"Invalid item at index {i}: missing 'content' key")
                continue

            content = item["content"]
            if not content or not isinstance(content, str):
                self.logger.warning(
                    f"Invalid content at index {i}: empty or non-string"
                )
                continue

            # Truncate if too long
            if len(content) > self.max_text_length:
                self.logger.warning(
                    f"Truncating text at index {i}: {len(content)} -> {self.max_text_length} chars"
                )
                item = item.copy()
                item["content"] = content[: self.max_text_length]
                # Copy metadata as well so the caller's original dict is not mutated
                item["metadata"] = dict(item.get("metadata") or {})
                item["metadata"]["truncated"] = True
                item["metadata"]["original_length"] = len(content)

            valid_texts.append(item)

        return valid_texts

    def _process_batch(self, batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Process a batch of text chunks to generate embeddings.

        Args:
            batch: List of dictionaries containing text chunks and metadata

        Returns:
            List of dictionaries with original content, metadata, and embeddings
        """
        # Extract content and check cache
        contents = []
        cache_results = {}
        for i, item in enumerate(batch):
            content = item["content"]
            # Check cache first
            if self.cache is not None:
                cache_key = self._get_cache_key(content)
                if cache_key in self.cache:
                    cache_results[i] = self.cache[cache_key]
                    self.stats["cache_hits"] += 1
                    continue
            contents.append((i, content))

        # Generate embeddings for non-cached content
        embeddings_map = {}
        if contents:
            content_texts = [content for _, content in contents]
            embeddings = self._generate_with_retry(content_texts)
            # Map embeddings back to indices
            for j, (original_index, content) in enumerate(contents):
                if j < len(embeddings):
                    embedding = embeddings[j]
                    embeddings_map[original_index] = embedding
                    # Cache the result
                    if self.cache is not None:
                        cache_key = self._get_cache_key(content)
                        self.cache[cache_key] = embedding

        # 🔗 Combine results
        results = []
        for i, item in enumerate(batch):
            result = item.copy()
            # Add embedding from cache or new generation
            if i in cache_results:
                result["embedding"] = cache_results[i]
                result["embedding_source"] = "cache"
            elif i in embeddings_map:
                result["embedding"] = embeddings_map[i]
                result["embedding_source"] = "api"
            else:
                result["embedding"] = []
                result["embedding_source"] = "failed"
                self.logger.warning(f"Missing embedding for batch item {i}")

            # Add embedding metadata
            if result["embedding"]:
                result["metadata"] = result.get("metadata", {})
                result["metadata"].update(
                    {
                        "embedding_model": self.model,
                        "embedding_dimension": len(result["embedding"]),
                        "embedding_generated_at": datetime.now().isoformat(),
                    }
                )
            results.append(result)

        return results

    def _generate_with_retry(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings with intelligent retry logic.

        Args:
            texts: List of text strings to embed

        Returns:
            List of embedding vectors (each is a list of floats)
        """
        for attempt in range(self.max_retries):
            try:
                self.stats["total_requests"] += 1
                # Generate embeddings using Gemini API
                embeddings = []
                for text in texts:
                    try:
                        # Approximate token count by whitespace-separated words
                        self.stats["total_tokens_processed"] += len(text.split())
                        # Call Gemini API
                        result = self.client.embed_content(
                            model=self.model,
                            content=text,
                            task_type="retrieval_document",
                            title="Document chunk for RAG system",
                        )
                        if result and "embedding" in result:
                            embeddings.append(result["embedding"])
                        else:
                            self.logger.warning(
                                f"No embedding in API response for text: {text[:50]}..."
                            )
                            embeddings.append([])
                    except Exception as e:
                        self.logger.warning(
                            f"Failed to embed individual text: {str(e)}"
                        )
                        embeddings.append([])
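                        # Note: per-text failures are swallowed here and recorded as
                        # empty vectors; only batch-level exceptions reach the outer
                        # retry loop below.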
self.stats["successful_requests"] += 1 | |
return embeddings | |
except Exception as e: | |
self.stats["failed_requests"] += 1 | |
self.logger.warning( | |
f"Embedding generation failed (attempt {attempt+1}/{self.max_retries}): {str(e)}" | |
) | |
if attempt < self.max_retries - 1: | |
# Exponential backoff with jitter | |
delay = self.retry_delay * (2**attempt) + (time.time() % 1) | |
self.logger.info(f"Retrying in {delay:.1f} seconds...") | |
time.sleep(delay) | |
# All retries failed | |
self.logger.error("All embedding generation attempts failed") | |
return [[] for _ in texts] | |

    def generate_query_embedding(self, query: str) -> List[float]:
        """
        Generate embedding for a single query string.

        Args:
            query: Query text to embed

        Returns:
            Embedding vector as a list of floats
        """
        if not self.client or not query:
            return []

        self.logger.info(f"Generating embedding for query: {query[:50]}...")

        # Check cache first
        if self.cache is not None:
            cache_key = self._get_cache_key(query, "query")
            if cache_key in self.cache:
                self.stats["cache_hits"] += 1
                return self.cache[cache_key]

        # Generate embedding
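        # Note: this reuses the shared retry helper, so the query is embedded with
        # task_type="retrieval_document" rather than a query-specific task type.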
        embeddings = self._generate_with_retry([query])
        embedding = embeddings[0] if embeddings else []

        # Cache the result
        if embedding and self.cache is not None:
            cache_key = self._get_cache_key(query, "query")
            self.cache[cache_key] = embedding

        return embedding

    def _get_cache_key(self, text: str, prefix: str = "doc") -> str:
        """
        Generate cache key for text.

        Args:
            text: Text content
            prefix: Key prefix

        Returns:
            Cache key string
        """
        # 🔐 Create hash of text + model for a unique key
        content_hash = hashlib.md5(f"{self.model}:{text}".encode()).hexdigest()
        return f"{prefix}:{content_hash}"

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get embedding generation statistics.

        Returns:
            Dictionary with statistics
        """
        runtime = datetime.now() - self.stats["start_time"]
        # Hit rate is computed over cache hits plus API requests so it stays within 0-100%
        total_lookups = self.stats["cache_hits"] + self.stats["total_requests"]
        return {
            **self.stats,
            "runtime_seconds": runtime.total_seconds(),
            "cache_hit_rate": (
                self.stats["cache_hits"] / max(1, total_lookups) * 100
            ),
            "success_rate": (
                self.stats["successful_requests"]
                / max(1, self.stats["total_requests"])
                * 100
            ),
            "avg_tokens_per_request": (
                self.stats["total_tokens_processed"]
                / max(1, self.stats["total_requests"])
            ),
            "cache_size": len(self.cache) if self.cache else 0,
            "model": self.model,
            "batch_size": self.batch_size,
        }

    def clear_cache(self):
        """Clear the embedding cache."""
        if self.cache:
            self.cache.clear()
            self.logger.info("Embedding cache cleared")

    def warm_up_cache(self, sample_texts: List[str]):
        """
        🔥 Warm up the cache with sample texts.

        Args:
            sample_texts: List of sample texts to pre-generate embeddings
        """
        if not sample_texts:
            return

        self.logger.info(f"🔥 Warming up cache with {len(sample_texts)} sample texts")
        sample_items = [{"content": text, "metadata": {}} for text in sample_texts]
        self.generate_embeddings(sample_items)
        self.logger.info("Cache warm-up completed")