import copy
from typing import Callable, Dict, Optional, Tuple, Union

import torch
from einops import rearrange
from megatron.core import parallel_state
from torch import Tensor

from cosmos_transfer1.diffusion.diffusion.modules.res_sampler import COMMON_SOLVER_OPTIONS
from cosmos_transfer1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp
from cosmos_transfer1.diffusion.training.models.model import DiffusionModel, broadcast_condition
from cosmos_transfer1.diffusion.training.models.model_image import CosmosCondition, diffusion_fsdp_class_decorator
from cosmos_transfer1.utils import log, misc
|
|
def deepcopy_no_copy_model(obj):
    """
    Deep-copy a condition object while preserving its `base_model` reference.

    We need a copy of the condition construct so that condition masks can be adjusted
    dynamically, but the ControlNet encoder plug-in also uses the condition construct to
    pass along the `base_model` object, which cannot be deep-copied; hence this function
    detaches `base_model` before copying and reattaches it to both objects afterwards.
    """
    if hasattr(obj, "base_model") and obj.base_model is not None:
        my_base_model = obj.base_model
        obj.base_model = None
        copied_obj = copy.deepcopy(obj)
        copied_obj.base_model = my_base_model
        obj.base_model = my_base_model
    else:
        copied_obj = copy.deepcopy(obj)
    return copied_obj
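# A minimal usage sketch (the mask attribute name is illustrative, not from this module):
# the copy can be mutated freely while both objects keep sharing the same, non-copyable
# base_model reference.
#
#   condition_copy = deepcopy_no_copy_model(condition)
#   condition_copy.some_mask = new_mask  # adjust masks on the copy only
#   assert condition_copy.base_model is condition.base_model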
|
|
class MultiviewDiffusionModel(DiffusionModel):
    def __init__(self, config):
        super().__init__(config)
        self.n_views = config.n_views

    @torch.no_grad()
    def encode(self, state: torch.Tensor) -> torch.Tensor:
        # Fold the view axis into the batch so the VAE encodes each view independently.
        state = rearrange(state, "B C (V T) H W -> (B V) C T H W", V=self.n_views)
        encoded_state = self.vae.encode(state)
        # Refold views into the temporal axis and scale to the sigma_data convention.
        encoded_state = rearrange(encoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) * self.sigma_data
        return encoded_state

    @torch.no_grad()
    def decode(self, latent: torch.Tensor) -> torch.Tensor:
        # Inverse of `encode`: split views out of the temporal axis, decode each view, refold.
        latent = rearrange(latent, "B C (V T) H W -> (B V) C T H W", V=self.n_views)
        decoded_state = self.vae.decode(latent / self.sigma_data)
        decoded_state = rearrange(decoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views)
        return decoded_state
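    # Shape sketch for the view-major latent layout (numbers are illustrative): with
    # n_views=3 and 8 latent frames per view, `encode` maps pixels of shape
    # (B, C, 3*T, H, W) to latents of shape (B, C', 3*8, H', W'), where dim 2 holds the
    # per-view temporal chunks back to back; `decode` inverts this layout exactly.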
|
    def compute_loss_with_epsilon_and_sigma(
        self,
        data_batch: dict[str, torch.Tensor],
        x0_from_data_batch: torch.Tensor,
        x0: torch.Tensor,
        condition: CosmosCondition,
        epsilon: torch.Tensor,
        sigma: torch.Tensor,
    ):
        if self.is_image_batch(data_batch):
            # Image batches have no temporal axis to shard, so turn context parallelism off.
            self.net.disable_context_parallel()
        else:
            if parallel_state.is_initialized():
                if parallel_state.get_context_parallel_world_size() > 1:
                    cp_group = parallel_state.get_context_parallel_group()
                    self.net.enable_context_parallel(cp_group)
                    log.debug("[CP] Split x0 and epsilon")

                    # Split per view so every CP rank receives a contiguous temporal chunk
                    # of each view rather than a chunk of the concatenated multiview video.
                    x0 = rearrange(x0, "B C (V T) H W -> (B V) C T H W", V=self.n_views)
                    epsilon = rearrange(epsilon, "B C (V T) H W -> (B V) C T H W", V=self.n_views)

                    x0 = split_inputs_cp(x=x0, seq_dim=2, cp_group=self.net.cp_group)
                    epsilon = split_inputs_cp(x=epsilon, seq_dim=2, cp_group=self.net.cp_group)

                    x0 = rearrange(x0, "(B V) C T H W -> B C (V T) H W", V=self.n_views)
                    epsilon = rearrange(epsilon, "(B V) C T H W -> B C (V T) H W", V=self.n_views)

        # Skip DiffusionModel's override and call the grandparent implementation directly:
        # the context-parallel split it would perform has already been done per view above.
        output_batch, kendall_loss, pred_mse, edm_loss = super(
            DiffusionModel, self
        ).compute_loss_with_epsilon_and_sigma(data_batch, x0_from_data_batch, x0, condition, epsilon, sigma)
        if not self.is_image_batch(data_batch):
            if (
                self.loss_reduce == "sum"
                and parallel_state.is_initialized()
                and parallel_state.get_context_parallel_world_size() > 1
            ):
                # With sum reduction each CP rank sees only 1/world_size of the sequence,
                # so rescale to keep the loss magnitude independent of the CP world size.
                kendall_loss *= parallel_state.get_context_parallel_world_size()

        return output_batch, kendall_loss, pred_mse, edm_loss
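    # The rearrange/split/rearrange round-trip above can be checked with plain tensor
    # ops, no CP group required (shapes are illustrative):
    #
    #   x = torch.zeros(2, 4, 3 * 8, 16, 16)                     # B C (V T) H W, V=3
    #   x = rearrange(x, "B C (V T) H W -> (B V) C T H W", V=3)  # (6, 4, 8, 16, 16)
    #   x = rearrange(x, "(B V) C T H W -> B C (V T) H W", V=3)  # back to (2, 4, 24, 16, 16)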
|
    def generate_samples_from_batch(
        self,
        data_batch: Dict,
        guidance: float = 1.5,
        seed: int = 1,
        state_shape: Tuple | None = None,
        n_sample: int | None = None,
        is_negative_prompt: bool = False,
        num_steps: int = 35,
        solver_option: COMMON_SOLVER_OPTIONS = "2ab",
        x_sigma_max: Optional[torch.Tensor] = None,
        sigma_max: float | None = None,
        guidance_other: Union[float, None] = None,
    ) -> Tensor:
        """
        Generate samples from the batch. Based on the given batch, it automatically determines
        whether to generate image or video samples.

        Args:
            data_batch (dict): raw data batch drawn from the training data loader.
            guidance (float): guidance weight for classifier-free guidance.
            seed (int): random seed.
            state_shape (tuple): shape of the state; defaults to self.state_shape if not provided.
            n_sample (int): number of samples to generate; defaults to the batch size.
            is_negative_prompt (bool): use negative-prompt T5 embeddings in the unconditioned branch if True.
            num_steps (int): number of steps for the diffusion sampler.
            solver_option (str): differential equation solver option; defaults to "2ab" (multistep solver).
            x_sigma_max (torch.Tensor): optional pre-sampled initial noise at sigma_max; sampled fresh if not provided.
            sigma_max (float): optional maximum noise level; defaults to self.sde.sigma_max.
            guidance_other (float): optional second guidance weight (see `get_x0_fn_from_batch`).
        """
        self._normalize_video_databatch_inplace(data_batch)
        self._augment_image_dim_inplace(data_batch)
        is_image_batch = self.is_image_batch(data_batch)
        if n_sample is None:
            input_key = self.input_image_key if is_image_batch else self.input_data_key
            n_sample = data_batch[input_key].shape[0]
        if state_shape is None:
            if is_image_batch:
                state_shape = (self.state_shape[0], 1, *self.state_shape[2:])
            else:
                state_shape = self.state_shape
        if sigma_max is None:
            sigma_max = self.sde.sigma_max
        x0_fn = self.get_x0_fn_from_batch(
            data_batch, guidance, is_negative_prompt=is_negative_prompt, guidance_other=guidance_other
        )
        if x_sigma_max is None:
            x_sigma_max = (
                misc.arch_invariant_rand(
                    (n_sample,) + tuple(state_shape),
                    torch.float32,
                    self.tensor_kwargs["device"],
                    seed,
                )
                * sigma_max
            )
        if self.net.is_context_parallel_enabled:
            # Shard the initial noise per view, matching the training-time CP split.
            x_sigma_max = rearrange(x_sigma_max, "B C (V T) H W -> (B V) C T H W", V=self.n_views)
            x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group)
            x_sigma_max = rearrange(x_sigma_max, "(B V) C T H W -> B C (V T) H W", V=self.n_views)

        samples = self.sampler(
            x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=sigma_max, solver_option=solver_option
        )
        if self.net.is_context_parallel_enabled:
            # Gather the per-rank temporal chunks back into full-length videos, per view.
            samples = rearrange(samples, "B C (V T) H W -> (B V) C T H W", V=self.n_views)
            samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group)
            samples = rearrange(samples, "(B V) C T H W -> B C (V T) H W", V=self.n_views)

        return samples
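    # Hypothetical call-site sketch (batch keys and weights depend on the dataloader and
    # config; not part of this module):
    #
    #   samples = model.generate_samples_from_batch(data_batch, guidance=7.0, num_steps=35, seed=0)
    #   video = model.decode(samples)  # latents back to pixel space, still B C (V T) H W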
|
    def get_x0_fn_from_batch(
        self,
        data_batch: Dict,
        guidance: float = 1.5,
        is_negative_prompt: bool = False,
        guidance_other: Union[float, None] = None,
    ) -> Callable:
        """
        Generates a callable function `x0_fn` based on the provided data batch and guidance factor.

        This function first processes the input data batch through a conditioning workflow
        (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested
        function `x0_fn` which applies a denoising operation on an input `noise_x` at a given
        noise level `sigma` using both the conditioned and unconditioned states.

        Args:
        - data_batch (Dict): A batch of data used for conditioning. The format and content of this
          dictionary should align with the expectations of `self.conditioner`.
        - guidance (float, optional): A scalar value that modulates the influence of the conditioned
          state relative to the unconditioned state in the output. Defaults to 1.5.
        - is_negative_prompt (bool): use negative-prompt T5 embeddings in the unconditioned branch if True.
        - guidance_other (float, optional): If provided, a second guidance weight applied to an
          additional condition branch that shares the trajectory of the primary condition.

        Returns:
        - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and
          `sigma`, and returns the x0 prediction.

        The returned function is suitable for use in scenarios where a denoised state is required
        based on both conditioned and unconditioned inputs, with an adjustable level of guidance
        influence.
        """
        if is_negative_prompt:
            condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
        else:
            condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)

        to_cp = self.net.is_context_parallel_enabled
        if parallel_state.is_initialized():
            condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp)
            uncondition = broadcast_condition(uncondition, to_tp=True, to_cp=to_cp)
        else:
            assert not to_cp, "parallel_state is not initialized, context parallel should be turned off."

        if guidance_other is not None:
            # Build the second condition branch from the unconditioned state plus the primary
            # condition's trajectory. Use deepcopy_no_copy_model rather than copy.deepcopy,
            # since the condition may carry a non-copyable base_model reference.
            assert not parallel_state.is_initialized(), "Parallel state not supported with two guidances."
            condition_other = deepcopy_no_copy_model(uncondition)
            condition_other.trajectory = condition.trajectory

            def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
                cond_x0 = self.denoise(noise_x, sigma, condition).x0
                uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0
                cond_other_x0 = self.denoise(noise_x, sigma, condition_other).x0
                raw_x0 = cond_x0 + guidance * (cond_x0 - uncond_x0) + guidance_other * (cond_other_x0 - uncond_x0)
                if "guided_image" in data_batch:
                    raise NotImplementedError("guided_image is not supported with two guidances")
                return raw_x0

        else:

            def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
                cond_x0 = self.denoise(noise_x, sigma, condition).x0
                uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0
                raw_x0 = cond_x0 + guidance * (cond_x0 - uncond_x0)
                if "guided_image" in data_batch:
                    # Spatially blend the guided image into the prediction using the provided mask.
                    assert "guided_mask" in data_batch, "guided_mask should be in data_batch if guided_image is present"
                    guide_image = data_batch["guided_image"]
                    guide_mask = data_batch["guided_mask"]
                    raw_x0 = guide_mask * guide_image + (1 - guide_mask) * raw_x0
                return raw_x0

        return x0_fn
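    # Guidance recap: with a single weight the prediction is
    #   x0 = cond_x0 + guidance * (cond_x0 - uncond_x0),
    # so guidance = 0 reproduces the conditional model and larger values push the sample
    # further from the unconditioned prediction. The two-guidance variant adds a second,
    # independently weighted direction from `uncondition` towards `condition_other`.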
|
|
@diffusion_fsdp_class_decorator
class FSDPDiffusionModel(MultiviewDiffusionModel):
    pass