|
import comfy.samplers |
|
from comfy.k_diffusion.sampling import default_noise_sampler |
|
from tqdm.auto import trange, tqdm |
|
from itertools import product |
|
import torch |
|
|
|
@torch.no_grad()
def sample_lcm_alt(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, euler_steps=0, ancestral=0.0, noise_mult = 1.0):
    """LCM sampling where the first ``euler_limit`` steps re-inject removed noise.

    Builds a per-step ``loop_control`` list (``True`` = Euler-like re-injection
    step, ``False`` = plain LCM step with fresh noise) and delegates the actual
    loop to ``sample_lcm_backbone``.

    Args:
        model, x, sigmas, extra_args, callback, disable, noise_sampler:
            Forwarded to ``sample_lcm_backbone``. ``extra_args`` defaults to an
            empty dict and ``noise_sampler`` to ``default_noise_sampler(x)``.
        euler_steps: Number of leading steps that re-inject the removed noise.
            It is taken modulo the step count (``len(sigmas) - 1``), so negative
            values count from the end (e.g. ``-1`` means all but the last step).
            NOTE(review): a value equal to a multiple of the step count wraps to
            0 re-injection steps -- confirm this is intended.
        ancestral: Blend factor between fresh and re-injected noise (see backbone).
        noise_mult: Scale applied to re-injected noise (see backbone).

    Returns:
        The sample produced by ``sample_lcm_backbone``.
    """
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    steps = len(sigmas) - 1
    # A single-sigma schedule has zero steps; the modulo would divide by zero.
    euler_limit = euler_steps % steps if steps > 0 else 0
    loop_control = [True] * euler_limit + [False] * (steps - euler_limit)
    return sample_lcm_backbone(model, x, sigmas, extra_args, callback, disable, noise_sampler, loop_control, ancestral, noise_mult)
|
|
|
@torch.no_grad()
def sample_lcm_cycle(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, euler_steps = 1, lcm_steps = 1, tweak_sigmas = False, ancestral=0.0, noise_mult=1.0):
    """LCM sampling that alternates Euler-like and plain LCM steps in a fixed cycle.

    The cycle is ``euler_steps`` re-injection steps (``True``) followed by
    ``lcm_steps`` fresh-noise steps (``False``). All steps except the final one
    follow whole repetitions of the cycle, then the tail of one cycle for any
    remainder. The final step (into the last sigma) continues the pattern
    cyclically; it only matters when that last sigma is > 0.

    Args:
        model, x, sigmas, extra_args, callback, disable, noise_sampler:
            Forwarded to ``sample_lcm_backbone``. ``extra_args`` defaults to an
            empty dict and ``noise_sampler`` to ``default_noise_sampler(x)``.
        euler_steps: Re-injection steps per cycle (non-negative).
        lcm_steps: Fresh-noise steps per cycle (non-negative).
        tweak_sigmas: If true, reorder the schedule so that cycle ``k`` uses
            ``sigmas[k], sigmas[k + repeats], sigmas[k + 2 * repeats], ...``.
            The sigmas after the last whole cycle keep their original order.
        ancestral: Blend factor between fresh and re-injected noise (see backbone).
        noise_mult: Scale applied to re-injected noise (see backbone).

    Returns:
        The sample produced by ``sample_lcm_backbone``.

    Raises:
        ValueError: if a step count is negative or both are zero.
    """
    if euler_steps < 0 or lcm_steps < 0 or euler_steps + lcm_steps == 0:
        raise ValueError("euler_steps and lcm_steps must be non-negative and not both zero")
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler

    # The final step is excluded from the pattern; clamp so tiny schedules stay valid.
    steps = max(len(sigmas) - 2, 0)
    cycle_length = euler_steps + lcm_steps
    repeats, leftover = divmod(steps, cycle_length)
    cycle = [True] * euler_steps + [False] * lcm_steps

    # Slice from cycle_length - leftover: cycle[-0:] would wrongly yield the whole cycle.
    loop_control = cycle * repeats + cycle[cycle_length - leftover:]
    # Give the final step an entry too (the next element of the cyclic pattern),
    # so the backbone never indexes past the end when the last sigma is > 0.
    loop_control += [cycle[0]] * (len(sigmas) - 1 - len(loop_control))

    if tweak_sigmas:
        index_map = torch.tensor(
            [i + j * repeats for i, j in product(range(repeats), range(cycle_length))]
            + list(range(cycle_length * repeats, len(sigmas))),
            dtype=torch.long,
            device=sigmas.device,
        )
        sigmas = torch.index_select(sigmas, 0, index_map)

    return sample_lcm_backbone(model, x, sigmas, extra_args, callback, disable, noise_sampler, loop_control, ancestral, noise_mult)
