File size: 2,019 Bytes
9decf80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import numpy as np
import tensorflow as tf
import faiss
import json
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Generator
from dataclasses import dataclass
import threading
from queue import Queue
import gc
try:
    from tqdm.notebook import tqdm
except ImportError:
    from tqdm import tqdm

@dataclass
class GPUMemoryStats:
    """Snapshot of GPU memory counters, in bytes, as produced by
    GPUMemoryMonitor.get_memory_stats()."""
    # Peak bytes allocated so far — used by the monitor as a proxy for
    # total capacity, NOT the physical card size.
    total: int
    # Bytes currently allocated.
    used: int
    # total - used at the moment of the snapshot.
    free: int

class GPUMemoryMonitor:
    """Monitor GPU memory usage with safe CPU fallback.

    All TensorFlow calls are guarded so that a machine without a usable
    GPU (or where TF itself fails) degrades to CPU-mode answers instead
    of raising.
    """
    def __init__(self):
        self.has_gpu = False
        try:
            gpus = tf.config.list_physical_devices('GPU')
            self.has_gpu = len(gpus) > 0
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt and
            # SystemExit still propagate; any TF failure means "no GPU".
            pass

    def get_memory_stats(self) -> Optional["GPUMemoryStats"]:
        """Get current GPU memory statistics, or None on CPU or error.

        NOTE(review): TF's memory_info exposes only 'current' and
        'peak'; 'peak' is used here as a proxy for total capacity, so
        `total` is the peak allocation so far, not the card's physical
        memory — confirm this is the intended semantics.
        """
        if not self.has_gpu:
            return None

        try:
            memory_info = tf.config.experimental.get_memory_info('GPU:0')
            return GPUMemoryStats(
                total=memory_info['peak'],
                used=memory_info['current'],
                free=memory_info['peak'] - memory_info['current']
            )
        except Exception:
            # Narrowed from a bare `except:`; treat any failure as
            # "stats unavailable".
            return None

    def get_memory_usage(self) -> float:
        """Get current GPU memory usage as a fraction in [0.0, 1.0].

        Returns 0.0 on CPU, when stats are unavailable, or before any
        allocation has been recorded (avoids division by zero).
        """
        if not self.has_gpu:
            return 0.0
        stats = self.get_memory_stats()
        if stats is None or stats.total == 0:
            return 0.0
        return stats.used / stats.total

    def should_reduce_batch_size(self) -> bool:
        """Return True when usage exceeds 90% and the batch should shrink."""
        if not self.has_gpu:
            return False
        return self.get_memory_usage() > 0.90

    def can_increase_batch_size(self) -> bool:
        """Return True when usage is below 70% (always True on CPU)."""
        if not self.has_gpu:
            return True  # Allow increase on CPU
        return self.get_memory_usage() < 0.70