# payman/src/embedding/embedding_generator.py
"""
Embedding Generator Module
This module generates vector embeddings for text chunks using Gemini
Embedding v3 via the Google Generative AI API.
Technology: Gemini Embedding v3 (gemini-embedding-exp-03-07)
"""
import logging
import os
import time
import hashlib
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional, Union
import json
# Import Gemini API and caching libraries
try:
    import google.generativeai as genai
    from cachetools import TTLCache
except ImportError as e:
    logging.warning(f"Some embedding libraries are not installed: {e}")
    # Fall back to None so later lookups fail gracefully instead of raising NameError
    genai = None
    TTLCache = None
from utils.error_handler import EmbeddingError, error_handler, ErrorType
class EmbeddingGenerator:
"""
    Generates vector embeddings for text chunks using Gemini Embedding v3.
Features:
- Gemini Embedding v3 API integration
- Batch processing with rate limiting
- Intelligent retry logic with exponential backoff
- Embedding caching mechanism
- Cost optimization
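    Example (minimal sketch; assumes a valid GEMINI_API_KEY is available in the
    environment or passed in the config dict):
        generator = EmbeddingGenerator({"batch_size": 5, "enable_caching": True})
        chunks = [{"content": "Example chunk text.", "metadata": {"source": "doc.pdf"}}]
        embedded = generator.generate_embeddings(chunks)
        query_vector = generator.generate_query_embedding("example search query")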
"""
def __init__(self, config: Optional[Dict[str, Any]] = None):
"""
Initialize the EmbeddingGenerator with configuration.
Args:
config: Configuration dictionary with API parameters
"""
self.config = config or {}
self.logger = logging.getLogger(__name__)
# API Configuration
self.api_key = self.config.get("api_key", os.environ.get("GEMINI_API_KEY"))
self.model = self.config.get("model", "gemini-embedding-exp-03-07")
self.batch_size = self.config.get("batch_size", 5)
self.max_retries = self.config.get("max_retries", 3)
self.retry_delay = self.config.get("retry_delay", 1)
# Performance settings
self.rate_limit_delay = self.config.get("rate_limit_delay", 0.1)
        self.max_text_length = self.config.get(
            "max_text_length", 8192
        )  # truncation limit in characters (a rough proxy for the model's token limit)
self.enable_caching = self.config.get("enable_caching", True)
self.cache_ttl = self.config.get("cache_ttl", 3600) # 1 hour
# Statistics tracking
self.stats = {
"total_requests": 0,
"successful_requests": 0,
"failed_requests": 0,
"cache_hits": 0,
"total_tokens_processed": 0,
"start_time": datetime.now(),
}
        # Initialize cache (disabled automatically if cachetools is unavailable)
        if self.enable_caching and TTLCache is not None:
            self.cache = TTLCache(maxsize=1000, ttl=self.cache_ttl)
        else:
            self.cache = None
# Validate and initialize API client
self._initialize_client()
def _initialize_client(self):
"""Initialize Gemini API client with validation."""
        if genai is None or not self.api_key:
            self.logger.warning(
                "google-generativeai is not installed or no Gemini API key was "
                "provided. Embeddings will not be generated."
            )
            self.client = None
            return
try:
# Configure Gemini API
genai.configure(api_key=self.api_key)
# Test API connection
self._test_api_connection()
self.client = genai
self.logger.info("Gemini API client initialized successfully")
except Exception as e:
self.logger.error(f"Failed to initialize Gemini API client: {str(e)}")
self.client = None
def _test_api_connection(self):
"""Test API connection with a simple request."""
try:
# Test with a simple embedding request
test_result = genai.embed_content(
model=self.model,
content="test connection",
task_type="retrieval_document",
)
if not test_result.get("embedding"):
raise Exception("No embedding returned from test request")
self.logger.info("API connection test successful")
except Exception as e:
raise EmbeddingError(f"API connection test failed: {str(e)}")
@error_handler(ErrorType.EMBEDDING_GENERATION)
def generate_embeddings(self, texts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
        Generate embeddings for a list of text chunks.
Args:
texts: List of dictionaries containing text chunks and metadata
Each dict should have 'content' and 'metadata' keys
Returns:
List of dictionaries with original content, metadata, and embeddings
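        Example (illustrative shapes only):
            input:  [{"content": "chunk text", "metadata": {"source": "doc.pdf"}}]
            output: [{"content": "chunk text", "metadata": {...},
                      "embedding": [0.01, ...], "embedding_source": "api"}]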
"""
if not self.client or not texts:
self.logger.warning("No API client or empty text list")
return texts
self.logger.info(f"Generating embeddings for {len(texts)} text chunks")
start_time = time.time()
# Filter and validate texts
valid_texts = self._validate_texts(texts)
if not valid_texts:
self.logger.warning("No valid texts to process")
return texts
# Process in batches to respect API limits
results = []
total_batches = (len(valid_texts) + self.batch_size - 1) // self.batch_size
for i in range(0, len(valid_texts), self.batch_size):
batch_num = (i // self.batch_size) + 1
batch = valid_texts[i : i + self.batch_size]
self.logger.info(
f"Processing batch {batch_num}/{total_batches} ({len(batch)} items)"
)
try:
batch_results = self._process_batch(batch)
results.extend(batch_results)
# Rate limiting between batches
if i + self.batch_size < len(valid_texts):
time.sleep(self.rate_limit_delay)
except Exception as e:
self.logger.error(f"Batch {batch_num} failed: {str(e)}")
# Add original items without embeddings
for item in batch:
item_copy = item.copy()
item_copy["embedding"] = []
item_copy["embedding_error"] = str(e)
results.append(item_copy)
        # Log overall processing time
        processing_time = time.time() - start_time
        self.logger.info(f"Embedding generation completed in {processing_time:.2f}s")
return results
def _validate_texts(self, texts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Validate and filter text inputs.
Args:
texts: List of text dictionaries
Returns:
List of valid text dictionaries
"""
valid_texts = []
for i, item in enumerate(texts):
if not isinstance(item, dict) or "content" not in item:
self.logger.warning(f"Invalid item at index {i}: missing 'content' key")
continue
content = item["content"]
if not content or not isinstance(content, str):
self.logger.warning(
f"Invalid content at index {i}: empty or non-string"
)
continue
# Truncate if too long
if len(content) > self.max_text_length:
self.logger.warning(
f"Truncating text at index {i}: {len(content)} -> {self.max_text_length} chars"
)
                item = item.copy()
                item["content"] = content[: self.max_text_length]
                # Copy the metadata dict so the caller's original is not mutated
                item["metadata"] = dict(item.get("metadata", {}))
                item["metadata"]["truncated"] = True
                item["metadata"]["original_length"] = len(content)
valid_texts.append(item)
return valid_texts
def _process_batch(self, batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Process a batch of text chunks to generate embeddings.
Args:
batch: List of dictionaries containing text chunks and metadata
Returns:
List of dictionaries with original content, metadata, and embeddings
"""
# Extract content and check cache
contents = []
cache_results = {}
for i, item in enumerate(batch):
content = item["content"]
# Check cache first
if self.cache is not None:
cache_key = self._get_cache_key(content)
if cache_key in self.cache:
cache_results[i] = self.cache[cache_key]
self.stats["cache_hits"] += 1
continue
contents.append((i, content))
# Generate embeddings for non-cached content
embeddings_map = {}
if contents:
content_texts = [content for _, content in contents]
embeddings = self._generate_with_retry(content_texts)
# Map embeddings back to indices
for j, (original_index, content) in enumerate(contents):
if j < len(embeddings):
embedding = embeddings[j]
embeddings_map[original_index] = embedding
# Cache the result
if self.cache is not None:
cache_key = self._get_cache_key(content)
self.cache[cache_key] = embedding
# 🔗 Combine results
results = []
for i, item in enumerate(batch):
result = item.copy()
# Add embedding from cache or new generation
if i in cache_results:
result["embedding"] = cache_results[i]
result["embedding_source"] = "cache"
elif i in embeddings_map:
result["embedding"] = embeddings_map[i]
result["embedding_source"] = "api"
else:
result["embedding"] = []
result["embedding_source"] = "failed"
self.logger.warning(f"Missing embedding for batch item {i}")
# Add embedding metadata
if result["embedding"]:
result["metadata"] = result.get("metadata", {})
result["metadata"].update(
{
"embedding_model": self.model,
"embedding_dimension": len(result["embedding"]),
"embedding_generated_at": datetime.now().isoformat(),
}
)
results.append(result)
return results
    def _generate_with_retry(
        self, texts: List[str], task_type: str = "retrieval_document"
    ) -> List[List[float]]:
        """
        Generate embeddings with intelligent retry logic.
        Args:
            texts: List of text strings to embed
            task_type: Gemini embedding task type ("retrieval_document" for
                document chunks, "retrieval_query" for search queries)
        Returns:
            List of embedding vectors (each is a list of floats)
        """
for attempt in range(self.max_retries):
try:
self.stats["total_requests"] += 1
# Generate embeddings using Gemini API
embeddings = []
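                # Note: failures for individual texts below are logged and an empty
                # embedding is appended in their place; only exceptions raised outside
                # the per-text loop reach the retry/backoff handling further down.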
for text in texts:
try:
# Track tokens
self.stats["total_tokens_processed"] += len(text.split())
                        # Call Gemini API (a title is only valid for the
                        # retrieval_document task type)
                        request_kwargs = {
                            "model": self.model,
                            "content": text,
                            "task_type": task_type,
                        }
                        if task_type == "retrieval_document":
                            request_kwargs["title"] = "Document chunk for RAG system"
                        result = self.client.embed_content(**request_kwargs)
if result and "embedding" in result:
embeddings.append(result["embedding"])
else:
self.logger.warning(
f"No embedding in API response for text: {text[:50]}..."
)
embeddings.append([])
except Exception as e:
self.logger.warning(
f"Failed to embed individual text: {str(e)}"
)
embeddings.append([])
self.stats["successful_requests"] += 1
return embeddings
except Exception as e:
self.stats["failed_requests"] += 1
self.logger.warning(
f"Embedding generation failed (attempt {attempt+1}/{self.max_retries}): {str(e)}"
)
if attempt < self.max_retries - 1:
# Exponential backoff with jitter
delay = self.retry_delay * (2**attempt) + (time.time() % 1)
self.logger.info(f"Retrying in {delay:.1f} seconds...")
time.sleep(delay)
# All retries failed
self.logger.error("All embedding generation attempts failed")
return [[] for _ in texts]
@error_handler(ErrorType.EMBEDDING_GENERATION)
def generate_query_embedding(self, query: str) -> List[float]:
"""
Generate embedding for a single query string.
Args:
query: Query text to embed
Returns:
Embedding vector as a list of floats
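        Example (illustrative):
            vector = generator.generate_query_embedding("refund policy for orders")
            # -> [0.012, -0.034, ...] on success, or [] if the client is unavailable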
"""
if not self.client or not query:
return []
self.logger.info(f"Generating embedding for query: {query[:50]}...")
# Check cache first
if self.cache is not None:
cache_key = self._get_cache_key(query, "query")
if cache_key in self.cache:
self.stats["cache_hits"] += 1
return self.cache[cache_key]
        # Generate embedding (queries use the retrieval_query task type)
        embeddings = self._generate_with_retry([query], task_type="retrieval_query")
embedding = embeddings[0] if embeddings else []
# Cache the result
if embedding and self.cache is not None:
cache_key = self._get_cache_key(query, "query")
self.cache[cache_key] = embedding
return embedding
def _get_cache_key(self, text: str, prefix: str = "doc") -> str:
"""
Generate cache key for text.
Args:
text: Text content
prefix: Key prefix
Returns:
Cache key string
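        Example: "doc:<32-character md5 hexdigest of '<model>:<text>'>"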
"""
# 🔐 Create hash of text + model for unique key
content_hash = hashlib.md5(f"{self.model}:{text}".encode()).hexdigest()
return f"{prefix}:{content_hash}"
def get_statistics(self) -> Dict[str, Any]:
"""
Get embedding generation statistics.
Returns:
Dictionary with statistics
"""
runtime = datetime.now() - self.stats["start_time"]
return {
**self.stats,
"runtime_seconds": runtime.total_seconds(),
"cache_hit_rate": (
self.stats["cache_hits"] / max(1, self.stats["total_requests"]) * 100
),
"success_rate": (
self.stats["successful_requests"]
/ max(1, self.stats["total_requests"])
* 100
),
"avg_tokens_per_request": (
self.stats["total_tokens_processed"]
/ max(1, self.stats["total_requests"])
),
"cache_size": len(self.cache) if self.cache else 0,
"model": self.model,
"batch_size": self.batch_size,
}
def clear_cache(self):
"""Clear the embedding cache."""
if self.cache:
self.cache.clear()
self.logger.info("Embedding cache cleared")
def warm_up_cache(self, sample_texts: List[str]):
"""
🔥 Warm up the cache with sample texts.
Args:
sample_texts: List of sample texts to pre-generate embeddings
"""
if not sample_texts:
return
self.logger.info(f"🔥 Warming up cache with {len(sample_texts)} sample texts")
sample_items = [{"content": text, "metadata": {}} for text in sample_texts]
self.generate_embeddings(sample_items)
self.logger.info("Cache warm-up completed")