import time
import torch
import inspect


def fold_path(fn:str):
    ''' Fold a path like `from/to/file.py` to `f/t/file.py` (keep the first letter of each directory). '''
    return '/'.join([p[:1] for p in fn.split('/')[:-1]]) + '/' + fn.split('/')[-1]


def summary_frame_info(frame:inspect.FrameInfo):
    ''' Convert a FrameInfo object to a summary string. '''
    return f'{frame.function} @ {fold_path(frame.filename)}:{frame.lineno}'


class GPUMonitor:
    '''
    Monitor for GPU memory analysis: records peak GPU memory usage over a period of time.
    Each snapshot captures the peak usage since the last reset point, i.e. since
    init, an explicit reset(), or the previous snapshot.
    '''

    def __init__(self):
        self.reset()
        self.clear()
        self.log_fn = 'gpu_monitor.log'


    def snapshot(self, desc:str='snapshot'):
        ''' Record peak memory since the last reset point, log it, then reset the peak stats. '''
        timestamp = time.time()
        caller_frame = inspect.stack()[1]
        peak_MB = torch.cuda.max_memory_allocated() / 1024 / 1024
        free_mem, total_mem = torch.cuda.mem_get_info(0)
        free_mem_MB, total_mem_MB = free_mem / 1024 / 1024, total_mem / 1024 / 1024

        record = {
                'until'     : time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(timestamp)),
                'until_raw' : timestamp,
                'position'  : summary_frame_info(caller_frame),
                'peak'      : peak_MB,
                'peak_msg'  : f'{peak_MB:.2f} MB',
                'free'      : free_mem_MB,
                'total'     : total_mem_MB,
                'free_msg'  : f'{free_mem_MB:.2f} MB',
                'total_msg' : f'{total_mem_MB:.2f} MB',
                'desc'      : desc,
            }

        self.max_peak_MB = max(self.max_peak_MB, peak_MB)

        self.records.append(record)
        self._update_log(record)

        self.reset()
        return record


    def report_latest(self, k:int=1):
        ''' Pretty-print the latest k records. '''
        import rich
        caller_frame = inspect.stack()[1]
        caller_info = summary_frame_info(caller_frame)
        rich.print(f'{caller_info} -> latest {k} records:')
        for record in self.records[-k:]:
            msg = self._generate_log_msg(record)
            rich.print(msg)


    def report_all(self):
        self.report_latest(len(self.records))


    def reset(self):
        ''' Reset CUDA peak-memory statistics so the next snapshot starts fresh. '''
        torch.cuda.reset_peak_memory_stats()


    def clear(self):
        ''' Drop all records and the running maximum. '''
        self.records = []
        self.max_peak_MB = 0

    def _generate_log_msg(self, record):
        # Name the local `until`, not `time`, to avoid shadowing the `time` module.
        until = record['until']
        peak = record['peak']
        desc = record['desc']
        position = record['position']
        return f'[{until}] ⛰️ {peak:>8.2f} MB πŸ“Œ {desc} 🌐 {position}'


    def _update_log(self, record):
        msg = self._generate_log_msg(record)
        with open(self.log_fn, 'a') as f:
            f.write(msg + '\n')
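

# Usage sketch (hypothetical tensor sizes; the GPU part assumes a CUDA device
# is available and is skipped otherwise).
if __name__ == '__main__':
    # Pure-Python self-check of the path-folding helper (no GPU needed).
    assert fold_path('from/to/file.py') == 'f/t/file.py'

    if torch.cuda.is_available():
        monitor = GPUMonitor()
        x = torch.randn(1024, 1024, device='cuda')  # ~4 MB fp32 tensor
        monitor.snapshot('after allocating x')
        monitor.report_latest()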