1inkusFace committed
Commit 7a6cd8f · verified · 1 Parent(s): eedc8c2

Update skyreelsinfer/offload.py

Files changed (1)
  1. skyreelsinfer/offload.py +493 -311
skyreelsinfer/offload.py CHANGED
@@ -1,333 +1,515 @@
- import spaces
- import gradio as gr
- import argparse
- import sys
  import os
- import random
- import subprocess
- from PIL import Image
- import numpy as np
-
- # Removed environment-specific lines
- from diffusers.utils import export_to_video
- from diffusers.utils import load_image

  import torch
- import logging
- from collections import OrderedDict
-
- torch.backends.cuda.matmul.allow_tf32 = False
- torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
- torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
- torch.backends.cudnn.allow_tf32 = False
- torch.backends.cudnn.deterministic = False
- torch.backends.cudnn.benchmark = False
- torch.set_float32_matmul_precision("highest")
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
- logger = logging.getLogger(__name__)


- # --- Dummy Classes (Keep for standalone execution) ---
  class OffloadConfig:
-     def __init__(
-         self,
-         high_cpu_memory: bool = False,
-         parameters_level: bool = False,
-         compiler_transformer: bool = False,
-         compiler_cache: str = "",
-     ):
-         self.high_cpu_memory = high_cpu_memory
-         self.parameters_level = parameters_level
-         self.compiler_transformer = compiler_transformer
-         self.compiler_cache = compiler_cache
-
-
- class TaskType:  # Keep here for infer
-     T2V = 0
-     I2V = 1
-
-
- class LlamaModel:
-     @staticmethod
-     def from_pretrained(*args, **kwargs):
-         return LlamaModel()
-
-     def to(self, device):
-         return self
-
-
- class HunyuanVideoTransformer3DModel:
-     @staticmethod
-     def from_pretrained(*args, **kwargs):
-         return HunyuanVideoTransformer3DModel()
-
-     def to(self, device):
-         return self
-
-
- class SkyreelsVideoPipeline:
-     @staticmethod
-     def from_pretrained(*args, **kwargs):
-         return SkyreelsVideoPipeline()
-
-     def to(self, device):
-         return self

-     def __call__(self, *args, **kwargs):
-         num_frames = kwargs.get("num_frames", 16)  # Default to 16 frames
-         height = kwargs.get("height", 512)
-         width = kwargs.get("width", 512)
-
-         if "image" in kwargs:  # I2V
-             image = kwargs["image"]
-             # Convert PIL Image to PyTorch tensor (and normalize to [0, 1])
-             image_tensor = torch.from_numpy(np.array(image)).float() / 255.0
-             image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0)  # (H, W, C) -> (1, C, H, W)
-
-             # Create video by repeating the image
-             frames = image_tensor.repeat(1, 1, num_frames, 1, 1)  # (1, C, T, H, W)
-             frames = frames + torch.randn_like(frames) * 0.05  # Add a little noise
-             # Correct shape: (1, C, T, H, W) - NO PERMUTE HERE
-
-         else:  # T2V
-             frames = torch.randn(1, 3, num_frames, height, width)  # (1, C, T, H, W) - Correct!
-
-         return type("obj", (object,), {"frames": frames})()  # No longer a list!

      def __init__(self):
-         super().__init__()
-         self._modules = OrderedDict()
-         self.vae = self.VAE()
-         self._modules["vae"] = self.vae
-
-     def named_children(self):
-         return self._modules.items()
-
-     class VAE:
-         def enable_tiling(self):
-             pass
-
-
- def quantize_(*args, **kwargs):
-     return
-
-
- def float8_weight_only():
-     return
-
-
- # --- End Dummy Classes ---
-
-
- class SkyReelsVideoSingleGpuInfer:
-     def _load_model(
-         self, model_id: str, base_model_id: str = "hunyuanvideo-community/HunyuanVideo", quant_model: bool = True
-     ):
-         logger.info(f"load model model_id:{model_id} quan_model:{quant_model}")
-         text_encoder = LlamaModel.from_pretrained(
-             base_model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
-         ).to("cpu")
-         transformer = HunyuanVideoTransformer3DModel.from_pretrained(
-             model_id, torch_dtype=torch.bfloat16, device="cpu"
-         ).to("cpu")
-
-         if quant_model:
-             quantize_(text_encoder, float8_weight_only())
-             text_encoder.to("cpu")
              torch.cuda.empty_cache()
