1inkusFace commited on
Commit
9eeb954
·
verified ·
1 Parent(s): 2bbaaed

Update skyreelsinfer/pipelines/pipeline_skyreels_video.py

Browse files
skyreelsinfer/pipelines/pipeline_skyreels_video.py CHANGED
@@ -240,7 +240,8 @@ class SkyreelsVideoPipeline(HunyuanVideoPipeline):
240
  batch_size = len(prompt)
241
  else:
242
  batch_size = prompt_embeds.shape[0]
243
- self.text_encoder.to("cuda")
 
244
 
245
  # 3. Encode input prompt
246
  (
@@ -341,7 +342,7 @@ class SkyreelsVideoPipeline(HunyuanVideoPipeline):
341
  self.text_encoder.to("cpu")
342
  self.vae.to("cpu")
343
  torch.cuda.empty_cache()
344
-
345
  with self.progress_bar(total=num_inference_steps) as progress_bar:
346
  for i, t in enumerate(timesteps):
347
  if self.interrupt:
@@ -414,7 +415,8 @@ class SkyreelsVideoPipeline(HunyuanVideoPipeline):
414
  progress_bar.update()
415
 
416
  if not output_type == "latent":
417
- self.vae.to("cuda")
 
418
  latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
419
  video = self.vae.decode(latents, return_dict=False)[0]
420
  video = self.video_processor.postprocess_video(video, output_type=output_type)
 
240
  batch_size = len(prompt)
241
  else:
242
  batch_size = prompt_embeds.shape[0]
243
+ if self.text_encoder.device.type == 'cpu':
244
+ self.text_encoder.to("cuda")
245
 
246
  # 3. Encode input prompt
247
  (
 
342
  self.text_encoder.to("cpu")
343
  self.vae.to("cpu")
344
  torch.cuda.empty_cache()
345
+ torch.cuda.reset_peak_memory_stats()
346
  with self.progress_bar(total=num_inference_steps) as progress_bar:
347
  for i, t in enumerate(timesteps):
348
  if self.interrupt:
 
415
  progress_bar.update()
416
 
417
  if not output_type == "latent":
418
+ if self.vae.device.type == 'cpu':
419
+ self.vae.to("cuda")
420
  latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
421
  video = self.vae.decode(latents, return_dict=False)[0]
422
  video = self.video_processor.postprocess_video(video, output_type=output_type)