sayakpaul HF Staff commited on
Commit
a806942
·
1 Parent(s): c341814
Files changed (2) hide show
  1. optimization.py +4 -3
  2. optimization_utils.py +1 -1
optimization.py CHANGED
@@ -100,14 +100,15 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
100
 
101
  compiled_landscape = aoti_compile(exported_landscape, INDUCTOR_CONFIGS)
102
  compiled_portrait = aoti_compile(exported_portrait, INDUCTOR_CONFIGS)
103
- # compiled_portrait.weights = (
104
- # compiled_landscape.weights
105
- # ) # Avoid weights duplication when serializing back to main process
106
 
107
  return compiled_landscape, compiled_portrait
108
 
109
  compiled_landscape, compiled_portrait = compile_transformer()
110
 
 
111
  def combined_transformer(*args, **kwargs):
112
  hidden_states: torch.Tensor = kwargs["hidden_states"]
113
  unpacked_hidden_states = LTXConditionPipeline._unpack_latents(
 
100
 
101
  compiled_landscape = aoti_compile(exported_landscape, INDUCTOR_CONFIGS)
102
  compiled_portrait = aoti_compile(exported_portrait, INDUCTOR_CONFIGS)
103
+ compiled_portrait.weights = (
104
+ compiled_landscape.weights
105
+ ) # Avoid weights duplication when serializing back to main process
106
 
107
  return compiled_landscape, compiled_portrait
108
 
109
  compiled_landscape, compiled_portrait = compile_transformer()
110
 
111
+ @torch.no_grad()
112
  def combined_transformer(*args, **kwargs):
113
  hidden_states: torch.Tensor = kwargs["hidden_states"]
114
  unpacked_hidden_states = LTXConditionPipeline._unpack_latents(
optimization_utils.py CHANGED
@@ -67,7 +67,7 @@ def aoti_compile(
67
  files: list[str | Weights] = [file for file in artifacts if isinstance(file, str)]
68
  package_aoti(archive_file, files)
69
  (weights,) = (artifact for artifact in artifacts if isinstance(artifact, Weights))
70
- zerogpu_weights = ZeroGPUWeights({name: weights.get_weight(name)[0] for name in weights})
71
  return ZeroGPUCompiledModel(archive_file, zerogpu_weights)
72
 
73
 
 
67
  files: list[str | Weights] = [file for file in artifacts if isinstance(file, str)]
68
  package_aoti(archive_file, files)
69
  (weights,) = (artifact for artifact in artifacts if isinstance(artifact, Weights))
70
+ zerogpu_weights = ZeroGPUWeights({name: weights.get_weight(name)[0] for name in weights}, to_cuda=True)
71
  return ZeroGPUCompiledModel(archive_file, zerogpu_weights)
72
 
73