-             quantize_(transformer, float8_weight_only())
-             transformer.to("cpu")
              torch.cuda.empty_cache()

-         pipe = SkyreelsVideoPipeline.from_pretrained(
-             base_model_id, transformer=transformer, text_encoder=text_encoder, torch_dtype=torch.bfloat16
-         ).to("cpu")
-         pipe.vae.enable_tiling()
          torch.cuda.empty_cache()
-         return pipe
-
-     def __init__(
-         self,
-         task_type: TaskType,
-         model_id: str,
-         quant_model: bool = True,
-         is_offload: bool = True,
-         offload_config: OffloadConfig = OffloadConfig(),
-         enable_cfg_parallel: bool = True,
-     ):
-         self.task_type = task_type
-         self.model_id = model_id
-         self.quant_model = quant_model
-         self.is_offload = is_offload
-         self.offload_config = offload_config
-         self.enable_cfg_parallel = enable_cfg_parallel
-         self.pipe = None
-         self.is_initialized = False
-         self.gpu_device = None
-
-     def initialize(self):
-         """Initializes the model and moves it to the GPU."""
-         if self.is_initialized:
-             return

-         if not torch.cuda.is_available():
-             raise RuntimeError("CUDA is not available. Cannot initialize model.")

-         self.gpu_device = "cuda:0"
-         self.pipe = self._load_model(model_id=self.model_id, quant_model=self.quant_model)

-         if self.is_offload:
-             pass
          else:
-             self.pipe.to(self.gpu_device)
-
-         if self.offload_config.compiler_transformer:
-             torch._dynamo.config.suppress_errors = True
-             os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
-             os.environ["TORCHINDUCTOR_CACHE_DIR"] = f"{self.offload_config.compiler_cache}"
-             self.pipe.transformer = torch.compile(
-                 self.pipe.transformer, mode="max-autotune-no-cudagraphs", dynamic=True
              )
-         if self.offload_config.compiler_transformer:
-             self.warm_up()
-         self.is_initialized = True
-
-     def warm_up(self):
-         if not self.is_initialized:
-             raise RuntimeError("Model must be initialized before warm-up.")
-
-         init_kwargs = {
-             "prompt": "A woman is dancing in a room",
-             "height": 544,
-             "width": 960,
-             "guidance_scale": 6,
-             "num_inference_steps": 1,
-             "negative_prompt": "bad quality",
-             "num_frames": 16,
-             "generator": torch.Generator(self.gpu_device).manual_seed(42),
-             "embedded_guidance_scale": 1.0,
-         }
-         if self.task_type == TaskType.I2V:
-             init_kwargs["image"] = Image.new("RGB", (544, 960), color="black")
-         self.pipe(**init_kwargs)
-         logger.info("Warm-up complete.")
-
-     def infer(self, **kwargs):
-         """Handles inference requests."""
-         if not self.is_initialized:
-             self.initialize()
-         if "seed" in kwargs:
-             kwargs["generator"] = torch.Generator(self.gpu_device).manual_seed(kwargs["seed"])
-             del kwargs["seed"]
-         assert (self.task_type == TaskType.I2V and "image" in kwargs) or self.task_type == TaskType.T2V
-         result = self.pipe(**kwargs).frames  # Return the tensor directly
-         return result
-
-
- _predictor = None
-
-
- @spaces.GPU(duration=90)
- def generate_video(prompt: str, seed: int, image: str = None) -> tuple[str, dict]:
-     """Generates a video based on the given prompt and seed.
-
-     Args:
-         prompt: The text prompt to guide video generation.
-         seed: The random seed for reproducibility.
-         image: Optional path to an image for Image-to-Video.
-
-     Returns:
-         A tuple containing the path to the generated video and the parameters used.
-     """
-     global _predictor
-
-     if seed == -1:
-         random.seed()
-         seed = int(random.randrange(4294967294))
-
-     if image is None:
-         task_type = TaskType.T2V
-         model_id = "Skywork/SkyReels-V1-Hunyuan-T2V"
-         kwargs = {
-             "prompt": prompt,
-             "height": 512,
-             "width": 512,
-             "num_frames": 16,
-             "num_inference_steps": 30,
-             "seed": seed,
-             "guidance_scale": 7.5,
-             "negative_prompt": "bad quality, worst quality",
-         }
-     else:
-         task_type = TaskType.I2V
-         model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
-         kwargs = {
-             "prompt": prompt,
-             "image": load_image(image),
-             "height": 512,
-             "width": 512,
-             "num_frames": 97,
-             "num_inference_steps": 30,
-             "seed": seed,
-             "guidance_scale": 6.0,
-             "embedded_guidance_scale": 1.0,
-             "negative_prompt": "Aerial view, low quality, bad hands",
-             "cfg_for": False,
-         }

