StupidGame's picture
Upload 1941 files
baa8e90
import torch
import os
from .resampler import Resampler
import contextlib
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention
from comfy.clip_vision import clip_preprocess
# Directory of this module; IPAdapter checkpoints are read from CURRENT_DIR/models.
CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))

# Output width of every IPAdapter to_k / to_v projection, in the order To_KV
# indexes them: input blocks, then output blocks, then the middle block.
# Each cross-attention layer contributes two consecutive entries (to_k, to_v).
SD_V12_CHANNELS = (
    [320] * (2 * 2)        # input blocks 1, 2
    + [640] * (2 * 2)      # input blocks 4, 5
    + [1280] * (2 * 2)     # input blocks 7, 8
    + [1280] * (3 * 2)     # output blocks 3, 4, 5
    + [640] * (3 * 2)      # output blocks 6, 7, 8
    + [320] * (3 * 2)      # output blocks 9, 10, 11
    + [1280] * (1 * 2)     # middle block
)
SD_XL_CHANNELS = (
    [640] * (2 * 2 * 2)        # input blocks 4, 5: 2 transformers each
    + [1280] * (2 * 10 * 2)    # input blocks 7, 8: 10 transformers each
    + [1280] * (3 * 10 * 2)    # output blocks 0, 1, 2: 10 transformers each
    + [640] * (3 * 2 * 2)      # output blocks 3, 4, 5: 2 transformers each
    + [1280] * (10 * 2)        # middle block: 10 transformers
)
def get_file_list(path):
    """Return the IPAdapter checkpoint file names found directly in *path*.

    Only ``.bin`` and ``.safetensors`` files are listed. The suffix check is
    case-insensitive, matching the ``.lower()`` suffix test in load_ipadapter.
    The result is sorted because ``os.listdir`` returns entries in arbitrary
    order, and this list is shown as a UI dropdown.

    Args:
        path: directory to scan (not recursive).

    Returns:
        Sorted list of matching file names (names only, not full paths).
    """
    return sorted(
        name for name in os.listdir(path)
        if name.lower().endswith((".bin", ".safetensors"))
    )
def set_model_patch_replace(model, patch_kwargs, key):
    """Register or extend the attn2 replacement patch stored under *key*.

    The patch lives in
    ``model.model_options["transformer_options"]["patches_replace"]["attn2"]``;
    missing intermediate dicts are created. The first call for a key creates a
    CrossAttentionPatch from ``patch_kwargs``. Later calls for the same key pass
    ``patch_kwargs`` to the existing patch's ``set_new_condition``.
    """
    attn2_patches = (
        model.model_options["transformer_options"]
        .setdefault("patches_replace", {})
        .setdefault("attn2", {})
    )
    if key in attn2_patches:
        attn2_patches[key].set_new_condition(**patch_kwargs)
    else:
        attn2_patches[key] = CrossAttentionPatch(**patch_kwargs)
def load_ipadapter(ckpt_path):
    """Load an IPAdapter checkpoint as ``{"image_proj": {...}, "ip_adapter": {...}}``.

    ``.safetensors`` checkpoints are flat, with ``image_proj.`` /
    ``ip_adapter.`` key prefixes. They are regrouped into the two sections,
    with the prefix stripped, and the ``ip_adapter`` entries are ordered by
    their leading integer layer index. Any other file is returned as loaded.
    Presumably it is already nested; only the ``ip_adapter`` section is checked.

    Args:
        ckpt_path: path to a ``.bin`` or ``.safetensors`` checkpoint.

    Returns:
        The checkpoint dict with ``image_proj`` and ``ip_adapter`` sections.

    Raises:
        ValueError: if the checkpoint has no non-empty ``ip_adapter`` section.
            This subclasses Exception, so callers catching Exception still work.
    """
    # Import explicitly: the top of this file only imports
    # comfy.model_management, so relying on comfy.utils being an attribute of
    # `comfy` depended on some other module having imported it first.
    import comfy.utils

    model = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
    if ckpt_path.lower().endswith(".safetensors"):
        image_proj = {}
        ip_adapter = {}
        for key, value in model.items():
            # Strip only the leading prefix (str.replace would strip every occurrence).
            if key.startswith("image_proj."):
                image_proj[key[len("image_proj."):]] = value
            elif key.startswith("ip_adapter."):
                ip_adapter[key[len("ip_adapter."):]] = value

        # To_KV assigns these weights positionally, and CrossAttentionPatch reads
        # layer i's key/value projections at indices 2*i and 2*i+1. So order by
        # the layer index first and then by the rest of the name
        # ("to_k_ip..." sorts before "to_v_ip..."), independent of file order.
        def layer_order(name):
            index, _, rest = name.partition(".")
            return int(index), rest

        model = {
            "image_proj": image_proj,
            "ip_adapter": {name: ip_adapter[name] for name in sorted(ip_adapter, key=layer_order)},
        }
    if not isinstance(model, dict) or not model.get("ip_adapter"):
        raise ValueError("invalid IPAdapter model {}".format(ckpt_path))
    return model
