import tensorflow as tf
from typing import Optional
from dataclasses import dataclass


@dataclass
class GPUMemoryStats:
    """GPU memory statistics; all values are in bytes."""
    total: int
    used: int
    free: int


class GPUMemoryMonitor:
    """Monitor GPU memory usage with safe CPU fallback."""

    def __init__(self):
        self.has_gpu = False
        try:
            gpus = tf.config.list_physical_devices('GPU')
            self.has_gpu = len(gpus) > 0
        except Exception:
            # Device enumeration can fail in misconfigured environments;
            # treat any failure as "no GPU" and fall back to CPU behavior.
            pass

    def get_memory_stats(self) -> Optional[GPUMemoryStats]:
        """Get current GPU memory statistics, or None on CPU-only hosts."""
        if not self.has_gpu:
            return None
        try:
            # TensorFlow exposes only 'current' and 'peak' allocations, so
            # the peak allocation serves as a stand-in for total capacity.
            memory_info = tf.config.experimental.get_memory_info('GPU:0')
            return GPUMemoryStats(
                total=memory_info['peak'],
                used=memory_info['current'],
                free=memory_info['peak'] - memory_info['current']
            )
        except Exception:
            return None

    def get_memory_usage(self) -> float:
        """Get current GPU memory usage as a fraction in [0.0, 1.0]."""
        if not self.has_gpu:
            return 0.0
        stats = self.get_memory_stats()
        if stats is None or stats.total == 0:
            return 0.0
        return stats.used / stats.total

    def should_reduce_batch_size(self) -> bool:
        """Check if batch size should be reduced based on memory usage."""
        if not self.has_gpu:
            return False
        usage = self.get_memory_usage()
        return usage > 0.90

    def can_increase_batch_size(self) -> bool:
        """Check if batch size can be increased based on memory usage."""
        if not self.has_gpu:
            return True
        usage = self.get_memory_usage()
        return usage < 0.70