MiniMax-Remover / pipeline_minimax_remover.py
from typing import List, Optional, Union

import numpy as np
import scipy.ndimage
import torch
import torch.nn.functional as F
from einops import rearrange

from diffusers.models import AutoencoderKLWan
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import is_torch_xla_available
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor

from transformer_minimax_remover import Transformer3DModel

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False
class Minimax_Remover_Pipeline(DiffusionPipeline):
    r"""
    Video object-removal pipeline built on the Wan VAE and a 3D transformer.
    """

    model_cpu_offload_seq = "transformer->vae"
    _callback_tensor_inputs = ["latents"]

    def __init__(
        self,
        transformer: Transformer3DModel,
        vae: AutoencoderKLWan,
        scheduler: FlowMatchEulerDiscreteScheduler,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            transformer=transformer,
            scheduler=scheduler,
        )
        # `temperal_downsample` (sic) is the attribute name used by AutoencoderKLWan.
        self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
        self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
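
    # For reference, a sketch of how the scale factors come out under the
    # assumption that the checkpoint uses the Wan 2.1 VAE config with
    # temperal_downsample = [False, True, True] (an assumption, not read from
    # this file):
    #   vae_scale_factor_temporal = 2 ** sum([False, True, True]) = 4
    #   vae_scale_factor_spatial  = 2 ** len([False, True, True]) = 8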
    def prepare_latents(
        self,
        batch_size: int,
        num_channels_latents: int = 16,
        height: int = 720,
        width: int = 1280,
        num_latent_frames: int = 21,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Reuse caller-provided latents if given; otherwise sample Gaussian noise.
        if latents is not None:
            return latents.to(device=device, dtype=dtype)

        shape = (
            batch_size,
            num_channels_latents,
            num_latent_frames,
            int(height) // self.vae_scale_factor_spatial,
            int(width) // self.vae_scale_factor_spatial,
        )
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        return latents
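
    # For example (hypothetical settings, not taken from this file): with
    # num_frames=81, height=480, width=832, the noise latents have shape
    #   (1, 16, (81 - 1) // 4 + 1, 480 // 8, 832 // 8) = (1, 16, 21, 60, 104).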
    def expand_masks(self, masks, iterations):
        # masks: (f, h, w, c) tensor with values in [0, 1]
        masks = masks.cpu().detach().numpy()
        masks2 = []
        for i in range(len(masks)):
            mask = masks[i]
            # Binarize, then grow the mask with `iterations` rounds of binary
            # dilation so the edited region safely covers object boundaries.
            mask = mask > 0
            mask = scipy.ndimage.binary_dilation(mask, iterations=iterations)
            masks2.append(mask)
        masks = np.array(masks2).astype(np.float32)
        masks = torch.from_numpy(masks)
        # Replicate the single mask channel to 3 and reorder to (1, c, f, h, w).
        masks = masks.repeat(1, 1, 1, 3)
        masks = rearrange(masks, "f h w c -> c f h w")
        masks = masks[None, ...]
        return masks
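
    # For example (hypothetical shapes): a (16, 480, 832, 1) mask stack in
    # [0, 1] comes back as a (1, 3, 16, 480, 832) binary tensor after dilation.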
    def resize(self, images, height, width):
        # images: (b, c, f, h, w); bilinearly resize every frame to (height, width).
        bsz, _, _, _, _ = images.shape
        images = rearrange(images, "b c f h w -> (b f) c h w")
        images = F.interpolate(images, (height, width), mode="bilinear")
        images = rearrange(images, "(b f) c h w -> b c f h w", b=bsz)
        return images
    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def current_timestep(self):
        return self._current_timestep

    @property
    def interrupt(self):
        return self._interrupt
    @torch.no_grad()
    def __call__(
        self,
        height: int = 720,
        width: int = 1280,
        num_frames: int = 81,
        num_inference_steps: int = 50,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        images: Optional[torch.Tensor] = None,
        masks: Optional[torch.Tensor] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "np",
        iterations: int = 16,
    ):
        self._current_timestep = None
        self._interrupt = False

        device = self._execution_device
        batch_size = 1
        transformer_dtype = torch.float16

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        num_channels_latents = 16
        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        latents = self.prepare_latents(
            batch_size,
            num_channels_latents,
            height,
            width,
            num_latent_frames,
            torch.float16,
            device,
            generator,
            latents,
        )

        # Dilate and resize the masks, resize the input video, and build the
        # masked conditioning video. Tensors are moved to the pipeline's
        # execution device instead of a hard-coded "cuda:0".
        masks = self.expand_masks(masks, iterations)
        masks = self.resize(masks, height, width).to(device).half()
        masks[masks > 0] = 1
        images = rearrange(images, "f h w c -> c f h w")
        images = self.resize(images[None, ...], height, width).to(device).half()
        masked_images = images * (1 - masks)
        # Per-channel latent statistics from the VAE config. `latents_std` holds
        # the reciprocal of the std, so normalization below is a multiply.
        latents_mean = (
            torch.tensor(self.vae.config.latents_mean)
            .view(1, self.vae.config.z_dim, 1, 1, 1)
            .to(self.vae.device, torch.float16)
        )
        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
            self.vae.device, torch.float16
        )

        # Encode the masked video and the masks (rescaled from [0, 1] to [-1, 1]),
        # then normalize both with the VAE's latent statistics.
        masked_latents = self.vae.encode(masked_images.half()).latent_dist.mode()
        masks_latents = self.vae.encode(2 * masks.half() - 1.0).latent_dist.mode()
        masked_latents = (masked_latents - latents_mean) * latents_std
        masks_latents = (masks_latents - latents_mean) * latents_std
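
        # Restating the normalization above: with per-channel mean m_c and
        # std s_c from the VAE config,
        #   z_hat_c = (z_c - m_c) / s_c
        # and the inverse, z_c = z_hat_c * s_c + m_c, is applied after
        # denoising and before decoding.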
        self._num_timesteps = len(timesteps)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Condition the transformer by concatenating the noisy latents
                # with the masked-video latents and the mask latents along the
                # channel dimension.
                latent_model_input = latents.to(transformer_dtype)
                latent_model_input = torch.cat([latent_model_input, masked_latents, masks_latents], dim=1)
                timestep = t.expand(latents.shape[0])

                noise_pred = self.transformer(
                    hidden_states=latent_model_input.half(),
                    timestep=timestep,
                )[0]

                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
                progress_bar.update()
        # Denormalize the latents, decode to pixel space, and postprocess to
        # the requested output type.
        latents = latents.half() / latents_std + latents_mean
        video = self.vae.decode(latents, return_dict=False)[0]
        video = self.video_processor.postprocess_video(video, output_type=output_type)

        return WanPipelineOutput(frames=video)
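

# Minimal usage sketch (not from the original file). It assumes a local
# checkpoint with "vae", "transformer", and "scheduler" subfolders, that
# Transformer3DModel supports `from_pretrained`, and that `video_frames`
# (f, h, w, 3) in [-1, 1] and `video_masks` (f, h, w, 1) in [0, 1] are
# user-supplied tensors -- every name and value range here is an assumption.
#
#   import torch
#   from diffusers.models import AutoencoderKLWan
#   from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
#   from transformer_minimax_remover import Transformer3DModel
#   from pipeline_minimax_remover import Minimax_Remover_Pipeline
#
#   vae = AutoencoderKLWan.from_pretrained("./ckpt/vae", torch_dtype=torch.float16)
#   transformer = Transformer3DModel.from_pretrained("./ckpt/transformer", torch_dtype=torch.float16)
#   scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("./ckpt/scheduler")
#   pipe = Minimax_Remover_Pipeline(transformer=transformer, vae=vae, scheduler=scheduler).to("cuda")
#
#   out = pipe(images=video_frames, masks=video_masks, num_frames=81,
#              height=480, width=832, num_inference_steps=12,
#              generator=torch.Generator("cuda").manual_seed(42))
#   frames = out.frames[0]  # (f, h, w, 3) numpy array when output_type="np"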