JasonSmithSO's picture
Upload 578 files
8866644 verified
import torch
import comfy.sd
import comfy.utils
from comfy import model_management
from comfy import diffusers_convert
class EXVAE(comfy.sd.VAE):
    """VAE wrapper for externally defined autoencoder architectures.

    The stock ``comfy.sd.VAE`` assumes 4 latent channels and a pixel/latent
    scale factor of 8. This subclass reads both from ``model_conf``
    (``embed_dim`` and ``embed_scale``) and overrides the encode/decode paths
    that depend on them.

    NOTE(review): ``comfy.sd.VAE.__init__`` is intentionally not called; only
    the attributes assigned in ``__init__`` below exist on the instance.
    Confirm no inherited method relies on state the base constructor sets.
    """

    def __init__(self, model_path, model_conf, dtype=torch.float32):
        """Instantiate the architecture named by ``model_conf["type"]`` and load its weights.

        Args:
            model_path: Checkpoint path, read with ``comfy.utils.load_torch_file``.
            model_conf: Model config. It must provide ``"type"``, ``"embed_dim"``
                and ``"embed_scale"``. Most architectures also receive the whole
                dict as their constructor config.
            dtype: Dtype the model is cast to after loading.

        Raises:
            NotImplementedError: If ``model_conf["type"]`` is not a supported type.
        """
        self.latent_dim = model_conf["embed_dim"]
        self.latent_scale = model_conf["embed_scale"]
        self.device = model_management.vae_device()
        self.offload_device = model_management.vae_offload_device()
        self.vae_dtype = dtype

        sd = comfy.utils.load_torch_file(model_path)
        model_type = model_conf["type"]
        if model_type == "AutoencoderKL":
            from .models.kl import AutoencoderKL
            model = AutoencoderKL(config=model_conf)
            # Keys in diffusers' naming scheme (up_blocks/resnets) are converted
            # with comfy's diffusers_convert helper before loading.
            if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd:
                sd = diffusers_convert.convert_vae_state_dict(sd)
        elif model_type == "AutoencoderKL-VideoDecoder":
            from .models.temporal_ae import AutoencoderKL
            model = AutoencoderKL(config=model_conf)
        elif model_type == "VQModel":
            from .models.vq import VQModel
            model = VQModel(config=model_conf)
        elif model_type == "ConsistencyDecoder":
            from .models.consistencydecoder import ConsistencyDecoder
            model = ConsistencyDecoder()
            # Every checkpoint key gets a "model." prefix. Presumably the wrapper
            # nests the network under a `model` attribute — verify in consistencydecoder.
            sd = {f"model.{k}": v for k, v in sd.items()}
        elif model_type == "MoVQ3":
            from .models.movq3 import MoVQ
            model = MoVQ(model_conf)
        else:
            raise NotImplementedError(f"Unknown VAE type '{model_conf['type']}'")

        self.first_stage_model = model.eval()
        # Non-strict load: mismatched keys are reported rather than raising.
        missing, unexpected = self.first_stage_model.load_state_dict(sd, strict=False)
        if missing:
            print("Missing VAE keys", missing)
        if unexpected:
            print("Leftover VAE keys", unexpected)
        self.first_stage_model.to(self.vae_dtype).to(self.offload_device)

    # The encode/decode methods below replace the base-class versions because the
    # source repo hard-codes 4 VAE channels and a scale factor of 8.

    def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap=16):
        """Decode latents tile by tile, averaging three differently shaped tilings.

        Used as the out-of-memory fallback of ``decode``. Averaging tilings of
        different aspect ratios presumably hides seams from any single grid.

        Args:
            samples: Latent batch indexed as (N, C, H, W).
            tile_x: Base tile width. The other passes use ``tile_x // 2`` and ``tile_x * 2``.
            tile_y: Base tile height. The other passes use ``tile_y * 2`` and ``tile_y // 2``.
            overlap: Tile overlap forwarded to ``comfy.utils.tiled_scale``.

        Returns:
            Channels-first pixel tensor, upscaled by ``latent_scale``, clamped to [0, 1].
        """
        batch = samples.shape[0]
        height, width = samples.shape[2], samples.shape[3]
        tilings = ((tile_x, tile_y), (tile_x // 2, tile_y * 2), (tile_x * 2, tile_y // 2))
        steps = sum(
            batch * comfy.utils.get_tiled_scale_steps(width, height, tx, ty, overlap)
            for tx, ty in tilings
        )
        pbar = comfy.utils.ProgressBar(steps)

        def decode_fn(a):
            # +1 here and the final /2 below together map decoder output x to (x + 1) / 2.
            return (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()

        total = (
            comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=self.latent_scale, pbar=pbar)
            + comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=self.latent_scale, pbar=pbar)
            + comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount=self.latent_scale, pbar=pbar)
        )
        return torch.clamp((total / 3.0) / 2.0, min=0.0, max=1.0)

    def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
        """Encode images tile by tile, averaging three differently shaped tilings.

        Used as the out-of-memory fallback of ``encode``.

        Args:
            pixel_samples: Channels-first image batch indexed as (N, C, H, W).
            tile_x: Base tile width. The other passes use ``tile_x * 2`` and ``tile_x // 2``.
            tile_y: Base tile height. The other passes use ``tile_y // 2`` and ``tile_y * 2``.
            overlap: Tile overlap forwarded to ``comfy.utils.tiled_scale``.

        Returns:
            Latent tensor with ``latent_dim`` channels, downscaled by ``latent_scale``.
        """
        batch = pixel_samples.shape[0]
        height, width = pixel_samples.shape[2], pixel_samples.shape[3]
        tilings = ((tile_x, tile_y), (tile_x // 2, tile_y * 2), (tile_x * 2, tile_y // 2))
        steps = sum(
            batch * comfy.utils.get_tiled_scale_steps(width, height, tx, ty, overlap)
            for tx, ty in tilings
        )
        pbar = comfy.utils.ProgressBar(steps)
        downscale = 1 / self.latent_scale

        def encode_fn(a):
            # Affine map 2a - 1 (takes [0, 1] to [-1, 1]) before encoding.
            return self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()

        samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount=downscale, out_channels=self.latent_dim, pbar=pbar)
        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=downscale, out_channels=self.latent_dim, pbar=pbar)
        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=downscale, out_channels=self.latent_dim, pbar=pbar)
        samples /= 3.0
        return samples

    def decode(self, samples_in):
        """Decode a latent batch to images, falling back to tiled decoding on OOM.

        Args:
            samples_in: Latent batch indexed as (N, C, H, W).

        Returns:
            Channels-last CPU tensor with values clamped to [0, 1]. On the regular
            path its shape is (N, round(H * latent_scale), round(W * latent_scale), 3).
        """
        self.first_stage_model = self.first_stage_model.to(self.device)
        try:
            # Empirical memory estimate, proportional to the number of output
            # pixels: latent area * latent_scale**2. This previously hard-coded
            # 64 == 8**2 and ignored the configured scale.
            memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * self.latent_scale ** 2) * 1.7
            model_management.free_memory(memory_used, self.device)
            free_memory = model_management.get_free_memory(self.device)
            batch_number = max(1, int(free_memory / memory_used))

            pixel_samples = torch.empty(
                (
                    samples_in.shape[0],
                    3,
                    round(samples_in.shape[2] * self.latent_scale),
                    round(samples_in.shape[3] * self.latent_scale),
                ),
                device="cpu",
            )
            for x in range(0, samples_in.shape[0], batch_number):
                samples = samples_in[x:x + batch_number].to(self.vae_dtype).to(self.device)
                decoded = self.first_stage_model.decode(samples).cpu().float()
                pixel_samples[x:x + batch_number] = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)
        except model_management.OOM_EXCEPTION:
            print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
            pixel_samples = self.decode_tiled_(samples_in)
        finally:
            # Always offload, even if decoding (or the tiled fallback) raised
            # something other than OOM; otherwise the model stays on the device.
            self.first_stage_model = self.first_stage_model.to(self.offload_device)
        return pixel_samples.cpu().movedim(1, -1)

    def encode(self, pixel_samples):
        """Encode an image batch to latents, falling back to tiled encoding on OOM.

        Args:
            pixel_samples: Channels-last image batch (N, H, W, C). Values are
                mapped with ``2 * x - 1`` before encoding, so presumably in [0, 1].

        Returns:
            CPU float tensor. On the regular path its shape is
            (N, latent_dim, H // latent_scale, W // latent_scale).
        """
        pixel_samples = pixel_samples.movedim(-1, 1)
        self.first_stage_model = self.first_stage_model.to(self.device)
        try:
            # NOTE: this constant, like the one in decode, is estimated from the
            # VAE's observed memory usage and could change.
            memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.7
            model_management.free_memory(memory_used, self.device)
            free_memory = model_management.get_free_memory(self.device)
            batch_number = max(1, int(free_memory / memory_used))

            samples = torch.empty(
                (
                    pixel_samples.shape[0],
                    self.latent_dim,
                    round(pixel_samples.shape[2] // self.latent_scale),
                    round(pixel_samples.shape[3] // self.latent_scale),
                ),
                device="cpu",
            )
            for x in range(0, pixel_samples.shape[0], batch_number):
                pixels_in = (2. * pixel_samples[x:x + batch_number] - 1.).to(self.vae_dtype).to(self.device)
                samples[x:x + batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
        except model_management.OOM_EXCEPTION:
            print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
            samples = self.encode_tiled_(pixel_samples)
        finally:
            # Always offload, even if encoding (or the tiled fallback) raised
            # something other than OOM.
            self.first_stage_model = self.first_stage_model.to(self.offload_device)
        return samples