-     if _predictor is None:
-         _predictor = SkyReelsVideoSingleGpuInfer(
-             task_type=task_type,
-             model_id=model_id,
-             quant_model=True,
-             is_offload=True,
-             offload_config=OffloadConfig(
-                 high_cpu_memory=True,
-                 parameters_level=True,
-                 compiler_transformer=False,
-             ),
-         )
-         _predictor.initialize()
-         logger.info("Predictor initialized")
-
-     with torch.no_grad():
-         output = _predictor.infer(**kwargs)
-     '''
-     output = (output.numpy() * 255).astype(np.uint8)
-     # Correct Transpose: (1, C, T, H, W) -> (1, T, H, W, C)
-     output = output.transpose(0, 2, 3, 4, 1)
-     output = output[0]  # Remove batch dimension: (T, H, W, C)
-     '''
-
-     save_dir = f"./result"
-     os.makedirs(save_dir, exist_ok=True)
-     video_out_file = f"{save_dir}/{seed}.mp4"
-     print(f"generate video, local path: {video_out_file}")
-     export_to_video(output, video_out_file, fps=24)
-     return video_out_file, kwargs
-
-
- def create_gradio_interface():
-     with gr.Blocks() as demo:
-         with gr.Row():
-             with gr.Column():
-                 image = gr.Image(label="Upload Image", type="filepath")
-                 prompt = gr.Textbox(label="Input Prompt")
-                 seed = gr.Number(label="Random Seed", value=-1)
-             with gr.Column():
-                 submit_button = gr.Button("Generate Video")
-                 output_video = gr.Video(label="Generated Video")
-                 output_params = gr.Textbox(label="Output Parameters")
-
-         submit_button.click(
-             fn=generate_video,
-             inputs=[prompt, seed, image],
-             outputs=[output_video, output_params],
          )
-     return demo


- if __name__ == "__main__":
-     demo = create_gradio_interface()
-     demo.queue().launch()
+ import functools
+ import gc
  import os
+ import time
+ from dataclasses import dataclass

  import torch
+ from diffusers.pipelines import DiffusionPipeline
+ from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor


+ @dataclass
  class OffloadConfig:
+     # high_cpu_memory: Whether to use pinned memory for offload optimization. This can effectively prevent increased model offload latency caused by memory swapping.
+     high_cpu_memory: bool = True
+     # parameters_level: Whether to enable parameter-level offload. This further reduces VRAM requirements but may result in increased latency.
+     parameters_level: bool = False
+     # compiler_transformer: Whether to enable compilation optimization for the transformer.
+     compiler_transformer: bool = False
+     compiler_cache: str = "/tmp/compile_cache"


+ class HfHook:
      def __init__(self):
