import importlib
import os
import socket
import sys

import ipdb  # noqa: F401
import torch
import torch.nn as nn
from omegaconf import OmegaConf

HOSTNAME = socket.gethostname()

if "trinity" in HOSTNAME:
    # Might be outdated
    config_path = "/home/amylin2/latent-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml"
    weights_path = "/home/amylin2/latent-diffusion/model.ckpt"
elif "grogu" in HOSTNAME:
    # Might be outdated
    config_path = "/home/jasonzh2/code/latent-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml"
    weights_path = "/home/jasonzh2/code/latent-diffusion/model.ckpt"
elif "ender" in HOSTNAME:
    config_path = "/home/jason/ray_diffusion/external/latent-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml"
    weights_path = "/home/jason/ray_diffusion/external/latent-diffusion/model.ckpt"
else:
    config_path = None
    weights_path = None

if weights_path is not None:
    # Make the latent-diffusion checkout importable so the autoencoder class
    # named in the config can be resolved when it is instantiated below.
    LDM_PATH = os.path.dirname(weights_path)
    if LDM_PATH not in sys.path:
        sys.path.append(LDM_PATH)


def resize(image, size=None, scale_factor=None):
    """Bilinear resize; pass exactly one of `size` or `scale_factor`."""
    return nn.functional.interpolate(
        image,
        size=size,
        scale_factor=scale_factor,
        mode="bilinear",
        align_corners=False,
    )
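

# Usage sketch for `resize` (hypothetical shapes, not part of the original
# module). It forwards to nn.functional.interpolate, which accepts exactly one
# of `size` or `scale_factor`:
#
#   x = torch.randn(2, 3, 448, 448)
#   resize(x, size=(224, 224)).shape   # -> torch.Size([2, 3, 224, 224])
#   resize(x, scale_factor=0.5).shape  # -> torch.Size([2, 3, 224, 224])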


def instantiate_from_config(config):
    if "target" not in config:
        if config == "__is_first_stage__":
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))
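

# Illustration with a hypothetical config (mirroring the layout of the OmegaConf
# files this module loads): `instantiate_from_config` builds the object named by
# `target` and forwards `params` as keyword arguments.
#
#   cfg = {"target": "torch.nn.Linear", "params": {"in_features": 8, "out_features": 4}}
#   layer = instantiate_from_config(cfg)  # equivalent to torch.nn.Linear(8, 4)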


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)
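

# Example (using a torch name purely for illustration): the string is split on
# the last "." into a module path and an attribute name, so any importable
# "pkg.module.ClassOrFunction" path works.
#
#   get_obj_from_str("torch.nn.Linear")  # -> <class 'torch.nn.modules.linear.Linear'>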


class PretrainedVAE(nn.Module):
    def __init__(self, freeze_weights=True, num_patches_x=16, num_patches_y=16):
        super().__init__()
        config = OmegaConf.load(config_path)
        self.model = instantiate_from_config(config.model)
        self.model.init_from_ckpt(weights_path)
        self.model.eval()
        self.feature_dim = 16
        self.num_patches_x = num_patches_x
        self.num_patches_y = num_patches_y
        if freeze_weights:
            for param in self.model.parameters():
                param.requires_grad = False

    def forward(self, x, autoresize=False):
        """
        Spatial dimensions of the output will be H // 16, W // 16. If autoresize
        is True, the input is resized so that the output feature map has the
        expected dimensions.

        Args:
            x (torch.Tensor): Images (B, C, H, W). Should be normalized to [-1, 1].
            autoresize (bool): Whether to resize the input to match the num_patch
                dimensions.

        Returns:
            torch.Tensor: Latent sample (B, 16, h, w)
        """
        *B, c, h, w = x.shape
        x = x.reshape(-1, c, h, w)
        if autoresize:
            new_w = self.num_patches_x * 16
            new_h = self.num_patches_y * 16
            x = resize(x, size=(new_h, new_w))
        decoded, latent = self.model(x)
        # A little ambiguous because every dim is 16 here, but the layout is (c, h, w).
        latent_sample = latent.sample().reshape(
            *B, self.feature_dim, self.num_patches_y, self.num_patches_x
        )
        return latent_sample
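

# Usage sketch (assumes config_path/weights_path resolved for this host; the
# 256x256 input below is hypothetical):
#
#   vae = PretrainedVAE(num_patches_x=16, num_patches_y=16)
#   images = torch.rand(2, 3, 256, 256) * 2 - 1  # normalize to [-1, 1]
#   latents = vae(images)                        # -> (2, 16, 16, 16)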


# Global store for intermediate feature maps captured by forward hooks.
activations = {}


def get_activation(name):
    def hook(model, input, output):
        activations[name] = output

    return hook
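

# Hook-pattern illustration (the probe layer is hypothetical, not part of this
# file): registering get_activation("name") on a submodule stores that
# submodule's output in the global `activations` dict on every forward pass.
#
#   layer = nn.Linear(4, 4)
#   layer.register_forward_hook(get_activation("probe"))
#   layer(torch.randn(1, 4))
#   activations["probe"].shape  # -> torch.Size([1, 4])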


class SpatialDino(nn.Module):
    def __init__(
        self,
        freeze_weights=True,
        model_type="dinov2_vits14",
        num_patches_x=16,
        num_patches_y=16,
        activation_hooks=False,
    ):
        super().__init__()
        self.model = torch.hub.load("facebookresearch/dinov2", model_type)
        self.feature_dim = self.model.embed_dim
        self.num_patches_x = num_patches_x
        self.num_patches_y = num_patches_y
        if freeze_weights:
            for param in self.model.parameters():
                param.requires_grad = False

        self.activation_hooks = activation_hooks
        if self.activation_hooks:
            self.model.blocks[5].register_forward_hook(get_activation("encoder1"))
            self.model.blocks[11].register_forward_hook(get_activation("encoder2"))
            self.activations = activations

    def forward(self, x, autoresize=False):
        """
        Spatial dimensions of the output will be H // 14, W // 14. If autoresize
        is True, the output is resized to the expected dimensions.

        Args:
            x (torch.Tensor): Images (B, C, H, W). Should be ImageNet normalized.
            autoresize (bool): Whether to resize the output to match the num_patch
                dimensions.

        Returns:
            feature_map (torch.Tensor): (B, C, h, w)
        """
        *B, c, h, w = x.shape
        x = x.reshape(-1, c, h, w)

        # if autoresize:
        #     new_w = self.num_patches_x * 14
        #     new_h = self.num_patches_y * 14
        #     x = resize(x, size=(new_h, new_w))

        # Output will be (B, H * W, C)
        features = self.model.forward_features(x)["x_norm_patchtokens"]
        features = features.permute(0, 2, 1)
        features = features.reshape(  # (B, C, H, W)
            -1, self.feature_dim, h // 14, w // 14
        )
        if autoresize:
            features = resize(features, size=(self.num_patches_y, self.num_patches_x))
        features = features.reshape(
            *B, self.feature_dim, self.num_patches_y, self.num_patches_x
        )
        return features
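

# Usage sketch (the input resolution is hypothetical; DINOv2 expects ImageNet
# normalization and H, W divisible by the patch size 14; dinov2_vits14 has
# embed_dim 384):
#
#   dino = SpatialDino(model_type="dinov2_vits14", num_patches_x=16, num_patches_y=16)
#   images = torch.randn(2, 3, 224, 224)  # 224 / 14 = 16 patches per side
#   feats = dino(images)                  # -> (2, 384, 16, 16)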