import gc
import os
import re
import time
from collections import defaultdict
from contextlib import contextmanager
import psutil
import torch
from packaging import version
from .config import config_to_primitive
from .core import debug, find, info, warn
from .typing import *
def parse_version(ver: str):
return version.parse(ver)
def get_rank():
# SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
# therefore LOCAL_RANK needs to be checked first
rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
for key in rank_keys:
rank = os.environ.get(key)
if rank is not None:
return int(rank)
return 0
def get_device():
return torch.device(f"cuda:{get_rank()}")
def load_module_weights(
path, module_name=None, ignore_modules=None, mapping=None, map_location=None
) -> Tuple[dict, int, int]:
if module_name is not None and ignore_modules is not None:
raise ValueError("module_name and ignore_modules cannot be both set")
if map_location is None:
map_location = get_device()
ckpt = torch.load(path, map_location=map_location)
state_dict = ckpt["state_dict"]
if mapping is not None:
state_dict_to_load = {}
for k, v in state_dict.items():
            # keep keys that don't already sit at a mapping destination, so the
            # remapped weights below can take their place without clashing
            if not any(k.startswith(m["to"]) for m in mapping):
                state_dict_to_load[k] = v
for k, v in state_dict.items():
for m in mapping:
if k.startswith(m["from"]):
k_dest = k.replace(m["from"], m["to"])
info(f"Mapping {k} => {k_dest}")
state_dict_to_load[k_dest] = v.clone()
state_dict = state_dict_to_load
state_dict_to_load = state_dict
if ignore_modules is not None:
state_dict_to_load = {}
for k, v in state_dict.items():
ignore = any(
[k.startswith(ignore_module + ".") for ignore_module in ignore_modules]
)
if ignore:
continue
state_dict_to_load[k] = v
if module_name is not None:
state_dict_to_load = {}
for k, v in state_dict.items():
m = re.match(rf"^{module_name}\.(.*)$", k)
if m is None:
continue
state_dict_to_load[m.group(1)] = v
return state_dict_to_load, ckpt["epoch"], ckpt["global_step"]
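# Illustrative usage (the checkpoint path and submodule name are hypothetical): load only
# the weights stored under "geometry." from a Lightning-style checkpoint and apply them
# to that submodule in place.
#
#   state_dict, epoch, global_step = load_module_weights(
#       "outputs/last.ckpt", module_name="geometry", map_location="cpu"
#   )
#   model.geometry.load_state_dict(state_dict)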
def C(value: Any, epoch: int, global_step: int) -> float:
    if not isinstance(value, (int, float)):
value = config_to_primitive(value)
if not isinstance(value, list):
            raise TypeError(f"Scalar specification only supports list, got {type(value)}")
if len(value) == 3:
value = [0] + value
assert len(value) == 4
start_step, start_value, end_value, end_step = value
if isinstance(end_step, int):
current_step = global_step
value = start_value + (end_value - start_value) * max(
min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0
)
elif isinstance(end_step, float):
current_step = epoch
value = start_value + (end_value - start_value) * max(
min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0
)
return value
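# Illustrative schedule semantics for C (values normally arrive from the parsed config):
# [start_step, start_value, end_value, end_step] ramps linearly from start_value to
# end_value; an int end_step counts global steps, a float end_step counts epochs, and a
# 3-element list implies start_step = 0. For example, [0, 0.1, 1.0, 1000] evaluates to
# 0.55 at global_step 500 and stays at 1.0 from step 1000 onward.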
def cleanup():
gc.collect()
torch.cuda.empty_cache()
try:
import tinycudann as tcnn
tcnn.free_temporary_memory()
    except Exception:
        # tinycudann may be missing or unable to free memory here; ignore and move on
        pass
def finish_with_cleanup(func: Callable):
def wrapper(*args, **kwargs):
out = func(*args, **kwargs)
cleanup()
return out
return wrapper
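# Illustrative usage (the decorated function is hypothetical): release CUDA / tinycudann
# caches automatically after a memory-heavy call.
#
#   @finish_with_cleanup
#   def run_export(system):
#       return system.export()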
def _distributed_available():
return torch.distributed.is_available() and torch.distributed.is_initialized()
def barrier():
if not _distributed_available():
return
else:
torch.distributed.barrier()
def broadcast(tensor, src=0):
if not _distributed_available():
return tensor
else:
torch.distributed.broadcast(tensor, src=src)
return tensor
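# Illustrative DDP usage (the rank-0 computation is hypothetical; all ranks must allocate
# a tensor of the same shape before broadcasting):
#
#   t = torch.zeros(3, device=get_device())
#   if get_rank() == 0:
#       t = compute_summary()  # hypothetical, returns a shape-(3,) tensor on this device
#   t = broadcast(t, src=0)
#   barrier()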
def enable_gradient(model, enabled: bool = True) -> None:
for param in model.parameters():
param.requires_grad_(enabled)
class TimeRecorder:
_instance = None
    def __init__(self):
        # guard so repeated instantiations of the singleton don't reset its state
        if getattr(self, "_initialized", False):
            return
        self.items = {}
        self.accumulations = defaultdict(list)
        self.time_scale = 1000.0  # ms
        self.time_unit = "ms"
        self.enabled = False
        self._initialized = True
def __new__(cls):
# singleton
if cls._instance is None:
cls._instance = super(TimeRecorder, cls).__new__(cls)
return cls._instance
def enable(self, enabled: bool) -> None:
self.enabled = enabled
def start(self, name: str) -> None:
if not self.enabled:
return
torch.cuda.synchronize()
self.items[name] = time.time()
    def end(self, name: str, accumulate: bool = False) -> Optional[float]:
        if not self.enabled or name not in self.items:
            return None
        torch.cuda.synchronize()
        start_time = self.items.pop(name)
        delta = time.time() - start_time
        if accumulate:
            self.accumulations[name].append(delta)
        t = delta * self.time_scale
        info(f"{name}: {t:.2f}{self.time_unit}")
        return t
    def get_accumulation(self, name: str, average: bool = False) -> Optional[float]:
        if not self.enabled or name not in self.accumulations:
            return None
        acc = self.accumulations.pop(name)
        total = sum(acc)
        if average:
            t = total / len(acc) * self.time_scale
        else:
            t = total * self.time_scale
        info(f"{name} over {len(acc)} calls: {t:.2f}{self.time_unit}")
        return t
### global time recorder
time_recorder = TimeRecorder()
@contextmanager
def time_recorder_enabled():
enabled = time_recorder.enabled
time_recorder.enable(enabled=True)
try:
yield
finally:
time_recorder.enable(enabled=enabled)
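# Illustrative usage of the global recorder (section and model names are examples):
# timings are only collected while recording is enabled, e.g. inside the context manager.
#
#   with time_recorder_enabled():
#       time_recorder.start("forward")
#       out = model(batch)
#       time_recorder.end("forward", accumulate=True)
#       time_recorder.get_accumulation("forward", average=True)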
def show_vram_usage(name):
    # report both GPU memory in use on the current device and this process's resident RAM
    available, total = torch.cuda.mem_get_info()
    used = total - available
    rss = psutil.Process(os.getpid()).memory_info().rss
    print(f"{name}: VRAM {used / 1024**2:.1f}MB, RSS {rss / 1024**2:.1f}MB")