Spaces:
Configuration error
Configuration error
import torch | |
from torch.nn import Linear | |
from types import MethodType | |
import comfy.model_management | |
import comfy.samplers | |
from comfy.cldm.cldm import ControlNet | |
from comfy.controlnet import ControlLora | |
def patch_controlnet(model, control_net):
    """Patch a ComfyUI control net in place so it can run against a Kolors model.

    In both supported cases the control model's ``forward`` is wrapped so that
    ``context`` is first passed through the base model's
    ``diffusion_model.encoder_hid_proj`` (under CUDA fp16 autocast) before the
    stock ``comfy.cldm.cldm.ControlNet.forward`` is called.

    * ``ControlLora``: every ``label_emb.0.0.*`` entry is removed from
      ``control_net.control_weights``. ``pre_run`` is wrapped so that the
      forward wrapper is installed on ``control_model`` after the stock
      ``pre_run`` runs. ``copy`` is wrapped so that clones (and clones of
      clones) keep both wrappers.
    * ``comfy.controlnet.ControlNet``: the base model's ``label_emb`` module is
      assigned onto the control model (and its wrapped model), and
      ``control_model.forward`` is wrapped directly.

    Args:
        model: Kolors model object. ``model.model.diffusion_model`` must
            provide ``encoder_hid_proj``, plus ``label_emb`` for the plain
            ControlNet case.
        control_net: a ``ControlLora`` or ``comfy.controlnet.ControlNet``
            instance. It is mutated in place.

    Returns:
        The same ``control_net`` object, now patched.

    Raises:
        NotImplementedError: if ``control_net`` is neither supported type.
    """
    import comfy.controlnet

    # Captured from the class (not the instance) so that patching the same
    # object twice never stacks the projection.
    cldm_forward = ControlNet.forward

    def kolors_forward(self, x, hint, timesteps, context, **kwargs):
        # Lazily look up the projection on every call, as the base model may be
        # moved/reloaded between patch time and sampling time.
        # torch.autocast("cuda", dtype=float16) is the non-deprecated
        # equivalent of torch.cuda.amp.autocast(enabled=True).
        with torch.autocast("cuda", dtype=torch.float16, enabled=True):
            context = model.model.diffusion_model.encoder_hid_proj(context)
        return cldm_forward(self, x, hint, timesteps, context, **kwargs)

    # ControlLora must be tested first: it is checked before the generic
    # ControlNet branch, matching the original dispatch order.
    if isinstance(control_net, ControlLora):
        # Drop the first label_emb layer's weights from the LoRA payload.
        # NOTE(review): presumably because the Kolors label_emb input layer is
        # shaped differently from the one these weights target — confirm.
        # The list is materialized first because the dict is mutated below.
        stale_keys = [
            key for key in control_net.control_weights
            if key.startswith("label_emb.0.0.")
        ]
        for key in stale_keys:
            control_net.control_weights.pop(key)

        # Class-level originals, so repeated patching never chains wrappers.
        lora_pre_run = ControlLora.pre_run
        lora_copy = ControlLora.copy

        def kolors_pre_run(self, *args, **kwargs):
            result = lora_pre_run(self, *args, **kwargs)
            # control_model is expected to be (re)created by the stock
            # pre_run, so the forward wrapper is installed afterwards; the
            # hasattr guard skips the case where it does not exist yet.
            if hasattr(self, "control_model"):
                self.control_model.forward = MethodType(kolors_forward, self.control_model)
            return result

        def kolors_copy(self, *args, **kwargs):
            clone = lora_copy(self, *args, **kwargs)
            # Patch both pre_run AND copy on the clone; patching only pre_run
            # meant a copy of a copy lost the Kolors context projection.
            clone.pre_run = MethodType(kolors_pre_run, clone)
            clone.copy = MethodType(kolors_copy, clone)
            return clone

        control_net.pre_run = MethodType(kolors_pre_run, control_net)
        control_net.copy = MethodType(kolors_copy, control_net)
    elif isinstance(control_net, comfy.controlnet.ControlNet):
        # Share the base model's label embedding with the control model.
        label_emb = model.model.diffusion_model.label_emb
        control_net.control_model.label_emb = label_emb
        control_net.control_model_wrapped.model.label_emb = label_emb
        control_net.control_model.forward = MethodType(kolors_forward, control_net.control_model)
    else:
        raise NotImplementedError(f"Type {control_net} not supported for KolorsControlNetPatch")
    return control_net