+         device_id = os.environ.get("LOCAL_RANK", 0)
+         self.execution_device = f"cuda:{device_id}"
+
+     def detach_hook(self, module):
+         pass
+
+
+ class Offload:
+     def __init__(self) -> None:
+         self.active_models = []
+         self.active_models_ids = []
+         self.active_subcaches = {}
+         self.models = {}
+         self.verboseLevel = 0
+         self.models_to_quantize = []
+         self.pinned_modules_data = {}
+         self.blocks_of_modules = {}
+         self.blocks_of_modules_sizes = {}
+         self.compile = False
+         self.device_mem_capacity = torch.cuda.get_device_properties(0).total_memory
+         self.last_reserved_mem_check = 0
+         self.loaded_blocks = {}
+         self.prev_blocks_names = {}
+         self.next_blocks_names = {}
+         device_id = os.environ.get("LOCAL_RANK", 0)
+         self.device_id = f"cuda:{device_id}"
+         self.default_stream = torch.cuda.default_stream(self.device_id)  # torch.cuda.current_stream()
+         self.transfer_stream = torch.cuda.Stream()
+         self.async_transfers = False
+         self.last_run_model = None
+
+     @classmethod
+     def offload(cls, pipeline: DiffusionPipeline, config: OffloadConfig = OffloadConfig()):
+         """
+         Enable offloading for multiple models in the pipeline, supporting video generation inference on user-level GPUs.
+         pipe: the pipeline object
+         config: offload strategy configuration
+         """
+         self = cls()
+         self.pinned_modules_data = {}
+         if config.parameters_level:
+             model_budgets = {
+                 "transformer": 600 * 1024 * 1024,
+                 "text_encoder": 3 * 1024 * 1024 * 1024,
+                 "text_encoder_2": 3 * 1024 * 1024 * 1024,
+             }
+             self.async_transfers = True
+         else:
+             model_budgets = {}
+
+         device_id = os.getenv("LOCAL_RANK", 0)
+         torch.set_default_device(f"cuda:{device_id}")
+         pipeline.hf_device_map = torch.device(f"cuda:{device_id}")
+         pipe_or_dict_of_modules = pipeline.components
+         if config.compiler_transformer:
+             pipeline.transformer.to("cuda")
+         models = {
+             k: v
+             for k, v in pipe_or_dict_of_modules.items()
+             if isinstance(v, torch.nn.Module) and not (config.compiler_transformer and k == "transformer")
+         }
+         print_info = {k: type(v) for k, v in models.items()}
+         print(f"offload models: {print_info}")
+         if config.compiler_transformer:
+             pipeline.text_encoder.to("cpu")
+             pipeline.text_encoder_2.to("cpu")
              torch.cuda.empty_cache()
