import gc
import os
import re
import time
from collections import defaultdict
from contextlib import contextmanager

import psutil
import torch
from packaging import version

from .config import config_to_primitive
from .core import debug, find, info, warn
from .typing import *


def parse_version(ver: str):
    return version.parse(ver)


def get_rank():
    # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
    # therefore LOCAL_RANK needs to be checked first
    rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
    for key in rank_keys:
        rank = os.environ.get(key)
        if rank is not None:
            return int(rank)
    return 0


def get_device():
    return torch.device(f"cuda:{get_rank()}")


def load_module_weights(
    path, module_name=None, ignore_modules=None, mapping=None, map_location=None
) -> Tuple[dict, int, int]:
    if module_name is not None and ignore_modules is not None:
        raise ValueError("module_name and ignore_modules cannot be both set")
    if map_location is None:
        map_location = get_device()

    ckpt = torch.load(path, map_location=map_location)
    state_dict = ckpt["state_dict"]

    if mapping is not None:
        state_dict_to_load = {}
        # keep keys that are not already occupied by a mapping target
        for k, v in state_dict.items():
            if not any(k.startswith(m["to"]) for m in mapping):
                state_dict_to_load[k] = v
        # copy mapped keys from their source prefix to their destination prefix
        for k, v in state_dict.items():
            for m in mapping:
                if k.startswith(m["from"]):
                    k_dest = k.replace(m["from"], m["to"])
                    info(f"Mapping {k} => {k_dest}")
                    state_dict_to_load[k_dest] = v.clone()
        state_dict = state_dict_to_load

    state_dict_to_load = state_dict
    if ignore_modules is not None:
        state_dict_to_load = {}
        for k, v in state_dict.items():
            ignore = any(
                k.startswith(ignore_module + ".") for ignore_module in ignore_modules
            )
            if ignore:
                continue
            state_dict_to_load[k] = v
    if module_name is not None:
        state_dict_to_load = {}
        for k, v in state_dict.items():
            m = re.match(rf"^{module_name}\.(.*)$", k)
            if m is None:
                continue
            state_dict_to_load[m.group(1)] = v
    return state_dict_to_load, ckpt["epoch"], ckpt["global_step"]
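

# Illustrative usage sketch (not part of the original module): the three filtering
# modes of load_module_weights. The checkpoint path and module names below are
# hypothetical placeholders.
def _example_load_module_weights():
    # Keep only keys under "geometry." and strip that prefix.
    state_dict, epoch, global_step = load_module_weights(
        "path/to/checkpoint.ckpt", module_name="geometry", map_location="cpu"
    )
    # Alternatively, drop the listed sub-modules and keep everything else:
    # load_module_weights("path/to/checkpoint.ckpt", ignore_modules=["renderer"], map_location="cpu")
    # Or rename a prefix before loading:
    # load_module_weights(
    #     "path/to/checkpoint.ckpt",
    #     mapping=[{"from": "encoder", "to": "geometry.encoder"}],
    # )
    return state_dict, epoch, global_step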


def C(value: Any, epoch: int, global_step: int) -> float:
    if isinstance(value, (int, float)):
        return value
    value = config_to_primitive(value)
    if not isinstance(value, list):
        raise TypeError(f"Scalar specification only supports list, got {type(value)}")
    if len(value) == 3:
        value = [0] + value
    assert len(value) == 4
    start_step, start_value, end_value, end_step = value
    # an integer end_step interpolates over global steps, a float end_step over epochs
    if isinstance(end_step, int):
        current_step = global_step
    elif isinstance(end_step, float):
        current_step = epoch
    else:
        raise TypeError(f"end_step must be int (steps) or float (epochs), got {type(end_step)}")
    value = start_value + (end_value - start_value) * max(
        min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0
    )
    return value
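

# Worked example of the scalar schedule that C() evaluates (values below are
# illustrative): a 4-element spec [start_step, start_value, end_value, end_step]
# interpolates linearly, e.g. [0, 0.0, 1.0, 1000] at global_step=250 gives
# 0.0 + (1.0 - 0.0) * 250 / 1000 = 0.25; the interpolation fraction is clamped to
# [0, 1] outside the step range. A 3-element spec [start_value, end_value, end_step]
# implies start_step=0, and a float end_step interpolates over epochs rather than
# global steps.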


def cleanup():
    gc.collect()
    torch.cuda.empty_cache()
    try:
        import tinycudann as tcnn

        tcnn.free_temporary_memory()
    except Exception:
        pass


def finish_with_cleanup(func: Callable):
    def wrapper(*args, **kwargs):
        out = func(*args, **kwargs)
        cleanup()
        return out

    return wrapper
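

# Minimal usage sketch (the wrapped function is hypothetical): decorate any
# memory-heavy step so that cleanup() runs right after it returns.
@finish_with_cleanup
def _example_heavy_step():
    # ... allocate temporary CUDA tensors, run a model, etc. ...
    pass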


def _distributed_available():
    return torch.distributed.is_available() and torch.distributed.is_initialized()


def barrier():
    if not _distributed_available():
        return
    torch.distributed.barrier()


def broadcast(tensor, src=0):
    if not _distributed_available():
        return tensor
    torch.distributed.broadcast(tensor, src=src)
    return tensor


def enable_gradient(model, enabled: bool = True) -> None:
    for param in model.parameters():
        param.requires_grad_(enabled)
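

# Typical use (model attribute names are hypothetical): freeze a sub-module during
# fine-tuning with enable_gradient(model.encoder, enabled=False), then unfreeze it
# later with enable_gradient(model.encoder, enabled=True).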


class TimeRecorder:
    _instance = None

    def __init__(self):
        self.items = {}
        self.accumulations = defaultdict(list)
        self.time_scale = 1000.0  # ms
        self.time_unit = "ms"
        self.enabled = False

    def __new__(cls):
        # singleton
        if cls._instance is None:
            cls._instance = super(TimeRecorder, cls).__new__(cls)
        return cls._instance

    def enable(self, enabled: bool) -> None:
        self.enabled = enabled

    def start(self, name: str) -> None:
        if not self.enabled:
            return
        torch.cuda.synchronize()
        self.items[name] = time.time()

    def end(self, name: str, accumulate: bool = False) -> float:
        if not self.enabled or name not in self.items:
            return
        torch.cuda.synchronize()
        start_time = self.items.pop(name)
        delta = time.time() - start_time
        if accumulate:
            self.accumulations[name].append(delta)
        t = delta * self.time_scale
        info(f"{name}: {t:.2f}{self.time_unit}")
        return t

    def get_accumulation(self, name: str, average: bool = False) -> float:
        if not self.enabled or name not in self.accumulations:
            return
        acc = self.accumulations.pop(name)
        total = sum(acc)
        if average:
            t = total / len(acc) * self.time_scale
        else:
            t = total * self.time_scale
        info(f"{name} for {len(acc)} times: {t:.2f}{self.time_unit}")
        return t


### global time recorder
time_recorder = TimeRecorder()


@contextmanager
def time_recorder_enabled():
    enabled = time_recorder.enabled
    time_recorder.enable(enabled=True)
    try:
        yield
    finally:
        time_recorder.enable(enabled=enabled)
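

# Usage sketch for the global recorder (requires a CUDA device, since start()/end()
# call torch.cuda.synchronize(); the section name "render" is hypothetical).
def _example_time_recorder():
    with time_recorder_enabled():
        time_recorder.start("render")
        # ... timed region ...
        time_recorder.end("render", accumulate=True)
        # Report the accumulated total (or the per-call average) for "render".
        time_recorder.get_accumulation("render", average=True)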


def show_vram_usage(name):
    # report GPU memory in use (via CUDA) and the resident CPU memory of this process
    available, total = torch.cuda.mem_get_info()
    used = total - available
    rss = psutil.Process(os.getpid()).memory_info().rss
    print(f"{name}: {used / 1024**2:.1f}MB VRAM, {rss / 1024**2:.1f}MB RSS")