#credit to ExponentialML for this module
#from https://github.com/ExponentialML/ComfyUI_Native_DynamiCrafter

import os
import torch

import comfy
import comfy.utils
import comfy.model_patcher

from einops import rearrange
from comfy import model_base, model_management

from .lvdm.modules.networks.openaimodel3d import UNetModel as DynamiCrafterUNetModel
from .utils.model_utils import (
    DynamiCrafterBase,
    DYNAMICRAFTER_CONFIG,
    load_image_proj_dict,
    load_dynamicrafter_dict,
    get_image_proj_model,
)


class DynamiCrafter:
    def __init__(self):
        self.model_patcher = None

    # There is probably a better way to do this, but with the apply_model callback, this seems necessary.
    # The model gets wrapped by a CFG Denoiser class, which handles the conditioning parts there.
    # We cannot access it, so we must find the conditioning according to how ComfyUI handles it.
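    # Illustrative sketch (an assumption about ComfyUI's batching, not part of the original comment):
    # with CFG, the sampler hands us cond and uncond text embeddings concatenated along dim 0.
    # get_conditioning_pair() below scans adjacent rows and returns the first pair that differs,
    # stacked as [row i + 1, row i]; e.g. for a c_crossattn of shape (2, 77, hidden_dim) it would
    # return torch.cat([c_crossattn[[1]], c_crossattn[[0]]]).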
    def get_conditioning_pair(self, c_crossattn, use_cfg: bool):
        if not use_cfg:
            return c_crossattn

        conditioning_group = []

        for i in range(c_crossattn.shape[0]):
            # Get the positive and negative conditioning.
            positive_idx = i + 1
            negative_idx = i

            if positive_idx >= c_crossattn.shape[0]:
                break

            if not torch.equal(c_crossattn[[positive_idx]], c_crossattn[[negative_idx]]):
                conditioning_group = [
                    c_crossattn[[positive_idx]],
                    c_crossattn[[negative_idx]]
                ]
                break

        if len(conditioning_group) == 0:
            raise ValueError("Could not get the appropriate conditioning group.")

        return torch.cat(conditioning_group)
    # apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}
    def _forward(self, *args):
        transformer_options = self.model_patcher.model_options['transformer_options']
        conditioning = transformer_options['conditioning']

        apply_model = args[0]
        # forward_dict
        fd = args[1]

        x, t, model_in_kwargs, _ = fd['input'], fd['timestep'], fd['c'], fd['cond_or_uncond']
        c_crossattn = model_in_kwargs.pop("c_crossattn")

        c_concat = conditioning['c_concat']
        num_video_frames = conditioning['num_video_frames']
        fs = conditioning['fs']
        original_num_frames = num_video_frames

        # Better way to determine if we're using CFG.
        # The cond batch will always be num_frames >= 2 since we're doing video,
        # so we need to get this condition differently here.
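        # Worked example (illustrative numbers, not from the original comment): with 16 video frames,
        # x.shape[0] is 16 without CFG and 32 with CFG (cond + uncond batches concatenated), so
        # x.shape[0] > num_video_frames implies CFG is active.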
        if x.shape[0] > num_video_frames:
            num_video_frames *= 2
            batch_size = 2
            use_cfg = True
        else:
            use_cfg = False
            batch_size = 1

        if use_cfg:
            c_concat = torch.cat([c_concat] * 2)

        self.validate_forwardable_latent(x, c_concat, num_video_frames, use_cfg)

        x_in, c_concat = map(
            lambda xc: rearrange(xc, '(b t) c h w -> b c t h w', b=batch_size),
            (x, c_concat)
        )

        # We always assume video, so there will always be batched conditionings.
        c_crossattn = self.get_conditioning_pair(c_crossattn, use_cfg)
        c_crossattn = c_crossattn[:2] if use_cfg else c_crossattn[:1]
        context_in = c_crossattn

        img_embs = conditioning['image_emb']

        if use_cfg:
            img_emb_uncond = conditioning['image_emb_uncond']
            img_embs = torch.cat([img_embs, img_emb_uncond])

        fs = torch.cat([fs] * x_in.shape[0])
        outs = []

        for i in range(batch_size):
            model_in_kwargs['transformer_options']['cond_idx'] = i
            x_out = apply_model(
                x_in[[i]],
                t=torch.cat([t[:1]]),
                context_in=context_in[[i]],
                c_crossattn=c_crossattn,
                cc_concat=c_concat[[i]],  # "cc" is to handle a naming conflict with the apply_model wrapper.
                # We want to handle this in the UNet forward.
                num_video_frames=num_video_frames // 2 if batch_size > 1 else num_video_frames,
                img_emb=img_embs[[i]],
                fs=fs[[i]],
                **model_in_kwargs
            )
            outs.append(x_out)

        x_out = torch.cat(list(reversed(outs)))
        x_out = rearrange(x_out, 'b c t h w -> (b t) c h w')

        return x_out
    def assign_forward_args(
        self,
        model,
        c_concat,
        image_emb,
        image_emb_uncond,
        fs,
        frames,
    ):
        model.model_options['transformer_options']['conditioning'] = {
            "c_concat": c_concat,
            "image_emb": image_emb,
            "image_emb_uncond": image_emb_uncond,
            "fs": fs,
            "num_video_frames": frames,
        }
    def validate_forwardable_latent(self, latent, c_concat, num_video_frames, use_cfg):
        check_no_cfg = latent.shape[0] != num_video_frames
        check_with_cfg = latent.shape[0] != (num_video_frames * 2)

        latent_batch_size = latent.shape[0] if not use_cfg else latent.shape[0] // 2
        num_frames = num_video_frames if not use_cfg else num_video_frames // 2

        if all([check_no_cfg, check_with_cfg]):
            raise ValueError(
                "Please make sure your latent inputs match the number of frames in the DynamiCrafter Processor. "
                f"Got a latent batch size of ({latent_batch_size}) with the number of frames being ({num_frames})."
            )

        latent_h, latent_w = latent.shape[-2:]
        c_concat_h, c_concat_w = c_concat.shape[-2:]

        if not all([latent_h == c_concat_h, latent_w == c_concat_w]):
            raise ValueError(
                "Please make sure that your input latent and image frames are the same height and width. "
                f"Image Size: {c_concat_w * 8}x{c_concat_h * 8}, Latent Size: {latent_w * 8}x{latent_h * 8}"
            )
    def process_image_conditioning(
        self,
        model,
        clip_vision,
        vae,
        image_proj_model,
        images,
        use_interpolate,
        fps: int,
        frames: int,
        scale_latents: bool
    ):
        self.model_patcher = model

        encoded_latent = vae.encode(images[:, :, :, :3])

        encoded_image = clip_vision.encode_image(images[:1])['last_hidden_state']
        image_emb = image_proj_model(encoded_image)

        encoded_image_uncond = clip_vision.encode_image(torch.zeros_like(images)[:1])['last_hidden_state']
        image_emb_uncond = image_proj_model(encoded_image_uncond)

        c_concat = encoded_latent

        if scale_latents:
            vae_process_input = vae.process_input
            vae.process_input = lambda image: (image - .5) * 2
            c_concat = vae.encode(images[:, :, :, :3])
            vae.process_input = vae_process_input
            c_concat = model.model.process_latent_in(c_concat) * 1.3
        else:
            c_concat = model.model.process_latent_in(c_concat)

        fs = torch.tensor([fps], dtype=torch.long, device=model_management.intermediate_device())
        model.set_model_unet_function_wrapper(self._forward)
        used_interpolate_processing = False

        if use_interpolate and frames > 16:
            raise ValueError(
                "When using interpolation mode, the maximum number of frames is 16. "
                "If you're doing long video generation, consider using the last frame "
                "from the first generation for the next one (autoregressive)."
            )
        if encoded_latent.shape[0] == 1:
            c_concat = torch.cat([c_concat] * frames, dim=0)[:frames]

            if use_interpolate:
                mask = torch.zeros_like(c_concat)
                mask[:1] = c_concat[:1]
                c_concat = mask
                used_interpolate_processing = True
        else:
            if use_interpolate and c_concat.shape[0] in [2, 3]:
                input_frame_count = c_concat.shape[0]

                # We're just padding to the same type and size as the concat.
                masked_frames = torch.zeros_like(torch.cat([c_concat[:1]] * frames))[:frames]

                # Start frame
                masked_frames[:1] = c_concat[:1]

                end_frame_idx = -1

                # TODO
                speed = 1.0

                if speed < 1.0:
                    possible_speeds = list(torch.linspace(0, 1.0, c_concat.shape[0]))
                    speed_from_frames = enumerate(possible_speeds)
                    speed_idx = min(speed_from_frames, key=lambda n: n[1] - speed)[0]
                    end_frame_idx = speed_idx

                # End frame
                masked_frames[-1:] = c_concat[[end_frame_idx]]

                # Possible middle frame, but not working at the moment.
                if input_frame_count == 3:
                    middle_idx = masked_frames.shape[0] // 2
                    middle_idx_frame = c_concat.shape[0] // 2
                    masked_frames[[middle_idx]] = c_concat[[middle_idx_frame]]

                c_concat = masked_frames
                used_interpolate_processing = True

                print(f"Using interpolation mode with {input_frame_count} frames.")

        if c_concat.shape[0] < frames and not used_interpolate_processing:
            print(
                "Multiple images found, but interpolation mode is unset. Using the first frame as condition.",
            )
            c_concat = torch.cat([c_concat[:1]] * frames)

        c_concat = c_concat[:frames]
        if encoded_latent.shape[0] == 1:
            encoded_latent = torch.cat([encoded_latent] * frames)[:frames]

        if encoded_latent.shape[0] < frames and encoded_latent.shape[0] != 1:
            encoded_latent = torch.cat(
                [encoded_latent] + [encoded_latent[-1:]] * abs(encoded_latent.shape[0] - frames)
            )[:frames]

        # We could store this as state in this node class instance, but to prevent any weird edge cases,
        # it should always be passed through in a 'stateless' way, letting ComfyUI handle the
        # transformer_options state.
        self.assign_forward_args(model, c_concat, image_emb, image_emb_uncond, fs, frames)

        return (model, {"samples": torch.zeros_like(c_concat)}, {"samples": encoded_latent},)
    # Loader for the DynamiCrafter model.
    def load_model_sicts(self, model_path: str):
        model_state_dict = comfy.utils.load_torch_file(model_path)

        dynamicrafter_dict = load_dynamicrafter_dict(model_state_dict)
        image_proj_dict = load_image_proj_dict(model_state_dict)

        return dynamicrafter_dict, image_proj_dict
    def get_prediction_type(self, is_eps: bool, model_config):
        if not is_eps and "image_cross_attention_scale_learnable" in model_config.unet_config.keys():
            model_config.unet_config["image_cross_attention_scale_learnable"] = False

        return model_base.ModelType.EPS if is_eps else model_base.ModelType.V_PREDICTION
    def handle_model_management(self, dynamicrafter_dict: dict, model_config):
        parameters = comfy.utils.calculate_parameters(dynamicrafter_dict, "model.diffusion_model.")
        load_device = model_management.get_torch_device()

        unet_dtype = model_management.unet_dtype(
            model_params=parameters,
            supported_dtypes=model_config.supported_inference_dtypes
        )
        manual_cast_dtype = model_management.unet_manual_cast(
            unet_dtype,
            load_device,
            model_config.supported_inference_dtypes
        )
        model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)

        inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
        offload_device = model_management.unet_offload_device()

        return load_device, inital_load_device
    def check_leftover_keys(self, state_dict: dict):
        left_over = state_dict.keys()
        if len(left_over) > 0:
            print("left over keys:", left_over)
    def load_dynamicrafter(self, model_path):
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Could not find a DynamiCrafter checkpoint at: {model_path}")

        dynamicrafter_dict, image_proj_dict = self.load_model_sicts(model_path)
        model_config = DynamiCrafterBase(DYNAMICRAFTER_CONFIG)

        dynamicrafter_dict, is_eps = model_config.process_dict_version(state_dict=dynamicrafter_dict)

        MODEL_TYPE = self.get_prediction_type(is_eps, model_config)
        load_device, inital_load_device = self.handle_model_management(dynamicrafter_dict, model_config)

        model = model_base.BaseModel(
            model_config,
            model_type=MODEL_TYPE,
            device=inital_load_device,
            unet_model=DynamiCrafterUNetModel
        )

        image_proj_model = get_image_proj_model(image_proj_dict)
        model.load_model_weights(dynamicrafter_dict, "model.diffusion_model.")

        self.check_leftover_keys(dynamicrafter_dict)

        model_patcher = comfy.model_patcher.ModelPatcher(
            model,
            load_device=load_device,
            offload_device=model_management.unet_offload_device(),
            current_device=inital_load_device
        )

        return (model_patcher, image_proj_model,)
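

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It assumes a
# ComfyUI runtime where `clip_vision`, `vae`, and an image batch `images`
# (shape [B, H, W, C], values in 0..1) come from upstream loader nodes, and
# "models/dynamicrafter.ckpt" is a hypothetical checkpoint path.
#
#   dc = DynamiCrafter()
#   model_patcher, image_proj_model = dc.load_dynamicrafter("models/dynamicrafter.ckpt")
#   model, empty_latent, init_latent = dc.process_image_conditioning(
#       model_patcher, clip_vision, vae, image_proj_model, images,
#       use_interpolate=False, fps=24, frames=16, scale_latents=False,
#   )
#   # `model` is then sampled as usual; `init_latent["samples"]` holds the
#   # per-frame encoded image latents and `empty_latent` a zeroed latent batch.
# ---------------------------------------------------------------------------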