Spaces:
Configuration error
Configuration error
import torch | |
from collections import OrderedDict | |
from comfy import model_base | |
from comfy import utils | |
from comfy import diffusers_convert | |
try: | |
import comfy.text_encoders.sd2_clip | |
except ImportError: | |
from comfy import sd2_clip | |
from comfy import supported_models_base | |
from comfy import latent_formats | |
from ..lvdm.modules.encoders.resampler import Resampler | |
# Hyper-parameters for the DynamiCrafter diffusion UNet.
# NOTE(review): presumably consumed as keyword arguments by the UNet
# constructor elsewhere in the package — confirm against the caller.
DYNAMICRAFTER_CONFIG = dict(
    in_channels=8,
    out_channels=4,
    model_channels=320,
    attention_resolutions=[4, 2, 1],
    num_res_blocks=2,
    channel_mult=[1, 2, 4, 4],
    num_head_channels=64,
    transformer_depth=1,
    context_dim=1024,
    use_linear=True,
    use_checkpoint=False,
    # Temporal (video) layers.
    temporal_conv=True,
    temporal_attention=True,
    temporal_selfatt_only=True,
    use_relative_position=False,
    use_causal_attention=False,
    temporal_length=16,
    addition_attention=True,
    # Image cross-attention conditioning.
    image_cross_attention=True,
    image_cross_attention_scale_learnable=True,
    # Frame-stride ("fs") conditioning.
    default_fs=3,
    fs_condition=True,
)
# Keyword arguments for the image-projection ``Resampler`` (unpacked with
# ``**`` in ``get_image_proj_model``).
# NOTE(review): ``video_length`` equals DYNAMICRAFTER_CONFIG's
# ``temporal_length`` (16); confirm whether the two must be kept in sync.
IMAGE_PROJ_CONFIG = dict(
    dim=1024,
    depth=4,
    dim_head=64,
    heads=12,
    num_queries=16,
    embedding_dim=1280,
    output_dim=1024,
    ff_mult=4,
    video_length=16,
)
def process_list_or_str(target_key_or_keys, k):
    """Return True if *k* contains the target substring (or any of several).

    Args:
        target_key_or_keys: A single substring, or a list/tuple/set of
            substrings; with a collection, one match is enough.
        k: The key to test (normally a state-dict key string).

    Returns:
        ``True`` when a target occurs anywhere in ``k`` (a substring match,
        not a prefix match), otherwise ``False``. An empty collection never
        matches.
    """
    if isinstance(target_key_or_keys, (list, tuple, set, frozenset)):
        # Generator (not a list) so any() stops at the first hit.
        return any(target in k for target in target_key_or_keys)
    return target_key_or_keys in k
def simple_state_dict_loader(state_dict: dict, target_key: str, target_dict: dict = None):
    """Return a new dict holding a subset of *state_dict*.

    Two modes:

    * ``target_dict is None``: keep every entry of ``state_dict`` whose key
      contains ``target_key`` (a substring, or a list of substrings of which
      any may match; see ``process_list_or_str``).
    * Otherwise: take exactly the keys of ``target_dict`` (its values are
      ignored), in its order, and look each one up in ``state_dict``.
      ``target_key`` is ignored in this mode. A key missing from
      ``state_dict`` raises ``KeyError``.

    ``state_dict`` itself is not modified.
    """
    if target_dict is not None:
        return {name: state_dict[name] for name in target_dict}

    return {
        name: value
        for name, value in state_dict.items()
        if process_list_or_str(target_key, name)
    }
def load_image_proj_dict(state_dict: dict):
    """Return the entries of *state_dict* whose key contains ``'image_proj'``."""
    return simple_state_dict_loader(state_dict, target_key='image_proj')
def load_dynamicrafter_dict(state_dict: dict):
    """Return the entries of *state_dict* whose key contains ``'model.diffusion_model'``."""
    return simple_state_dict_loader(state_dict, target_key='model.diffusion_model')
def load_vae_dict(state_dict: dict):
    """Return the entries of *state_dict* whose key contains ``'first_stage_model'``."""
    return simple_state_dict_loader(state_dict, target_key='first_stage_model')