+             pipeline.transformer.to("cuda")
+             pipeline.vae.to("cuda")
+
+             def move_text_encoder_to_gpu(pipe):
+                 torch.cuda.empty_cache()
+                 pipe.text_encoder.to("cuda")
+                 pipe.text_encoder_2.to("cuda")
+
+             def move_text_encoder_to_cpu(pipe):
+                 pipe.text_encoder.to("cpu")
+                 pipe.text_encoder_2.to("cpu")
+                 torch.cuda.empty_cache()
+
+             setattr(pipeline, "text_encoder_to_cpu", functools.partial(move_text_encoder_to_cpu, pipeline))
+             setattr(pipeline, "text_encoder_to_gpu", functools.partial(move_text_encoder_to_gpu, pipeline))
+
+             for k, module in pipe_or_dict_of_modules.items():
+                 if isinstance(module, torch.nn.Module):
+                     for submodule_name, submodule in module.named_modules():
+                         if not hasattr(submodule, "_hf_hook"):
+                             setattr(submodule, "_hf_hook", HfHook())
+             return self
+
+         sizeofbfloat16 = torch.bfloat16.itemsize
+         modelPinned = config.high_cpu_memory
+         # Pin in RAM models
+         # Calculate the VRAM requirements of the computational modules to determine whether parameters-level offload is necessary.
+         for model_name, curr_model in models.items():
+             curr_model.to("cpu").eval()
+             pinned_parameters_data = {}
+             current_model_size = 0
+             print(f"{model_name} move to pinned memory:{modelPinned}")
+             for p in curr_model.parameters():
+                 if isinstance(p, AffineQuantizedTensor):
+                     if not modelPinned and p.tensor_impl.scale.dtype == torch.float32:
+                         p.tensor_impl.scale = p.tensor_impl.scale.to(torch.bfloat16)
+                     current_model_size += torch.numel(p.tensor_impl.scale) * sizeofbfloat16
+                     current_model_size += torch.numel(p.tensor_impl.float8_data) * sizeofbfloat16 / 2
+                     if modelPinned:
+                         p.tensor_impl.float8_data = p.tensor_impl.float8_data.pin_memory()
+                         p.tensor_impl.scale = p.tensor_impl.scale.pin_memory()
+                         pinned_parameters_data[p] = [p.tensor_impl.float8_data, p.tensor_impl.scale]
+                 else:
+                     p.data = p.data.to(torch.bfloat16) if p.data.dtype == torch.float32 else p.data.to(p.data.dtype)
+                     current_model_size += torch.numel(p.data) * p.data.element_size()
+                     if modelPinned:
+                         p.data = p.data.pin_memory()
+                         pinned_parameters_data[p] = p.data
+
+             for buffer in curr_model.buffers():
+                 buffer.data = (
+                     buffer.data.to(torch.bfloat16)
+                     if buffer.data.dtype == torch.float32
+                     else buffer.data.to(buffer.data.dtype)
+                 )
+                 current_model_size += torch.numel(buffer.data) * buffer.data.element_size()
+                 if modelPinned:
+                     buffer.data = buffer.data.pin_memory()
+
+             if model_name not in self.models:
+                 self.models[model_name] = curr_model
+
+             curr_model_budget = model_budgets.get(model_name, 0)
+             if curr_model_budget > 0 and curr_model_budget > current_model_size:
+                 model_budgets[model_name] = 0
+
+             if modelPinned:
+                 pinned_buffers_data = {b: b.data for b in curr_model.buffers()}
+                 pinned_parameters_data.update(pinned_buffers_data)
+                 self.pinned_modules_data[model_name] = pinned_parameters_data
+             gc.collect()
              torch.cuda.empty_cache()

+         # if config.compiler_transformer:
+         #     module = pipeline.transformer
+         #     print("wrap transformer forward")
+         #     # gpu model wrap
+         #     for submodule_name, submodule in module.named_modules():
+         #         if not hasattr(submodule, "_hf_hook"):
+         #             setattr(submodule, "_hf_hook", HfHook())
+         #
+         #     forward_method = getattr(module, "forward")
+         #
+         #     def wrap_unload_all(*args, **kwargs):
+         #         self.unload_all("transformer")
+         #         return forward_method(*args, **kwargs)
+         #
+         #     setattr(module, "forward", functools.update_wrapper(wrap_unload_all, forward_method))
+
+         # wrap forward methods
+         for model_name, curr_model in models.items():
+             current_budget = model_budgets.get(model_name, 0)
+             current_size = 0
+             self.loaded_blocks[model_name] = None
+             cur_blocks_prefix, prev_blocks_name, cur_blocks_name, cur_blocks_seq = None, None, None, -1
+
+             for submodule_name, submodule in curr_model.named_modules():
+                 # create a fake accelerate parameter so that the _execution_device property returns always "cuda"
+                 if not hasattr(submodule, "_hf_hook"):
+                     setattr(submodule, "_hf_hook", HfHook())
+
+                 if not submodule_name:
+                     continue
+
+                 # usr parameters-level offload
+                 if current_budget > 0:
+                     if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+                         if cur_blocks_prefix == None:
+                             cur_blocks_prefix = submodule_name + "."
+                         else:
+                             if not submodule_name.startswith(cur_blocks_prefix):
+                                 cur_blocks_prefix = submodule_name + "."
+                                 cur_blocks_name, cur_blocks_seq = None, -1
+                     else:
+                         if cur_blocks_prefix is not None:
+                             if submodule_name.startswith(cur_blocks_prefix):
+                                 num = int(submodule_name[len(cur_blocks_prefix) :].split(".")[0])
+                                 if num != cur_blocks_seq and (cur_blocks_name == None or current_size > current_budget):
+                                     prev_blocks_name = cur_blocks_name
+                                     cur_blocks_name = cur_blocks_prefix + str(num)
+                                 cur_blocks_seq = num
+                             else:
+                                 cur_blocks_prefix = None
+                                 prev_blocks_name = None
+                                 cur_blocks_name = None
+                                 cur_blocks_seq = -1
+
+                 if hasattr(submodule, "forward"):
+                     submodule_forward = getattr(submodule, "forward")
+                     if not callable(submodule_forward):
+                         print("***")
+                         continue
+                     if len(submodule_name.split(".")) == 1:
+                         self.hook_me(submodule, curr_model, model_name, submodule_name, submodule_forward)
+                     else:
+                         self.hook_me_light(
+                             submodule, model_name, cur_blocks_name, submodule_forward, context=submodule_name
+                         )
+                     current_size = self.add_module_to_blocks(model_name, cur_blocks_name, submodule, prev_blocks_name)
+
+         gc.collect()
          torch.cuda.empty_cache()
