import tensorflow as tf
from typing import List, Dict, Optional
from dataclasses import dataclass
from tqdm.auto import tqdm


@dataclass
class GPUMemoryStats:
    """Snapshot of GPU memory counters, in bytes.

    NOTE(review): `total` is populated from TF's *peak* allocator usage, not
    the physical device capacity (TF does not expose capacity through
    `get_memory_info`). Treat it as "high-water mark", not "card size".
    """
    total: int  # peak bytes ever allocated by the TF allocator
    used: int   # bytes currently allocated
    free: int   # total - used (headroom relative to the peak, not the card)


class GPUMemoryMonitor:
    """Monitor GPU memory usage with safe CPU fallback.

    All query methods degrade gracefully when no GPU is visible: stats are
    `None`, usage is 0.0, and batch-size hints default to the permissive
    answer ("don't reduce", "may increase").
    """

    # Usage fraction above which callers should shrink their batch size.
    HIGH_WATERMARK = 0.90
    # Usage fraction below which callers may safely grow their batch size.
    LOW_WATERMARK = 0.70

    def __init__(self) -> None:
        # Detect GPU presence once at construction time. list_physical_devices
        # can raise (e.g. RuntimeError on a misconfigured runtime), in which
        # case we fall back to CPU mode rather than crashing.
        self.has_gpu = False
        try:
            gpus = tf.config.list_physical_devices('GPU')
            self.has_gpu = len(gpus) > 0
        except Exception:  # was a bare except; never swallow SystemExit et al.
            pass

    def get_memory_stats(self) -> Optional[GPUMemoryStats]:
        """Get current GPU memory statistics.

        Returns:
            A GPUMemoryStats snapshot, or None when no GPU is available or
            the TF memory-info query fails (e.g. device not yet initialized).
        """
        if not self.has_gpu:
            return None
        try:
            memory_info = tf.config.experimental.get_memory_info('GPU:0')
            # 'peak' stands in for capacity here — see GPUMemoryStats note.
            return GPUMemoryStats(
                total=memory_info['peak'],
                used=memory_info['current'],
                free=memory_info['peak'] - memory_info['current'],
            )
        except Exception:  # was a bare except; query is best-effort by design
            return None

    def get_memory_usage(self) -> float:
        """Get current GPU memory usage as a fraction in [0.0, 1.0].

        Returns 0.0 on CPU-only hosts, on query failure, or when the peak
        counter is still zero (guards against division by zero).
        """
        stats = self.get_memory_stats()  # already None on the CPU path
        if stats is None or stats.total == 0:
            return 0.0
        return stats.used / stats.total

    def should_reduce_batch_size(self) -> bool:
        """Check if batch size should be reduced based on memory usage.

        Always False without a GPU — there is no device pressure to relieve.
        """
        if not self.has_gpu:
            return False
        return self.get_memory_usage() > self.HIGH_WATERMARK

    def can_increase_batch_size(self) -> bool:
        """Check if batch size can be increased based on memory usage.

        Always True without a GPU — host memory is not tracked here.
        """
        if not self.has_gpu:
            return True
        return self.get_memory_usage() < self.LOW_WATERMARK