tori29umai committed
Commit 82b1671 · verified · 1 Parent(s): fd4a82d

Update diffusers_helper/memory.py

Files changed (1):
  1. diffusers_helper/memory.py +78 -40
diffusers_helper/memory.py CHANGED
@@ -1,13 +1,36 @@
  # By lllyasviel
+ # DynamicSwap + zeroGPU variant supporting both Windows and Hugging Face Space environments

-
+ import os
  import torch

+ # Check whether we are running inside a Hugging Face Space
+ IN_HF_SPACE = os.environ.get('SPACE_ID') is not None

+ # Set up the CPU device
  cpu = torch.device('cpu')
- gpu = torch.device(f'cuda:{torch.cuda.current_device()}')
- gpu_complete_modules = []

+ # In a stateless GPU environment, do not initialize CUDA in the main process
+ def get_gpu_device():
+     if IN_HF_SPACE:
+         # On Spaces, defer device initialization
+         return 'cuda'
+     try:
+         if torch.cuda.is_available():
+             return torch.device(f'cuda:{torch.cuda.current_device()}')
+         else:
+             print("CUDA is not available; using CPU as the default device")
+             return torch.device('cpu')
+     except Exception as e:
+         print(f"Error while initializing the CUDA device: {e}")
+         print("Falling back to the CPU device")
+         return torch.device('cpu')
+
+ # The GPU device (either a string or an actual torch.device)
+ gpu = get_gpu_device()
+
+ # List of modules that are loaded onto the GPU in full
+ gpu_complete_modules = []

  class DynamicSwapInstaller:
      @staticmethod
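Note on the new device handling: on Spaces, `gpu` is now the bare string `'cuda'` rather than a `torch.device`, because calling `torch.cuda.current_device()` in the main process would initialize CUDA and break the stateless zeroGPU worker model. A minimal sketch, assuming a caller that needs a real device object (the tensor `x` is illustrative, not part of this commit):

import torch
from diffusers_helper.memory import gpu, IN_HF_SPACE

# `gpu` may be the string 'cuda' (Spaces) or a torch.device (elsewhere);
# constructing torch.device(...) from a string does not itself initialize CUDA.
device = torch.device(gpu) if isinstance(gpu, str) else gpu

# Actual allocations should happen inside the GPU worker on Spaces.
if not IN_HF_SPACE and device.type == 'cuda':
    x = torch.zeros(1, device=device)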
@@ -32,37 +55,36 @@ class DynamicSwapInstaller:
                      return _buffers[name].to(**kwargs)
              return super(original_class, self).__getattr__(name)

-         module.__class__ = type('DynamicSwap_' + original_class.__name__, (original_class,), {
-             '__getattr__': hacked_get_attr,
-         })
-
-         return
+         module.__class__ = type(
+             'DynamicSwap_' + original_class.__name__,
+             (original_class,),
+             {'__getattr__': hacked_get_attr}
+         )

      @staticmethod
      def _uninstall_module(module: torch.nn.Module):
          if 'forge_backup_original_class' in module.__dict__:
              module.__class__ = module.__dict__.pop('forge_backup_original_class')
-         return

      @staticmethod
      def install_model(model: torch.nn.Module, **kwargs):
          for m in model.modules():
              DynamicSwapInstaller._install_module(m, **kwargs)
-         return

      @staticmethod
      def uninstall_model(model: torch.nn.Module):
          for m in model.modules():
              DynamicSwapInstaller._uninstall_module(m)
-         return


- def fake_diffusers_current_device(model: torch.nn.Module, target_device: torch.device):
+ def fake_diffusers_current_device(model: torch.nn.Module, target_device):
+     # Convert a string device to torch.device
+     if isinstance(target_device, str):
+         target_device = torch.device(target_device)
      if hasattr(model, 'scale_shift_table'):
          model.scale_shift_table.data = model.scale_shift_table.data.to(target_device)
          return
-
-     for k, p in model.named_modules():
+     for _, p in model.named_modules():
          if hasattr(p, 'weight'):
              p.to(target_device)
              return
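For context, the class-swap trick that `_install_module` relies on: the module's class is replaced by a dynamically created subclass whose `__getattr__` re-materializes parameters on a target device at access time, so the stored weights can stay on the CPU. A self-contained toy version, simplified from the real installer (names and the omitted buffer handling are illustrative):

import torch

def install_toy_swap(module: torch.nn.Module, **kwargs):
    original_class = module.__class__
    module.__dict__['forge_backup_original_class'] = original_class

    def hacked_get_attr(self, name):
        # nn.Module stores parameters in _parameters, so normal attribute
        # lookup misses and Python falls through to __getattr__.
        params = self.__dict__.get('_parameters')
        if params is not None and name in params and params[name] is not None:
            return params[name].to(**kwargs)  # materialize on the target device
        return super(original_class, self).__getattr__(name)

    module.__class__ = type('DynamicSwap_' + original_class.__name__,
                            (original_class,), {'__getattr__': hacked_get_attr})

linear = torch.nn.Linear(4, 4)          # stored weights stay where they are
install_toy_swap(linear, device='cpu')  # pass device='cuda' to swap on access
print(type(linear).__name__)            # DynamicSwap_Linear
print(linear.weight.device)             # resolved through hacked_get_attr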
@@ -71,64 +93,80 @@ def fake_diffusers_current_device(model: torch.nn.Module, target_device: torch.device):
  def get_cuda_free_memory_gb(device=None):
      if device is None:
          device = gpu
-
-     memory_stats = torch.cuda.memory_stats(device)
-     bytes_active = memory_stats['active_bytes.all.current']
-     bytes_reserved = memory_stats['reserved_bytes.all.current']
-     bytes_free_cuda, _ = torch.cuda.mem_get_info(device)
-     bytes_inactive_reserved = bytes_reserved - bytes_active
-     bytes_total_available = bytes_free_cuda + bytes_inactive_reserved
-     return bytes_total_available / (1024 ** 3)
+     if isinstance(device, str):
+         device = torch.device(device)
+     if device.type != 'cuda':
+         # Default value when the device is not CUDA
+         return 6.0
+     try:
+         stats = torch.cuda.memory_stats(device)
+         active = stats['active_bytes.all.current']
+         reserved = stats['reserved_bytes.all.current']
+         free_cuda, _ = torch.cuda.mem_get_info(device)
+         inactive = reserved - active
+         available = free_cuda + inactive
+         return available / (1024 ** 3)
+     except Exception as e:
+         print(f"Error retrieving CUDA memory info: {e}")
+         return 6.0


  def move_model_to_device_with_memory_preservation(model, target_device, preserved_memory_gb=0):
-     print(f'Moving {model.__class__.__name__} to {target_device} with preserved memory: {preserved_memory_gb} GB')
-
+     print(f"Moving {model.__class__.__name__} to {target_device} with preserved memory: {preserved_memory_gb} GB")
+     if isinstance(target_device, str):
+         target_device = torch.device(target_device)
+     # When the target is the CPU (or no GPU is in use), move directly
+     if target_device.type == 'cpu':
+         model.to(device=target_device)
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+         return
      for m in model.modules():
          if get_cuda_free_memory_gb(target_device) <= preserved_memory_gb:
              torch.cuda.empty_cache()
              return
-
          if hasattr(m, 'weight'):
              m.to(device=target_device)
-
      model.to(device=target_device)
      torch.cuda.empty_cache()
-     return


  def offload_model_from_device_for_memory_preservation(model, target_device, preserved_memory_gb=0):
-     print(f'Offloading {model.__class__.__name__} from {target_device} to preserve memory: {preserved_memory_gb} GB')
-
+     print(f"Offloading {model.__class__.__name__} from {target_device} to preserve memory: {preserved_memory_gb} GB")
+     if isinstance(target_device, str):
+         target_device = torch.device(target_device)
+     if target_device.type == 'cpu':
+         model.to(device=cpu)
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+         return
      for m in model.modules():
          if get_cuda_free_memory_gb(target_device) >= preserved_memory_gb:
              torch.cuda.empty_cache()
              return
-
          if hasattr(m, 'weight'):
              m.to(device=cpu)
-
      model.to(device=cpu)
      torch.cuda.empty_cache()
-     return


  def unload_complete_models(*args):
      for m in gpu_complete_modules + list(args):
+         if m is None:
+             continue
          m.to(device=cpu)
-         print(f'Unloaded {m.__class__.__name__} as complete.')
-
+         print(f"Unloaded {m.__class__.__name__} as complete.")
      gpu_complete_modules.clear()
-     torch.cuda.empty_cache()
-     return
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()


  def load_model_as_complete(model, target_device, unload=True):
+     if isinstance(target_device, str):
+         target_device = torch.device(target_device)
      if unload:
          unload_complete_models()
-
      model.to(device=target_device)
-     print(f'Loaded {model.__class__.__name__} to {target_device} as complete.')
-
+     print(f"Loaded {model.__class__.__name__} to {target_device} as complete.")
      gpu_complete_modules.append(model)
-     return
+
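The rewritten `get_cuda_free_memory_gb` keeps the same accounting as before: free memory reported by the driver plus memory the caching allocator has reserved but is not actively using, converted to GiB. A short usage sketch of the preservation helpers (the toy `model` is illustrative, not part of this commit):

import torch
from diffusers_helper.memory import (
    get_cuda_free_memory_gb,
    move_model_to_device_with_memory_preservation,
    offload_model_from_device_for_memory_preservation,
)

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024),
                            torch.nn.Linear(1024, 1024))

if torch.cuda.is_available():
    print(f'free: {get_cuda_free_memory_gb("cuda"):.2f} GB')  # strings now accepted

    # Move submodules to the GPU, stopping once free memory would drop to ~2 GB.
    move_model_to_device_with_memory_preservation(model, 'cuda', preserved_memory_gb=2)

    # ... run inference ...

    # Offload submodules back to the CPU until at least 4 GB is free again.
    offload_model_from_device_for_memory_preservation(model, 'cuda', preserved_memory_gb=4)
else:
    print(get_cuda_free_memory_gb('cpu'))  # non-CUDA devices return the 6.0 fallback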
 
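Taken together, a sketch of how these helpers might be driven from a zeroGPU Space, assuming the standard `spaces.GPU` decorator available in Spaces; `text_encoder` is a placeholder module, not part of this commit:

import torch
import spaces  # available inside Hugging Face Spaces
from diffusers_helper.memory import gpu, load_model_as_complete, unload_complete_models

text_encoder = torch.nn.Linear(8, 8)  # placeholder for a real encoder

@spaces.GPU  # CUDA is initialized only inside this worker, never at import time
def encode(x: torch.Tensor) -> torch.Tensor:
    # `gpu` is the string 'cuda' on Spaces; the patched helper normalizes it.
    load_model_as_complete(text_encoder, target_device=gpu)
    y = text_encoder(x.to(gpu))
    unload_complete_models()  # move everything marked complete back to the CPU
    return y.cpu()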