+         return self
+
+     def add_module_to_blocks(self, model_name, blocks_name, submodule, prev_block_name):
+
+         entry_name = model_name if blocks_name is None else model_name + "/" + blocks_name
+         if entry_name in self.blocks_of_modules:
+             blocks_params = self.blocks_of_modules[entry_name]
+             blocks_params_size = self.blocks_of_modules_sizes[entry_name]
+         else:
+             blocks_params = []
+             self.blocks_of_modules[entry_name] = blocks_params
+             blocks_params_size = 0
+             if blocks_name != None:
+                 prev_entry_name = None if prev_block_name == None else model_name + "/" + prev_block_name
+                 self.prev_blocks_names[entry_name] = prev_entry_name
+                 if not prev_block_name == None:
+                     self.next_blocks_names[prev_entry_name] = entry_name
+
+         for p in submodule.parameters(recurse=False):
+             blocks_params.append(p)
+             if isinstance(p, AffineQuantizedTensor):
+                 blocks_params_size += p.tensor_impl.float8_data.nbytes
+                 blocks_params_size += p.tensor_impl.scale.nbytes
+             else:
+                 blocks_params_size += p.data.nbytes
+
+         for p in submodule.buffers(recurse=False):
+             blocks_params.append(p)
+             blocks_params_size += p.data.nbytes
+
+         self.blocks_of_modules_sizes[entry_name] = blocks_params_size
+
+         return blocks_params_size
+
+     def can_model_be_cotenant(self, model_name):
+         cotenants_map = {
+             "text_encoder": ["vae", "text_encoder_2"],
+             "text_encoder_2": ["vae", "text_encoder"],
+         }
+         potential_cotenants = cotenants_map.get(model_name, None)
+         if potential_cotenants is None:
+             return False
+         for existing_cotenant in self.active_models_ids:
+             if existing_cotenant not in potential_cotenants:
+                 return False
+         return True
+
+     @torch.compiler.disable()
+     def gpu_load_blocks(self, model_name, blocks_name, async_load=False):
+         if blocks_name != None:
+             self.loaded_blocks[model_name] = blocks_name
+
+         def cpu_to_gpu(stream_to_use, blocks_params, record_for_stream=None):
+             with torch.cuda.stream(stream_to_use):
+                 for p in blocks_params:
+                     if isinstance(p, AffineQuantizedTensor):
+                         p.tensor_impl.float8_data = p.tensor_impl.float8_data.cuda(
+                             non_blocking=True, device=self.device_id
+                         )
+                         p.tensor_impl.scale = p.tensor_impl.scale.cuda(non_blocking=True, device=self.device_id)
+                     else:
+                         p.data = p.data.cuda(non_blocking=True, device=self.device_id)
+
+                     if record_for_stream != None:
+                         if isinstance(p, AffineQuantizedTensor):
+                             p.tensor_impl.float8_data.record_stream(record_for_stream)
+                             p.tensor_impl.scale.record_stream(record_for_stream)
+                         else:
+                             p.data.record_stream(record_for_stream)
+
+         entry_name = model_name if blocks_name is None else model_name + "/" + blocks_name
+         if self.verboseLevel >= 2:
+             model = self.models[model_name]
+             model_name = model._get_name()
+             print(f"Loading model {entry_name} ({model_name}) in GPU")
+
+         if self.async_transfers and blocks_name != None:
+             first = self.prev_blocks_names[entry_name] == None
+             next_blocks_entry = self.next_blocks_names[entry_name] if entry_name in self.next_blocks_names else None
+             if first:
+                 cpu_to_gpu(torch.cuda.current_stream(), self.blocks_of_modules[entry_name])
+                 torch.cuda.synchronize()
+
+             if next_blocks_entry != None:
+                 cpu_to_gpu(self.transfer_stream, self.blocks_of_modules[next_blocks_entry])

