File size: 47,086 Bytes
226c7c9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 |
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union
import torch
from einops import rearrange
from megatron.core import parallel_state
from torch import Tensor
from cosmos_transfer1.diffusion.conditioner import BaseVideoCondition, VideoConditionerWithCtrl, VideoExtendCondition
from cosmos_transfer1.diffusion.inference.inference_utils import merge_patches_into_video, split_video_into_patches
from cosmos_transfer1.diffusion.model.model_t2w import DiffusionT2WModel, broadcast_condition
from cosmos_transfer1.diffusion.model.model_v2w import DiffusionV2WModel, DistillV2WModel
from cosmos_transfer1.diffusion.module.parallel import broadcast, cat_outputs_cp, split_inputs_cp
from cosmos_transfer1.diffusion.networks.distill_controlnet_wrapper import DistillControlNet
from cosmos_transfer1.utils import log, misc
from cosmos_transfer1.utils.lazy_config import instantiate as lazy_instantiate
# Data-batch key name "is_preprocessed".
IS_PREPROCESSED_KEY: str = "is_preprocessed"
# Module-level generic type variable.
T = TypeVar("T")
class VideoDiffusionModelWithCtrl(DiffusionV2WModel):
    """``DiffusionV2WModel`` extended with a ControlNet branch driven by a control (hint) input.

    The base model is built by the parent class and optionally loaded from
    ``config.base_load_from``; the control network is instantiated from ``config.net_ctrl``
    and initialized non-strictly from the base model's state dict.
    """

    def build_model(self) -> torch.nn.ModuleDict:
        """Build the base model, load its weights, and build the control model around it.

        Returns:
            A ``ModuleDict`` with ``net`` (the control network), ``conditioner`` and ``logvar``
            (both taken from the base model). The base model is attached as ``base_model``.
        """
        log.info("Start creating base model")
        base_model = super().build_model()
        # initialize base model
        self.load_base_model(base_model)
        log.info("Done creating base model")

        log.info("Start creating ctrlnet model")
        net = lazy_instantiate(self.config.net_ctrl)
        conditioner = base_model.conditioner
        logvar = base_model.logvar
        # initialize controlnet encoder
        model = torch.nn.ModuleDict({"net": net, "conditioner": conditioner, "logvar": logvar})
        # Non-strict load: copy every base-model weight whose name also exists in the control model.
        model.load_state_dict(base_model.state_dict(), strict=False)
        model.base_model = base_model
        log.info("Done creating ctrlnet model")

        self.hint_key = self.config.hint_key["hint_key"]
        return model

    @property
    def base_net(self):
        """Network of the base model (``self.model.base_model.net``)."""
        return self.model.base_model.net

    @property
    def conditioner(self):
        """Conditioner of the control model (the same object as the base model's conditioner)."""
        return self.model.conditioner

    def load_base_model(self, base_model) -> None:
        """Load base-model weights from ``config.base_load_from["load_path"]`` when a path is set.

        The checkpoint may contain a non-None ``"ema"`` state dict (preferred; ``-`` in its keys is
        mapped back to ``.``), a ``"model"`` state dict, or be a bare state dict. Loading is
        non-strict and the missing / unexpected keys are logged.

        Args:
            base_model: module receiving the weights via ``load_state_dict``.
        """
        config = self.config
        if config.base_load_from is not None:
            checkpoint_path = config.base_load_from["load_path"]
        else:
            checkpoint_path = ""

        if checkpoint_path:
            log.info(f"Loading base model checkpoint (local): {checkpoint_path}")
            # NOTE: weights_only=False unpickles arbitrary Python objects; only load trusted checkpoints.
            state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False)
            log.success(f"Complete loading base model checkpoint (local): {checkpoint_path}")

            if "ema" in state_dict and state_dict["ema"] is not None:
                # Copy the base model weights from ema model.
                log.info("Copying ema to base model")
                base_state_dict = {k.replace("-", "."): v for k, v in state_dict["ema"].items()}
            elif "model" in state_dict:
                # Copy the base model weights from reg model.
                log.warning("Using non-EMA base model")
                base_state_dict = state_dict["model"]
            else:
                log.info("Loading from an EMA only model")
                base_state_dict = state_dict
            missing, unexpected = base_model.load_state_dict(base_state_dict, strict=False)
            log.info(f"Missing: {missing}")
            log.info(f"Unexpected: {unexpected}")
            log.info("Done loading the base model checkpoint.")

    def get_data_and_condition(
        self, data_batch: dict[str, Tensor], **kwargs
    ) -> Tuple[Tensor, Tensor, VideoConditionerWithCtrl]:
        """Prepare data/condition via the parent class, then encode and attach the control input.

        Side effects on ``data_batch``: sets ``"hint_key"`` to the configured hint key and
        ``"latent_hint"`` to the encoded control latent.

        Args:
            data_batch: input batch; must contain the configured hint key.
            **kwargs: forwarded to the parent ``get_data_and_condition``.

        Returns:
            ``(raw_state, latent_state, condition)`` from the parent, where ``condition`` additionally
            carries the control latent under the hint key and ``base_model``.

        Raises:
            KeyError: if the configured hint key is missing from ``data_batch``.
        """
        hint_key = self.config.hint_key["hint_key"]
        # Fail fast (before the parent does any work) when the control input is missing.
        if hint_key not in data_batch:
            raise KeyError(hint_key)
        data_batch["hint_key"] = hint_key
        raw_state, latent_state, condition = super().get_data_and_condition(data_batch, **kwargs)

        use_multicontrol = (
            ("control_weight" in data_batch)
            and not isinstance(data_batch["control_weight"], float)
            and data_batch["control_weight"].shape[0] > 1
        )
        if use_multicontrol:  # encode individual conditions separately
            num_conditions = data_batch[data_batch["hint_key"]].size(1) // 3
            # One encoding per condition, each with only condition ``i`` enabled; stacked along dim 0.
            latent_hint = torch.cat(
                [
                    self.encode_latent(data_batch, cond_mask=[j == i for j in range(num_conditions)])
                    for i in range(num_conditions)
                ]
            )
        else:
            latent_hint = self.encode_latent(data_batch)

        # add extra conditions
        data_batch["latent_hint"] = latent_hint
        setattr(condition, hint_key, latent_hint)
        setattr(condition, "base_model", self.model.base_model)
        return raw_state, latent_state, condition

    def get_x_from_clean(
        self,
        in_clean_img: torch.Tensor,
        sigma_max: float | None,
        seed: int = 1,
    ) -> Optional[Tensor]:
        """Add seeded Gaussian noise scaled by ``sigma_max`` to a clean input.

        Args:
            in_clean_img (torch.Tensor): input clean image for image-to-image/video-to-video by adding
                noise then denoising.
            sigma_max (float | None): maximum sigma applied to ``in_clean_img``; ``None`` means
                ``self.sde.sigma_max``.
            seed (int): seed of the noise generator.

        Returns:
            ``in_clean_img + noise * sigma_max``, or ``None`` when ``in_clean_img`` is ``None``.
        """
        if in_clean_img is None:
            return None
        generator = torch.Generator(device=self.tensor_kwargs["device"])
        generator.manual_seed(seed)
        noise = torch.randn(*in_clean_img.shape, **self.tensor_kwargs, generator=generator)
        if sigma_max is None:
            sigma_max = self.sde.sigma_max
        return in_clean_img + noise * sigma_max

    def encode_latent(self, data_batch: dict, cond_mask: Optional[list] = None) -> torch.Tensor:
        """Encode the control input one 3-channel condition at a time and concatenate the latents.

        The control input ``data_batch[data_batch["hint_key"]]`` holds several conditions concatenated
        along dim 1, three channels each. Conditions disabled by the mask are replaced with zeros
        before encoding.

        Mask selection:
            * more than one condition and ``hint_dropout_rate > 0``:
                - grad enabled: random mask (each kept with prob. ``1 - hint_dropout_rate``),
                  falling back to all-enabled when every condition was dropped;
                - otherwise: ``cond_mask`` if non-empty, else ``config.hint_mask``;
            * any other case: all conditions enabled.

        Args:
            data_batch: batch containing ``"hint_key"`` and the control input under that key.
            cond_mask: optional per-condition booleans (``None`` / empty means "not given").

        Returns:
            The per-condition latents concatenated along dim 1.
        """
        x = data_batch[data_batch["hint_key"]]
        # control input goes through tokenizer, which always takes 3-input channels
        num_conditions = x.size(1) // 3  # input conditions were concatenated along channel dimension
        if num_conditions > 1 and self.config.hint_dropout_rate > 0:
            if torch.is_grad_enabled():  # during training, randomly dropout some conditions
                cond_mask = torch.rand(num_conditions) > self.config.hint_dropout_rate
                if not cond_mask.any():  # make sure at least one condition is present
                    cond_mask = [True] * num_conditions
            elif not cond_mask:  # during inference, use hint_mask to indicate which conditions are used
                cond_mask = self.config.hint_mask
        else:
            # NOTE(review): an explicitly passed cond_mask is overridden here (single condition or
            # hint_dropout_rate == 0) — confirm this is intended for the multicontrol path.
            cond_mask = [True] * num_conditions

        latent = []
        for idx in range(0, x.size(1), 3):
            x_rgb = x[:, idx : idx + 3]
            if not cond_mask[idx // 3]:  # if the condition is not selected, replace with a black image
                x_rgb = torch.zeros_like(x_rgb)
            latent.append(self.encode(x_rgb))
        return torch.cat(latent, dim=1)

    def get_x0_fn_from_batch(
        self,
        data_batch: Dict,
        guidance: float = 1.5,
        is_negative_prompt: bool = False,
        condition_latent: torch.Tensor = None,
        num_condition_t: Union[int, None] = None,
        condition_video_augment_sigma_in_inference: float = None,
        seed: int = 1,
        target_h: int = 88,
        target_w: int = 160,
        patch_h: int = 88,
        patch_w: int = 160,
        use_batch_processing: bool = True,
    ) -> Callable:
        """
        Generates a callable function `x0_fn` based on the provided data batch and guidance factor.

        This function first processes the input data batch through a conditioning workflow (`conditioner`) to
        obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a
        denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and
        unconditioned states, combined as `cond + guidance * (cond - uncond)`.

        Args:
        - data_batch (Dict): batch already processed by `self.get_data_and_condition` (it must contain
          `latent_hint` and `hint_key`).
        - guidance (float, optional): A scalar value that modulates the influence of the conditioned state
          relative to the unconditioned state in the output. Defaults to 1.5.
        - is_negative_prompt (bool): use negative prompt t5 in uncondition if true
        - condition_latent (torch.Tensor): latent tensor in shape B,C,T,H,W as condition to generate video.
          When None, zeros shaped like `latent_hint` are used, `num_condition_t` becomes 0 and the augment
          sigma becomes 1000.
        - num_condition_t (int): number of condition latent T, used in inference to decide the condition region
          and config.conditioner.video_cond_bool.condition_location == "first_n"
        - condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference
        - seed (int): seed forwarded to every `denoise` call
        - target_h (int): final stitched latent height
        - target_w (int): final stitched latent width
        - patch_h (int): latent patch height for each network inference
        - patch_w (int): latent patch width for each network inference
        - use_batch_processing (bool): if True, denoise the whole batch in one call; otherwise denoise one
          sample at a time. The per-sample path re-assembles its outputs assuming the batch size equals
          the number of patches (`n_img_h * n_img_w`).

        Returns:
        - Callable: A function `x0_fn(noise_x, sigma)` that returns the x0 prediction, merged over
          overlapping patches and split again into `patch_h` x `patch_w` patches.
        """
        # data_batch should be the one processed by self.get_data_and_condition
        if is_negative_prompt:
            condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
        else:
            condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)

        # Add conditions for long video generation.
        if condition_latent is None:
            condition_latent = torch.zeros(data_batch["latent_hint"].shape, **self.tensor_kwargs)
            num_condition_t = 0
            condition_video_augment_sigma_in_inference = 1000

        # The per-sample path denoises one sample at a time, so its masks are built for a single sample.
        mask_latent = condition_latent if use_batch_processing else condition_latent[:1]
        condition = self.add_condition_video_indicator_and_video_input_mask(mask_latent, condition, num_condition_t)
        uncondition = self.add_condition_video_indicator_and_video_input_mask(
            mask_latent, uncondition, num_condition_t
        )
        condition.video_cond_bool = True
        uncondition.video_cond_bool = False  # Not do cfg on condition frames

        # Add extra conditions for ctrlnet.
        latent_hint = data_batch["latent_hint"]
        hint_key = data_batch["hint_key"]
        setattr(condition, hint_key, latent_hint)
        if data_batch.get("use_none_hint"):
            setattr(uncondition, hint_key, None)
        else:
            setattr(uncondition, hint_key, latent_hint)

        # Handle regional prompting information
        if "regional_contexts" in data_batch and "region_masks" in data_batch:
            setattr(condition, "regional_contexts", data_batch["regional_contexts"])
            setattr(condition, "region_masks", data_batch["region_masks"])
            # For unconditioned generation, we still need the region masks but not the regional contexts
            setattr(uncondition, "region_masks", data_batch["region_masks"])
            setattr(uncondition, "regional_contexts", None)

        to_cp = self.net.is_context_parallel_enabled
        # For inference, check if parallel_state is initialized
        if parallel_state.is_initialized():
            condition = broadcast_condition(condition, to_tp=False, to_cp=to_cp)
            uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=to_cp)

            cp_group = parallel_state.get_context_parallel_group()
            latent_hint = getattr(condition, hint_key)
            seq_dim = 3 if latent_hint.ndim == 6 else 2
            latent_hint = split_inputs_cp(latent_hint, seq_dim=seq_dim, cp_group=cp_group)
            setattr(condition, hint_key, latent_hint)
            if getattr(uncondition, hint_key) is not None:
                setattr(uncondition, hint_key, latent_hint)

            if hasattr(condition, "regional_contexts") and getattr(condition, "regional_contexts") is not None:
                regional_contexts = getattr(condition, "regional_contexts")
                regional_contexts = split_inputs_cp(regional_contexts, seq_dim=2, cp_group=cp_group)
                setattr(condition, "regional_contexts", regional_contexts)
            if hasattr(condition, "region_masks") and getattr(condition, "region_masks") is not None:
                region_masks = getattr(condition, "region_masks")
                region_masks = split_inputs_cp(region_masks, seq_dim=2, cp_group=cp_group)
                setattr(condition, "region_masks", region_masks)
                setattr(uncondition, "region_masks", region_masks)

        setattr(condition, "base_model", self.model.base_model)
        setattr(uncondition, "base_model", self.model.base_model)
        if hasattr(self, "hint_encoders"):
            self.model.net.hint_encoders = self.hint_encoders

        def _guided_x0(x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
            # Conditioned then unconditioned denoising, combined with classifier-free guidance.
            cond_x0 = self.denoise(
                x,
                s,
                condition,
                condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
                seed=seed,
            ).x0_pred_replaced
            uncond_x0 = self.denoise(
                x,
                s,
                uncondition,
                condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
                seed=seed,
            ).x0_pred_replaced
            return cond_x0 + guidance * (cond_x0 - uncond_x0)

        def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
            # Patch grid covering the target size; overlaps are chosen so the patches tile it exactly.
            n_img_w = (target_w - 1) // patch_w + 1
            n_img_h = (target_h - 1) // patch_h + 1
            overlap_size_w = overlap_size_h = 0
            if n_img_w > 1:
                overlap_size_w = (n_img_w * patch_w - target_w) // (n_img_w - 1)
                assert n_img_w * patch_w - overlap_size_w * (n_img_w - 1) == target_w
            if n_img_h > 1:
                overlap_size_h = (n_img_h * patch_h - target_h) // (n_img_h - 1)
                assert n_img_h * patch_h - overlap_size_h * (n_img_h - 1) == target_h

            if use_batch_processing:
                condition.gt_latent = condition_latent
                uncondition.gt_latent = condition_latent
                setattr(condition, hint_key, latent_hint)
                if getattr(uncondition, hint_key) is not None:
                    setattr(uncondition, hint_key, latent_hint)
                # Batch denoising
                x0 = _guided_x0(noise_x, sigma)
            else:
                outputs = []
                for idx, sample in enumerate(noise_x):
                    one = slice(idx, idx + 1)
                    condition.gt_latent = condition_latent[one]
                    uncondition.gt_latent = condition_latent[one]
                    setattr(condition, hint_key, latent_hint[one])
                    if getattr(uncondition, hint_key) is not None:
                        setattr(uncondition, hint_key, latent_hint[one])
                    outputs.append(_guided_x0(sample.unsqueeze(0), sigma[one]))
                # Stacked per-sample outputs are ordered patch-row-major; flatten back into a batch.
                x0 = rearrange(torch.stack(outputs), "(n t) b ... -> (b n t) ...", n=n_img_h, t=n_img_w)

            merged = merge_patches_into_video(x0, overlap_size_h, overlap_size_w, n_img_h, n_img_w)
            return split_video_into_patches(merged, patch_h, patch_w)

        return x0_fn

    def generate_samples_from_batch(
        self,
        data_batch: Dict,
        guidance: float = 1.5,
        seed: int = 1,
        state_shape: Tuple | None = None,
        n_sample: int | None = None,
        is_negative_prompt: bool = False,
        num_steps: int = 35,
        condition_latent: Union[torch.Tensor, None] = None,
        num_condition_t: Union[int, None] = None,
        condition_video_augment_sigma_in_inference: float = None,
        x_sigma_max: Optional[torch.Tensor] = None,
        sigma_max: float | None = None,
        target_h: int = 88,
        target_w: int = 160,
        patch_h: int = 88,
        patch_w: int = 160,
        use_batch_processing: bool = True,
    ) -> Tensor:
        """
        Generate samples from the batch. Based on given batch, it will automatically determine whether to
        generate image or video samples.

        Different from the base model, this function supports condition latent as input; it will create a
        different x0_fn if condition latent is given. If this feature is stabilized, we could consider moving
        this function to the base model.

        Args:
            condition_latent (Optional[torch.Tensor]): latent tensor in shape B,C,T,H,W as condition to generate video.
            num_condition_t (Optional[int]): number of condition latent T, if None, will use the whole first half
            x_sigma_max (Optional[torch.Tensor]): initial noisy latent; when None, seeded noise of shape
                `(n_sample, *state_shape)` scaled by `sigma_max` is drawn.
            sigma_max (float | None): starting noise level; defaults to `self.sde.sigma_max`.
            target_h, target_w, patch_h, patch_w, use_batch_processing: see `get_x0_fn_from_batch`.

        Returns:
            The sampler output (gathered across context-parallel ranks when context parallelism is enabled).
        """
        assert patch_h <= target_h and patch_w <= target_w, (
            f"patch size ({patch_h}, {patch_w}) must not exceed target size ({target_h}, {target_w})"
        )
        if n_sample is None:
            input_key = self.input_data_key
            n_sample = data_batch[input_key].shape[0]
        if state_shape is None:
            log.debug(f"Default Video state shape is used. {self.state_shape}")
            state_shape = self.state_shape

        x0_fn = self.get_x0_fn_from_batch(
            data_batch,
            guidance,
            is_negative_prompt=is_negative_prompt,
            condition_latent=condition_latent,
            num_condition_t=num_condition_t,
            condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
            seed=seed,
            target_h=target_h,
            target_w=target_w,
            patch_h=patch_h,
            patch_w=patch_w,
            use_batch_processing=use_batch_processing,
        )

        if sigma_max is None:
            sigma_max = self.sde.sigma_max

        if x_sigma_max is None:
            x_sigma_max = (
                misc.arch_invariant_rand(
                    (n_sample,) + tuple(state_shape),
                    torch.float32,
                    self.tensor_kwargs["device"],
                    seed,
                )
                * sigma_max
            )

        if self.net.is_context_parallel_enabled:
            x_sigma_max = broadcast(x_sigma_max, to_tp=False, to_cp=True)
            x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group)

        samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=sigma_max)

        if self.net.is_context_parallel_enabled:
            samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group)

        return samples
