tori29umai's picture
Update diffusers_helper/memory.py
82b1671 verified
# By lllyasviel
# WindowsとHugging Face Space環境の両方に対応した DynamicSwap + zeroGPU 対応バージョン
import os
import torch
# Hugging Face Space環境で実行されているかどうかを確認
IN_HF_SPACE = os.environ.get('SPACE_ID') is not None
# CPU デバイスを設定
cpu = torch.device('cpu')
# ステートレスGPU環境では、メインプロセスでCUDAを初期化しない
def get_gpu_device():
if IN_HF_SPACE:
# Spacesではデバイスの初期化を遅延させる
return 'cuda'
try:
if torch.cuda.is_available():
return torch.device(f'cuda:{torch.cuda.current_device()}')
else:
print("CUDAが利用できません。デフォルトデバイスとしてCPUを使用します")
return torch.device('cpu')
except Exception as e:
print(f"CUDAデバイスの初期化中にエラーが発生しました: {e}")
print("CPUデバイスにフォールバックします")
return torch.device('cpu')
# GPUデバイスを取得(文字列または実際のデバイスオブジェクト)
gpu = get_gpu_device()
# 完全にGPUにロードされたモジュールのリスト
gpu_complete_modules = []
class DynamicSwapInstaller:
@staticmethod
def _install_module(module: torch.nn.Module, **kwargs):
original_class = module.__class__
module.__dict__['forge_backup_original_class'] = original_class
def hacked_get_attr(self, name: str):
if '_parameters' in self.__dict__:
_parameters = self.__dict__['_parameters']
if name in _parameters:
p = _parameters[name]
if p is None:
return None
if p.__class__ == torch.nn.Parameter:
return torch.nn.Parameter(p.to(**kwargs), requires_grad=p.requires_grad)
else:
return p.to(**kwargs)
if '_buffers' in self.__dict__:
_buffers = self.__dict__['_buffers']
if name in _buffers:
return _buffers[name].to(**kwargs)
return super(original_class, self).__getattr__(name)
module.__class__ = type(
'DynamicSwap_' + original_class.__name__,
(original_class,),
{'__getattr__': hacked_get_attr}
)
@staticmethod
def _uninstall_module(module: torch.nn.Module):
if 'forge_backup_original_class' in module.__dict__:
module.__class__ = module.__dict__.pop('forge_backup_original_class')
@staticmethod
def install_model(model: torch.nn.Module, **kwargs):
for m in model.modules():
DynamicSwapInstaller._install_module(m, **kwargs)
@staticmethod
def uninstall_model(model: torch.nn.Module):
for m in model.modules():
DynamicSwapInstaller._uninstall_module(m)
def fake_diffusers_current_device(model: torch.nn.Module, target_device):
# 文字列デバイスをtorch.deviceに変換
if isinstance(target_device, str):
target_device = torch.device(target_device)
if hasattr(model, 'scale_shift_table'):
model.scale_shift_table.data = model.scale_shift_table.data.to(target_device)
return
for _, p in model.named_modules():
if hasattr(p, 'weight'):
p.to(target_device)
return
def get_cuda_free_memory_gb(device=None):
if device is None:
device = gpu
if isinstance(device, str):
device = torch.device(device)
if device.type != 'cuda':
# CUDAでない場合はデフォルト値
return 6.0
try:
stats = torch.cuda.memory_stats(device)
active = stats['active_bytes.all.current']
reserved = stats['reserved_bytes.all.current']
free_cuda, _ = torch.cuda.mem_get_info(device)
inactive = reserved - active
available = free_cuda + inactive
return available / (1024 ** 3)
except Exception as e:
print(f"CUDAメモリ情報取得エラー: {e}")
return 6.0
def move_model_to_device_with_memory_preservation(model, target_device, preserved_memory_gb=0):
print(f"{model.__class__.__name__}{target_device} に移動します。保持メモリ: {preserved_memory_gb} GB")
if isinstance(target_device, str):
target_device = torch.device(target_device)
# CPUまたはGPU未使用時は直接移動
if target_device.type == 'cpu':
model.to(device=target_device)
if torch.cuda.is_available():
torch.cuda.empty_cache()
return
for m in model.modules():
if get_cuda_free_memory_gb(target_device) <= preserved_memory_gb:
torch.cuda.empty_cache()
return
if hasattr(m, 'weight'):
m.to(device=target_device)
model.to(device=target_device)
torch.cuda.empty_cache()
def offload_model_from_device_for_memory_preservation(model, target_device, preserved_memory_gb=0):
print(f"メモリ保持のため {model.__class__.__name__}{target_device} からオフロードします: {preserved_memory_gb} GB")
if isinstance(target_device, str):
target_device = torch.device(target_device)
if target_device.type == 'cpu':
model.to(device=cpu)
if torch.cuda.is_available():
torch.cuda.empty_cache()
return
for m in model.modules():
if get_cuda_free_memory_gb(target_device) >= preserved_memory_gb:
torch.cuda.empty_cache()
return
if hasattr(m, 'weight'):
m.to(device=cpu)
model.to(device=cpu)
torch.cuda.empty_cache()
def unload_complete_models(*args):
for m in gpu_complete_modules + list(args):
if m is None:
continue
m.to(device=cpu)
print(f"{m.__class__.__name__} を完全にアンロードしました")
gpu_complete_modules.clear()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def load_model_as_complete(model, target_device, unload=True):
if isinstance(target_device, str):
target_device = torch.device(target_device)
if unload:
unload_complete_models()
model.to(device=target_device)
print(f"{model.__class__.__name__}{target_device} に完全にロードしました")
gpu_complete_modules.append(model)