File size: 2,083 Bytes
817cd1e
 
7e8823d
817cd1e
 
 
 
7e8823d
817cd1e
7e8823d
817cd1e
 
7e8823d
 
 
 
817cd1e
 
 
 
 
 
 
 
 
 
 
 
 
7e8823d
817cd1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4953ce6
 
 
 
 
 
 
 
 
 
817cd1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from subprocess import check_output
from threading import Timer
from typing import Callable, List, Tuple


def get_gpu_memory() -> List[Tuple[int, int]]:
    """
    Get the used and total GPU memory (VRAM) in MiB by querying ``nvidia-smi``

    Each non-empty line printed by ``nvidia-smi`` is parsed into one tuple of integers.

    :return memory_values: List of (used, total) tuples in MiB, one per line reported by nvidia-smi
    :raises OSError: if ``nvidia-smi`` cannot be executed
    :raises subprocess.CalledProcessError: if ``nvidia-smi`` exits with a non-zero status
    :raises ValueError: if a reported field is not an integer
    """

    command = [
        "nvidia-smi",
        "--query-gpu=memory.used,memory.total",
        "--format=csv,noheader,nounits",
    ]
    output = check_output(command).decode("ascii")

    memory_values = []
    # splitlines() copes with both "\n" and "\r\n" and does not depend on a trailing
    # newline, so the last GPU is never dropped by accident.
    for line in output.splitlines():
        if not line.strip():
            continue
        memory_values.append(tuple(int(field) for field in line.split(",")))
    return memory_values


class RepeatingTimer(Timer):
    """
    Timer that keeps calling its function every ``interval`` seconds

    The function is first called one interval after ``start()`` and then again after
    every further interval, until ``cancel()`` sets the ``finished`` event.
    """

    def run(self):
        stop_event = self.finished
        while True:
            # Sleep for one interval; returns early as soon as cancel() is called.
            stop_event.wait(self.interval)
            if stop_event.is_set():
                break
            self.function(*self.args, **self.kwargs)


# Handle to the currently registered watcher timer: watch_gpu_memory() assigns it and
# stop_watcher() resets it to None. None means that no watcher is registered.
gpu_memory_watcher: RepeatingTimer = None


def watch_gpu_memory(interval: int = 1, callback: Callable[[List[Tuple[int, int]]], None] = None) -> RepeatingTimer:
    """
    Start a repeating timer to watch the GPU memory usage

    Only one watcher may run at a time. A previously registered watcher that has been
    cancelled, or whose thread has already ended, does not block starting a new one.

    :param interval: Interval in seconds
    :param callback: Function called with the result of get_gpu_memory() on every tick, defaults to print
    :return timer: RepeatingTimer object
    :raises RuntimeError: if a watcher is already running
    """
    global gpu_memory_watcher
    previous = gpu_memory_watcher
    # The returned timer can be cancelled directly by the caller, which leaves the global
    # handle set; treat it as running only while its thread is alive and not cancelled.
    if previous is not None and previous.is_alive() and not previous.finished.is_set():
        raise RuntimeError("GPU memory watcher is already running")

    if callback is None:
        callback = print

    timer = RepeatingTimer(interval, lambda: callback(get_gpu_memory()))
    timer.start()
    # Publish the handle only after start() succeeded so a failed start leaves no stale state.
    gpu_memory_watcher = timer

    return timer


def stop_watcher() -> None:
    """
    Stop the registered GPU memory watcher, if there is one

    Cancels the timer stored in ``gpu_memory_watcher`` and resets the handle to None.
    Does nothing when no watcher is registered, so it is safe to call repeatedly.
    """
    global gpu_memory_watcher
    watcher = gpu_memory_watcher
    if watcher is None:
        return

    watcher.cancel()
    # Rebind directly instead of ``del`` followed by assignment: deleting the global would
    # leave the name unbound for a moment and make a concurrent reader fail with NameError.
    gpu_memory_watcher = None


if __name__ == "__main__":
    from time import sleep

    # Demo: report GPU memory once per second for about 20 seconds. After 10 seconds a
    # second watcher is requested to show that it is rejected while the first one runs.
    watch_gpu_memory()

    try:
        for elapsed in range(1, 21):
            sleep(1)
            if elapsed == 10:
                try:
                    watch_gpu_memory()
                except RuntimeError:
                    print("Got exception")
    finally:
        # Always stop the watcher, even when the demo is interrupted; otherwise the
        # still-running timer thread would keep the process from exiting.
        stop_watcher()