class VideoDiffusionT2VModelWithCtrl(DiffusionT2WModel):
    """``DiffusionT2WModel`` extended with a ControlNet branch driven by a control (hint) input.

    The base model is built by the parent class and optionally loaded from
    ``config.base_load_from``; the control network is instantiated from ``config.net_ctrl``
    and initialized non-strictly from the base model's state dict.
    """

    def build_model(self) -> torch.nn.ModuleDict:
        """Build the base model, load its weights, and build the control model around it.

        Returns:
            A ``ModuleDict`` with ``net`` (the control network), ``conditioner`` and ``logvar``
            (both taken from the base model). The base model is attached as ``base_model``.
        """
        log.info("Start creating base model")
        base_model = super().build_model()
        # initialize base model
        self.load_base_model(base_model)
        log.info("Done creating base model")

        log.info("Start creating ctrlnet model")
        net = lazy_instantiate(self.config.net_ctrl)
        conditioner = base_model.conditioner
        logvar = base_model.logvar
        # initialize controlnet encoder
        model = torch.nn.ModuleDict({"net": net, "conditioner": conditioner, "logvar": logvar})
        # Non-strict load: copy every base-model weight whose name also exists in the control model.
        model.load_state_dict(base_model.state_dict(), strict=False)
        model.base_model = base_model
        log.info("Done creating ctrlnet model")

        self.hint_key = self.config.hint_key["hint_key"]
        return model

    @property
    def base_net(self):
        """Network of the base model (``self.model.base_model.net``)."""
        return self.model.base_model.net

    @property
    def conditioner(self):
        """Conditioner of the control model (the same object as the base model's conditioner)."""
        return self.model.conditioner

    def load_base_model(self, base_model) -> None:
        """Load base-model weights from ``config.base_load_from["load_path"]`` when a path is set.

        The checkpoint may contain a non-None ``"ema"`` state dict (preferred; ``-`` in its keys is
        mapped back to ``.``), a ``"model"`` state dict, or be a bare state dict. Loading is
        non-strict and the missing / unexpected keys are logged.

        Args:
            base_model: module receiving the weights via ``load_state_dict``.
        """
        config = self.config
        if config.base_load_from is not None:
            checkpoint_path = config.base_load_from["load_path"]
        else:
            checkpoint_path = ""

        if checkpoint_path:
            log.info(f"Loading base model checkpoint (local): {checkpoint_path}")
            # NOTE: weights_only=False unpickles arbitrary Python objects; only load trusted checkpoints.
            state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False)
            log.success(f"Complete loading base model checkpoint (local): {checkpoint_path}")

            # A checkpoint may carry an "ema" entry that is None; fall through to "model" in that case.
            if "ema" in state_dict and state_dict["ema"] is not None:
                # Copy the base model weights from ema model.
                log.info("Copying ema to base model")
                base_state_dict = {k.replace("-", "."): v for k, v in state_dict["ema"].items()}
            elif "model" in state_dict:
                # Copy the base model weights from reg model.
                log.warning("Using non-EMA base model")
                base_state_dict = state_dict["model"]
            else:
                log.info("Loading from an EMA only model")
                base_state_dict = state_dict
            missing, unexpected = base_model.load_state_dict(base_state_dict, strict=False)
            log.info(f"Missing: {missing}")
            log.info(f"Unexpected: {unexpected}")
            log.info("Done loading the base model checkpoint.")

    def get_data_and_condition(
        self, data_batch: dict[str, Tensor], **kwargs
    ) -> Tuple[Tensor, Tensor, VideoConditionerWithCtrl]:
        """Prepare data/condition via the parent class, then encode and attach the control input.

        Side effects on ``data_batch``: sets ``"hint_key"`` to the configured hint key and
        ``"latent_hint"`` to the encoded control latent.

        Args:
            data_batch: input batch; must contain the configured hint key.
            **kwargs: forwarded to the parent ``get_data_and_condition``.

        Returns:
            ``(raw_state, latent_state, condition)`` from the parent, where ``condition`` additionally
            carries the control latent under the hint key and ``base_model``.

        Raises:
            KeyError: if the configured hint key is missing from ``data_batch``.
        """
        hint_key = self.config.hint_key["hint_key"]
        # Fail fast (before the parent does any work) when the control input is missing.
        if hint_key not in data_batch:
            raise KeyError(hint_key)
        data_batch["hint_key"] = hint_key
        raw_state, latent_state, condition = super().get_data_and_condition(data_batch, **kwargs)

        use_multicontrol = (
            ("control_weight" in data_batch)
            and not isinstance(data_batch["control_weight"], float)
            and data_batch["control_weight"].shape[0] > 1
        )
        if use_multicontrol:  # encode individual conditions separately
            num_conditions = data_batch[data_batch["hint_key"]].size(1) // 3
            # One encoding per condition, each with only condition ``i`` enabled; stacked along dim 0.
            latent_hint = torch.cat(
                [
                    self.encode_latent(data_batch, cond_mask=[j == i for j in range(num_conditions)])
                    for i in range(num_conditions)
                ]
            )
        else:
            latent_hint = self.encode_latent(data_batch)

        # add extra conditions
        data_batch["latent_hint"] = latent_hint
        setattr(condition, hint_key, latent_hint)
        setattr(condition, "base_model", self.model.base_model)
        return raw_state, latent_state, condition

    def get_x_from_clean(
        self,
        in_clean_img: torch.Tensor,
        sigma_max: float | None,
        seed: int = 1,
    ) -> Optional[Tensor]:
        """Add seeded Gaussian noise scaled by ``sigma_max`` to a clean input.

        Args:
            in_clean_img (torch.Tensor): input clean image for image-to-image/video-to-video by adding
                noise then denoising.
            sigma_max (float | None): maximum sigma applied to ``in_clean_img``; ``None`` means
                ``self.sde.sigma_max``.
            seed (int): seed of the noise generator.

        Returns:
            ``in_clean_img + noise * sigma_max``, or ``None`` when ``in_clean_img`` is ``None``.
        """
        if in_clean_img is None:
            return None
        generator = torch.Generator(device=self.tensor_kwargs["device"])
        generator.manual_seed(seed)
        noise = torch.randn(*in_clean_img.shape, **self.tensor_kwargs, generator=generator)
        if sigma_max is None:
            sigma_max = self.sde.sigma_max
        return in_clean_img + noise * sigma_max

    def encode_latent(self, data_batch: dict, cond_mask: Optional[list] = None) -> torch.Tensor:
        """Encode the control input one 3-channel condition at a time and concatenate the latents.

        The control input ``data_batch[data_batch["hint_key"]]`` holds several conditions concatenated
        along dim 1, three channels each. Conditions disabled by the mask are replaced with zeros
        before encoding.

        Mask selection:
            * more than one condition and ``hint_dropout_rate > 0``:
                - grad enabled: random mask (each kept with prob. ``1 - hint_dropout_rate``),
                  falling back to all-enabled when every condition was dropped;
                - otherwise: ``cond_mask`` if non-empty, else ``config.hint_mask``;
            * any other case: all conditions enabled.

        Args:
            data_batch: batch containing ``"hint_key"`` and the control input under that key.
            cond_mask: optional per-condition booleans (``None`` / empty means "not given").

        Returns:
            The per-condition latents concatenated along dim 1.
        """
        x = data_batch[data_batch["hint_key"]]
        # control input goes through tokenizer, which always takes 3-input channels
        num_conditions = x.size(1) // 3  # input conditions were concatenated along channel dimension
        if num_conditions > 1 and self.config.hint_dropout_rate > 0:
            if torch.is_grad_enabled():  # during training, randomly dropout some conditions
                cond_mask = torch.rand(num_conditions) > self.config.hint_dropout_rate
                if not cond_mask.any():  # make sure at least one condition is present
                    cond_mask = [True] * num_conditions
            elif not cond_mask:  # during inference, use hint_mask to indicate which conditions are used
                cond_mask = self.config.hint_mask
        else:
            # NOTE(review): an explicitly passed cond_mask is overridden here (single condition or
            # hint_dropout_rate == 0) — confirm this is intended for the multicontrol path.
            cond_mask = [True] * num_conditions

        latent = []
        for idx in range(0, x.size(1), 3):
            x_rgb = x[:, idx : idx + 3]
            if not cond_mask[idx // 3]:  # if the condition is not selected, replace with a black image
                x_rgb = torch.zeros_like(x_rgb)
            latent.append(self.encode(x_rgb))
        return torch.cat(latent, dim=1)

    def get_x0_fn_from_batch(
        self,
        data_batch: Dict,
        guidance: float = 1.5,
        is_negative_prompt: bool = False,
        condition_latent: torch.Tensor = None,
        num_condition_t: Union[int, None] = None,
        condition_video_augment_sigma_in_inference: float = None,
    ) -> Callable:
        """
        Generates a callable function `x0_fn` based on the provided data batch and guidance factor.

        This function first processes the input data batch through a conditioning workflow (`conditioner`) to
        obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a
        denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and
        unconditioned states, combined as `cond + guidance * (cond - uncond)`.

        Args:
        - data_batch (Dict): batch already processed by `self.get_data_and_condition` (it must contain
          `latent_hint` and `hint_key`).
        - guidance (float, optional): A scalar value that modulates the influence of the conditioned state
          relative to the unconditioned state in the output. Defaults to 1.5.
        - is_negative_prompt (bool): use negative prompt t5 in uncondition if true
        - condition_latent, num_condition_t, condition_video_augment_sigma_in_inference: accepted for interface
          compatibility with the video-to-world variant; this text-to-world implementation does not use them.

        Returns:
        - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and
          returns the x0 prediction.
        """
        # data_batch should be the one processed by self.get_data_and_condition
        if is_negative_prompt:
            condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
        else:
            condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)

        # Add extra conditions for ctrlnet.
        latent_hint = data_batch["latent_hint"]
        hint_key = data_batch["hint_key"]
        setattr(condition, hint_key, latent_hint)
        if data_batch.get("use_none_hint"):
            setattr(uncondition, hint_key, None)
        else:
            setattr(uncondition, hint_key, latent_hint)

        # Handle regional prompting information
        if "regional_contexts" in data_batch and "region_masks" in data_batch:
            setattr(condition, "regional_contexts", data_batch["regional_contexts"])
            setattr(condition, "region_masks", data_batch["region_masks"])
            # For unconditioned generation, we still need the region masks but not the regional contexts
            setattr(uncondition, "region_masks", data_batch["region_masks"])
            setattr(uncondition, "regional_contexts", None)

        to_cp = self.net.is_context_parallel_enabled
        # For inference, check if parallel_state is initialized
        if parallel_state.is_initialized():
            condition = broadcast_condition(condition, to_tp=False, to_cp=to_cp)
            uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=to_cp)

            cp_group = parallel_state.get_context_parallel_group()
            latent_hint = getattr(condition, hint_key)
            seq_dim = 3 if latent_hint.ndim == 6 else 2
            latent_hint = split_inputs_cp(latent_hint, seq_dim=seq_dim, cp_group=cp_group)
            setattr(condition, hint_key, latent_hint)
            if getattr(uncondition, hint_key) is not None:
                setattr(uncondition, hint_key, latent_hint)

            if hasattr(condition, "regional_contexts") and getattr(condition, "regional_contexts") is not None:
                regional_contexts = getattr(condition, "regional_contexts")
                regional_contexts = split_inputs_cp(regional_contexts, seq_dim=2, cp_group=cp_group)
                setattr(condition, "regional_contexts", regional_contexts)
            if hasattr(condition, "region_masks") and getattr(condition, "region_masks") is not None:
                region_masks = getattr(condition, "region_masks")
                region_masks = split_inputs_cp(region_masks, seq_dim=2, cp_group=cp_group)
                setattr(condition, "region_masks", region_masks)
                setattr(uncondition, "region_masks", region_masks)

        setattr(condition, "base_model", self.model.base_model)
        setattr(uncondition, "base_model", self.model.base_model)
        if hasattr(self, "hint_encoders"):
            self.model.net.hint_encoders = self.hint_encoders

        def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
            cond_x0 = self.denoise(noise_x, sigma, condition).x0
            uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0
            return cond_x0 + guidance * (cond_x0 - uncond_x0)

        return x0_fn

    def generate_samples_from_batch(
        self,
        data_batch: Dict,
        guidance: float = 1.5,
        seed: int = 1,
        state_shape: Tuple | None = None,
        n_sample: int | None = None,
        is_negative_prompt: bool = False,
        num_steps: int = 35,
        condition_latent: Union[torch.Tensor, None] = None,
        num_condition_t: Union[int, None] = None,
        condition_video_augment_sigma_in_inference: float = None,
        x_sigma_max: Optional[torch.Tensor] = None,
        sigma_max: float | None = None,
        **kwargs,
    ) -> Tensor:
        """
        Generate samples from the batch. Based on given batch, it will automatically determine whether to
        generate image or video samples.

        Different from the base model, this function supports condition latent as input; it will create a
        different x0_fn if condition latent is given. If this feature is stabilized, we could consider moving
        this function to the base model.

        Args:
            condition_latent (Optional[torch.Tensor]): latent tensor in shape B,C,T,H,W as condition to generate video.
            num_condition_t (Optional[int]): number of condition latent T, if None, will use the whole first half
            x_sigma_max (Optional[torch.Tensor]): initial noisy latent; when None, seeded noise of shape
                `(n_sample, *state_shape)` scaled by `sigma_max` is drawn.
            sigma_max (float | None): starting noise level; defaults to `self.sde.sigma_max`.
            **kwargs: accepted for call-site compatibility with the video-to-world variant; ignored.

        Returns:
            The sampler output (gathered across context-parallel ranks when context parallelism is enabled).
        """
        if n_sample is None:
            input_key = self.input_data_key
            n_sample = data_batch[input_key].shape[0]
        if state_shape is None:
            log.debug(f"Default Video state shape is used. {self.state_shape}")
            state_shape = self.state_shape

        x0_fn = self.get_x0_fn_from_batch(
            data_batch,
            guidance,
            is_negative_prompt=is_negative_prompt,
            condition_latent=condition_latent,
            num_condition_t=num_condition_t,
            condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
        )

        if sigma_max is None:
            sigma_max = self.sde.sigma_max

        if x_sigma_max is None:
            x_sigma_max = (
                misc.arch_invariant_rand(
                    (n_sample,) + tuple(state_shape),
                    torch.float32,
                    self.tensor_kwargs["device"],
                    seed,
                )
                * sigma_max
            )

        if self.net.is_context_parallel_enabled:
            x_sigma_max = broadcast(x_sigma_max, to_tp=False, to_cp=True)
            x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group)

        samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=sigma_max)

        if self.net.is_context_parallel_enabled:
            samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group)

        return samples
