|
|
|
|
|
"""Performance optimization utilities for agent calls: response caching,
request batching, request deduplication, and concurrency limiting."""

import asyncio
|
import time |
|
import hashlib |
|
from typing import Dict, Any, List, Optional, Callable, TypeVar, Generic |
|
from dataclasses import dataclass, field |
|
from functools import wraps, lru_cache |
|
import json |
|
|
|
from ankigen_core.logging import logger |
|
from ankigen_core.models import Card |
|
|
|
|
|
T = TypeVar("T") |
|
|
|
|
|
@dataclass |
|
class CacheConfig: |
|
"""Configuration for agent response caching""" |
|
|
|
enable_caching: bool = True |
|
cache_ttl: int = 3600 |
|
max_cache_size: int = 1000 |
|
cache_backend: str = "memory" |
|
cache_directory: Optional[str] = None |
|
|
|
def __post_init__(self): |
|
if self.cache_backend == "file" and not self.cache_directory: |
|
self.cache_directory = "cache/agents" |
|
|
|
|
|
@dataclass |
|
class PerformanceConfig: |
|
"""Configuration for performance optimizations""" |
|
|
|
enable_batch_processing: bool = True |
|
max_batch_size: int = 10 |
|
batch_timeout: float = 2.0 |
|
enable_parallel_execution: bool = True |
|
max_concurrent_requests: int = 5 |
|
enable_request_deduplication: bool = True |
|
enable_response_caching: bool = True |
|
cache_config: CacheConfig = field(default_factory=CacheConfig) |
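
# Illustrative configuration sketch (the values below are arbitrary examples,
# not recommended defaults):
#
#     config = PerformanceConfig(
#         max_batch_size=20,
#         max_concurrent_requests=10,
#         cache_config=CacheConfig(cache_ttl=600, max_cache_size=500),
#     )
#     optimizer = PerformanceOptimizer(config)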
|
|
|
|
|
@dataclass |
|
class CacheEntry(Generic[T]): |
|
"""Cache entry with metadata""" |
|
|
|
value: T |
|
created_at: float |
|
access_count: int = 0 |
|
last_accessed: float = field(default_factory=time.time) |
|
cache_key: str = "" |
|
|
|
def is_expired(self, ttl: int) -> bool: |
|
"""Check if cache entry is expired""" |
|
return time.time() - self.created_at > ttl |
|
|
|
def touch(self): |
|
"""Update access metadata""" |
|
self.access_count += 1 |
|
self.last_accessed = time.time() |
|
|
|
|
|
class MemoryCache(Generic[T]): |
|
"""In-memory cache with LRU eviction""" |
|
|
|
def __init__(self, config: CacheConfig): |
|
self.config = config |
|
self._cache: Dict[str, CacheEntry[T]] = {} |
|
self._access_order: List[str] = [] |
|
self._lock = asyncio.Lock() |
|
|
|
async def get(self, key: str) -> Optional[T]: |
|
"""Get value from cache""" |
|
async with self._lock: |
|
entry = self._cache.get(key) |
|
if not entry: |
|
return None |
|
|
|
if entry.is_expired(self.config.cache_ttl): |
|
await self._remove(key) |
|
return None |
|
|
|
entry.touch() |
|
self._update_access_order(key) |
|
|
|
logger.debug(f"Cache hit for key: {key[:20]}...") |
|
return entry.value |
|
|
|
async def set(self, key: str, value: T) -> None: |
|
"""Set value in cache""" |
|
async with self._lock: |
|
|
|
            # Evict least recently used entries before inserting when at capacity
            if len(self._cache) >= self.config.max_cache_size:
|
await self._evict_lru() |
|
|
|
entry = CacheEntry(value=value, created_at=time.time(), cache_key=key) |
|
|
|
self._cache[key] = entry |
|
self._update_access_order(key) |
|
|
|
logger.debug(f"Cache set for key: {key[:20]}...") |
|
|
|
async def remove(self, key: str) -> bool: |
|
"""Remove entry from cache""" |
|
async with self._lock: |
|
return await self._remove(key) |
|
|
|
async def clear(self) -> None: |
|
"""Clear all cache entries""" |
|
async with self._lock: |
|
self._cache.clear() |
|
self._access_order.clear() |
|
logger.info("Cache cleared") |
|
|
|
async def _remove(self, key: str) -> bool: |
|
"""Internal remove method""" |
|
if key in self._cache: |
|
del self._cache[key] |
|
if key in self._access_order: |
|
self._access_order.remove(key) |
|
return True |
|
return False |
|
|
|
async def _evict_lru(self) -> None: |
|
"""Evict least recently used entries""" |
|
if not self._access_order: |
|
return |
|
|
|
|
|
        # Drop the oldest quarter of entries (front of _access_order is least recent)
        to_remove = self._access_order[: len(self._access_order) // 4]
|
for key in to_remove: |
|
await self._remove(key) |
|
|
|
logger.debug(f"Evicted {len(to_remove)} cache entries") |
|
|
|
def _update_access_order(self, key: str) -> None: |
|
"""Update access order for LRU tracking""" |
|
if key in self._access_order: |
|
self._access_order.remove(key) |
|
self._access_order.append(key) |
|
|
|
def get_stats(self) -> Dict[str, Any]: |
|
"""Get cache statistics""" |
|
total_accesses = sum(entry.access_count for entry in self._cache.values()) |
|
return { |
|
"entries": len(self._cache), |
|
"max_size": self.config.max_cache_size, |
|
"total_accesses": total_accesses, |
|
"hit_rate": total_accesses / max(1, len(self._cache)), |
|
} |
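
# Illustrative usage sketch for MemoryCache on its own (normally it is managed
# by PerformanceOptimizer); the key and value below are arbitrary examples:
#
#     cache: MemoryCache[str] = MemoryCache(CacheConfig(cache_ttl=60))
#     await cache.set("greeting", "hello")
#     value = await cache.get("greeting")  # -> "hello" until the TTL expires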
|
|
|
|
|
class BatchProcessor: |
|
"""Batch processor for agent requests""" |
|
|
|
def __init__(self, config: PerformanceConfig): |
|
self.config = config |
|
self._batches: Dict[str, List[Dict[str, Any]]] = {} |
|
self._batch_timers: Dict[str, asyncio.Task] = {} |
|
self._lock = asyncio.Lock() |
|
|
|
async def add_request( |
|
self, batch_key: str, request_data: Dict[str, Any], processor_func: Callable |
|
) -> Any: |
|
"""Add request to batch for processing""" |
|
|
|
if not self.config.enable_batch_processing: |
|
|
|
            # Batching disabled: process this single request immediately
            return await processor_func([request_data])
|
|
|
async with self._lock: |
|
|
|
            # First request for this key: create the batch and start its flush timer
            if batch_key not in self._batches:
|
self._batches[batch_key] = [] |
|
self._start_batch_timer(batch_key, processor_func) |
|
|
|
|
|
            # Queue this request in the pending batch
            self._batches[batch_key].append(request_data)
|
|
|
|
|
            # Flush immediately once the batch is full
            if len(self._batches[batch_key]) >= self.config.max_batch_size:
|
return await self._process_batch(batch_key, processor_func) |
|
|
|
|
|
        # Otherwise wait (outside the lock, so the timer can flush) for the
        # timer-driven batch processing
        return await self._wait_for_batch_result(
            batch_key, request_data, processor_func
        )
|
|
|
def _start_batch_timer(self, batch_key: str, processor_func: Callable) -> None: |
|
"""Start timer for batch processing""" |
|
|
|
async def timer(): |
|
await asyncio.sleep(self.config.batch_timeout) |
|
async with self._lock: |
|
if batch_key in self._batches and self._batches[batch_key]: |
|
await self._process_batch(batch_key, processor_func) |
|
|
|
self._batch_timers[batch_key] = asyncio.create_task(timer()) |
|
|
|
async def _process_batch( |
|
self, batch_key: str, processor_func: Callable |
|
) -> List[Any]: |
|
"""Process accumulated batch""" |
|
if batch_key not in self._batches: |
|
return [] |
|
|
|
        # Take ownership of the batch and cancel its flush timer; if we are
        # running inside the timer task itself, cancelling it would abort this
        # very processing, so skip the cancel in that case
        batch = self._batches.pop(batch_key)

        if batch_key in self._batch_timers:
            timer_task = self._batch_timers.pop(batch_key)
            if timer_task is not asyncio.current_task():
                timer_task.cancel()
|
|
|
if not batch: |
|
return [] |
|
|
|
logger.debug(f"Processing batch {batch_key} with {len(batch)} requests") |
|
|
|
try: |
|
|
|
results = await processor_func(batch) |
|
return results if isinstance(results, list) else [results] |
|
|
|
except Exception as e: |
|
logger.error(f"Batch processing failed for {batch_key}: {e}") |
|
raise |
|
|
|
async def _wait_for_batch_result( |
|
self, batch_key: str, request_data: Dict[str, Any], processor_func: Callable |
|
) -> Any: |
|
"""Wait for batch processing to complete""" |
|
|
|
|
|
|
|
|
|
while batch_key in self._batches: |
|
await asyncio.sleep(0.1) |
|
|
|
|
|
return await processor_func([request_data]) |
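
# Illustrative sketch of how BatchProcessor is driven (process_requests and
# do_work are hypothetical stand-ins; note the processor always receives a
# *list* of request dicts, even when batching is disabled):
#
#     async def process_requests(requests: List[Dict[str, Any]]) -> List[Any]:
#         return [do_work(r) for r in requests]
#
#     processor = BatchProcessor(PerformanceConfig(max_batch_size=5))
#     result = await processor.add_request(
#         "generation", {"topic": "x"}, process_requests
#     )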
|
|
|
|
|
class RequestDeduplicator: |
|
"""Deduplicates identical agent requests""" |
|
|
|
def __init__(self): |
|
self._pending_requests: Dict[str, asyncio.Future] = {} |
|
self._lock = asyncio.Lock() |
|
|
|
    @staticmethod
    @lru_cache(maxsize=1000)
    def _generate_request_hash(request_data: str) -> str:
        """Generate hash for request deduplication"""
        # staticmethod keeps lru_cache from pinning the instance in its cache
        return hashlib.md5(request_data.encode()).hexdigest()
|
|
|
async def deduplicate_request( |
|
self, request_data: Dict[str, Any], processor_func: Callable |
|
) -> Any: |
|
"""Deduplicate and process request""" |
|
|
|
|
|
        # Canonical JSON form so identical requests hash identically
        request_str = json.dumps(request_data, sort_keys=True)
|
request_hash = self._generate_request_hash(request_str) |
|
|
|
async with self._lock: |
|
|
|
if request_hash in self._pending_requests: |
|
logger.debug(f"Deduplicating request: {request_hash[:16]}...") |
|
return await self._pending_requests[request_hash] |
|
|
|
|
|
future = asyncio.create_task( |
|
self._process_unique_request(request_hash, request_data, processor_func) |
|
) |
|
|
|
self._pending_requests[request_hash] = future |
|
|
|
try: |
|
result = await future |
|
return result |
|
finally: |
|
|
|
async with self._lock: |
|
self._pending_requests.pop(request_hash, None) |
|
|
|
async def _process_unique_request( |
|
self, request_hash: str, request_data: Dict[str, Any], processor_func: Callable |
|
) -> Any: |
|
"""Process unique request""" |
|
logger.debug(f"Processing unique request: {request_hash[:16]}...") |
|
return await processor_func(request_data) |
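
# Illustrative sketch: two identical requests issued concurrently share one
# underlying call (call_agent is a hypothetical async processor):
#
#     dedup = RequestDeduplicator()
#     a, b = await asyncio.gather(
#         dedup.deduplicate_request({"topic": "x"}, call_agent),
#         dedup.deduplicate_request({"topic": "x"}, call_agent),
#     )  # call_agent runs once; both a and b receive its result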
|
|
|
|
|
class PerformanceOptimizer: |
|
"""Main performance optimization coordinator""" |
|
|
|
def __init__(self, config: PerformanceConfig): |
|
self.config = config |
|
self.cache = ( |
|
MemoryCache(config.cache_config) if config.enable_response_caching else None |
|
) |
|
self.batch_processor = ( |
|
BatchProcessor(config) if config.enable_batch_processing else None |
|
) |
|
self.deduplicator = ( |
|
RequestDeduplicator() if config.enable_request_deduplication else None |
|
) |
|
self._semaphore = asyncio.Semaphore(config.max_concurrent_requests) |
|
|
|
async def optimize_agent_call( |
|
self, |
|
agent_name: str, |
|
request_data: Dict[str, Any], |
|
processor_func: Callable, |
|
cache_key_generator: Optional[Callable[[Dict[str, Any]], str]] = None, |
|
) -> Any: |
|
"""Optimize agent call with caching, batching, and deduplication""" |
|
|
|
|
|
        # Try the response cache first (when enabled and a key generator is provided)
        cache_key = None
|
if self.cache and cache_key_generator: |
|
cache_key = cache_key_generator(request_data) |
|
|
|
|
|
cached_result = await self.cache.get(cache_key) |
|
if cached_result is not None: |
|
return cached_result |
|
|
|
|
|
        # Bound the number of concurrent agent calls
        async with self._semaphore:
|
|
|
if self.deduplicator and self.config.enable_request_deduplication: |
|
result = await self.deduplicator.deduplicate_request( |
|
request_data, processor_func |
|
) |
|
else: |
|
result = await processor_func(request_data) |
|
|
|
|
|
            # Cache successful results for future identical calls
            if self.cache and cache_key and result is not None:
                await self.cache.set(cache_key, result)

            return result
|
|
|
async def optimize_batch_processing( |
|
self, batch_key: str, request_data: Dict[str, Any], processor_func: Callable |
|
) -> Any: |
|
"""Optimize using batch processing""" |
|
if self.batch_processor: |
|
return await self.batch_processor.add_request( |
|
batch_key, request_data, processor_func |
|
) |
|
else: |
|
return await processor_func([request_data]) |
|
|
|
def get_performance_stats(self) -> Dict[str, Any]: |
|
"""Get performance optimization statistics""" |
|
stats = { |
|
"config": { |
|
"batch_processing": self.config.enable_batch_processing, |
|
"parallel_execution": self.config.enable_parallel_execution, |
|
"request_deduplication": self.config.enable_request_deduplication, |
|
"response_caching": self.config.enable_response_caching, |
|
}, |
|
"concurrency": { |
|
"max_concurrent": self.config.max_concurrent_requests, |
|
"current_available": self._semaphore._value, |
|
}, |
|
} |
|
|
|
if self.cache: |
|
stats["cache"] = self.cache.get_stats() |
|
|
|
return stats |
|
|
|
|
|
|
|
_global_optimizer: Optional[PerformanceOptimizer] = None |
|
|
|
|
|
def get_performance_optimizer( |
|
config: Optional[PerformanceConfig] = None, |
|
) -> PerformanceOptimizer: |
|
"""Get global performance optimizer instance""" |
|
global _global_optimizer |
|
if _global_optimizer is None: |
|
_global_optimizer = PerformanceOptimizer(config or PerformanceConfig()) |
|
return _global_optimizer |
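
# Illustrative end-to-end sketch (call_generator_agent and the key generator
# lambda are hypothetical stand-ins for real agent plumbing):
#
#     optimizer = get_performance_optimizer()
#     result = await optimizer.optimize_agent_call(
#         agent_name="generator",
#         request_data={"topic": "Python basics", "num_cards": 5},
#         processor_func=call_generator_agent,
#         cache_key_generator=lambda req: generate_card_cache_key(
#             req["topic"], "programming", req["num_cards"], "beginner"
#         ),
#     )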
|
|
|
|
|
|
|
def cache_response(cache_key_func: Callable[..., str], ttl: int = 3600):
    """Decorator to cache function responses"""
    # Note: ttl is currently unused; entries expire according to the
    # optimizer's CacheConfig.cache_ttl.
|
|
|
def decorator(func): |
|
@wraps(func) |
|
async def wrapper(*args, **kwargs): |
|
optimizer = get_performance_optimizer() |
|
if not optimizer.cache: |
|
return await func(*args, **kwargs) |
|
|
|
|
|
cache_key = cache_key_func(*args, **kwargs) |
|
|
|
|
|
cached_result = await optimizer.cache.get(cache_key) |
|
if cached_result is not None: |
|
return cached_result |
|
|
|
|
|
result = await func(*args, **kwargs) |
|
|
|
|
|
if result is not None: |
|
await optimizer.cache.set(cache_key, result) |
|
|
|
return result |
|
|
|
return wrapper |
|
|
|
return decorator |
|
|
|
|
|
def rate_limit(max_concurrent: int = 5): |
|
"""Decorator to apply rate limiting""" |
|
semaphore = asyncio.Semaphore(max_concurrent) |
|
|
|
def decorator(func): |
|
@wraps(func) |
|
async def wrapper(*args, **kwargs): |
|
async with semaphore: |
|
return await func(*args, **kwargs) |
|
|
|
return wrapper |
|
|
|
return decorator |
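
# Illustrative sketch combining the decorators (generate_cards_for_topic and
# its key-building lambda are hypothetical):
#
#     @rate_limit(max_concurrent=3)
#     @cache_response(lambda topic, count: f"demo:{topic}:{count}")
#     async def generate_cards_for_topic(topic: str, count: int) -> List[Card]:
#         ...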
|
|
|
|
|
|
|
def generate_card_cache_key( |
|
topic: str, subject: str, num_cards: int, difficulty: str, **kwargs |
|
) -> str: |
|
"""Generate cache key for card generation""" |
|
key_data = { |
|
"topic": topic, |
|
"subject": subject, |
|
"num_cards": num_cards, |
|
"difficulty": difficulty, |
|
"context": kwargs.get("context", {}), |
|
} |
|
key_str = json.dumps(key_data, sort_keys=True) |
|
return f"cards:{hashlib.md5(key_str.encode()).hexdigest()}" |
|
|
|
|
|
def generate_judgment_cache_key( |
|
cards: List[Card], judgment_type: str = "general" |
|
) -> str: |
|
"""Generate cache key for card judgment""" |
|
|
|
    # Summarize each card's content so the key is stable across identical cards
    card_data = []
|
for card in cards: |
|
card_data.append( |
|
{ |
|
"question": card.front.question, |
|
"answer": card.back.answer, |
|
"type": card.card_type, |
|
} |
|
) |
|
|
|
key_data = {"cards": card_data, "judgment_type": judgment_type} |
|
key_str = json.dumps(key_data, sort_keys=True) |
|
return f"judgment:{hashlib.md5(key_str.encode()).hexdigest()}" |
|
|