+         else:
+             cpu_to_gpu(self.default_stream, self.blocks_of_modules[entry_name])
+             torch.cuda.synchronize()
+
+     @torch.compiler.disable()
+     def gpu_unload_blocks(self, model_name, blocks_name):
+         if blocks_name != None:
+             self.loaded_blocks[model_name] = None
+
+         blocks_name = model_name if blocks_name is None else model_name + "/" + blocks_name
+
+         if self.verboseLevel >= 2:
+             model = self.models[model_name]
+             model_name = model._get_name()
+             print(f"Unloading model {blocks_name} ({model_name}) from GPU")
+
+         blocks_params = self.blocks_of_modules[blocks_name]
+
+         if model_name in self.pinned_modules_data:
+             pinned_parameters_data = self.pinned_modules_data[model_name]
+             for p in blocks_params:
+                 if isinstance(p, AffineQuantizedTensor):
+                     data = pinned_parameters_data[p]
+                     p.tensor_impl.float8_data = data[0]
+                     p.tensor_impl.scale = data[1]
+                 else:
+                     p.data = pinned_parameters_data[p]
+         else:
+             for p in blocks_params:
+                 if isinstance(p, AffineQuantizedTensor):
+                     p.tensor_impl.float8_data = p.tensor_impl.float8_data.cpu()
+                     p.tensor_impl.scale = p.tensor_impl.scale.cpu()
+                 else:
+                     p.data = p.data.cpu()
+
+     @torch.compiler.disable()
+     def gpu_load(self, model_name):
+         model = self.models[model_name]
+         self.active_models.append(model)
+         self.active_models_ids.append(model_name)
+
+         self.gpu_load_blocks(model_name, None)
+
+         # torch.cuda.current_stream().synchronize()
+
+     @torch.compiler.disable()
+     def unload_all(self, model_name: str):
+         if len(self.active_models_ids) == 0 and self.last_run_model == model_name:
+             self.last_run_model = model_name
+             return
+         for model_name in self.active_models_ids:
+             self.gpu_unload_blocks(model_name, None)
+             loaded_block = self.loaded_blocks[model_name]
+             if loaded_block != None:
+                 self.gpu_unload_blocks(model_name, loaded_block)
+                 self.loaded_blocks[model_name] = None
+
+         self.active_models = []
+         self.active_models_ids = []
+         self.active_subcaches = []
+         torch.cuda.empty_cache()
+         gc.collect()
+         self.last_reserved_mem_check = time.time()
+         self.last_run_model = model_name
+
+     def move_args_to_gpu(self, *args, **kwargs):
+         new_args = []
+         new_kwargs = {}
+         for arg in args:
+             if torch.is_tensor(arg):
+                 if arg.dtype == torch.float32:
+                     arg = arg.to(torch.bfloat16).cuda(non_blocking=True, device=self.device_id)
+                 else:
+                     arg = arg.cuda(non_blocking=True, device=self.device_id)
+             new_args.append(arg)
+
+         for k in kwargs:
+             arg = kwargs[k]
+             if torch.is_tensor(arg):
+                 if arg.dtype == torch.float32:
+                     arg = arg.to(torch.bfloat16).cuda(non_blocking=True, device=self.device_id)
+                 else:
+                     arg = arg.cuda(non_blocking=True, device=self.device_id)
+             new_kwargs[k] = arg
+
+         return new_args, new_kwargs
+
+     def ready_to_check_mem(self):
+         if self.compile:
+             return
+         cur_clock = time.time()
+         # can't check at each call if we can empty the cuda cache as quering the reserved memory value is a time consuming operation
+         if (cur_clock - self.last_reserved_mem_check) < 0.200:
+             return False
+         self.last_reserved_mem_check = cur_clock
+         return True
+
+     def empty_cache_if_needed(self):
+         mem_reserved = torch.cuda.memory_reserved()
+         mem_threshold = 0.9 * self.device_mem_capacity
+         if mem_reserved >= mem_threshold:
+             mem_allocated = torch.cuda.memory_allocated()
+             if mem_allocated <= 0.70 * mem_reserved:
+                 torch.cuda.empty_cache()
+                 tm = time.time()
+                 if self.verboseLevel >= 2:
+                     print(f"Empty Cuda cache at {tm}")
+
+     def any_param_or_buffer(self, target_module: torch.nn.Module):
+
+         for _ in target_module.parameters(recurse=False):
+             return True
+
+         for _ in target_module.buffers(recurse=False):
+             return True
+
+         return False
+
+     def hook_me_light(self, target_module, model_name, blocks_name, previous_method, context):
+
+         anyParam = self.any_param_or_buffer(target_module)
+
+         def check_empty_cuda_cache(module, *args, **kwargs):
+             if self.ready_to_check_mem():
+                 self.empty_cache_if_needed()
+             return previous_method(*args, **kwargs)
+
+         def load_module_blocks(module, *args, **kwargs):
+             if blocks_name == None:
+                 if self.ready_to_check_mem():
+                     self.empty_cache_if_needed()
+             else:
+                 loaded_block = self.loaded_blocks[model_name]
+                 if loaded_block == None or loaded_block != blocks_name:
+                     if loaded_block != None:
+                         self.gpu_unload_blocks(model_name, loaded_block)
+                         if self.ready_to_check_mem():
+                             self.empty_cache_if_needed()
+                     self.loaded_blocks[model_name] = blocks_name
+                     self.gpu_load_blocks(model_name, blocks_name)
+             return previous_method(*args, **kwargs)
+
+         if hasattr(target_module, "_mm_id"):
+             orig_model_name = getattr(target_module, "_mm_id")
+             if self.verboseLevel >= 2:
+                 print(
+                     f"Model '{model_name}' shares module '{target_module._get_name()}' with module '{orig_model_name}' "
+                 )
+             assert not anyParam
+             return
+         setattr(target_module, "_mm_id", model_name)

