from typing import Callable, Dict, List, Optional, Union

import numpy as np
import torch

from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
    _resize_with_antialiasing,
    StableVideoDiffusionPipelineOutput,
    StableVideoDiffusionPipeline,
    retrieve_timesteps,
)
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor

logger = logging.get_logger(__name__)


class DepthCrafterPipeline(StableVideoDiffusionPipeline):

    @torch.inference_mode()
    def encode_video(
        self,
        video: torch.Tensor,
        chunk_size: int = 14,
    ) -> torch.Tensor:
        """Encode video frames into CLIP image embeddings.

        :param video: [b, c, h, w] in range [-1, 1]; b may span multiple videos or frames.
        :param chunk_size: number of frames to encode per forward pass.
        :return: image embeddings of shape [b, 1024].
        """
        # The CLIP image encoder expects 224x224 inputs in [0, 1].
        video_224 = _resize_with_antialiasing(video.float(), (224, 224))
        video_224 = (video_224 + 1.0) / 2.0

        # Encode in chunks to bound peak memory usage.
        embeddings = []
        for i in range(0, video_224.shape[0], chunk_size):
            tmp = self.feature_extractor(
                images=video_224[i : i + chunk_size],
                do_normalize=True,
                do_center_crop=False,
                do_resize=False,
                do_rescale=False,
                return_tensors="pt",
            ).pixel_values.to(video.device, dtype=video.dtype)
            embeddings.append(self.image_encoder(tmp).image_embeds)

        embeddings = torch.cat(embeddings, dim=0)
        return embeddings
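
    # Example (sketch): assuming `pipe` is a loaded DepthCrafterPipeline and
    # `frames` is a float tensor of shape [t, 3, h, w] in [-1, 1], the CLIP
    # embeddings fed to the UNet cross-attention could be obtained with:
    #
    #     emb = pipe.encode_video(frames, chunk_size=8)  # -> [t, 1024]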

    @torch.inference_mode()
    def encode_vae_video(
        self,
        video: torch.Tensor,
        chunk_size: int = 14,
    ) -> torch.Tensor:
        """Encode video frames into VAE latents.

        :param video: [b, c, h, w] in range [-1, 1]; b may span multiple videos or frames.
        :param chunk_size: number of frames to encode per forward pass.
        :return: VAE latents of shape [b, latent_channels, h // 8, w // 8].
        """
        # Encode in chunks and take the mode of the latent distribution.
        video_latents = []
        for i in range(0, video.shape[0], chunk_size):
            video_latents.append(
                self.vae.encode(video[i : i + chunk_size]).latent_dist.mode()
            )
        video_latents = torch.cat(video_latents, dim=0)
        return video_latents
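
    # Example (sketch): with the same `pipe` and `frames` as above, the latent
    # conditioning stream could be computed as:
    #
    #     lat = pipe.encode_vae_video(frames.to(pipe.vae.dtype), chunk_size=8)
    #     # lat.shape is typically [t, 4, h // 8, w // 8] for the SVD VAE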

    @staticmethod
    def check_inputs(video, height, width):
        """Validate the input video type and the target resolution.

        :param video: input video as a torch.Tensor or np.ndarray.
        :param height: target height in pixels, must be divisible by 8.
        :param width: target width in pixels, must be divisible by 8.
        """
        if not isinstance(video, torch.Tensor) and not isinstance(video, np.ndarray):
            raise ValueError(
                f"Expected `video` to be a `torch.Tensor` or `np.ndarray`, but got {type(video)}"
            )

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(
                f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
            )
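
    # Example (sketch):
    #     DepthCrafterPipeline.check_inputs(np.zeros((8, 320, 512, 3)), 320, 512)  # passes
    #     DepthCrafterPipeline.check_inputs(video, 321, 512)  # raises ValueError (321 % 8 != 0)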

    @torch.no_grad()
    def __call__(
        self,
        video: Union[np.ndarray, torch.Tensor],
        height: int = 576,
        width: int = 1024,
        num_inference_steps: int = 25,
        guidance_scale: float = 1.0,
        window_size: Optional[int] = 110,
        noise_aug_strength: float = 0.02,
        decode_chunk_size: Optional[int] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        return_dict: bool = True,
        overlap: int = 25,
        track_time: bool = False,
    ):
        """Run depth inference on a video with sliding-window denoising.

        :param video: [t, h, w, c] if np.ndarray or [t, c, h, w] if torch.Tensor, in range [0, 1].
        :param height: target height, must be divisible by 8.
        :param width: target width, must be divisible by 8.
        :param num_inference_steps: number of denoising steps.
        :param guidance_scale: classifier-free guidance scale; values <= 1.0 disable guidance.
        :param window_size: number of frames processed per sliding window.
        :param noise_aug_strength: strength of the noise added to the conditioning frames.
        :param decode_chunk_size: number of frames to encode/decode per chunk.
        :param generator: torch.Generator(s) for reproducible noise.
        :param latents: optional pre-generated initial latents.
        :param output_type: "pil", "np", "pt", or "latent".
        :param callback_on_step_end: callback invoked at the end of each denoising step.
        :param callback_on_step_end_tensor_inputs: names of tensors passed to the callback.
        :param return_dict: whether to wrap the result in a StableVideoDiffusionPipelineOutput.
        :param overlap: number of overlapping frames between consecutive windows.
        :param track_time: whether to print per-stage CUDA timings.
        :return: decoded frames (or latents if output_type == "latent"), wrapped in a
            StableVideoDiffusionPipelineOutput unless return_dict is False.
        """
        # 0. Default resolution and sliding-window parameters.
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor
        num_frames = video.shape[0]
        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 8
        if num_frames <= window_size:
            window_size = num_frames
            overlap = 0
        stride = window_size - overlap

        # 1. Check inputs.
        self.check_inputs(video, height, width)

        # 2. Define call parameters.
        batch_size = 1
        device = self._execution_device
        self._guidance_scale = guidance_scale

        # 3. Bring the video to [t, c, h, w] on the execution device, in [-1, 1].
        if isinstance(video, np.ndarray):
            video = torch.from_numpy(video.transpose(0, 3, 1, 2))
        else:
            assert isinstance(video, torch.Tensor)
        video = video.to(device=device, dtype=self.dtype)
        video = video * 2.0 - 1.0

        if track_time:
            start_event = torch.cuda.Event(enable_timing=True)
            encode_event = torch.cuda.Event(enable_timing=True)
            denoise_event = torch.cuda.Event(enable_timing=True)
            decode_event = torch.cuda.Event(enable_timing=True)
            start_event.record()

        # 4. Encode the frames into CLIP image embeddings: [1, t, 1024].
        video_embeddings = self.encode_video(
            video, chunk_size=decode_chunk_size
        ).unsqueeze(0)
        torch.cuda.empty_cache()

        # 5. Add light noise augmentation to the conditioning frames, then encode
        #    them into VAE latents: [1, t, latent_channels, h/8, w/8].
        noise = randn_tensor(
            video.shape, generator=generator, device=device, dtype=video.dtype
        )
        video = video + noise_aug_strength * noise

        needs_upcasting = (
            self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        )
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)

        video_latents = self.encode_vae_video(
            video.to(self.vae.dtype),
            chunk_size=decode_chunk_size,
        ).unsqueeze(0)

        if track_time:
            encode_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = start_event.elapsed_time(encode_event)
            print(f"Elapsed time for encoding video: {elapsed_time_ms} ms")

        torch.cuda.empty_cache()

        # Cast the VAE back to fp16 after encoding.
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        # 6. Micro-conditioning: fps and motion_bucket_id are fixed to the SVD
        #    defaults (7, 127); noise_aug_strength matches the augmentation above.
        added_time_ids = self._get_add_time_ids(
            7,
            127,
            noise_aug_strength,
            video_embeddings.dtype,
            batch_size,
            1,
            False,
        )
        added_time_ids = added_time_ids.to(device)

        # 7. Prepare timesteps.
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, None, None
        )
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        # 8. Prepare initial noise latents for a single sliding window.
        num_channels_latents = self.unet.config.in_channels
        latents_init = self.prepare_latents(
            batch_size,
            window_size,
            num_channels_latents,
            height,
            width,
            video_embeddings.dtype,
            device,
            generator,
            latents,
        )
        latents_all = None

        # Linear blending weights for the overlapping frames between windows.
        idx_start = 0
        if overlap > 0:
            weights = torch.linspace(0, 1, overlap, device=device)
            weights = weights.view(1, overlap, 1, 1, 1)
        else:
            weights = None

        torch.cuda.empty_cache()
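
        # Blending sketch: with overlap = 3, weights = [0, 0.5, 1] (broadcast over
        # batch/channel/spatial dims), so the first overlapping frame keeps the
        # previous window's result and the last one takes the new window's result:
        #
        #     blended = new[:, :overlap] * weights + old[:, -overlap:] * (1 - weights)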

        # 9. Sliding-window denoising over the full video.
        while idx_start < num_frames - overlap:
            idx_end = min(idx_start + window_size, num_frames)
            self.scheduler.set_timesteps(num_inference_steps, device=device)

            # Initial noise for the current window; roll latents_init so the
            # overlapping frames reuse the same noise in the next window.
            latents = latents_init[:, : idx_end - idx_start].clone()
            latents_init = torch.cat(
                [latents_init[:, -overlap:], latents_init[:, :stride]], dim=1
            )

            video_latents_current = video_latents[:, idx_start:idx_end]
            video_embeddings_current = video_embeddings[:, idx_start:idx_end]

            with self.progress_bar(total=num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):
                    # Anchor the overlapping frames to the previous window's
                    # result, re-noised to the first sigma level.
                    if latents_all is not None and i == 0:
                        latents[:, :overlap] = (
                            latents_all[:, -overlap:]
                            + latents[:, :overlap]
                            / self.scheduler.init_noise_sigma
                            * self.scheduler.sigmas[i]
                        )

                    # Conditional branch: concatenate the video latents along the
                    # channel dimension as conditioning.
                    latent_model_input = latents
                    latent_model_input = self.scheduler.scale_model_input(
                        latent_model_input, t
                    )
                    latent_model_input = torch.cat(
                        [latent_model_input, video_latents_current], dim=2
                    )
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=video_embeddings_current,
                        added_time_ids=added_time_ids,
                        return_dict=False,
                    )[0]

                    # Unconditional branch for classifier-free guidance: zeroed
                    # latent conditioning and zeroed image embeddings.
                    if self.do_classifier_free_guidance:
                        latent_model_input = latents
                        latent_model_input = self.scheduler.scale_model_input(
                            latent_model_input, t
                        )
                        latent_model_input = torch.cat(
                            [latent_model_input, torch.zeros_like(latent_model_input)],
                            dim=2,
                        )
                        noise_pred_uncond = self.unet(
                            latent_model_input,
                            t,
                            encoder_hidden_states=torch.zeros_like(
                                video_embeddings_current
                            ),
                            added_time_ids=added_time_ids,
                            return_dict=False,
                        )[0]

                        noise_pred = noise_pred_uncond + self.guidance_scale * (
                            noise_pred - noise_pred_uncond
                        )

                    latents = self.scheduler.step(noise_pred, t, latents).prev_sample

                    if callback_on_step_end is not None:
                        callback_kwargs = {}
                        for k in callback_on_step_end_tensor_inputs:
                            callback_kwargs[k] = locals()[k]
                        callback_outputs = callback_on_step_end(
                            self, i, t, callback_kwargs
                        )
                        latents = callback_outputs.pop("latents", latents)

                    if i == len(timesteps) - 1 or (
                        (i + 1) > num_warmup_steps
                        and (i + 1) % self.scheduler.order == 0
                    ):
                        progress_bar.update()

            # Stitch the current window into the running result, linearly blending
            # the overlapping frames.
            if latents_all is None:
                latents_all = latents.clone()
            else:
                assert weights is not None
                latents_all[:, -overlap:] = latents[
                    :, :overlap
                ] * weights + latents_all[:, -overlap:] * (1 - weights)
                latents_all = torch.cat([latents_all, latents[:, overlap:]], dim=1)

            idx_start += stride

        if track_time:
            denoise_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = encode_event.elapsed_time(denoise_event)
            print(f"Elapsed time for denoising video: {elapsed_time_ms} ms")

        # 10. Decode the stitched latents back to frames unless latents were requested.
        if output_type != "latent":
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
            frames = self.decode_latents(latents_all, num_frames, decode_chunk_size)

            if track_time:
                decode_event.record()
                torch.cuda.synchronize()
                elapsed_time_ms = denoise_event.elapsed_time(decode_event)
                print(f"Elapsed time for decoding video: {elapsed_time_ms} ms")

            frames = self.video_processor.postprocess_video(
                video=frames, output_type=output_type
            )
        else:
            frames = latents_all

        self.maybe_free_model_hooks()

        if not return_dict:
            return frames

        return StableVideoDiffusionPipelineOutput(frames=frames)
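

# Illustrative usage sketch (not part of the original pipeline). The checkpoint
# is a placeholder: substitute an actual DepthCrafter pipeline checkpoint, run on
# a CUDA device, and adjust resolution/window settings to your GPU memory.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path or hub id of a DepthCrafter pipeline checkpoint (placeholder).",
    )
    args = parser.parse_args()

    pipe = DepthCrafterPipeline.from_pretrained(args.checkpoint, torch_dtype=torch.float16)
    pipe.to("cuda")

    # Dummy clip of 40 frames in [0, 1], shape [t, h, w, c]; h and w must match
    # the requested height/width and be divisible by 8.
    dummy_video = np.random.rand(40, 320, 512, 3).astype(np.float32)

    result = pipe(
        dummy_video,
        height=320,
        width=512,
        num_inference_steps=5,
        guidance_scale=1.0,
        window_size=30,
        overlap=10,
        output_type="np",
    )
    print(result.frames[0].shape)  # [t, h, w, c], values in [0, 1]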