|
|
|
@torch.no_grad()
def sample_lcm_backbone(model, x, sigmas, extra_args, callback, disable, noise_sampler, loop_control, ancestral, noise_mult = 1.0):
    """Shared LCM-style sampling loop with a per-step choice of noise source.

    At each step the model's denoised prediction becomes the new sample. If the
    next sigma is > 0, noise scaled by that next sigma is then added back:

    * ``loop_control[i]`` falsy: fresh noise from ``noise_sampler``.
    * ``loop_control[i]`` truthy: the noise the model just removed,
      ``(x - denoised) / sigma``, is re-injected, depending on ``ancestral``:
        - ``ancestral >= 1``: fresh noise only;
        - ``0 < ancestral < 1``: ``ancestral**0.5 * fresh + (1 - ancestral)**0.5 * removed``;
        - otherwise (``<= 0``, treated as 0): ``removed * noise_mult``.

    Args:
        model: Called as ``model(x, sigma * s_in, **extra_args)``; returns the
            denoised prediction.
        x: Initial sample.
        sigmas: Sigma schedule; the loop runs ``len(sigmas) - 1`` steps.
        extra_args: Keyword arguments forwarded to ``model``.
        callback: Optional; called every step with a dict containing
            ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: Passed to ``trange`` to control the progress bar.
        noise_sampler: Called as ``noise_sampler(sigma, sigma_next)`` for fresh noise.
        loop_control: Indexable of per-step flags; needs an entry for every step
            whose next sigma is > 0.
        ancestral: Fresh/re-injected noise blend factor, as described above.
        noise_mult: Scale of re-injected noise; only used in the ``ancestral <= 0`` case.

    Returns:
        The final sample.
    """
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        denoised = model(x, sigma * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised})

        if sigma_next > 0:
            if not loop_control[i] or ancestral >= 1.0:
                noise = noise_sampler(sigma, sigma_next)
            elif ancestral > 0.0:
                removed_noise = (x - denoised) / sigma
                noise = (ancestral**0.5) * noise_sampler(sigma, sigma_next) + ((1.0 - ancestral)**0.5) * removed_noise
            else:
                # Covers ancestral == 0 and out-of-range values (negative / NaN),
                # which previously left `noise` unassigned or stale.
                noise = ((x - denoised) / sigma) * noise_mult
            # `denoised` was handed to `callback`; copy before the in-place add
            # so that reference is not mutated.
            x = denoised.clone()
            x += sigma_next * noise
        else:
            x = denoised
    return x
|
|
|
class LCMScheduler:
    """Scheduler node returning an ``sgm_uniform`` sigma schedule (on the CPU)."""

    CATEGORY = "sampling/custom_sampling/schedulers"
    RETURN_TYPES = ("SIGMAS",)
    FUNCTION = "get_sigmas"

    @classmethod
    def INPUT_TYPES(cls):
        # Input declaration: a model plus the number of sampling steps.
        step_count = ("INT", {"default": 8, "min": 1, "max": 10000})
        return {"required": {"model": ("MODEL",), "steps": step_count}}

    def get_sigmas(self, model, steps):
        """Compute the schedule for ``steps`` steps and return it as a 1-tuple."""
        schedule = comfy.samplers.calculate_sigmas_scheduler(model.model, "sgm_uniform", steps)
        return (schedule.cpu(),)
|
|
|
class SamplerLCMAlternative:
    """Sampler node wrapping ``sample_lcm_alt`` in a ``KSAMPLER``."""

    CATEGORY = "sampling/custom_sampling/samplers"
    RETURN_TYPES = ("SAMPLER",)
    FUNCTION = "get_sampler"

    @classmethod
    def INPUT_TYPES(cls):
        # Widget order matters for the UI, so keys are listed in display order.
        fields = {
            "euler_steps": ("INT", {"default": 0, "min": -10000, "max": 10000}),
            "ancestral": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01, "round": True}),
            "noise_mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.001, "round": True}),
        }
        return {"required": fields}

    def get_sampler(self, euler_steps, ancestral, noise_mult):
        """Bind the widget values as extra options of ``sample_lcm_alt``."""
        options = {"euler_steps": euler_steps, "noise_mult": noise_mult, "ancestral": ancestral}
        return (comfy.samplers.KSAMPLER(sample_lcm_alt, extra_options=options),)
|
|
|
class SamplerLCMCycle:
    """Sampler node wrapping ``sample_lcm_cycle`` in a ``KSAMPLER``."""

    CATEGORY = "sampling/custom_sampling/samplers"
    RETURN_TYPES = ("SAMPLER",)
    FUNCTION = "get_sampler"

    @classmethod
    def INPUT_TYPES(cls):
        # Widget order matters for the UI, so keys are listed in display order.
        fields = {
            "euler_steps": ("INT", {"default": 1, "min": 1, "max": 50}),
            "lcm_steps": ("INT", {"default": 2, "min": 1, "max": 50}),
            "tweak_sigmas": ("BOOLEAN", {"default": False}),
            "ancestral": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01, "round": False}),
        }
        return {"required": fields}

    def get_sampler(self, euler_steps, lcm_steps, tweak_sigmas, ancestral):
        """Bind the widget values as extra options of ``sample_lcm_cycle``."""
        options = {
            "euler_steps": euler_steps,
            "lcm_steps": lcm_steps,
            "tweak_sigmas": tweak_sigmas,
            "ancestral": ancestral,
        }
        return (comfy.samplers.KSAMPLER(sample_lcm_cycle, extra_options=options),)
|
|
|
|
|
# Node type name -> implementing class; presumably consumed by the ComfyUI
# custom-node loader (verify against the package's __init__).
NODE_CLASS_MAPPINGS = {
    "LCMScheduler": LCMScheduler,
    "SamplerLCMAlternative": SamplerLCMAlternative,
    "SamplerLCMCycle": SamplerLCMCycle,
}
|
|