class ImageProjModel(torch.nn.Module):
    """Project an image embedding into extra cross-attention tokens.

    A single Linear layer maps each ``clip_embeddings_dim`` vector to
    ``clip_extra_context_tokens * cross_attention_dim`` features. These are
    reshaped into ``clip_extra_context_tokens`` tokens of width
    ``cross_attention_dim`` and passed through a LayerNorm.
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        projected_width = clip_extra_context_tokens * cross_attention_dim
        # Attribute names `proj` / `norm` define the checkpoint keys (proj.weight, norm.bias, ...).
        self.proj = torch.nn.Linear(clip_embeddings_dim, projected_width)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Return normalised tokens of shape (-1, clip_extra_context_tokens, cross_attention_dim)."""
        tokens = self.proj(image_embeds)
        tokens = tokens.reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        return self.norm(tokens)
# Cross Attention to_k, to_v for IPAdapter
class To_KV(torch.nn.Module):
    """IPAdapter key/value projections for every patched cross-attention layer.

    ``to_kvs`` is a flat list ordered input -> output -> middle. The key
    projection of patched layer ``i`` is at index ``2 * i`` and its value
    projection at ``2 * i + 1``; this is how CrossAttentionPatch indexes it.
    """

    def __init__(self, cross_attention_dim):
        super().__init__()
        # 2048 is treated as SDXL elsewhere in this file; anything else uses the SD1.x/2.x layout.
        channels = SD_XL_CHANNELS if cross_attention_dim == 2048 else SD_V12_CHANNELS
        self.to_kvs = torch.nn.ModuleList([
            torch.nn.Linear(cross_attention_dim, channel, bias=False) for channel in channels
        ])

    def load_state_dict(self, state_dict):
        """Assign checkpoint weights to ``to_kvs`` positionally.

        The i-th value of *state_dict* (in iteration order) becomes the weight
        tensor of ``to_kvs[i]``. The tensor is assigned directly, so the
        checkpoint's shape and dtype are kept.

        Raises:
            ValueError: if the number of weights differs from ``len(self.to_kvs)``.
                Previously, too few weights silently left randomly initialised
                projections in use, and too many raised an opaque IndexError.
        """
        tensors = list(state_dict.values())
        if len(tensors) != len(self.to_kvs):
            raise ValueError(
                "IPAdapter checkpoint has {} to_k/to_v weights, but this model expects {}".format(
                    len(tensors), len(self.to_kvs)
                )
            )
        # input -> output -> middle
        for layer, tensor in zip(self.to_kvs, tensors):
            layer.weight.data = tensor
class IPAdapterModel(torch.nn.Module):
    """IPAdapter image projection plus per-layer to_k/to_v weights.

    Args:
        state_dict: checkpoint dict with ``"image_proj"`` and ``"ip_adapter"``
            sections, loaded into the projection and into To_KV respectively.
        plus: when truthy, the projection is a Resampler; otherwise an
            ImageProjModel.
        cross_attention_dim: width of the text-encoder context.
        clip_embeddings_dim: width of the incoming CLIP vision features.
        clip_extra_context_tokens: number of image tokens produced.
        sdxl_plus: selects the Resampler's SDXL hyper-parameters
            (dim 1280, 20 heads) instead of (cross_attention_dim, 12 heads).
    """

    def __init__(self, state_dict, plus, cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4, sdxl_plus=False):
        super().__init__()
        self.plus = plus
        if not self.plus:
            projector = ImageProjModel(
                cross_attention_dim=cross_attention_dim,
                clip_embeddings_dim=clip_embeddings_dim,
                clip_extra_context_tokens=clip_extra_context_tokens,
            )
        else:
            projector = Resampler(
                dim=1280 if sdxl_plus else cross_attention_dim,
                depth=4,
                dim_head=64,
                heads=20 if sdxl_plus else 12,
                num_queries=clip_extra_context_tokens,
                embedding_dim=clip_embeddings_dim,
                output_dim=cross_attention_dim,
                ff_mult=4,
            )
        self.image_proj_model = projector
        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        self.ip_layers = To_KV(cross_attention_dim)
        self.ip_layers.load_state_dict(state_dict["ip_adapter"])

    @torch.inference_mode()
    def get_image_embeds(self, cond, uncond):
        """Project *cond* and *uncond* image features; return them as a (cond, uncond) tuple."""
        return tuple(self.image_proj_model(features) for features in (cond, uncond))
