sayakpaul (HF Staff) committed
Commit 10bbb52 · 1 Parent(s): 5460f04
Files changed (5):
  1. check.py +82 -0
  2. optimization.py +21 -14
  3. optimization_utils.py +43 -4
  4. reproduce.py +117 -0
  5. requirements.txt +0 -11
check.py ADDED
@@ -0,0 +1,82 @@
+from diffusers import LTXConditionPipeline
+from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
+import numpy as np
+from PIL import Image
+import torch
+from diffusers.utils import load_image, load_video, export_to_video
+from optimization import optimize_pipeline_
+
+MODEL_ID = "Lightricks/LTX-Video-0.9.8-13B-distilled"
+
+LANDSCAPE_WIDTH = 480
+LANDSCAPE_HEIGHT = 832
+MAX_SEED = np.iinfo(np.int32).max
+
+FIXED_FPS = 24
+MIN_FRAMES_MODEL = 8
+MAX_FRAMES_MODEL = 96
+
+MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)
+MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)
+
+def resize_image(image: Image.Image) -> Image.Image:
+    if image.height > image.width:
+        transposed = image.transpose(Image.Transpose.ROTATE_90)
+        resized = resize_image_landscape(transposed)
+        return resized.transpose(Image.Transpose.ROTATE_270)
+    return resize_image_landscape(image)
+
+
+def resize_image_landscape(image: Image.Image) -> Image.Image:
+    target_aspect = LANDSCAPE_WIDTH / LANDSCAPE_HEIGHT
+    width, height = image.size
+    in_aspect = width / height
+    if in_aspect > target_aspect:
+        new_width = round(height * target_aspect)
+        left = (width - new_width) // 2
+        image = image.crop((left, 0, left + new_width, height))
+    else:
+        new_height = round(width / target_aspect)
+        top = (height - new_height) // 2
+        image = image.crop((0, top, width, top + new_height))
+    return image.resize((LANDSCAPE_WIDTH, LANDSCAPE_HEIGHT), Image.LANCZOS)
+
+
+pipe = LTXConditionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16).to("cuda")
+dummy_image = Image.new("RGB", (LANDSCAPE_WIDTH, LANDSCAPE_HEIGHT))
+video = load_video(export_to_video([dummy_image]))
+condition1 = LTXVideoCondition(video=video, frame_index=0)
+optimize_pipeline_(
+    pipe,
+    conditions=[condition1],
+    prompt="prompt",
+    height=LANDSCAPE_HEIGHT,
+    width=LANDSCAPE_WIDTH,
+    num_frames=MAX_FRAMES_MODEL,
+    num_inference_steps=2
+)
+
+default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
+default_negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards, watermark, text, signature"
+input_image = load_image("peng.png")
+duration_seconds = MAX_DURATION
+guidance_scale = 1.0
+num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
+current_seed = 42
+resized_image = resize_image(input_image)
+steps = 8
+
+video = load_video(export_to_video([resized_image]))
+condition1 = LTXVideoCondition(video=video, frame_index=0)
+
+output_frames_list = pipe(
+    conditions=[condition1],
+    prompt=default_prompt_i2v,
+    negative_prompt=default_negative_prompt,
+    height=resized_image.height,
+    width=resized_image.width,
+    num_frames=num_frames,
+    guidance_scale=float(guidance_scale),
+    num_inference_steps=int(steps),
+    generator=torch.Generator(device="cuda").manual_seed(current_seed),
+).frames[0]
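
Note: a quick check of the crop math in resize_image_landscape, using an assumed 1280x720 input (illustrative values only; the script itself loads "peng.png"):

    # hypothetical input size for illustration
    width, height = 1280, 720
    target_aspect = 480 / 832                    # ~0.577
    in_aspect = width / height                   # ~1.778 > target_aspect
    new_width = round(height * target_aspect)    # 415
    left = (width - new_width) // 2              # 432 -> crop box (432, 0, 847, 720)
    # the centered crop is then resized to (480, 832) with Image.LANCZOS
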
optimization.py CHANGED
@@ -9,11 +9,9 @@ from typing import ParamSpec
 import spaces
 import torch
 from torch.utils._pytree import tree_map_only
-from torchao.quantization import quantize_
-from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
+from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
 from diffusers import LTXConditionPipeline
-from optimization_utils import capture_component_call
-from optimization_utils import aoti_compile
+from optimization_utils import capture_component_call, aoti_compile, cudagraph
 
 
 P = ParamSpec("P")
@@ -33,7 +31,8 @@ INDUCTOR_CONFIGS = {
     "epilogue_fusion": False,
     "coordinate_descent_tuning": True,
     "coordinate_descent_check_all_directions": True,
-    "max_autotune": True,
+    # "max_autotune": True,
+    "max_autotune": False,
     "triton.cudagraphs": True,
 }
 TRANSFORMER_SPATIAL_PATCH_SIZE = 1
@@ -50,15 +49,15 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
     latent_height = height // VAE_SPATIAL_COMPRESSION_RATIO
     latent_width = width // VAE_SPATIAL_COMPRESSION_RATIO
 
+    with capture_component_call(pipeline, "transformer") as call:
+        pipeline(*args, **kwargs)
+
     @spaces.GPU(duration=1500)
     def compile_transformer():
-        with capture_component_call(pipeline, "transformer") as call:
-            pipeline(*args, **kwargs)
-
-        dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
+        # dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
         # dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
 
-        quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
+        quantize_(pipeline.transformer, float8_dynamic_activation_float8_weight())
 
         hidden_states: torch.Tensor = call.kwargs["hidden_states"]
         unpacked_hidden_states = LTXConditionPipeline._unpack_latents(
@@ -88,14 +87,13 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
             mod=pipeline.transformer,
             args=call.args,
             kwargs=call.kwargs | {"hidden_states": hidden_states_landscape},
-            dynamic_shapes=dynamic_shapes,
+            # dynamic_shapes=dynamic_shapes,
         )
-
         exported_portrait = torch.export.export(
             mod=pipeline.transformer,
             args=call.args,
             kwargs=call.kwargs | {"hidden_states": hidden_states_portrait},
-            dynamic_shapes=dynamic_shapes,
+            # dynamic_shapes=dynamic_shapes,
         )
 
         compiled_landscape = aoti_compile(exported_landscape, INDUCTOR_CONFIGS)
@@ -108,6 +106,7 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
 
     compiled_landscape, compiled_portrait = compile_transformer()
 
+    @torch.no_grad()
     def combined_transformer(*args, **kwargs):
         hidden_states: torch.Tensor = kwargs["hidden_states"]
         unpacked_hidden_states = LTXConditionPipeline._unpack_latents(
@@ -126,7 +125,15 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
     transformer_config = pipeline.transformer.config
     transformer_dtype = pipeline.transformer.dtype
     cache_context = pipeline.transformer.cache_context
-    pipeline.transformer = combined_transformer
+
+    with torch.no_grad():
+        combined_transformer(**call.kwargs)
+
+    pipeline.transformer = cudagraph(combined_transformer)
+
+    with torch.no_grad():
+        pipeline.transformer(**call.kwargs)
+
     pipeline.transformer.config = transformer_config  # pyright: ignore[reportAttributeAccessIssue]
     pipeline.transformer.dtype = transformer_dtype  # pyright: ignore[reportAttributeAccessIssue]
     pipeline.transformer.cache_context = cache_context
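
Note: taken together, these hunks move the capture of a real transformer call out of compile_transformer, switch to the float8_dynamic_activation_float8_weight quantization API, drop the dynamic_shapes arguments from both torch.export.export calls, and replace the direct assignment of combined_transformer with a warmup plus CUDA-graph wrap. A condensed sketch of the new flow (context lines assumed from the surrounding optimize_pipeline_ body, not shown in this diff):

    with capture_component_call(pipeline, "transformer") as call:
        pipeline(*args, **kwargs)                  # record one real transformer call

    compiled_landscape, compiled_portrait = compile_transformer()
    # inside compile_transformer: quantize_(pipeline.transformer, float8_dynamic_activation_float8_weight()),
    # then torch.export.export + aoti_compile for the landscape and portrait shapes (no dynamic_shapes)

    with torch.no_grad():
        combined_transformer(**call.kwargs)        # eager warmup of the shape-dispatching wrapper
    pipeline.transformer = cudagraph(combined_transformer)
    with torch.no_grad():
        pipeline.transformer(**call.kwargs)        # records the CUDA graph for this shape
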
optimization_utils.py CHANGED
@@ -5,12 +5,13 @@ Taken from https://huggingface.co/spaces/cbensimon/wan2-1-fast/
 import contextlib
 from contextvars import ContextVar
 from io import BytesIO
-from typing import Any
-from typing import cast
+from typing import Any, cast
 from unittest.mock import patch
 
 import torch
+from torch.utils._pytree import tree_map_only
 from torch._inductor.package.package import package_aoti
+from torch._inductor.package import load_package
 from torch.export.pt2_archive._package import AOTICompiledModel
 from torch.export.pt2_archive._package_weights import Weights
 
@@ -45,7 +46,9 @@ class ZeroGPUCompiledModel:
 
     def __call__(self, *args, **kwargs):
         if (compiled_model := self.compiled_model.get()) is None:
-            compiled_model = cast(AOTICompiledModel, torch._inductor.aoti_load_package(self.archive_file))
+            # compiled_model = cast(AOTICompiledModel, torch._inductor.aoti_load_package(self.archive_file))
+            # compiled_model = torch._inductor.aoti_load_package(self.archive_file, run_single_threaded=True)
+            compiled_model = load_package(self.archive_file, run_single_threaded=True)
             compiled_model.load_constants(self.weights.constants_map, check_full_update=True, user_managed=True)
             self.compiled_model.set(compiled_model)
         return compiled_model(*args, **kwargs)
@@ -67,7 +70,7 @@ def aoti_compile(
     files: list[str | Weights] = [file for file in artifacts if isinstance(file, str)]
     package_aoti(archive_file, files)
     (weights,) = (artifact for artifact in artifacts if isinstance(artifact, Weights))
-    zerogpu_weights = ZeroGPUWeights({name: weights.get_weight(name)[0] for name in weights})
+    zerogpu_weights = ZeroGPUWeights({name: weights.get_weight(name)[0] for name in weights}, to_cuda=True)
     return ZeroGPUCompiledModel(archive_file, zerogpu_weights)
 
 
@@ -100,3 +103,39 @@ def capture_component_call(
         except CapturedCallException as e:
             captured_call.args = e.args
             captured_call.kwargs = e.kwargs
+
+
+# Taken from
+# https://github.com/huggingface/flux-fast/blob/5027798d7f69a8e0e478df92f48663c40727f8ea/utils/pipeline_utils.py#L198C1-L231C14
+def cudagraph(f):
+    _graphs = {}
+    def f_(*args, **kwargs):
+        key = hash(tuple(tuple(kwargs[a].shape) for a in sorted(kwargs.keys())
+                         if isinstance(kwargs[a], torch.Tensor)))
+        if key in _graphs:
+            # use the cached wrapper if one exists. this will perform CUDAGraph replay
+            wrapped, *_ = _graphs[key]
+            return wrapped(*args, **kwargs)
+
+        # record a new CUDAGraph and cache it for future use
+        g = torch.cuda.CUDAGraph()
+        in_args, in_kwargs = tree_map_only(torch.Tensor, lambda t: t.clone(), (args, kwargs))
+        f(*in_args, **in_kwargs)  # stream warmup
+        with torch.cuda.graph(g):
+            out_tensors = f(*in_args, **in_kwargs)
+        def wrapped(*args, **kwargs):
+            # note that CUDAGraphs require inputs / outputs to be in fixed memory locations.
+            # inputs must be copied into the fixed input memory locations.
+            [a.copy_(b) for a, b in zip(in_args, args) if isinstance(a, torch.Tensor)]
+            for key in kwargs:
+                if isinstance(kwargs[key], torch.Tensor):
+                    in_kwargs[key].copy_(kwargs[key])
+            g.replay()
+            # clone() outputs on the way out to disconnect them from the fixed output memory
+            # locations. this allows for CUDAGraph reuse without accidentally overwriting memory
+            return [o.clone() for o in out_tensors]
+
+        # cache function that does CUDAGraph replay
+        _graphs[key] = (wrapped, g, in_args, in_kwargs, out_tensors)
+        return wrapped(*args, **kwargs)
+    return f_
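
Note: the cudagraph helper records one torch.cuda.CUDAGraph per distinct set of tensor-kwarg shapes and replays it on later calls, copying fresh inputs into the captured buffers and cloning the outputs. A minimal standalone sketch (hypothetical toy module and names; CUDA required, tensor inputs passed as keywords, outputs returned as a list):

    import torch
    from optimization_utils import cudagraph

    lin = torch.nn.Linear(64, 64, device="cuda", dtype=torch.float16)

    @cudagraph
    def step(*, x: torch.Tensor):
        return [lin(x)]               # the wrapper clones each output tensor on replay

    with torch.no_grad():
        x = torch.randn(8, 64, device="cuda", dtype=torch.float16)
        out1 = step(x=x)              # first call with this shape: warmup, then graph capture
        out2 = step(x=x * 2)          # same shape: CUDA graph replay with copied-in inputs
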
reproduce.py ADDED
@@ -0,0 +1,117 @@
+from diffusers import LTXConditionPipeline
+from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
+from diffusers.utils import load_video, export_to_video
+from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
+from io import BytesIO
+import contextlib
+from typing import Any, cast
+from unittest.mock import patch
+import torch
+from torch._inductor.package.package import package_aoti
+from torch._inductor.package import load_package
+from PIL import Image
+
+MODEL_ID = "Lightricks/LTX-Video-0.9.8-13B-distilled"
+LANDSCAPE_WIDTH = 480
+LANDSCAPE_HEIGHT = 832
+MAX_FRAMES_MODEL = 96
+INDUCTOR_CONFIGS = {
+    "conv_1x1_as_mm": True,
+    "epilogue_fusion": False,
+    "coordinate_descent_tuning": True,
+    "coordinate_descent_check_all_directions": True,
+    "max_autotune": False,
+    "triton.cudagraphs": True,
+}
+INDUCTOR_CONFIGS_OVERRIDES = {
+    "aot_inductor.package_constants_in_so": False,
+    "aot_inductor.package_constants_on_disk": True,
+    "aot_inductor.package": True,
+}
+
+
+@contextlib.contextmanager
+def capture_component_call(
+    pipeline: LTXConditionPipeline,
+    component_name: str,
+    component_method="forward",
+):
+    class CapturedCallException(Exception):
+        def __init__(self, *args, **kwargs):
+            super().__init__()
+            self.args = args
+            self.kwargs = kwargs
+
+    class CapturedCall:
+        def __init__(self):
+            self.args: tuple[Any, ...] = ()
+            self.kwargs: dict[str, Any] = {}
+
+    component = getattr(pipeline, component_name)
+    captured_call = CapturedCall()
+
+    def capture_call(*args, **kwargs):
+        raise CapturedCallException(*args, **kwargs)
+
+    with patch.object(component, component_method, new=capture_call):
+        try:
+            yield captured_call
+        except CapturedCallException as e:
+            captured_call.args = e.args
+            captured_call.kwargs = e.kwargs
+
+
+pipe = LTXConditionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16).to("cuda")
+quantize_(pipe.transformer, float8_dynamic_activation_float8_weight())
+
+resized_image = Image.new("RGB", (LANDSCAPE_WIDTH, LANDSCAPE_HEIGHT))
+video = load_video(export_to_video([resized_image]))
+condition1 = LTXVideoCondition(video=video, frame_index=0)
+
+with capture_component_call(pipe, "transformer") as call:
+    pipe(
+        conditions=[condition1],
+        prompt="prompt",
+        height=LANDSCAPE_HEIGHT,
+        width=LANDSCAPE_WIDTH,
+        num_frames=MAX_FRAMES_MODEL,
+        num_inference_steps=2
+    )
+
+hidden_states: torch.Tensor = call.kwargs["hidden_states"]
+exported = torch.export.export(
+    mod=pipe.transformer,
+    args=call.args,
+    kwargs=call.kwargs | {"hidden_states": hidden_states},
+)
+
+assert exported.example_inputs is not None
+args, kwargs = exported.example_inputs
+gm = cast(torch.fx.GraphModule, exported.module())
+artifacts = torch._inductor.aot_compile(
+    gm, args, kwargs, options=INDUCTOR_CONFIGS | INDUCTOR_CONFIGS_OVERRIDES
+)
+archive_file = BytesIO()
+files = [file for file in artifacts if isinstance(file, str)]
+package_aoti(archive_file, files)
+compiled_model = load_package(archive_file, run_single_threaded=True)
+print("Package loaded.")
+
+transformer_config = pipe.transformer.config
+transformer_dtype = pipe.transformer.dtype
+cache_context = pipe.transformer.cache_context
+pipe.transformer = compiled_model
+pipe.transformer.config = transformer_config
+pipe.transformer.dtype = transformer_dtype
+pipe.transformer.cache_context = cache_context
+print("Configs done.")
+
+pipe(
+    conditions=[condition1],
+    prompt="prompt",
+    height=LANDSCAPE_HEIGHT,
+    width=LANDSCAPE_WIDTH,
+    num_frames=MAX_FRAMES_MODEL,
+    num_inference_steps=2
+)
+print("Okay")
requirements.txt DELETED
@@ -1,11 +0,0 @@
-git+https://github.com/huggingface/diffusers.git
-transformers
-accelerate
-safetensors
-sentencepiece
-peft
-ftfy
-imageio
-imageio-ffmpeg
-opencv-python
-torchao==0.11.0