"""GPU memory monitoring utilities with a safe CPU fallback."""

import gc
import json
import threading
from dataclasses import dataclass
from pathlib import Path
from queue import Queue
from typing import Dict, Generator, List, Optional, Tuple

import faiss
import numpy as np
import tensorflow as tf

try:
    # Prefer the notebook-friendly progress bar when running under Jupyter.
    from tqdm.notebook import tqdm
except ImportError:
    from tqdm import tqdm


@dataclass
class GPUMemoryStats:
    """Snapshot of GPU memory counters, in bytes."""

    total: int  # peak bytes allocated so far (proxy for capacity — see note below)
    used: int   # bytes currently allocated
    free: int   # total - used


class GPUMemoryMonitor:
    """Monitor GPU memory usage with safe CPU fallback.

    Every query degrades gracefully: on machines without a visible GPU
    (or when TensorFlow cannot report memory info) the monitor reports
    zero usage and permissive batch-size decisions instead of raising.
    """

    # Usage fraction above which callers should shrink their batches.
    HIGH_WATERMARK = 0.90
    # Usage fraction below which callers may grow their batches.
    LOW_WATERMARK = 0.70

    def __init__(self) -> None:
        self.has_gpu = False
        try:
            gpus = tf.config.list_physical_devices('GPU')
            self.has_gpu = len(gpus) > 0
        except Exception:
            # TF can raise if it was built/initialized without GPU support;
            # treat any failure as "no GPU" rather than crashing the caller.
            pass

    def get_memory_stats(self) -> Optional[GPUMemoryStats]:
        """Return current GPU memory statistics, or ``None`` on CPU/error.

        NOTE(review): ``tf.config.experimental.get_memory_info`` exposes
        only ``'current'`` and ``'peak'`` allocation counters — not the
        physical device capacity — so ``'peak'`` is used here as an
        approximation of total memory. Usage ratios derived from it are
        relative to the high-water mark, not to the card's real capacity.
        """
        if not self.has_gpu:
            return None
        try:
            memory_info = tf.config.experimental.get_memory_info('GPU:0')
            return GPUMemoryStats(
                total=memory_info['peak'],
                used=memory_info['current'],
                free=memory_info['peak'] - memory_info['current'],
            )
        except Exception:
            # Device name may be invalid or the backend may not support
            # memory introspection; report "unavailable" instead of raising.
            return None

    def get_memory_usage(self) -> float:
        """Return current GPU memory usage as a fraction in ``[0.0, 1.0]``."""
        if not self.has_gpu:
            return 0.0
        stats = self.get_memory_stats()
        # Guard against a missing snapshot and against division by zero
        # before the first allocation (peak == 0).
        if stats is None or stats.total == 0:
            return 0.0
        return stats.used / stats.total

    def should_reduce_batch_size(self) -> bool:
        """Return ``True`` when memory pressure calls for smaller batches."""
        if not self.has_gpu:
            return False
        return self.get_memory_usage() > self.HIGH_WATERMARK

    def can_increase_batch_size(self) -> bool:
        """Return ``True`` when there is headroom for larger batches."""
        if not self.has_gpu:
            return True  # No GPU constraint when running on CPU.
        return self.get_memory_usage() < self.LOW_WATERMARK