import gc
from typing import Any, Dict, Optional, Union

import torch
from accelerate.logging import get_logger

from finetune.constants import LOG_LEVEL, LOG_NAME

logger = get_logger(LOG_NAME, LOG_LEVEL)
def get_memory_statistics(precision: int = 3) -> Dict[str, Any]:
    """Return current and peak device memory usage in gigabytes, rounded to `precision` digits."""
    memory_allocated = None
    memory_reserved = None
    max_memory_allocated = None
    max_memory_reserved = None

    if torch.cuda.is_available():
        device = torch.cuda.current_device()
        memory_allocated = torch.cuda.memory_allocated(device)
        memory_reserved = torch.cuda.memory_reserved(device)
        max_memory_allocated = torch.cuda.max_memory_allocated(device)
        max_memory_reserved = torch.cuda.max_memory_reserved(device)
    elif torch.backends.mps.is_available():
        # MPS only exposes the currently allocated memory; the other statistics stay None.
        memory_allocated = torch.mps.current_allocated_memory()
    else:
        logger.warning("No CUDA, MPS, or ROCm device found. Memory statistics are not available.")

    # Statistics that are unavailable on this backend are reported as None
    # rather than crashing on round(None, ...).
    return {
        key: round(bytes_to_gigabytes(value), ndigits=precision) if value is not None else None
        for key, value in {
            "memory_allocated": memory_allocated,
            "memory_reserved": memory_reserved,
            "max_memory_allocated": max_memory_allocated,
            "max_memory_reserved": max_memory_reserved,
        }.items()
    }
def bytes_to_gigabytes(x: Optional[int]) -> Optional[float]:
    if x is not None:
        return x / 1024**3
    return None
def free_memory() -> None:
    if torch.cuda.is_available():
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    # TODO(aryan): handle non-cuda devices
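

# One possible way to address the TODO above (a hypothetical helper, not part
# of the original file): gc.collect() is backend-agnostic, and on Apple
# silicon torch.mps.empty_cache() plays the role of torch.cuda.empty_cache().
def free_memory_all_devices() -> None:
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()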
def unload_model(model: torch.nn.Module) -> None:
    model.to("cpu")
def make_contiguous(x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    if isinstance(x, torch.Tensor):
        return x.contiguous()
    elif isinstance(x, dict):
        return {k: make_contiguous(v) for k, v in x.items()}
    else:
        return x
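

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original module): shows
# how these helpers might be combined around a workload. The Linear model
# below is a stand-in for whatever model the caller actually runs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    print("before:", get_memory_statistics())

    model = torch.nn.Linear(4096, 4096)
    if torch.cuda.is_available():
        model.to("cuda")
        # Run a dummy forward pass so some device memory is actually allocated.
        model(torch.randn(64, 4096, device="cuda"))

    print("after work:", get_memory_statistics())

    # Move the model off the accelerator, then release cached blocks.
    unload_model(model)
    free_memory()
    print("after cleanup:", get_memory_statistics())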