import gc
import os
import re
import time
from collections import defaultdict
from contextlib import contextmanager

import psutil
import torch
from packaging import version

from .config import config_to_primitive
from .core import debug, find, info, warn
from .typing import *


def parse_version(ver: str):
    return version.parse(ver)


def get_rank():
    # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
    # therefore LOCAL_RANK needs to be checked first
    rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
    for key in rank_keys:
        rank = os.environ.get(key)
        if rank is not None:
            return int(rank)
    return 0


def get_device():
    return torch.device(f"cuda:{get_rank()}")


def load_module_weights(
    path, module_name=None, ignore_modules=None, mapping=None, map_location=None
) -> Tuple[dict, int, int]:
    if module_name is not None and ignore_modules is not None:
        raise ValueError("module_name and ignore_modules cannot both be set")
    if map_location is None:
        map_location = get_device()

    ckpt = torch.load(path, map_location=map_location)
    state_dict = ckpt["state_dict"]

    if mapping is not None:
        state_dict_to_load = {}
        # keep keys that do not collide with any mapping destination ...
        for k, v in state_dict.items():
            if not any(k.startswith(m["to"]) for m in mapping):
                state_dict_to_load[k] = v
        # ... then copy mapped keys over to their destination prefix
        for k, v in state_dict.items():
            for m in mapping:
                if k.startswith(m["from"]):
                    k_dest = k.replace(m["from"], m["to"])
                    info(f"Mapping {k} => {k_dest}")
                    state_dict_to_load[k_dest] = v.clone()
        state_dict = state_dict_to_load

    state_dict_to_load = state_dict

    if ignore_modules is not None:
        # drop every key that belongs to one of the ignored submodules
        state_dict_to_load = {}
        for k, v in state_dict.items():
            ignore = any(
                k.startswith(ignore_module + ".") for ignore_module in ignore_modules
            )
            if ignore:
                continue
            state_dict_to_load[k] = v

    if module_name is not None:
        # keep only keys under `module_name`, with the prefix stripped
        state_dict_to_load = {}
        for k, v in state_dict.items():
            m = re.match(rf"^{module_name}\.(.*)$", k)
            if m is None:
                continue
            state_dict_to_load[m.group(1)] = v

    return state_dict_to_load, ckpt["epoch"], ckpt["global_step"]


def C(value: Any, epoch: int, global_step: int) -> float:
    """Evaluate a scalar schedule: plain numbers are returned unchanged;
    a list [start_step, start_value, end_value, end_step] is interpolated
    linearly (a 3-element list implies start_step=0). An int end_step
    interpolates over global steps, a float end_step over epochs."""
    if isinstance(value, (int, float)):
        pass
    else:
        value = config_to_primitive(value)
        if not isinstance(value, list):
            raise TypeError(
                f"Scalar specification only supports list, got {type(value)}"
            )
        if len(value) == 3:
            value = [0] + value
        assert len(value) == 4
        start_step, start_value, end_value, end_step = value
        if isinstance(end_step, int):
            current_step = global_step
        elif isinstance(end_step, float):
            current_step = epoch
        else:
            raise TypeError(f"end_step must be int or float, got {type(end_step)}")
        value = start_value + (end_value - start_value) * max(
            min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0
        )
    return value
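# Worked example for `C` (illustrative values): a schedule spec equivalent to
# [0, 0.1, 1.0, 1000] (e.g. a list read from an OmegaConf config) has an
# integer end step, so it interpolates over global steps:
#   at global_step=500:    0.1 + (1.0 - 0.1) * 500 / 1000 = 0.55
#   at global_step>=1000:  clamped at the end value, 1.0
# A float end step (e.g. 10.0) interpolates over epochs instead.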
def cleanup():
    gc.collect()
    torch.cuda.empty_cache()
    try:
        import tinycudann as tcnn

        tcnn.free_temporary_memory()
    except Exception:
        # tinycudann is optional; ignore if it is not installed
        pass


def finish_with_cleanup(func: Callable):
    # run `func`, then release cached GPU memory before returning its result
    def wrapper(*args, **kwargs):
        out = func(*args, **kwargs)
        cleanup()
        return out

    return wrapper


def _distributed_available():
    return torch.distributed.is_available() and torch.distributed.is_initialized()


def barrier():
    if _distributed_available():
        torch.distributed.barrier()


def broadcast(tensor, src=0):
    if _distributed_available():
        torch.distributed.broadcast(tensor, src=src)
    return tensor


def enable_gradient(model, enabled: bool = True) -> None:
    for param in model.parameters():
        param.requires_grad_(enabled)


class TimeRecorder:
    _instance = None

    def __init__(self):
        # __init__ runs on every TimeRecorder() call, even though __new__
        # returns the shared singleton, so guard against resetting its state
        if getattr(self, "_initialized", False):
            return
        self._initialized = True
        self.items = {}
        self.accumulations = defaultdict(list)
        self.time_scale = 1000.0  # seconds -> milliseconds
        self.time_unit = "ms"
        self.enabled = False

    def __new__(cls):
        # singleton
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def enable(self, enabled: bool) -> None:
        self.enabled = enabled

    def start(self, name: str) -> None:
        if not self.enabled:
            return
        torch.cuda.synchronize()
        self.items[name] = time.time()

    def end(self, name: str, accumulate: bool = False) -> Optional[float]:
        if not self.enabled or name not in self.items:
            return None
        torch.cuda.synchronize()
        start_time = self.items.pop(name)
        delta = time.time() - start_time
        if accumulate:
            self.accumulations[name].append(delta)
        t = delta * self.time_scale
        info(f"{name}: {t:.2f}{self.time_unit}")
        return t

    def get_accumulation(self, name: str, average: bool = False) -> Optional[float]:
        if not self.enabled or name not in self.accumulations:
            return None
        acc = self.accumulations.pop(name)
        total = sum(acc)
        if average:
            t = total / len(acc) * self.time_scale
        else:
            t = total * self.time_scale
        info(f"{name} for {len(acc)} times: {t:.2f}{self.time_unit}")
        return t


### global time recorder
time_recorder = TimeRecorder()


@contextmanager
def time_recorder_enabled():
    # temporarily enable the global time recorder, restoring its previous
    # state on exit
    enabled = time_recorder.enabled
    time_recorder.enable(enabled=True)
    try:
        yield
    finally:
        time_recorder.enable(enabled=enabled)


def show_vram_usage(name):
    # report GPU memory in use on the current device and this process's RSS
    available, total = torch.cuda.mem_get_info()
    used = total - available
    print(
        f"{name}: {used / 1024**2:.1f}MB VRAM, "
        f"{psutil.Process(os.getpid()).memory_info().rss / 1024**2:.1f}MB RSS"
    )
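# Usage sketch for the time recorder (illustrative; assumes a CUDA device and
# that `info` logging is configured). The recorder is disabled by default, so
# start/end are no-ops unless enabled, e.g. via the context manager:
#
#   with time_recorder_enabled():
#       for _ in range(10):
#           time_recorder.start("forward")
#           ...  # some GPU work
#           time_recorder.end("forward", accumulate=True)
#       time_recorder.get_accumulation("forward", average=True)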