import numpy as np
import tensorflow as tf
import faiss
import json
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Generator
from dataclasses import dataclass
import threading
from queue import Queue
import gc

try:
    from tqdm.notebook import tqdm
except ImportError:
    from tqdm import tqdm


@dataclass
class GPUMemoryStats:
    """A snapshot of GPU memory counters, in bytes."""

    total: int
    used: int
    free: int


class GPUMemoryMonitor:
    """Monitor GPU memory usage with a safe CPU-only fallback."""

    def __init__(self):
        self.has_gpu = False
        try:
            gpus = tf.config.list_physical_devices('GPU')
            self.has_gpu = len(gpus) > 0
        except Exception:
            # Device enumeration can fail on misconfigured runtimes;
            # treat that the same as having no GPU.
            pass

    def get_memory_stats(self) -> Optional[GPUMemoryStats]:
        """Get current GPU memory statistics, or None when unavailable."""
        if not self.has_gpu:
            return None

        try:
            # TensorFlow exposes only the current and peak allocations for a
            # device, not its physical capacity, so peak usage stands in as
            # the best available proxy for "total" here.
            memory_info = tf.config.experimental.get_memory_info('GPU:0')
            return GPUMemoryStats(
                total=memory_info['peak'],
                used=memory_info['current'],
                free=memory_info['peak'] - memory_info['current'],
            )
        except Exception:
            return None

    def get_memory_usage(self) -> float:
        """Get current GPU memory usage as a fraction in [0.0, 1.0]."""
        if not self.has_gpu:
            return 0.0
        stats = self.get_memory_stats()
        if stats is None or stats.total == 0:
            return 0.0
        return stats.used / stats.total

    def should_reduce_batch_size(self) -> bool:
        """Check whether memory pressure (usage above 90%) calls for a smaller batch."""
        if not self.has_gpu:
            return False
        usage = self.get_memory_usage()
        return usage > 0.90

    def can_increase_batch_size(self) -> bool:
        """Check whether there is enough headroom (usage below 70%) to grow the batch."""
        if not self.has_gpu:
            return True
        usage = self.get_memory_usage()
        return usage < 0.70
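

# A minimal usage sketch, not part of the original module: an adaptive
# batching loop that halves the batch size under memory pressure and grows
# it when there is headroom. The 32/1024 bounds and the stubbed work step
# are hypothetical placeholders, not values from this codebase.
if __name__ == "__main__":
    monitor = GPUMemoryMonitor()
    batch_size = 128

    for step in range(10):
        # ... run one embedding/training step at `batch_size` here ...
        if monitor.should_reduce_batch_size():
            batch_size = max(32, batch_size // 2)    # back off under pressure
        elif monitor.can_increase_batch_size():
            batch_size = min(1024, batch_size * 2)   # grow while there is headroom
        stats = monitor.get_memory_stats()
        if stats is not None:
            print(f"step {step}: batch_size={batch_size}, "
                  f"used {stats.used / 1e6:.1f} MB of {stats.total / 1e6:.1f} MB peak")
        else:
            print(f"step {step}: batch_size={batch_size} (no GPU stats)")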