class VideoDistillModelWithCtrl(DistillV2WModel):
    """Distilled (single-step) video model with a ControlNet-style hint branch.

    Builds a ``DistillControlNet`` on top of the model created by
    ``DistillV2WModel.build_model`` and encodes control hints (one or more
    3-channel inputs concatenated along dim 1) into a latent that is attached
    to the condition object.
    """

    def build_model(self) -> torch.nn.ModuleDict:
        """Build the control model, seeding its base copy with the base net's weights.

        Returns:
            A ModuleDict with ``net`` (the DistillControlNet), ``conditioner`` and
            ``logvar`` (both taken from the base model), plus a ``base_model``
            attribute set to ``net.base_model.net``.
        """
        log.info("Start creating base model")
        base_model = super().build_model()
        log.info("Done creating base model")

        log.info("Start creating ctrlnet model")
        net = DistillControlNet(self.config)
        # Copy the freshly built base network's weights into the control net's base copy.
        net.base_model.net.load_state_dict(base_model["net"].state_dict())
        # Conditioner and logvar are reused from the base model, not re-created.
        conditioner = base_model.conditioner
        logvar = base_model.logvar
        model = torch.nn.ModuleDict({"net": net, "conditioner": conditioner, "logvar": logvar})
        # Expose the control net's base network directly on the model dict.
        model.base_model = net.base_model.net
        self.hint_key = self.config.hint_key["hint_key"]
        return model

    @property
    def base_net(self):
        """Return ``self.model.base_model.net``.

        NOTE(review): ``build_model`` already sets ``model.base_model`` to
        ``net.base_model.net``, so this returns *that* object's ``.net``
        attribute -- confirm the attribute exists / this is intended.
        """
        return self.model.base_model.net

    @property
    def conditioner(self):
        """The conditioner module stored in ``self.model``."""
        return self.model.conditioner

    def get_data_and_condition(
        self, data_batch: dict[str, Tensor], **kwargs
    ) -> Tuple[Tensor, Tensor, VideoConditionerWithCtrl]:
        """Run the parent data/condition pipeline and attach the encoded control hint.

        Side effects on ``data_batch``: sets ``"hint_key"`` (the configured hint key
        name) and ``"latent_hint"`` (the encoded hint).

        Args:
            data_batch: Input batch; must contain the configured hint key.
            **kwargs: Forwarded to the parent ``get_data_and_condition``.

        Returns:
            ``(raw_state, latent_state, condition)`` from the parent, with the hint
            latent stored on ``condition`` under the hint key and
            ``condition.base_model`` set to ``self.model.base_model``.
        """
        hint_key = self.config.hint_key["hint_key"]
        # encode_latent (and possibly the parent pipeline) reads the hint key name from the batch.
        data_batch["hint_key"] = hint_key
        raw_state, latent_state, condition = super().get_data_and_condition(data_batch, **kwargs)

        # A per-sample/per-condition control_weight tensor with more than one entry
        # means the conditions are encoded separately. Scalars (int/float) or a
        # missing/None weight mean a single joint encoding.
        control_weight = data_batch.get("control_weight")
        use_multicontrol = (
            control_weight is not None
            and not isinstance(control_weight, (int, float))
            and control_weight.shape[0] > 1
        )
        if use_multicontrol:
            num_conditions = data_batch[hint_key].size(1) // 3
            # Encode each condition on its own by selecting exactly one of them per pass.
            latent_hint = torch.cat(
                [
                    self.encode_latent(
                        data_batch, cond_mask=[j == i for j in range(num_conditions)]
                    )
                    for i in range(num_conditions)
                ]
            )
        else:
            latent_hint = self.encode_latent(data_batch)

        data_batch["latent_hint"] = latent_hint
        setattr(condition, hint_key, latent_hint)
        condition.base_model = self.model.base_model
        return raw_state, latent_state, condition

    def get_x_from_clean(
        self,
        in_clean_img: Optional[torch.Tensor],
        sigma_max: float | None,
        seed: int = 1,
    ) -> Optional[Tensor]:
        """Noise a clean input up to ``sigma_max`` for image-to-image / video-to-video.

        Args:
            in_clean_img: Clean input that will be noised and then denoised; may be None.
            sigma_max: Noise scale; ``self.sde.sigma_max`` is used when None.
            seed: Seed for the noise generator (created on ``tensor_kwargs["device"]``).

        Returns:
            ``in_clean_img + noise * sigma_max``, or None if ``in_clean_img`` is None.
        """
        if in_clean_img is None:
            return None
        generator = torch.Generator(device=self.tensor_kwargs["device"])
        generator.manual_seed(seed)
        noise = torch.randn(*in_clean_img.shape, **self.tensor_kwargs, generator=generator)
        if sigma_max is None:
            sigma_max = self.sde.sigma_max
        return in_clean_img + noise * sigma_max

    def encode_latent(self, data_batch: dict, cond_mask: Optional[list] = None) -> torch.Tensor:
        """Encode the control hint(s) held in ``data_batch`` into a latent.

        The hint tensor ``data_batch[data_batch["hint_key"]]`` holds one or more
        3-channel conditions concatenated along dim 1. Each condition is encoded
        separately with ``self.encode``, and the results are concatenated along
        dim 1. Masked-out conditions are replaced with zeros before encoding.

        Args:
            data_batch: Batch holding the hint tensor and the ``"hint_key"`` entry.
            cond_mask: Per-condition selection flags. Honored only at inference
                (grad disabled) when there are several conditions and
                ``config.hint_dropout_rate > 0``. If None/empty in that case,
                ``config.hint_mask`` is used instead.

        Returns:
            The per-condition latents concatenated along dim 1.
        """
        x = data_batch[data_batch["hint_key"]]
        # The tokenizer always takes 3 input channels; conditions are stacked along dim 1.
        num_conditions = x.size(1) // 3
        if num_conditions > 1 and self.config.hint_dropout_rate > 0:
            if torch.is_grad_enabled():  # training: randomly drop some conditions
                cond_mask = torch.rand(num_conditions) > self.config.hint_dropout_rate
                if not cond_mask.any():  # make sure at least one condition is present
                    cond_mask = [True] * num_conditions
            elif not cond_mask:  # inference with no explicit mask (None or empty)
                cond_mask = self.config.hint_mask
        else:
            # NOTE(review): this also overrides a caller-supplied cond_mask (e.g. the
            # single-condition masks from the multicontrol path) whenever
            # hint_dropout_rate == 0 -- confirm that is intended.
            cond_mask = [True] * num_conditions

        latent = []
        for idx in range(0, x.size(1), 3):
            x_rgb = x[:, idx : idx + 3]
            if not cond_mask[idx // 3]:  # unselected condition: replace with zeros
                x_rgb = torch.zeros_like(x_rgb)
            latent.append(self.encode(x_rgb))
        return torch.cat(latent, dim=1)

    def generate_samples_from_batch(
        self,
        data_batch: Dict,
        guidance: float = 1.5,
        seed: int = 1,
        state_shape: Tuple | None = None,
        n_sample: int | None = None,
        is_negative_prompt: bool = False,
        num_steps: int = 1,  # Ignored for distilled models
        condition_latent: Union[torch.Tensor, None] = None,
        num_condition_t: Union[int, None] = None,
        condition_video_augment_sigma_in_inference: Optional[float] = None,
        x_sigma_max: Optional[torch.Tensor] = None,
        sigma_max: float | None = None,
        target_h: int = 88,
        target_w: int = 160,
        patch_h: int = 88,
        patch_w: int = 160,
        **kwargs,
    ) -> torch.Tensor:
        """Single-step generation matching internal distilled model.

        Expects ``data_batch`` to already contain ``"latent_hint"`` and
        ``"hint_key"`` (both written by ``get_data_and_condition``).

        Args:
            data_batch: Input batch; normalized and image-dim-augmented in place.
            guidance: Forwarded to ``_forward_distilled``, which does not apply it.
            seed: Seed for the initial noise and for ``denoise``.
            state_shape: Latent shape without the batch dim; defaults to ``self.state_shape``.
            n_sample: Batch size; defaults to that of ``data_batch[self.input_data_key]``.
            is_negative_prompt: Build condition/uncondition with
                ``conditioner.get_condition_with_negative_prompt`` instead of
                ``conditioner.get_condition_uncondition``.
            num_steps: Ignored -- this path runs a single denoising step.
            condition_latent: Conditioning latent; zeros shaped like ``latent_hint`` if None.
            num_condition_t: Number of conditioning latent frames; forced to 0
                when ``condition_latent`` is None.
            condition_video_augment_sigma_in_inference: Augmentation sigma for the
                conditioning frames; forced to 1000 when ``condition_latent`` is None.
            x_sigma_max: Accepted for interface compatibility; unused (the initial
                noise is always drawn from ``seed``).
            sigma_max: Accepted for interface compatibility; unused
                (``_forward_distilled`` uses ``self.sde.sigma_max``).
            target_h, target_w, patch_h, patch_w: Tiling configuration that
                ``_forward_distilled`` validates.
            **kwargs: Forwarded to ``_forward_distilled``.

        Returns:
            The ``x0`` prediction, gathered across context-parallel ranks when CP is enabled.
        """
        # Same preprocessing as base class
        self._normalize_video_databatch_inplace(data_batch)
        self._augment_image_dim_inplace(data_batch)
        if n_sample is None:
            n_sample = data_batch[self.input_data_key].shape[0]
        if state_shape is None:
            log.debug(f"Default Video state shape is used. {self.state_shape}")
            state_shape = self.state_shape

        # Initial noise for the single denoising step, seeded for reproducibility.
        generator = torch.Generator(device=self.tensor_kwargs["device"])
        generator.manual_seed(seed)
        random_noise = torch.randn(n_sample, *state_shape, generator=generator, **self.tensor_kwargs)

        if is_negative_prompt:
            condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
        else:
            condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)

        # Without a conditioning latent, condition on zeros with no conditioning frames.
        if condition_latent is None:
            condition_latent = torch.zeros(data_batch["latent_hint"].shape, **self.tensor_kwargs)
            num_condition_t = 0
            condition_video_augment_sigma_in_inference = 1000

        condition.video_cond_bool = True
        condition = self.add_condition_video_indicator_and_video_input_mask(
            condition_latent, condition, num_condition_t
        )
        uncondition.video_cond_bool = True  # Not do cfg on condition frames
        uncondition = self.add_condition_video_indicator_and_video_input_mask(
            condition_latent, uncondition, num_condition_t
        )
        # Force the uncondition to share the condition's frame indicator and input mask.
        uncondition.condition_video_indicator = condition.condition_video_indicator.clone()
        uncondition.condition_video_input_mask = condition.condition_video_input_mask.clone()

        latent_hint = data_batch["latent_hint"]
        hint_key = data_batch["hint_key"]
        setattr(condition, hint_key, latent_hint)
        setattr(uncondition, hint_key, None if data_batch.get("use_none_hint") else latent_hint)

        cp_enabled = self.net.is_context_parallel_enabled
        # For inference, only broadcast/split when the parallel state is initialized.
        if parallel_state.is_initialized():
            condition = broadcast_condition(condition, to_tp=False, to_cp=cp_enabled)
            uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=cp_enabled)
            cp_group = parallel_state.get_context_parallel_group()
            latent_hint = getattr(condition, hint_key)
            # 6-D hints are split along dim 3, otherwise along dim 2.
            seq_dim = 3 if latent_hint.ndim == 6 else 2
            latent_hint = split_inputs_cp(latent_hint, seq_dim=seq_dim, cp_group=cp_group)
            setattr(condition, hint_key, latent_hint)
            if getattr(uncondition, hint_key) is not None:
                setattr(uncondition, hint_key, latent_hint)

        # not sure if this is consistent w the new distilled model?
        condition.base_model = self.model.base_model
        uncondition.base_model = self.model.base_model
        if hasattr(self, "hint_encoders"):
            self.model.net.hint_encoders = self.hint_encoders

        if cp_enabled:
            random_noise = split_inputs_cp(x=random_noise, seq_dim=2, cp_group=self.net.cp_group)
        condition.gt_latent = condition_latent
        uncondition.gt_latent = condition_latent

        samples = self._forward_distilled(
            epsilon=random_noise,
            condition=condition,
            uncondition=uncondition,
            guidance=guidance,
            hint_key=hint_key,
            condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
            target_h=target_h,
            target_w=target_w,
            patch_h=patch_h,
            patch_w=patch_w,
            seed=seed,
            inference_mode=True,
            **kwargs,
        )
        if cp_enabled:
            samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group)
        return samples

    def _forward_distilled(
        self,
        epsilon: torch.Tensor,
        condition: Any,
        uncondition: Any,
        guidance: float,
        hint_key: str,
        condition_video_augment_sigma_in_inference: float = 0.001,
        target_h: int = 88,
        target_w: int = 160,
        patch_h: int = 88,
        patch_w: int = 160,
        seed: int = 1,
        inference_mode: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Single forward pass for distilled models.

        Denoises ``epsilon * self.sde.sigma_max`` once at ``sigma = self.sde.sigma_max``
        with ``condition``, under ``torch.no_grad()``. ``uncondition``, ``guidance``,
        ``hint_key``, ``inference_mode`` and ``kwargs`` are accepted but not used, so
        no classifier-free guidance is applied here.

        Args:
            epsilon: Unit-scale noise; its first dim is the batch size.
            condition: Condition object passed to ``self.denoise``.
            condition_video_augment_sigma_in_inference: Forwarded to ``self.denoise``.
            target_h, target_w, patch_h, patch_w: Tiling configuration. It is only
                checked for consistency here (patches with an integer overlap must
                exactly cover the target size).
            seed: Forwarded to ``self.denoise``.

        Returns:
            The ``x0_pred_replaced`` output of ``self.denoise``.
        """
        n_img_w = (target_w - 1) // patch_w + 1
        n_img_h = (target_h - 1) // patch_h + 1
        if n_img_w > 1:
            overlap_size_w = (n_img_w * patch_w - target_w) // (n_img_w - 1)
            assert (
                n_img_w * patch_w - overlap_size_w * (n_img_w - 1) == target_w
            ), f"patch_w={patch_w} cannot tile target_w={target_w} with an integer overlap"
        if n_img_h > 1:
            overlap_size_h = (n_img_h * patch_h - target_h) // (n_img_h - 1)
            assert (
                n_img_h * patch_h - overlap_size_h * (n_img_h - 1) == target_h
            ), f"patch_h={patch_h} cannot tile target_h={target_h} with an integer overlap"

        # One sigma_max per batch element: a single denoising step at sigma_max.
        sigma = torch.tensor(self.sde.sigma_max).repeat(epsilon.size(0)).to(epsilon.device)
        # Direct network forward pass - no iterative sampling.
        with torch.no_grad():
            cond_x0 = self.denoise(
                noise_x=epsilon * self.sde.sigma_max,  # scale unit noise to sigma_max
                sigma=sigma,
                condition=condition,
                condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
                seed=seed,
            ).x0_pred_replaced
        return cond_x0
|