def get_base_model(state_dict: dict, version_checker=False):
    """Scan *state_dict* for a key containing ``framestride_embed``.

    NOTE(review): this looks like an unfinished stub. The detection result is
    computed but never used, ``version_checker`` is ignored, and the function
    always returns ``None``. Confirm what it was meant to return before
    relying on it.
    """
    # any() stops at the first matching key, just like a loop with `break`.
    is_256_model = any("framestride_embed" in key for key in state_dict.keys())  # noqa: F841
    return None
def get_image_proj_model(state_dict: dict):
    """Build the image-projection ``Resampler`` and load *state_dict* into it.

    Every occurrence of ``'image_proj_model.'`` is removed from each key
    before loading, so a key such as ``image_proj_model.<param>`` becomes
    ``<param>``. ``load_state_dict`` is called with its default arguments.

    Args:
        state_dict: Weights for the projection model.

    Returns:
        The ``Resampler`` built from ``IMAGE_PROJ_CONFIG`` with the weights
        loaded.
    """
    renamed = {}
    for key, value in state_dict.items():
        renamed[key.replace('image_proj_model.', '')] = value

    image_proj_model = Resampler(**IMAGE_PROJ_CONFIG)
    image_proj_model.load_state_dict(renamed)
    print("Image Projection Model loaded successfully")
    return image_proj_model
class DynamiCrafterBase(supported_models_base.BASE):
    """ComfyUI ``supported_models_base.BASE`` subclass used for DynamiCrafter.

    Uses the SD1.5 latent format and an SD2 ``sd2_clip`` text encoder whose
    weights live under the ``clip_h.`` prefix.
    """

    unet_config = {}
    unet_extra_config = {}

    latent_format = latent_formats.SD15

    def process_clip_state_dict(self, state_dict):
        """Move text-encoder weights from checkpoint prefixes to ``clip_h.``.

        Both ``conditioner.embedders.0.model.`` (SD2 in sgm format) and
        ``cond_stage_model.model.`` are mapped to ``clip_h.``. The result is
        then passed through ``utils.clip_text_transformers_convert`` to move
        it under ``clip_h.transformer.``.
        NOTE(review): ``filter_keys=True`` presumably drops keys matching
        neither prefix; confirm in ``comfy.utils``.
        """
        replace_prefix = {
            "conditioner.embedders.0.model.": "clip_h.",  # SD2 in sgm format
            "cond_stage_model.model.": "clip_h.",
        }
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
        return utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.")

    def process_clip_state_dict_for_saving(self, state_dict):
        """Map ``clip_h`` keys back to ``cond_stage_model.model`` for saving.

        After the prefix swap the dict is converted with
        ``diffusers_convert.convert_text_enc_state_dict_v20``.
        """
        replace_prefix = {"clip_h": "cond_stage_model.model"}
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
        return diffusers_convert.convert_text_enc_state_dict_v20(state_dict)

    def clip_target(self):
        """Return the ``ClipTarget`` pairing the SD2 tokenizer and CLIP model."""
        # Resolve sd2_clip here instead of relying on a module global. A plain
        # `import comfy.text_encoders.sd2_clip` binds only the name `comfy`, so
        # a global `sd2_clip` would be undefined whenever the newer
        # comfy.text_encoders location is the one that imports successfully.
        try:
            from comfy.text_encoders import sd2_clip
        except ImportError:
            from comfy import sd2_clip
        return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)

    def process_dict_version(self, state_dict: dict):
        """Rename ``framestride_embed`` to ``fps_embedding`` in every key.

        Args:
            state_dict: Checkpoint weights; not modified.

        Returns:
            A tuple ``(processed_dict, is_eps)``:

            * ``processed_dict`` is a new ``OrderedDict`` with the same
              values, in the same order, and the renamed keys.
            * ``is_eps`` is ``True`` if at least one key was renamed.

            NOTE(review): the link between this rename and the ``is_eps``
            name is not evident from this file; confirm against the caller.
        """
        processed_dict = OrderedDict()
        is_eps = False
        for key, value in state_dict.items():
            if "framestride_embed" in key:
                key = key.replace("framestride_embed", "fps_embedding")
                is_eps = True
            processed_dict[key] = value
        return processed_dict, is_eps