import importlib
import os
import socket
import sys

import ipdb  # noqa: F401
import torch
import torch.nn as nn
from omegaconf import OmegaConf

# Per-host paths to the pretrained latent-diffusion autoencoder config and
# checkpoint. Unknown hosts fall back to None (PretrainedVAE then cannot load).
HOSTNAME = socket.gethostname()

if "trinity" in HOSTNAME:
    # Might be outdated
    config_path = "/home/amylin2/latent-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml"
    weights_path = "/home/amylin2/latent-diffusion/model.ckpt"
elif "grogu" in HOSTNAME:
    # Might be outdated
    config_path = "/home/jasonzh2/code/latent-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml"
    weights_path = "/home/jasonzh2/code/latent-diffusion/model.ckpt"
elif "ender" in HOSTNAME:
    config_path = "/home/jason/ray_diffusion/external/latent-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml"
    weights_path = "/home/jason/ray_diffusion/external/latent-diffusion/model.ckpt"
else:
    config_path = None
    weights_path = None

if weights_path is not None:
    # Make the latent-diffusion repo importable so instantiate_from_config can
    # resolve dotted paths like "ldm.models....".
    LDM_PATH = os.path.dirname(weights_path)
    if LDM_PATH not in sys.path:
        sys.path.append(LDM_PATH)


def resize(image, size=None, scale_factor=None):
    """Bilinearly resize a (B, C, H, W) tensor to `size` or by `scale_factor`."""
    return nn.functional.interpolate(
        image,
        size=size,
        scale_factor=scale_factor,
        mode="bilinear",
        align_corners=False,
    )


def instantiate_from_config(config):
    """Instantiate the object described by a latent-diffusion style config.

    Args:
        config: Mapping with a dotted class path under "target" and optional
            constructor kwargs under "params". The sentinel strings
            "__is_first_stage__" / "__is_unconditional__" yield None.

    Returns:
        The instantiated object, or None for the sentinel configs.

    Raises:
        KeyError: If `config` is neither a sentinel nor contains "target".
    """
    if "target" not in config:
        # Sentinels used by latent-diffusion configs to mark stages that need
        # no instantiation.
        if config in ("__is_first_stage__", "__is_unconditional__"):
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def get_obj_from_str(string, reload=False):
    """Resolve a dotted path "pkg.module.Name" to the named module attribute.

    Args:
        string: Dotted import path, e.g. "ldm.models.autoencoder.AutoencoderKL".
        reload: If True, reload the module before resolving the attribute.
    """
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


class PretrainedVAE(nn.Module):
    """Frozen latent-diffusion KL autoencoder used as a feature extractor."""

    def __init__(self, freeze_weights=True, num_patches_x=16, num_patches_y=16):
        super().__init__()
        config = OmegaConf.load(config_path)
        self.model = instantiate_from_config(config.model)
        self.model.init_from_ckpt(weights_path)
        self.model.eval()
        # Latent channel count of the 16x16x16 KL autoencoder config.
        self.feature_dim = 16
        self.num_patches_x = num_patches_x
        self.num_patches_y = num_patches_y
        if freeze_weights:
            for param in self.model.parameters():
                param.requires_grad = False

    def forward(self, x, autoresize=False):
        """
        Spatial dimensions of output will be H // 16, W // 16. If autoresize is
        True, then the input will be resized such that the output feature map is
        the correct dimensions.

        Args:
            x (torch.Tensor): Images (B, C, H, W). Should be normalized to be
                [-1, 1].
            autoresize (bool): Whether to resize the input to match the
                num_patch dimensions.

        Returns:
            torch.Tensor: Latent sample (B, 16, h, w)
        """
        *B, c, h, w = x.shape
        x = x.reshape(-1, c, h, w)
        if autoresize:
            new_w = self.num_patches_x * 16
            new_h = self.num_patches_y * 16
            x = resize(x, size=(new_h, new_w))
        decoded, latent = self.model(x)
        # Reshape using the latent's actual spatial dims rather than assuming
        # (num_patches_y, num_patches_x): without autoresize, the encoded size
        # is (h // 16, w // 16) and a hard-coded reshape would fail whenever
        # those differ from the configured patch counts.
        latent_sample = latent.sample()
        latent_h, latent_w = latent_sample.shape[-2:]
        latent_sample = latent_sample.reshape(
            *B, self.feature_dim, latent_h, latent_w
        )
        return latent_sample


# Shared store populated by the forward hooks registered below; keyed by the
# name passed to get_activation.
activations = {}


def get_activation(name):
    """Return a forward hook that stores the module's output under `name`."""

    def hook(model, input, output):
        activations[name] = output

    return hook


class SpatialDino(nn.Module):
    """DINOv2 backbone returning a spatial (B, C, h, w) patch-feature map."""

    def __init__(
        self,
        freeze_weights=True,
        model_type="dinov2_vits14",
        num_patches_x=16,
        num_patches_y=16,
        activation_hooks=False,
    ):
        super().__init__()
        self.model = torch.hub.load("facebookresearch/dinov2", model_type)
        self.feature_dim = self.model.embed_dim
        self.num_patches_x = num_patches_x
        self.num_patches_y = num_patches_y
        if freeze_weights:
            for param in self.model.parameters():
                param.requires_grad = False

        self.activation_hooks = activation_hooks
        if self.activation_hooks:
            # Capture intermediate block outputs at two depths for probing.
            self.model.blocks[5].register_forward_hook(get_activation("encoder1"))
            self.model.blocks[11].register_forward_hook(get_activation("encoder2"))
            self.activations = activations

    def forward(self, x, autoresize=False):
        """
        Spatial dimensions of output will be H // 14, W // 14. If autoresize is
        True, then the output will be resized to the correct dimensions.

        Args:
            x (torch.Tensor): Images (B, C, H, W). Should be ImageNet
                normalized.
            autoresize (bool): Whether to resize the output to match the
                num_patch dimensions.

        Returns:
            feature_map (torch.tensor): (B, C, h, w)
        """
        *B, c, h, w = x.shape
        x = x.reshape(-1, c, h, w)

        # Patch tokens come back as (B, H * W, C); fold them into a map.
        features = self.model.forward_features(x)["x_norm_patchtokens"]
        features = features.permute(0, 2, 1)
        features = features.reshape(  # (B, C, H, W)
            -1, self.feature_dim, h // 14, w // 14
        )
        if autoresize:
            features = resize(features, size=(self.num_patches_y, self.num_patches_x))
        # Reshape using the map's actual spatial dims: without autoresize they
        # are (h // 14, w // 14), which may differ from num_patches_{y,x}; a
        # hard-coded reshape would raise in that case.
        out_h, out_w = features.shape[-2:]
        features = features.reshape(*B, self.feature_dim, out_h, out_w)
        return features