class IPAdapter:
    """ComfyUI node that adds IPAdapter image conditioning to a MODEL.

    It loads an IPAdapter checkpoint from ``CURRENT_DIR/models`` and encodes the
    reference image with the given CLIP vision model. It then registers a
    CrossAttentionPatch on every cross-attention (attn2) layer of a clone of
    the input model.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", ),
                "image": ("IMAGE", ),
                "clip_vision": ("CLIP_VISION", ),
                "weight": ("FLOAT", {
                    "default": 1,
                    "min": -1,  # Minimum value
                    "max": 3,  # Maximum value
                    "step": 0.05  # Slider's step
                }),
                "model_name": (get_file_list(os.path.join(CURRENT_DIR, "models")), ),
                "dtype": (["fp16", "fp32"], ),
            },
            "optional": {
                "mask": ("MASK",),
            }
        }

    RETURN_TYPES = ("MODEL", "CLIP_VISION_OUTPUT")
    FUNCTION = "adapter"
    CATEGORY = "loaders"

    def adapter(self, model, image, clip_vision, weight, model_name, dtype, mask=None):
        """Return ``(patched model clone, CLIP vision outputs)``.

        Args:
            model: MODEL to clone; the clone receives the attn2 patches.
            image: IMAGE passed to ``clip_vision``.
            clip_vision: CLIP_VISION model used to embed ``image``.
            weight: scale applied to the IPAdapter attention output.
            model_name: checkpoint file name inside ``CURRENT_DIR/models``.
            dtype: ``"fp16"`` or ``"fp32"``; on an MPS device fp32 is always used.
            mask: optional MASK (squeezed, then moved to the device) limiting
                where the adapter contributes.
        """
        device = comfy.model_management.get_torch_device()
        self.dtype = torch.float32 if dtype == "fp32" or device.type == "mps" else torch.float16
        self.weight = weight  # ip_adapter scale

        # CURRENT_DIR was previously joined twice; the inner absolute path made the outer join a no-op.
        ip_state_dict = load_ipadapter(os.path.join(CURRENT_DIR, "models", model_name))
        # "plus" checkpoints are recognised by a "latents" entry in image_proj.
        self.plus = "latents" in ip_state_dict["image_proj"]
        # cross_attention_dim is equal to text_encoder output
        self.cross_attention_dim = ip_state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[1]
        self.sdxl = self.cross_attention_dim == 2048
        self.sdxl_plus = self.sdxl and self.plus
        # number of tokens of ip_adapter embedding
        if self.plus:
            self.clip_extra_context_tokens = ip_state_dict["image_proj"]["latents"].shape[1]
        else:
            # proj outputs clip_extra_context_tokens * cross_attention_dim features.
            self.clip_extra_context_tokens = ip_state_dict["image_proj"]["proj.weight"].shape[0] // self.cross_attention_dim

        cond, uncond, outputs = self.clip_vision_encode(clip_vision, image, self.plus)
        self.clip_embeddings_dim = cond.shape[-1]

        self.ipadapter = IPAdapterModel(
            ip_state_dict,
            plus=self.plus,
            cross_attention_dim=self.cross_attention_dim,
            clip_embeddings_dim=self.clip_embeddings_dim,
            clip_extra_context_tokens=self.clip_extra_context_tokens,
            sdxl_plus=self.sdxl_plus,
        )
        self.ipadapter.to(device, dtype=self.dtype)
        self.image_emb, self.uncond_image_emb = self.ipadapter.get_image_embeds(
            cond.to(device, dtype=self.dtype), uncond.to(device, dtype=self.dtype)
        )
        self.image_emb = self.image_emb.to(device, dtype=self.dtype)
        self.uncond_image_emb = self.uncond_image_emb.to(device, dtype=self.dtype)
        # Not sure of batch size at this point. Not read by the code shown here;
        # kept for anything that inspects the attribute.
        self.cond_uncond_image_emb = None

        new_model = model.clone()
        if mask is not None:
            mask = mask.squeeze().to(device)

        # Patch keys:
        #   SD1.x/2.x: ("input" | "output" | "middle", block_id)
        #   SDXL:      ("input" | "output" | "middle", block_id, transformer_index)
        # "number" is the layer's position in To_KV (weights 2*number, 2*number+1),
        # so layers must be visited in the same input -> output -> middle order.
        patch_kwargs = {
            "number": 0,
            "weight": self.weight,
            "ipadapter": self.ipadapter,
            "dtype": self.dtype,
            "cond": self.image_emb,
            "uncond": self.uncond_image_emb,
            "mask": mask
        }
        if not self.sdxl:
            for block_id in [1, 2, 4, 5, 7, 8]:  # input_blocks that have cross attention
                set_model_patch_replace(new_model, patch_kwargs, ("input", block_id))
                patch_kwargs["number"] += 1
            for block_id in [3, 4, 5, 6, 7, 8, 9, 10, 11]:  # output_blocks that have cross attention
                set_model_patch_replace(new_model, patch_kwargs, ("output", block_id))
                patch_kwargs["number"] += 1
            set_model_patch_replace(new_model, patch_kwargs, ("middle", 0))
        else:
            for block_id in [4, 5, 7, 8]:  # input_blocks that have cross attention
                transformer_depth = 2 if block_id in [4, 5] else 10
                for index in range(transformer_depth):
                    set_model_patch_replace(new_model, patch_kwargs, ("input", block_id, index))
                    patch_kwargs["number"] += 1
            for block_id in range(6):  # output_blocks that have cross attention
                transformer_depth = 2 if block_id in [3, 4, 5] else 10
                for index in range(transformer_depth):
                    set_model_patch_replace(new_model, patch_kwargs, ("output", block_id, index))
                    patch_kwargs["number"] += 1
            for index in range(10):
                set_model_patch_replace(new_model, patch_kwargs, ("middle", 0, index))
                patch_kwargs["number"] += 1

        return (new_model, outputs)

    def clip_vision_encode(self, clip_vision, image, plus=False):
        """Encode *image* and build the conditional / unconditional adapter inputs.

        Returns:
            ``(cond, uncond, outputs)``. For ``plus`` models, ``cond`` is the
            second-to-last hidden state and ``uncond`` is the same layer for
            an all-zero image. Otherwise ``cond`` is ``outputs.image_embeds``
            and ``uncond`` is zeros of the same shape. In ``outputs``,
            ``hidden_states`` is set to None and other non-None entries are
            moved to the CPU.
        """
        inputs = clip_preprocess(image)
        comfy.model_management.load_model_gpu(clip_vision.patcher)
        pixel_values = inputs.to(clip_vision.load_device)

        if clip_vision.dtype != torch.float32:
            precision_scope = torch.autocast
        else:
            precision_scope = lambda a, b: contextlib.nullcontext(a)
        autocast_device = comfy.model_management.get_autocast_device(clip_vision.load_device)

        with precision_scope(autocast_device, torch.float32):
            outputs = clip_vision.model(pixel_values=pixel_values, output_hidden_states=True)

        if plus:
            cond = outputs.hidden_states[-2]
            with precision_scope(autocast_device, torch.float32):
                uncond = clip_vision.model(torch.zeros_like(pixel_values), output_hidden_states=True).hidden_states[-2]
        else:
            cond = outputs.image_embeds
            uncond = torch.zeros_like(cond)

        for k in outputs:
            t = outputs[k]
            if k == "hidden_states":
                outputs[k] = None
            elif t is not None:
                outputs[k] = t.cpu()
        return cond, uncond, outputs
class CrossAttentionPatch:
    """Replacement for a UNet attn2 call that adds IPAdapter image attention.

    One instance is registered per cross-attention layer. It holds one or more
    adapter conditions (weight, adapter, cond/uncond embeddings, optional mask).
    Each condition adds a weighted attention over that adapter's image tokens
    on top of the regular text cross-attention.
    """

    # forward for patching
    def __init__(self, weight, ipadapter, dtype, number, cond, uncond, mask=None):
        """Create the patch with its first condition.

        Args:
            weight: scale for this adapter's attention output.
            ipadapter: object exposing ``ip_layers.to_kvs`` (an IPAdapterModel).
            dtype: autocast dtype used in ``__call__``.
            number: index of this layer; selects ``to_kvs[2 * number]`` (key)
                and ``to_kvs[2 * number + 1]`` (value).
            cond: image prompt embedding for the conditional pass.
            uncond: image prompt embedding for the unconditional pass.
            mask: optional 2-D mask limiting where the adapter contributes.
        """
        self.weights = [weight]
        self.ipadapters = [ipadapter]
        self.conds = [cond]
        self.unconds = [uncond]
        self.dtype = dtype
        self.number = number
        self.masks = [mask]

    def set_new_condition(self, weight, ipadapter, cond, uncond, dtype, number, mask=None):
        """Append another adapter condition and adopt its dtype.

        ``number`` is accepted for signature compatibility but ignored; the
        layer index set at construction is kept.
        """
        self.weights.append(weight)
        self.ipadapters.append(ipadapter)
        self.conds.append(cond)
        self.unconds.append(uncond)
        self.masks.append(mask)
        self.dtype = dtype

    @staticmethod
    def _downsampled_size(original_shape, num_tokens):
        """Return the (height, width) of the feature map that holds *num_tokens* tokens.

        Each stride-2 UNet downsample rounds up (ceil), so at rate r the map
        is ``ceil(h / r) x ceil(w / r)``. The previous floor-based check never
        matched when the latent size was not divisible by r, and then fell back
        to rate 8, producing a mask with the wrong number of tokens.
        Rates 1, 2, 4 and 8 are tried; if none matches, the old rate-8
        floor size is returned.
        """
        height, width = original_shape
        for rate in (1, 2, 4, 8):
            scaled_h = -(-height // rate)  # ceil division
            scaled_w = -(-width // rate)
            if scaled_h * scaled_w == num_tokens:
                return scaled_h, scaled_w
        return height // 8, width // 8

    def __call__(self, n, context_attn2, value_attn2, extra_options):
        """Compute text cross-attention and add each adapter's weighted image attention.

        Args:
            n: query input, unpacked as 3-D (batch, tokens, channels).
            context_attn2: key input for the text attention.
            value_attn2: value input for the text attention.
            extra_options: reads ``"cond_or_uncond"`` (per batch chunk, 0 selects
                the cond embedding and 1 the uncond one), ``"original_shape"``
                (indices 2 and 3 give the latent height and width) and
                ``"n_heads"``.

        Returns:
            The attention output, cast back to ``n``'s dtype.
        """
        org_dtype = n.dtype
        cond_or_uncond = extra_options["cond_or_uncond"]
        original_shape = (extra_options["original_shape"][2], extra_options["original_shape"][3])
        with torch.autocast("cuda", dtype=self.dtype):
            q = n
            k = context_attn2
            v = value_attn2
            b, _, _ = q.shape
            batch_prompt = b // len(cond_or_uncond)
            out = optimized_attention(q, k, v, extra_options["n_heads"])
            # Spatial size of this layer's tokens; identical for every adapter,
            # so it is computed once, on the first masked adapter.
            mask_size = None
            for weight, cond, uncond, ipadapter, mask in zip(self.weights, self.conds, self.unconds, self.ipadapters, self.masks):
                to_k = ipadapter.ip_layers.to_kvs[self.number * 2]
                to_v = ipadapter.ip_layers.to_kvs[self.number * 2 + 1]
                k_cond = to_k(cond).repeat(batch_prompt, 1, 1)
                k_uncond = to_k(uncond).repeat(batch_prompt, 1, 1)
                v_cond = to_v(cond).repeat(batch_prompt, 1, 1)
                v_uncond = to_v(uncond).repeat(batch_prompt, 1, 1)
                ip_k = torch.cat([(k_cond, k_uncond)[i] for i in cond_or_uncond], dim=0)
                ip_v = torch.cat([(v_cond, v_uncond)[i] for i in cond_or_uncond], dim=0)
                # Convert ip_k and ip_v to the same dtype as q
                ip_k = ip_k.to(dtype=q.dtype)
                ip_v = ip_v.to(dtype=q.dtype)
                ip_out = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"])
                if mask is not None:
                    if mask_size is None:
                        mask_size = self._downsampled_size(original_shape, q.shape[1])
                    mask_downsample = torch.nn.functional.interpolate(
                        mask.unsqueeze(0).unsqueeze(0), size=mask_size, mode="nearest"
                    ).squeeze(0)
                    # (1, tokens, 1) broadcasts over batch and channels.
                    ip_out = ip_out * mask_downsample.view(1, -1, 1)
                out = out + ip_out * weight
        return out.to(dtype=org_dtype)