multimodalart HF Staff commited on
Commit
17b56a5
·
verified ·
1 Parent(s): dff15f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -94,21 +94,21 @@ pipe = create_ltx_video_pipeline(
94
  sampler=PIPELINE_CONFIG_YAML["sampler"], # "from_checkpoint" or specific sampler
95
  device=DEVICE,
96
  enhance_prompt=False, # Assuming Gradio controls this, or set based on YAML later
97
- )
98
 
99
  # Create Latent Upsampler
100
  latent_upsampler = create_latent_upsampler(
101
  latent_upsampler_model_path=spatial_upsampler_path,
102
  device=DEVICE
103
  )
104
- latent_upsampler = latent_upsampler.to(torch.bfloat16 if PIPELINE_CONFIG_YAML["precision"] == "bfloat16" else torch.float32)
105
 
106
 
107
  # Multi-scale pipeline (wrapper)
108
  multi_scale_pipe = LTXMultiScalePipeline(
109
  video_pipeline=pipe,
110
  latent_upsampler=latent_upsampler
111
- )
112
  # --- End Global Configuration & Model Loading ---
113
 
114
 
@@ -287,7 +287,7 @@ def generate(prompt,
287
  "output_type": "latent"
288
  }
289
  latents = pipe(**first_pass_args).images # .images here is actually latents
290
-
291
  # 2. Upsample latents manually
292
  # Need to handle normalization around latent upsampler if it expects unnormalized latents
293
  latents_unnorm = un_normalize_latents(latents, pipe.vae, vae_per_channel_normalize=True)
@@ -324,8 +324,9 @@ def generate(prompt,
324
  decode_noise_val = PIPELINE_CONFIG_YAML.get("decode_noise_scale", 0.025)
325
  upsampled_latents = upsampled_latents * (1 - decode_noise_val) + noise * decode_noise_val
326
 
327
-
328
  result_frames_tensor = pipe.vae.decode(upsampled_latents, **decode_kwargs).sample
 
329
  # result_frames_tensor shape: (B, C, F_video, H_video, W_video)
330
 
331
  # --- Post-processing: Cropping and Converting to PIL ---
 
94
  sampler=PIPELINE_CONFIG_YAML["sampler"], # "from_checkpoint" or specific sampler
95
  device=DEVICE,
96
  enhance_prompt=False, # Assuming Gradio controls this, or set based on YAML later
97
+ ).to(torch.bfloat16)
98
 
99
  # Create Latent Upsampler
100
  latent_upsampler = create_latent_upsampler(
101
  latent_upsampler_model_path=spatial_upsampler_path,
102
  device=DEVICE
103
  )
104
+ latent_upsampler = latent_upsampler.to(torch.bfloat16)
105
 
106
 
107
  # Multi-scale pipeline (wrapper)
108
  multi_scale_pipe = LTXMultiScalePipeline(
109
  video_pipeline=pipe,
110
  latent_upsampler=latent_upsampler
111
+ ).to(torch.bfloat16)
112
  # --- End Global Configuration & Model Loading ---
113
 
114
 
 
287
  "output_type": "latent"
288
  }
289
  latents = pipe(**first_pass_args).images # .images here is actually latents
290
+ print("First pass done!")
291
  # 2. Upsample latents manually
292
  # Need to handle normalization around latent upsampler if it expects unnormalized latents
293
  latents_unnorm = un_normalize_latents(latents, pipe.vae, vae_per_channel_normalize=True)
 
324
  decode_noise_val = PIPELINE_CONFIG_YAML.get("decode_noise_scale", 0.025)
325
  upsampled_latents = upsampled_latents * (1 - decode_noise_val) + noise * decode_noise_val
326
 
327
+ print("before vae decoding")
328
  result_frames_tensor = pipe.vae.decode(upsampled_latents, **decode_kwargs).sample
329
+ print("after vae decoding?")
330
  # result_frames_tensor shape: (B, C, F_video, H_video, W_video)
331
 
332
  # --- Post-processing: Cropping and Converting to PIL ---