import numpy as np
import tensorflow as tf
import faiss
import json
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Generator
from dataclasses import dataclass
import threading
from queue import Queue
import gc

try:
    from tqdm.notebook import tqdm
except ImportError:
    from tqdm import tqdm


@dataclass
class GPUMemoryStats:
    total: int
    used: int
    free: int


class GPUMemoryMonitor:
    """Monitor GPU memory usage with safe CPU fallback."""

    def __init__(self):
        self.has_gpu = False
        try:
            gpus = tf.config.list_physical_devices('GPU')
            self.has_gpu = len(gpus) > 0
        except Exception:  # any probe failure means we fall back to CPU-only mode
            pass

    def get_memory_stats(self) -> Optional[GPUMemoryStats]:
        """Get current GPU memory statistics."""
        if not self.has_gpu:
            return None
        try:
            # TF's get_memory_info reports only 'current' and 'peak' byte
            # counts, so peak usage stands in for total capacity here.
            memory_info = tf.config.experimental.get_memory_info('GPU:0')
            return GPUMemoryStats(
                total=memory_info['peak'],
                used=memory_info['current'],
                free=memory_info['peak'] - memory_info['current']
            )
        except Exception:
            return None

    def get_memory_usage(self) -> float:
        """Get current GPU memory usage as a fraction in [0.0, 1.0]."""
        if not self.has_gpu:
            return 0.0
        stats = self.get_memory_stats()
        if stats is None or stats.total == 0:
            return 0.0
        return stats.used / stats.total

    def should_reduce_batch_size(self) -> bool:
        """Check if batch size should be reduced based on memory usage."""
        if not self.has_gpu:
            return False
        usage = self.get_memory_usage()
        return usage > 0.90

    def can_increase_batch_size(self) -> bool:
        """Check if batch size can be increased based on memory usage."""
        if not self.has_gpu:
            return True  # Allow increase on CPU
        usage = self.get_memory_usage()
        return usage < 0.70