PengWeixuanSZU commited on
Commit
8c24d5c
·
verified ·
1 Parent(s): 6a2f6f4

Update pipeline_minimax_remover.py

Browse files
Files changed (1) hide show
  1. pipeline_minimax_remover.py +2 -6
pipeline_minimax_remover.py CHANGED
@@ -121,6 +121,7 @@ class Minimax_Remover_Pipeline(DiffusionPipeline):
121
  output_type: Optional[str] = "np",
122
  iterations: int = 16
123
  ):
 
124
  self._current_timestep = None
125
  self._interrupt = False
126
  device = self._execution_device
@@ -145,16 +146,11 @@ class Minimax_Remover_Pipeline(DiffusionPipeline):
145
  latents,
146
  )
147
 
148
- print(f"1 images.shape: {images.shape}, masks.shape:{masks.shape}")
149
  masks = self.expand_masks(masks, iterations)
150
- print(f"2 images.shape: {images.shape}, masks.shape:{masks.shape}")
151
  masks = self.resize(masks, height, width).to("cuda:0").half()
152
- print(f"3 images.shape: {images.shape}, masks.shape:{masks.shape}")
153
  masks[masks>0] = 1
154
  images = rearrange(images, "f h w c -> c f h w")
155
- print(f"4 images.shape: {images.shape}, masks.shape:{masks.shape}")
156
  images = self.resize(images[None,...], height, width).to("cuda:0").half()
157
- print(f"5 images.shape: {images.shape}, masks.shape:{masks.shape}")
158
 
159
  masked_images = images * (1-masks)
160
 
@@ -199,4 +195,4 @@ class Minimax_Remover_Pipeline(DiffusionPipeline):
199
  video = self.vae.decode(latents, return_dict=False)[0]
200
  video = self.video_processor.postprocess_video(video, output_type=output_type)
201
 
202
- return WanPipelineOutput(frames=video)
 
121
  output_type: Optional[str] = "np",
122
  iterations: int = 16
123
  ):
124
+
125
  self._current_timestep = None
126
  self._interrupt = False
127
  device = self._execution_device
 
146
  latents,
147
  )
148
 
 
149
  masks = self.expand_masks(masks, iterations)
 
150
  masks = self.resize(masks, height, width).to("cuda:0").half()
 
151
  masks[masks>0] = 1
152
  images = rearrange(images, "f h w c -> c f h w")
 
153
  images = self.resize(images[None,...], height, width).to("cuda:0").half()
 
154
 
155
  masked_images = images * (1-masks)
156
 
 
195
  video = self.vae.decode(latents, return_dict=False)[0]
196
  video = self.video_processor.postprocess_video(video, output_type=output_type)
197
 
198
+ return WanPipelineOutput(frames=video)