|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union |
|
|
|
import torch |
|
from einops import rearrange |
|
from megatron.core import parallel_state |
|
from torch import Tensor |
|
|
|
from cosmos_transfer1.diffusion.conditioner import BaseVideoCondition, VideoConditionerWithCtrl, VideoExtendCondition |
|
from cosmos_transfer1.diffusion.inference.inference_utils import merge_patches_into_video, split_video_into_patches |
|
from cosmos_transfer1.diffusion.model.model_t2w import DiffusionT2WModel, broadcast_condition |
|
from cosmos_transfer1.diffusion.model.model_v2w import DiffusionV2WModel, DistillV2WModel |
|
from cosmos_transfer1.diffusion.module.parallel import broadcast, cat_outputs_cp, split_inputs_cp |
|
from cosmos_transfer1.diffusion.networks.distill_controlnet_wrapper import DistillControlNet |
|
from cosmos_transfer1.utils import log, misc |
|
from cosmos_transfer1.utils.lazy_config import instantiate as lazy_instantiate |
|
|
|
T = TypeVar("T") |
|
IS_PREPROCESSED_KEY = "is_preprocessed" |
|
|
|
|
|
class VideoDiffusionModelWithCtrl(DiffusionV2WModel): |
|
def build_model(self) -> torch.nn.ModuleDict: |
|
log.info("Start creating base model") |
|
base_model = super().build_model() |
|
|
|
self.load_base_model(base_model) |
|
log.info("Done creating base model") |
|
|
|
log.info("Start creating ctrlnet model") |
|
net = lazy_instantiate(self.config.net_ctrl) |
|
conditioner = base_model.conditioner |
|
logvar = base_model.logvar |
|
|
|
model = torch.nn.ModuleDict({"net": net, "conditioner": conditioner, "logvar": logvar}) |
|
model.load_state_dict(base_model.state_dict(), strict=False) |
|
|
|
model.base_model = base_model |
|
log.info("Done creating ctrlnet model") |
|
|
|
self.hint_key = self.config.hint_key["hint_key"] |
|
return model |
|
|
|
@property |
|
def base_net(self): |
|
return self.model.base_model.net |
|
|
|
@property |
|
def conditioner(self): |
|
return self.model.conditioner |
|
|
|
def load_base_model(self, base_model) -> None: |
|
config = self.config |
|
if config.base_load_from is not None: |
|
checkpoint_path = config.base_load_from["load_path"] |
|
else: |
|
checkpoint_path = "" |
|
|
|
if checkpoint_path: |
|
log.info(f"Loading base model checkpoint (local): {checkpoint_path}") |
|
state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False) |
|
log.success(f"Complete loading base model checkpoint (local): {checkpoint_path}") |
|
|
|
if "ema" in state_dict and state_dict["ema"] is not None: |
|
|
|
log.info("Copying ema to base model") |
|
base_state_dict = {k.replace("-", "."): v for k, v in state_dict["ema"].items()} |
|
elif "model" in state_dict: |
|
|
|
log.warning("Using non-EMA base model") |
|
base_state_dict = state_dict["model"] |
|
else: |
|
log.info("Loading from an EMA only model") |
|
base_state_dict = state_dict |
|
missing, unexpected = base_model.load_state_dict(base_state_dict, strict=False) |
|
log.info(f"Missing: {missing}") |
|
log.info(f"Unexpected: {unexpected}") |
|
log.info("Done loading the base model checkpoint.") |
|
|
|
def get_data_and_condition( |
|
self, data_batch: dict[str, Tensor], **kwargs |
|
) -> Tuple[Tensor, VideoConditionerWithCtrl]: |
|
|
|
hint_key = self.config.hint_key["hint_key"] |
|
_data = {hint_key: data_batch[hint_key]} |
|
if IS_PREPROCESSED_KEY in data_batch: |
|
_data[IS_PREPROCESSED_KEY] = data_batch[IS_PREPROCESSED_KEY] |
|
data_batch[hint_key] = _data[hint_key] |
|
data_batch["hint_key"] = hint_key |
|
raw_state, latent_state, condition = super().get_data_and_condition(data_batch, **kwargs) |
|
use_multicontrol = ( |
|
("control_weight" in data_batch) |
|
and not isinstance(data_batch["control_weight"], float) |
|
and data_batch["control_weight"].shape[0] > 1 |
|
) |
|
if use_multicontrol: |
|
latent_hint = [] |
|
num_conditions = data_batch[data_batch["hint_key"]].size(1) // 3 |
|
for i in range(num_conditions): |
|
cond_mask = [False] * num_conditions |
|
cond_mask[i] = True |
|
latent_hint += [self.encode_latent(data_batch, cond_mask=cond_mask)] |
|
latent_hint = torch.cat(latent_hint) |
|
else: |
|
latent_hint = self.encode_latent(data_batch) |
|
|
|
|
|
data_batch["latent_hint"] = latent_hint |
|
setattr(condition, hint_key, latent_hint) |
|
setattr(condition, "base_model", self.model.base_model) |
|
return raw_state, latent_state, condition |
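
    # Illustrative sketch of when the multicontrol branch above fires (key names
    # and shapes are assumptions, not part of this module's API): a non-float
    # control weight with a leading dim > 1 makes each 3-channel hint slice get
    # encoded separately with a one-hot cond_mask.
    #
    #     data_batch["control_weight"] = torch.ones(2, 1)       # two controls
    #     data_batch[hint_key] = torch.randn(1, 6, 8, 88, 160)  # 2 x 3 channels
    #     # -> encode_latent runs twice: cond_mask=[True, False], then [False, True]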
|
|
|
def get_x_from_clean( |
|
self, |
|
in_clean_img: torch.Tensor, |
|
sigma_max: float | None, |
|
seed: int = 1, |
|
) -> Tensor: |
|
""" |
|
in_clean_img (torch.Tensor): input clean image for image-to-image/video-to-video by adding noise then denoising |
|
sigma_max (float): maximum sigma applied to in_clean_image for image-to-image/video-to-video |
|
""" |
|
if in_clean_img is None: |
|
return None |
|
generator = torch.Generator(device=self.tensor_kwargs["device"]) |
|
generator.manual_seed(seed) |
|
noise = torch.randn(*in_clean_img.shape, **self.tensor_kwargs, generator=generator) |
|
if sigma_max is None: |
|
sigma_max = self.sde.sigma_max |
|
x_sigma_max = in_clean_img + noise * sigma_max |
|
return x_sigma_max |
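
    # Usage sketch (illustrative only; `model` and `clean_latent` are assumptions):
    #
    #     x_sigma_max = model.get_x_from_clean(clean_latent, sigma_max=80.0, seed=0)
    #     # then pass it on: generate_samples_from_batch(..., x_sigma_max=x_sigma_max,
    #     # sigma_max=80.0) to denoise from the partially-noised input.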
|
|
|
def encode_latent(self, data_batch: dict, cond_mask: list = []) -> torch.Tensor: |
|
x = data_batch[data_batch["hint_key"]] |
|
latent = [] |
|
|
|
num_conditions = x.size(1) // 3 |
|
        if num_conditions > 1 and self.config.hint_dropout_rate > 0:
            if torch.is_grad_enabled():  # during training, randomly drop out individual hints
                cond_mask = torch.rand(num_conditions) > self.config.hint_dropout_rate
                if not cond_mask.any():  # make sure at least one hint is kept
                    cond_mask = [True] * num_conditions
            elif not cond_mask:  # during inference, fall back to the configured hint mask
                cond_mask = self.config.hint_mask
        else:
            cond_mask = [True] * num_conditions
|
for idx in range(0, x.size(1), 3): |
|
x_rgb = x[:, idx : idx + 3] |
|
if not cond_mask[idx // 3]: |
|
x_rgb = torch.zeros_like(x_rgb) |
|
latent.append(self.encode(x_rgb)) |
|
latent = torch.cat(latent, dim=1) |
|
return latent |
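
    # Hint layout assumed by `encode_latent` (a sketch; the hint names are
    # hypothetical): hints are concatenated channel-wise in 3-channel groups,
    # and cond_mask[i] zeroes out group i before encoding.
    #
    #     x = torch.cat([edge_rgb, depth_rgb], dim=1)  # (B, 6, T, H, W)
    #     # cond_mask=[True, False] keeps the first hint and blanks the second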
|
|
|
def get_x0_fn_from_batch( |
|
self, |
|
data_batch: Dict, |
|
guidance: float = 1.5, |
|
is_negative_prompt: bool = False, |
|
condition_latent: torch.Tensor = None, |
|
num_condition_t: Union[int, None] = None, |
|
condition_video_augment_sigma_in_inference: float = None, |
|
seed: int = 1, |
|
target_h: int = 88, |
|
target_w: int = 160, |
|
patch_h: int = 88, |
|
patch_w: int = 160, |
|
use_batch_processing: bool = True, |
|
) -> Callable: |
|
""" |
|
Generates a callable function `x0_fn` based on the provided data batch and guidance factor. |
|
|
|
This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. |
|
|
|
Args: |
|
- data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` |
|
- guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. |
|
- is_negative_prompt (bool): use negative prompt t5 in uncondition if true |
|
condition_latent (torch.Tensor): latent tensor in shape B,C,T,H,W as condition to generate video. |
|
- num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" |
|
- condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference |
|
- target_h (int): final stitched latent height |
|
- target_w (int): final stitched latent width |
|
- patch_h (int): latent patch height for each network inference |
|
- patch_w (int): latent patch width for each network inference |
|
|
|
Returns: |
|
- Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin |
|
|
|
The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. |
|
""" |
|
|
|
if is_negative_prompt: |
|
condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) |
|
else: |
|
condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) |
|
|
|
|
|
if condition_latent is None: |
|
condition_latent = torch.zeros(data_batch["latent_hint"].shape, **self.tensor_kwargs) |
|
num_condition_t = 0 |
|
condition_video_augment_sigma_in_inference = 1000 |
|
|
|
if use_batch_processing: |
|
condition = self.add_condition_video_indicator_and_video_input_mask( |
|
condition_latent, condition, num_condition_t |
|
) |
|
uncondition = self.add_condition_video_indicator_and_video_input_mask( |
|
condition_latent, uncondition, num_condition_t |
|
) |
|
else: |
|
condition = self.add_condition_video_indicator_and_video_input_mask( |
|
condition_latent[:1], condition, num_condition_t |
|
) |
|
uncondition = self.add_condition_video_indicator_and_video_input_mask( |
|
condition_latent[:1], uncondition, num_condition_t |
|
) |
|
condition.video_cond_bool = True |
|
uncondition.video_cond_bool = False |
|
|
|
|
|
latent_hint = data_batch["latent_hint"] |
|
hint_key = data_batch["hint_key"] |
|
setattr(condition, hint_key, latent_hint) |
|
if "use_none_hint" in data_batch and data_batch["use_none_hint"]: |
|
setattr(uncondition, hint_key, None) |
|
else: |
|
setattr(uncondition, hint_key, latent_hint) |
|
|
|
|
|
|
|
if "regional_contexts" in data_batch and "region_masks" in data_batch: |
|
setattr(condition, "regional_contexts", data_batch["regional_contexts"]) |
|
setattr(condition, "region_masks", data_batch["region_masks"]) |
|
|
|
setattr(uncondition, "region_masks", data_batch["region_masks"]) |
|
setattr(uncondition, "regional_contexts", None) |
|
|
|
to_cp = self.net.is_context_parallel_enabled |
|
|
|
if parallel_state.is_initialized(): |
|
condition = broadcast_condition(condition, to_tp=False, to_cp=to_cp) |
|
uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=to_cp) |
|
|
|
cp_group = parallel_state.get_context_parallel_group() |
|
latent_hint = getattr(condition, hint_key) |
|
seq_dim = 3 if latent_hint.ndim == 6 else 2 |
|
latent_hint = split_inputs_cp(latent_hint, seq_dim=seq_dim, cp_group=cp_group) |
|
setattr(condition, hint_key, latent_hint) |
|
if getattr(uncondition, hint_key) is not None: |
|
setattr(uncondition, hint_key, latent_hint) |
|
if hasattr(condition, "regional_contexts") and getattr(condition, "regional_contexts") is not None: |
|
regional_contexts = getattr(condition, "regional_contexts") |
|
regional_contexts = split_inputs_cp(regional_contexts, seq_dim=2, cp_group=cp_group) |
|
setattr(condition, "regional_contexts", regional_contexts) |
|
|
|
if hasattr(condition, "region_masks") and getattr(condition, "region_masks") is not None: |
|
region_masks = getattr(condition, "region_masks") |
|
region_masks = split_inputs_cp(region_masks, seq_dim=2, cp_group=cp_group) |
|
setattr(condition, "region_masks", region_masks) |
|
setattr(uncondition, "region_masks", region_masks) |
|
|
|
setattr(condition, "base_model", self.model.base_model) |
|
setattr(uncondition, "base_model", self.model.base_model) |
|
if hasattr(self, "hint_encoders"): |
|
self.model.net.hint_encoders = self.hint_encoders |
|
|
|
def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor): |
|
B = noise_x.shape[0] |
|
w, h = target_w, target_h |
|
n_img_w = (w - 1) // patch_w + 1 |
|
n_img_h = (h - 1) // patch_h + 1 |
|
|
|
overlap_size_w = overlap_size_h = 0 |
|
if n_img_w > 1: |
|
overlap_size_w = (n_img_w * patch_w - w) // (n_img_w - 1) |
|
assert n_img_w * patch_w - overlap_size_w * (n_img_w - 1) == w |
|
if n_img_h > 1: |
|
overlap_size_h = (n_img_h * patch_h - h) // (n_img_h - 1) |
|
assert n_img_h * patch_h - overlap_size_h * (n_img_h - 1) == h |
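
            # Illustrative arithmetic for the patch grid above (values assumed):
            # w=280, patch_w=160 -> n_img_w=2 and
            # overlap_size_w = (2 * 160 - 280) // (2 - 1) = 40,
            # so 2 * 160 - 40 * (2 - 1) == 280 and the assert holds.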
|
|
|
if use_batch_processing: |
|
condition.gt_latent = condition_latent |
|
uncondition.gt_latent = condition_latent |
|
setattr(condition, hint_key, latent_hint) |
|
if getattr(uncondition, hint_key) is not None: |
|
setattr(uncondition, hint_key, latent_hint) |
|
|
|
cond_x0 = self.denoise( |
|
noise_x, |
|
sigma, |
|
condition, |
|
condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, |
|
seed=seed, |
|
).x0_pred_replaced |
|
|
|
uncond_x0 = self.denoise( |
|
noise_x, |
|
sigma, |
|
uncondition, |
|
condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, |
|
seed=seed, |
|
).x0_pred_replaced |
|
x0 = cond_x0 + guidance * (cond_x0 - uncond_x0) |
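                # With this convention guidance=0 reproduces the conditional
                # prediction; it is equivalent to
                # uncond_x0 + (1 + guidance) * (cond_x0 - uncond_x0),
                # i.e. a classifier-free guidance scale of (1 + guidance).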
|
|
|
merged = merge_patches_into_video(x0, overlap_size_h, overlap_size_w, n_img_h, n_img_w) |
|
return split_video_into_patches(merged, patch_h, patch_w) |
|
|
|
batch_images = noise_x |
|
batch_sigma = sigma |
|
output = [] |
|
for idx, cur_images in enumerate(batch_images): |
|
noise_x = cur_images.unsqueeze(0) |
|
sigma = batch_sigma[idx : idx + 1] |
|
condition.gt_latent = condition_latent[idx : idx + 1] |
|
uncondition.gt_latent = condition_latent[idx : idx + 1] |
|
setattr(condition, hint_key, latent_hint[idx : idx + 1]) |
|
if getattr(uncondition, hint_key) is not None: |
|
setattr(uncondition, hint_key, latent_hint[idx : idx + 1]) |
|
|
|
cond_x0 = self.denoise( |
|
noise_x, |
|
sigma, |
|
condition, |
|
condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, |
|
seed=seed, |
|
).x0_pred_replaced |
|
uncond_x0 = self.denoise( |
|
noise_x, |
|
sigma, |
|
uncondition, |
|
condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, |
|
seed=seed, |
|
).x0_pred_replaced |
|
x0 = cond_x0 + guidance * (cond_x0 - uncond_x0) |
|
output.append(x0) |
|
output = rearrange(torch.stack(output), "(n t) b ... -> (b n t) ...", n=n_img_h, t=n_img_w) |
|
final_output = merge_patches_into_video(output, overlap_size_h, overlap_size_w, n_img_h, n_img_w) |
|
final_output = split_video_into_patches(final_output, patch_h, patch_w) |
|
return final_output |
|
|
|
return x0_fn |
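
    # Usage sketch for the function above (illustrative only; `model`, `data_batch`
    # and the tensor shapes are assumptions, not part of this module's API):
    #
    #     x0_fn = model.get_x0_fn_from_batch(data_batch, guidance=7.0, seed=0)
    #     sigma = torch.full((noise_x.shape[0],), 80.0, device=noise_x.device)
    #     x0 = x0_fn(noise_x, sigma)  # classifier-free-guided x0 prediction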
|
|
|
def generate_samples_from_batch( |
|
self, |
|
data_batch: Dict, |
|
guidance: float = 1.5, |
|
seed: int = 1, |
|
state_shape: Tuple | None = None, |
|
n_sample: int | None = None, |
|
is_negative_prompt: bool = False, |
|
num_steps: int = 35, |
|
condition_latent: Union[torch.Tensor, None] = None, |
|
num_condition_t: Union[int, None] = None, |
|
condition_video_augment_sigma_in_inference: float = None, |
|
x_sigma_max: Optional[torch.Tensor] = None, |
|
sigma_max: float | None = None, |
|
target_h: int = 88, |
|
target_w: int = 160, |
|
patch_h: int = 88, |
|
patch_w: int = 160, |
|
use_batch_processing: bool = True, |
|
) -> Tensor: |
|
""" |
|
Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. |
|
Different from the base model, this function support condition latent as input, it will create a differnt x0_fn if condition latent is given. |
|
If this feature is stablized, we could consider to move this function to the base model. |
|
|
|
Args: |
|
condition_latent (Optional[torch.Tensor]): latent tensor in shape B,C,T,H,W as condition to generate video. |
|
num_condition_t (Optional[int]): number of condition latent T, if None, will use the whole first half |
|
""" |
|
assert patch_h <= target_h and patch_w <= target_w |
|
if n_sample is None: |
|
input_key = self.input_data_key |
|
n_sample = data_batch[input_key].shape[0] |
|
if state_shape is None: |
|
log.debug(f"Default Video state shape is used. {self.state_shape}") |
|
state_shape = self.state_shape |
|
x0_fn = self.get_x0_fn_from_batch( |
|
data_batch, |
|
guidance, |
|
is_negative_prompt=is_negative_prompt, |
|
condition_latent=condition_latent, |
|
num_condition_t=num_condition_t, |
|
condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, |
|
seed=seed, |
|
target_h=target_h, |
|
target_w=target_w, |
|
patch_h=patch_h, |
|
patch_w=patch_w, |
|
use_batch_processing=use_batch_processing, |
|
) |
|
|
|
if sigma_max is None: |
|
sigma_max = self.sde.sigma_max |
|
|
|
if x_sigma_max is None: |
|
x_sigma_max = ( |
|
misc.arch_invariant_rand( |
|
(n_sample,) + tuple(state_shape), |
|
torch.float32, |
|
self.tensor_kwargs["device"], |
|
seed, |
|
) |
|
* sigma_max |
|
) |
|
|
|
if self.net.is_context_parallel_enabled: |
|
x_sigma_max = broadcast(x_sigma_max, to_tp=False, to_cp=True) |
|
x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) |
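            # Context parallelism: the latent is split along the temporal axis
            # (seq_dim=2) across CP ranks before sampling, and the per-rank
            # outputs are concatenated back together below via cat_outputs_cp.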
|
|
|
samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=sigma_max) |
|
|
|
if self.net.is_context_parallel_enabled: |
|
samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) |
|
|
|
return samples |
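
    # Usage sketch for patch-stitched generation (argument values are assumptions):
    # target 176x320 latents with 88x160 patches form a 2x2 grid with zero overlap.
    #
    #     samples = model.generate_samples_from_batch(
    #         data_batch, guidance=7.0, num_steps=35, seed=0,
    #         target_h=176, target_w=320, patch_h=88, patch_w=160,
    #     )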
|
|
|
|
|
class VideoDiffusionT2VModelWithCtrl(DiffusionT2WModel): |
|
def build_model(self) -> torch.nn.ModuleDict: |
|
log.info("Start creating base model") |
|
base_model = super().build_model() |
|
|
|
self.load_base_model(base_model) |
|
log.info("Done creating base model") |
|
|
|
log.info("Start creating ctrlnet model") |
|
net = lazy_instantiate(self.config.net_ctrl) |
|
conditioner = base_model.conditioner |
|
logvar = base_model.logvar |
|
|
|
model = torch.nn.ModuleDict({"net": net, "conditioner": conditioner, "logvar": logvar}) |
|
model.load_state_dict(base_model.state_dict(), strict=False) |
|
|
|
model.base_model = base_model |
|
log.info("Done creating ctrlnet model") |
|
|
|
self.hint_key = self.config.hint_key["hint_key"] |
|
return model |
|
|
|
@property |
|
def base_net(self): |
|
return self.model.base_model.net |
|
|
|
@property |
|
def conditioner(self): |
|
return self.model.conditioner |
|
|
|
def load_base_model(self, base_model) -> None: |
|
config = self.config |
|
if config.base_load_from is not None: |
|
checkpoint_path = config.base_load_from["load_path"] |
|
else: |
|
checkpoint_path = "" |
|
if checkpoint_path: |
|
log.info(f"Loading base model checkpoint (local): {checkpoint_path}") |
|
state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False) |
|
log.success(f"Complete loading base model checkpoint (local): {checkpoint_path}") |
|
|
|
if "ema" in state_dict: |
|
|
|
log.info("Copying ema to base model") |
|
base_state_dict = {k.replace("-", "."): v for k, v in state_dict["ema"].items()} |
|
elif "model" in state_dict: |
|
|
|
log.warning("Using non-EMA base model") |
|
base_state_dict = state_dict["model"] |
|
else: |
|
log.info("Loading from an EMA only model") |
|
base_state_dict = state_dict |
|
base_model.load_state_dict(base_state_dict, strict=False) |
|
log.info("Done loading the base model checkpoint.") |
|
|
|
def get_data_and_condition( |
|
self, data_batch: dict[str, Tensor], **kwargs |
|
) -> Tuple[Tensor, VideoConditionerWithCtrl]: |
|
|
|
hint_key = self.config.hint_key["hint_key"] |
|
_data = {hint_key: data_batch[hint_key]} |
|
if IS_PREPROCESSED_KEY in data_batch: |
|
_data[IS_PREPROCESSED_KEY] = data_batch[IS_PREPROCESSED_KEY] |
|
data_batch[hint_key] = _data[hint_key] |
|
data_batch["hint_key"] = hint_key |
|
raw_state, latent_state, condition = super().get_data_and_condition(data_batch, **kwargs) |
|
use_multicontrol = ( |
|
("control_weight" in data_batch) |
|
and not isinstance(data_batch["control_weight"], float) |
|
and data_batch["control_weight"].shape[0] > 1 |
|
) |
|
if use_multicontrol: |
|
latent_hint = [] |
|
num_conditions = data_batch[data_batch["hint_key"]].size(1) // 3 |
|
for i in range(num_conditions): |
|
cond_mask = [False] * num_conditions |
|
cond_mask[i] = True |
|
latent_hint += [self.encode_latent(data_batch, cond_mask=cond_mask)] |
|
latent_hint = torch.cat(latent_hint) |
|
else: |
|
latent_hint = self.encode_latent(data_batch) |
|
|
|
|
|
data_batch["latent_hint"] = latent_hint |
|
setattr(condition, hint_key, latent_hint) |
|
setattr(condition, "base_model", self.model.base_model) |
|
return raw_state, latent_state, condition |
|
|
|
def get_x_from_clean( |
|
self, |
|
in_clean_img: torch.Tensor, |
|
sigma_max: float | None, |
|
seed: int = 1, |
|
) -> Tensor: |
|
""" |
|
in_clean_img (torch.Tensor): input clean image for image-to-image/video-to-video by adding noise then denoising |
|
sigma_max (float): maximum sigma applied to in_clean_image for image-to-image/video-to-video |
|
""" |
|
if in_clean_img is None: |
|
return None |
|
generator = torch.Generator(device=self.tensor_kwargs["device"]) |
|
generator.manual_seed(seed) |
|
noise = torch.randn(*in_clean_img.shape, **self.tensor_kwargs, generator=generator) |
|
if sigma_max is None: |
|
sigma_max = self.sde.sigma_max |
|
x_sigma_max = in_clean_img + noise * sigma_max |
|
return x_sigma_max |
|
|
|
def encode_latent(self, data_batch: dict, cond_mask: list = []) -> torch.Tensor: |
|
x = data_batch[data_batch["hint_key"]] |
|
latent = [] |
|
|
|
num_conditions = x.size(1) // 3 |
|
        if num_conditions > 1 and self.config.hint_dropout_rate > 0:
            if torch.is_grad_enabled():  # during training, randomly drop out individual hints
                cond_mask = torch.rand(num_conditions) > self.config.hint_dropout_rate
                if not cond_mask.any():  # make sure at least one hint is kept
                    cond_mask = [True] * num_conditions
            elif not cond_mask:  # during inference, fall back to the configured hint mask
                cond_mask = self.config.hint_mask
        else:
            cond_mask = [True] * num_conditions
|
for idx in range(0, x.size(1), 3): |
|
x_rgb = x[:, idx : idx + 3] |
|
if not cond_mask[idx // 3]: |
|
x_rgb = torch.zeros_like(x_rgb) |
|
latent.append(self.encode(x_rgb)) |
|
latent = torch.cat(latent, dim=1) |
|
return latent |
|
|
|
def get_x0_fn_from_batch( |
|
self, |
|
data_batch: Dict, |
|
guidance: float = 1.5, |
|
is_negative_prompt: bool = False, |
|
condition_latent: torch.Tensor = None, |
|
num_condition_t: Union[int, None] = None, |
|
condition_video_augment_sigma_in_inference: float = None, |
|
) -> Callable: |
|
""" |
|
Generates a callable function `x0_fn` based on the provided data batch and guidance factor. |
|
|
|
This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. |
|
|
|
Args: |
|
- data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` |
|
- guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. |
|
- is_negative_prompt (bool): use negative prompt t5 in uncondition if true |
|
condition_latent (torch.Tensor): latent tensor in shape B,C,T,H,W as condition to generate video. |
|
- num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" |
|
- condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference |
|
|
|
Returns: |
|
- Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin |
|
|
|
The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. |
|
""" |
|
|
|
if is_negative_prompt: |
|
condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) |
|
else: |
|
condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) |
|
|
|
latent_hint = data_batch["latent_hint"] |
|
hint_key = data_batch["hint_key"] |
|
setattr(condition, hint_key, latent_hint) |
|
if "use_none_hint" in data_batch and data_batch["use_none_hint"]: |
|
setattr(uncondition, hint_key, None) |
|
else: |
|
setattr(uncondition, hint_key, latent_hint) |
|
|
|
|
|
if "regional_contexts" in data_batch and "region_masks" in data_batch: |
|
setattr(condition, "regional_contexts", data_batch["regional_contexts"]) |
|
setattr(condition, "region_masks", data_batch["region_masks"]) |
|
|
|
setattr(uncondition, "region_masks", data_batch["region_masks"]) |
|
setattr(uncondition, "regional_contexts", None) |
|
|
|
to_cp = self.net.is_context_parallel_enabled |
|
|
|
if parallel_state.is_initialized(): |
|
condition = broadcast_condition(condition, to_tp=False, to_cp=to_cp) |
|
uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=to_cp) |
|
|
|
cp_group = parallel_state.get_context_parallel_group() |
|
latent_hint = getattr(condition, hint_key) |
|
seq_dim = 3 if latent_hint.ndim == 6 else 2 |
|
latent_hint = split_inputs_cp(latent_hint, seq_dim=seq_dim, cp_group=cp_group) |
|
setattr(condition, hint_key, latent_hint) |
|
if getattr(uncondition, hint_key) is not None: |
|
setattr(uncondition, hint_key, latent_hint) |
|
|
|
if hasattr(condition, "regional_contexts") and getattr(condition, "regional_contexts") is not None: |
|
regional_contexts = getattr(condition, "regional_contexts") |
|
regional_contexts = split_inputs_cp(regional_contexts, seq_dim=2, cp_group=cp_group) |
|
setattr(condition, "regional_contexts", regional_contexts) |
|
|
|
if hasattr(condition, "region_masks") and getattr(condition, "region_masks") is not None: |
|
region_masks = getattr(condition, "region_masks") |
|
region_masks = split_inputs_cp(region_masks, seq_dim=2, cp_group=cp_group) |
|
setattr(condition, "region_masks", region_masks) |
|
setattr(uncondition, "region_masks", region_masks) |
|
setattr(condition, "base_model", self.model.base_model) |
|
setattr(uncondition, "base_model", self.model.base_model) |
|
if hasattr(self, "hint_encoders"): |
|
self.model.net.hint_encoders = self.hint_encoders |
|
|
|
def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: |
|
cond_x0 = self.denoise( |
|
noise_x, |
|
sigma, |
|
condition, |
|
).x0 |
|
uncond_x0 = self.denoise( |
|
noise_x, |
|
sigma, |
|
uncondition, |
|
).x0 |
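            # Unlike the V2W control path above, this returns the raw x0
            # prediction (`.x0`): there are no condition frames to paste back.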
|
return cond_x0 + guidance * (cond_x0 - uncond_x0) |
|
|
|
return x0_fn |
|
|
|
def generate_samples_from_batch( |
|
self, |
|
data_batch: Dict, |
|
guidance: float = 1.5, |
|
seed: int = 1, |
|
state_shape: Tuple | None = None, |
|
n_sample: int | None = None, |
|
is_negative_prompt: bool = False, |
|
num_steps: int = 35, |
|
condition_latent: Union[torch.Tensor, None] = None, |
|
num_condition_t: Union[int, None] = None, |
|
condition_video_augment_sigma_in_inference: float = None, |
|
x_sigma_max: Optional[torch.Tensor] = None, |
|
sigma_max: float | None = None, |
|
**kwargs, |
|
) -> Tensor: |
|
""" |
|
Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. |
|
Different from the base model, this function support condition latent as input, it will create a differnt x0_fn if condition latent is given. |
|
If this feature is stablized, we could consider to move this function to the base model. |
|
|
|
Args: |
|
condition_latent (Optional[torch.Tensor]): latent tensor in shape B,C,T,H,W as condition to generate video. |
|
num_condition_t (Optional[int]): number of condition latent T, if None, will use the whole first half |
|
""" |
|
if n_sample is None: |
|
input_key = self.input_data_key |
|
n_sample = data_batch[input_key].shape[0] |
|
if state_shape is None: |
|
log.debug(f"Default Video state shape is used. {self.state_shape}") |
|
state_shape = self.state_shape |
|
|
|
x0_fn = self.get_x0_fn_from_batch( |
|
data_batch, |
|
guidance, |
|
is_negative_prompt=is_negative_prompt, |
|
condition_latent=condition_latent, |
|
num_condition_t=num_condition_t, |
|
condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, |
|
) |
|
|
|
if sigma_max is None: |
|
sigma_max = self.sde.sigma_max |
|
|
|
if x_sigma_max is None: |
|
x_sigma_max = ( |
|
misc.arch_invariant_rand( |
|
(n_sample,) + tuple(state_shape), |
|
torch.float32, |
|
self.tensor_kwargs["device"], |
|
seed, |
|
) |
|
* sigma_max |
|
) |
|
|
|
if self.net.is_context_parallel_enabled: |
|
x_sigma_max = broadcast(x_sigma_max, to_tp=False, to_cp=True) |
|
x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) |
|
|
|
samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=sigma_max) |
|
|
|
if self.net.is_context_parallel_enabled: |
|
samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) |
|
return samples |
|
|
|
|
|
class VideoDistillModelWithCtrl(DistillV2WModel): |
|
def build_model(self) -> torch.nn.ModuleDict: |
|
log.info("Start creating base model") |
|
base_model = super().build_model() |
|
|
|
log.info("Done creating base model") |
|
|
|
log.info("Start creating ctrlnet model") |
|
net = DistillControlNet(self.config) |
|
net.base_model.net.load_state_dict(base_model["net"].state_dict()) |
|
conditioner = base_model.conditioner |
|
logvar = base_model.logvar |
|
|
|
model = torch.nn.ModuleDict({"net": net, "conditioner": conditioner, "logvar": logvar}) |
|
model.base_model = net.base_model.net |
|
|
|
self.hint_key = self.config.hint_key["hint_key"] |
|
return model |
|
|
|
@property |
|
def base_net(self): |
|
return self.model.base_model.net |
|
|
|
@property |
|
def conditioner(self): |
|
return self.model.conditioner |
|
|
|
def get_data_and_condition( |
|
self, data_batch: dict[str, Tensor], **kwargs |
|
) -> Tuple[Tensor, VideoConditionerWithCtrl]: |
|
|
|
hint_key = self.config.hint_key["hint_key"] |
|
_data = {hint_key: data_batch[hint_key]} |
|
if IS_PREPROCESSED_KEY in data_batch: |
|
_data[IS_PREPROCESSED_KEY] = data_batch[IS_PREPROCESSED_KEY] |
|
data_batch[hint_key] = _data[hint_key] |
|
data_batch["hint_key"] = hint_key |
|
raw_state, latent_state, condition = super().get_data_and_condition(data_batch, **kwargs) |
|
use_multicontrol = ( |
|
("control_weight" in data_batch) |
|
and not isinstance(data_batch["control_weight"], float) |
|
and data_batch["control_weight"].shape[0] > 1 |
|
) |
|
if use_multicontrol: |
|
latent_hint = [] |
|
num_conditions = data_batch[data_batch["hint_key"]].size(1) // 3 |
|
for i in range(num_conditions): |
|
cond_mask = [False] * num_conditions |
|
cond_mask[i] = True |
|
latent_hint += [self.encode_latent(data_batch, cond_mask=cond_mask)] |
|
latent_hint = torch.cat(latent_hint) |
|
else: |
|
latent_hint = self.encode_latent(data_batch) |
|
|
|
|
|
data_batch["latent_hint"] = latent_hint |
|
setattr(condition, hint_key, latent_hint) |
|
setattr(condition, "base_model", self.model.base_model) |
|
return raw_state, latent_state, condition |
|
|
|
def get_x_from_clean( |
|
self, |
|
in_clean_img: torch.Tensor, |
|
sigma_max: float | None, |
|
seed: int = 1, |
|
) -> Tensor: |
|
""" |
|
in_clean_img (torch.Tensor): input clean image for image-to-image/video-to-video by adding noise then denoising |
|
sigma_max (float): maximum sigma applied to in_clean_image for image-to-image/video-to-video |
|
""" |
|
if in_clean_img is None: |
|
return None |
|
generator = torch.Generator(device=self.tensor_kwargs["device"]) |
|
generator.manual_seed(seed) |
|
noise = torch.randn(*in_clean_img.shape, **self.tensor_kwargs, generator=generator) |
|
if sigma_max is None: |
|
sigma_max = self.sde.sigma_max |
|
x_sigma_max = in_clean_img + noise * sigma_max |
|
return x_sigma_max |
|
|
|
def encode_latent(self, data_batch: dict, cond_mask: list = []) -> torch.Tensor: |
|
x = data_batch[data_batch["hint_key"]] |
|
latent = [] |
|
|
|
num_conditions = x.size(1) // 3 |
|
        if num_conditions > 1 and self.config.hint_dropout_rate > 0:
            if torch.is_grad_enabled():  # during training, randomly drop out individual hints
                cond_mask = torch.rand(num_conditions) > self.config.hint_dropout_rate
                if not cond_mask.any():  # make sure at least one hint is kept
                    cond_mask = [True] * num_conditions
            elif not cond_mask:  # during inference, fall back to the configured hint mask
                cond_mask = self.config.hint_mask
        else:
            cond_mask = [True] * num_conditions
|
for idx in range(0, x.size(1), 3): |
|
x_rgb = x[:, idx : idx + 3] |
|
if not cond_mask[idx // 3]: |
|
x_rgb = torch.zeros_like(x_rgb) |
|
latent.append(self.encode(x_rgb)) |
|
latent = torch.cat(latent, dim=1) |
|
return latent |
|
|
|
def generate_samples_from_batch( |
|
self, |
|
data_batch: Dict, |
|
guidance: float = 1.5, |
|
seed: int = 1, |
|
state_shape: Tuple | None = None, |
|
n_sample: int | None = None, |
|
is_negative_prompt: bool = False, |
|
num_steps: int = 1, |
|
condition_latent: Union[torch.Tensor, None] = None, |
|
num_condition_t: Union[int, None] = None, |
|
condition_video_augment_sigma_in_inference: float = None, |
|
x_sigma_max: Optional[torch.Tensor] = None, |
|
sigma_max: float | None = None, |
|
target_h: int = 88, |
|
target_w: int = 160, |
|
patch_h: int = 88, |
|
patch_w: int = 160, |
|
**kwargs, |
|
) -> torch.Tensor: |
|
"""Single-step generation matching internal distilled model""" |
|
|
|
|
|
self._normalize_video_databatch_inplace(data_batch) |
|
self._augment_image_dim_inplace(data_batch) |
|
|
|
if n_sample is None: |
|
|
|
input_key = self.input_data_key |
|
n_sample = data_batch[input_key].shape[0] |
|
if state_shape is None: |
|
log.debug(f"Default Video state shape is used. {self.state_shape}") |
|
state_shape = self.state_shape |
|
if sigma_max is None: |
|
sigma_max = self.sde.sigma_max |
|
if x_sigma_max is None: |
|
x_sigma_max = ( |
|
misc.arch_invariant_rand( |
|
(n_sample,) + tuple(state_shape), |
|
torch.float32, |
|
self.tensor_kwargs["device"], |
|
seed, |
|
) |
|
* sigma_max |
|
) |
|
|
|
|
|
batch_shape = (n_sample, *state_shape) |
|
generator = torch.Generator(device=self.tensor_kwargs["device"]) |
|
generator.manual_seed(seed) |
|
random_noise = torch.randn(*batch_shape, generator=generator, **self.tensor_kwargs) |
|
|
|
if is_negative_prompt: |
|
condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) |
|
else: |
|
condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) |
|
|
|
if condition_latent is None: |
|
condition_latent = torch.zeros(data_batch["latent_hint"].shape, **self.tensor_kwargs) |
|
num_condition_t = 0 |
|
condition_video_augment_sigma_in_inference = 1000 |
|
|
|
condition.video_cond_bool = True |
|
condition = self.add_condition_video_indicator_and_video_input_mask( |
|
condition_latent, condition, num_condition_t |
|
) |
|
uncondition.video_cond_bool = True |
|
uncondition = self.add_condition_video_indicator_and_video_input_mask( |
|
condition_latent, uncondition, num_condition_t |
|
) |
|
uncondition.condition_video_indicator = condition.condition_video_indicator.clone() |
|
uncondition.condition_video_input_mask = condition.condition_video_input_mask.clone() |
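        # Mirror the condition-frame indicator and input mask onto the
        # unconditioned branch so both CFG branches see identical video
        # conditioning and differ only in text/hint inputs.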
|
|
|
latent_hint = data_batch["latent_hint"] |
|
hint_key = data_batch["hint_key"] |
|
setattr(condition, hint_key, latent_hint) |
|
if "use_none_hint" in data_batch and data_batch["use_none_hint"]: |
|
setattr(uncondition, hint_key, None) |
|
else: |
|
setattr(uncondition, hint_key, latent_hint) |
|
|
|
to_cp = self.net.is_context_parallel_enabled |
|
|
|
if parallel_state.is_initialized(): |
|
condition = broadcast_condition(condition, to_tp=False, to_cp=to_cp) |
|
uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=to_cp) |
|
|
|
cp_group = parallel_state.get_context_parallel_group() |
|
latent_hint = getattr(condition, hint_key) |
|
seq_dim = 3 if latent_hint.ndim == 6 else 2 |
|
latent_hint = split_inputs_cp(latent_hint, seq_dim=seq_dim, cp_group=cp_group) |
|
setattr(condition, hint_key, latent_hint) |
|
if getattr(uncondition, hint_key) is not None: |
|
setattr(uncondition, hint_key, latent_hint) |
|
|
|
|
|
setattr(condition, "base_model", self.model.base_model) |
|
setattr(uncondition, "base_model", self.model.base_model) |
|
if hasattr(self, "hint_encoders"): |
|
self.model.net.hint_encoders = self.hint_encoders |
|
|
|
cp_enabled = self.net.is_context_parallel_enabled |
|
if cp_enabled: |
|
random_noise = split_inputs_cp(x=random_noise, seq_dim=2, cp_group=self.net.cp_group) |
|
|
|
condition.gt_latent = condition_latent |
|
uncondition.gt_latent = condition_latent |
|
|
|
if self.net.is_context_parallel_enabled: |
|
x_sigma_max = broadcast(x_sigma_max, to_tp=False, to_cp=True) |
|
x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) |
|
|
|
samples = self._forward_distilled( |
|
epsilon=random_noise, |
|
condition=condition, |
|
uncondition=uncondition, |
|
guidance=guidance, |
|
hint_key=hint_key, |
|
condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, |
|
target_h=target_h, |
|
target_w=target_w, |
|
patch_h=patch_h, |
|
patch_w=patch_w, |
|
seed=seed, |
|
inference_mode=True, |
|
**kwargs, |
|
) |
|
cp_enabled = self.net.is_context_parallel_enabled |
|
if cp_enabled: |
|
samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) |
|
|
|
return samples |
|
|
|
def _forward_distilled( |
|
self, |
|
epsilon: torch.Tensor, |
|
condition: Any, |
|
uncondition: Any, |
|
guidance: float, |
|
hint_key: str, |
|
condition_video_augment_sigma_in_inference: float = 0.001, |
|
target_h: int = 88, |
|
target_w: int = 160, |
|
patch_h: int = 88, |
|
patch_w: int = 160, |
|
seed: int = 1, |
|
inference_mode: bool = True, |
|
**kwargs, |
|
) -> torch.Tensor: |
|
"""Single forward pass for distilled models""" |
|
B = epsilon.shape[0] |
|
|
|
w, h = target_w, target_h |
|
n_img_w = (w - 1) // patch_w + 1 |
|
n_img_h = (h - 1) // patch_h + 1 |
|
|
|
overlap_size_w = overlap_size_h = 0 |
|
if n_img_w > 1: |
|
overlap_size_w = (n_img_w * patch_w - w) // (n_img_w - 1) |
|
assert n_img_w * patch_w - overlap_size_w * (n_img_w - 1) == w |
|
if n_img_h > 1: |
|
overlap_size_h = (n_img_h * patch_h - h) // (n_img_h - 1) |
|
assert n_img_h * patch_h - overlap_size_h * (n_img_h - 1) == h |
|
|
|
|
|
sigma_max = torch.tensor(self.sde.sigma_max).repeat(epsilon.size(0)).to(epsilon.device) |
|
|
|
|
|
with torch.no_grad(): |
|
cond_x0 = self.denoise( |
|
noise_x=epsilon * self.sde.sigma_max, |
|
sigma=sigma_max, |
|
condition=condition, |
|
condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, |
|
seed=seed, |
|
).x0_pred_replaced |
|
return cond_x0 |
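
    # Note: the distilled student maps noise at sigma_max to an x0 prediction in a
    # single denoise call, so no iterative sampler loop is needed; the unconditioned
    # branch and `guidance` are currently unused in this single-step path.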
|
|