import torch

import comfy.sd
import comfy.utils
from comfy import model_management
from comfy import diffusers_convert

class EXVAE(comfy.sd.VAE):
    """comfy.sd.VAE subclass that loads additional autoencoder architectures
    (KL, video KL, VQ, ConsistencyDecoder, MoVQ3) and takes the latent channel
    count and scale factor from the model config instead of hardcoding them."""
    def __init__(self, model_path, model_conf, dtype=torch.float32):
        self.latent_dim = model_conf["embed_dim"]
        self.latent_scale = model_conf["embed_scale"]
        self.device = model_management.vae_device()
        self.offload_device = model_management.vae_offload_device()
        self.vae_dtype = dtype

        sd = comfy.utils.load_torch_file(model_path)
        model = None

        if model_conf["type"] == "AutoencoderKL":
            from .models.kl import AutoencoderKL
            model = AutoencoderKL(config=model_conf)
            if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys():
                # checkpoint uses diffusers key names; convert them
                sd = diffusers_convert.convert_vae_state_dict(sd)
        elif model_conf["type"] == "AutoencoderKL-VideoDecoder":
            from .models.temporal_ae import AutoencoderKL
            model = AutoencoderKL(config=model_conf)
        elif model_conf["type"] == "VQModel":
            from .models.vq import VQModel
            model = VQModel(config=model_conf)
        elif model_conf["type"] == "ConsistencyDecoder":
            from .models.consistencydecoder import ConsistencyDecoder
            model = ConsistencyDecoder()
            # prefix checkpoint keys with "model." to match the wrapper's parameter names
            sd = {f"model.{k}": v for k, v in sd.items()}
        elif model_conf["type"] == "MoVQ3":
            from .models.movq3 import MoVQ
            model = MoVQ(model_conf)
        else:
            raise NotImplementedError(f"Unknown VAE type '{model_conf['type']}'")

        self.first_stage_model = model.eval()
        m, u = self.first_stage_model.load_state_dict(sd, strict=False)
        if len(m) > 0: print("Missing VAE keys", m)
        if len(u) > 0: print("Leftover VAE keys", u)
        self.first_stage_model.to(self.vae_dtype).to(self.offload_device)
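
    # model_conf must provide at least "type", "embed_dim" (latent channels)
    # and "embed_scale" (pixel-to-latent downscale factor); the model classes
    # above are constructed from the full dict. Weights are loaded with
    # strict=False, so missing or leftover keys are only reported, not fatal.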

    ### The encode/decode functions below are overridden because the upstream implementation hardcodes 4 VAE channels and a scale factor of 8.
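    # Both tiled helpers below run comfy.utils.tiled_scale three times, with
    # tall (x//2, y*2), wide (x*2, y//2) and square (x, y) tiles, then average
    # the three results; this mirrors comfy.sd.VAE and is presumably meant to
    # soften seams at tile borders. Only the latent channel count and scale
    # factor differ here, coming from self.latent_dim / self.latent_scale.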
    def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap=16):
        steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
        steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
        steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
        pbar = comfy.utils.ProgressBar(steps)

        decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
        output = torch.clamp((
            (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=self.latent_scale, pbar=pbar) +
             comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=self.latent_scale, pbar=pbar) +
             comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount=self.latent_scale, pbar=pbar))
            / 3.0) / 2.0, min=0.0, max=1.0)
        return output

    def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
        steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
        steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
        steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
        pbar = comfy.utils.ProgressBar(steps)

        encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
        samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount=(1 / self.latent_scale), out_channels=self.latent_dim, pbar=pbar)
        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=(1 / self.latent_scale), out_channels=self.latent_dim, pbar=pbar)
        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=(1 / self.latent_scale), out_channels=self.latent_dim, pbar=pbar)
        samples /= 3.0
        return samples
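
    # decode()/encode() follow the comfy.sd.VAE pattern: estimate how much
    # VRAM one sample needs (the constants are rough empirical estimates, see
    # the NOTE in encode()), ask model_management to free that much, size the
    # batch to the memory actually available, and fall back to the tiled path
    # if an out-of-memory error is raised anyway.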
    def decode(self, samples_in):
        self.first_stage_model = self.first_stage_model.to(self.device)
        try:
            memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.7
            model_management.free_memory(memory_used, self.device)
            free_memory = model_management.get_free_memory(self.device)
            batch_number = int(free_memory / memory_used)
            batch_number = max(1, batch_number)

            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.latent_scale), round(samples_in.shape[3] * self.latent_scale)), device="cpu")
            for x in range(0, samples_in.shape[0], batch_number):
                samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
                pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0)
        except model_management.OOM_EXCEPTION as e:
            print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
            pixel_samples = self.decode_tiled_(samples_in)

        self.first_stage_model = self.first_stage_model.to(self.offload_device)
        pixel_samples = pixel_samples.cpu().movedim(1, -1)
        return pixel_samples
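
    # Note: encode() expects channels-last image tensors ([B, H, W, C], as
    # ComfyUI passes IMAGE data) and decode() returns the same layout; the
    # movedim calls convert to and from the channels-first layout the model uses.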
    def encode(self, pixel_samples):
        self.first_stage_model = self.first_stage_model.to(self.device)
        pixel_samples = pixel_samples.movedim(-1, 1)
        try:
            # NOTE: this constant, like the one in decode() above, is estimated from the VAE's memory usage and could change.
            memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.7
            model_management.free_memory(memory_used, self.device)
            free_memory = model_management.get_free_memory(self.device)
            batch_number = int(free_memory / memory_used)
            batch_number = max(1, batch_number)

            samples = torch.empty((pixel_samples.shape[0], self.latent_dim, round(pixel_samples.shape[2] // self.latent_scale), round(pixel_samples.shape[3] // self.latent_scale)), device="cpu")
            for x in range(0, pixel_samples.shape[0], batch_number):
                pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
        except model_management.OOM_EXCEPTION as e:
            print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
            samples = self.encode_tiled_(pixel_samples)

        self.first_stage_model = self.first_stage_model.to(self.offload_device)
        return samples
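
# Example usage (a sketch; the checkpoint path and config values below are
# illustrative only and not taken from this repo):
#
#   model_conf = {"type": "AutoencoderKL", "embed_dim": 4, "embed_scale": 8}
#   vae = EXVAE("path/to/vae.safetensors", model_conf, dtype=torch.float16)
#   latents = vae.encode(images)   # images: [B, H, W, C] floats in [0, 1]
#   decoded = vae.decode(latents)  # back to [B, H, W, C] in [0, 1]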