+         if blocks_name != None and anyParam:
+             setattr(
+                 target_module,
+                 "forward",
+                 functools.update_wrapper(functools.partial(load_module_blocks, target_module), previous_method),
+             )
+             # print(f"new cache:{blocks_name}")
          else:
+             setattr(
+                 target_module,
+                 "forward",
+                 functools.update_wrapper(functools.partial(check_empty_cuda_cache, target_module), previous_method),
              )

+     def hook_me(self, target_module, model, model_name, module_id, previous_method):
+         def check_change_module(module, *args, **kwargs):
+             performEmptyCacheTest = False
+             if not model_name in self.active_models_ids:
+                 new_model_name = getattr(module, "_mm_id")
+                 if not self.can_model_be_cotenant(new_model_name):
+                     self.unload_all(model_name)
+                     performEmptyCacheTest = False
+                 self.gpu_load(new_model_name)
+             args, kwargs = self.move_args_to_gpu(*args, **kwargs)
+             if performEmptyCacheTest:
+                 self.empty_cache_if_needed()
+             return previous_method(*args, **kwargs)
+
+         if hasattr(target_module, "_mm_id"):
+             return
+         setattr(target_module, "_mm_id", model_name)
+
+         setattr(
+             target_module,
+             "forward",
+             functools.update_wrapper(functools.partial(check_change_module, target_module), previous_method),
          )

+         if not self.verboseLevel >= 1:
+             return

+         if module_id == None or module_id == "":
+             model_name = model._get_name()
+             print(f"Hooked in model '{model_name}' ({model_name})")