|
"""This file contains the model definition of TiTok. |
|
|
|
Copyright (2024) Bytedance Ltd. and/or its affiliates |
|
|
|
Licensed under the Apache License, Version 2.0 (the "License"); |
|
you may not use this file except in compliance with the License. |
|
You may obtain a copy of the License at |
|
|
|
http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
Unless required by applicable law or agreed to in writing, software |
|
distributed under the License is distributed on an "AS IS" BASIS, |
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
See the License for the specific language governing permissions and |
|
limitations under the License. |
|
""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
from einops import rearrange |
|
|
|
from .modules.base_model import BaseModel |
|
from .modules.blocks import TiTokEncoder, TiTokDecoder |
|
from .quantizer.quantizer import VectorQuantizer, DiagonalGaussianDistribution |
|
from .modules.maskgit_vqgan import Encoder as Pixel_Encoder
|
from .modules.maskgit_vqgan import Decoder as Pixel_Decoder |
|
from .modules.maskgit_vqgan import VectorQuantizer as Pixel_Quantizer |
|
import json |
|
from omegaconf import OmegaConf |
|
from pathlib import Path |
|
|
|
from huggingface_hub import PyTorchModelHubMixin |
|
|
|
|
|
class PretrainedTokenizer(nn.Module): |
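    """Frozen MaskGIT-VQGAN tokenizer used as an off-the-shelf pixel codec."""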
|
def __init__(self, pretrained_weight): |
|
super().__init__() |
|
conf = OmegaConf.create( |
|
{ |
|
"channel_mult": [1, 1, 2, 2, 4], |
|
"num_resolutions": 5, |
|
"dropout": 0.0, |
|
"hidden_channels": 128, |
|
"num_channels": 3, |
|
"num_res_blocks": 2, |
|
"resolution": 256, |
|
"z_channels": 256, |
|
} |
|
) |
|
        self.encoder = Pixel_Encoder(conf)
|
self.decoder = Pixel_Decoder(conf) |
|
self.quantize = Pixel_Quantizer( |
|
num_embeddings=1024, embedding_dim=256, commitment_cost=0.25 |
|
) |
|
|
|
self.load_state_dict( |
|
torch.load(pretrained_weight, map_location=torch.device("cpu")), strict=True |
|
) |
|
|
|
self.eval() |
|
for param in self.parameters(): |
|
param.requires_grad = False |
|
|
|
@torch.no_grad() |
|
def encode(self, x): |
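        """Encode images in [0, 1] into discrete codebook indices."""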
|
hidden_states = self.encoder(x) |
|
        _, codebook_indices, _ = self.quantize(hidden_states)  # only the indices are needed
|
return codebook_indices.detach() |
|
|
|
@torch.no_grad() |
|
def decode(self, codes): |
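        """Decode codebook indices back to images, clamped to [0, 1]."""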
|
quantized_states = self.quantize.get_codebook_entry(codes) |
|
rec_images = self.decoder(quantized_states) |
|
rec_images = torch.clamp(rec_images, 0.0, 1.0) |
|
return rec_images.detach() |
|
|
|
@torch.no_grad() |
|
def decode_tokens(self, codes): |
|
return self.decode(codes) |
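

# Usage sketch for PretrainedTokenizer (the weight path is hypothetical; the
# checkpoint must match the MaskGIT-VQGAN config hard-coded above):
#
#   tokenizer = PretrainedTokenizer("maskgit-vqgan-imagenet-f16-256.bin")
#   images = torch.rand(4, 3, 256, 256)   # pixel values in [0, 1]
#   codes = tokenizer.encode(images)      # discrete codebook indices
#   recon = tokenizer.decode(codes)       # reconstructions in [0, 1]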
|
|
|
|
|
class TiTok( |
|
BaseModel, |
|
PyTorchModelHubMixin, |
|
tags=["arxiv:2406.07550", "image-tokenization"], |
|
repo_url="https://github.com/bytedance/1d-tokenizer", |
|
license="apache-2.0", |
|
): |
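    """TiTok: a 1D image tokenizer (arXiv:2406.07550).

    A ViT encoder compresses an image into a small set of latent tokens,
    which are quantized with a VQ codebook ("vq") or modeled as a diagonal
    Gaussian ("vae"); a ViT decoder maps the tokens back to pixel space.
    """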
|
def __init__(self, config): |
|
|
|
if isinstance(config, dict): |
|
config = OmegaConf.create(config) |
|
|
|
super().__init__() |
|
self.config = config |
|
|
|
self.finetune_decoder = config.model.vq_model.get("finetune_decoder", True) |
|
|
|
self.quantize_mode = config.model.vq_model.get("quantize_mode", "vq") |
|
if self.quantize_mode not in ["vq", "vae"]: |
|
raise ValueError(f"Unsupported quantize mode {self.quantize_mode}.") |
|
|
|
if self.finetune_decoder and self.quantize_mode not in ["vq"]: |
|
raise ValueError( |
|
"Only supprot finetune_decoder with vq quantization for now." |
|
) |
|
|
|
self.encoder = TiTokEncoder(config) |
|
self.decoder = TiTokDecoder(config) |
|
|
|
self.num_latent_tokens = config.model.vq_model.num_latent_tokens |
|
scale = self.encoder.width**-0.5 |
|
self.latent_tokens = nn.Parameter( |
|
scale * torch.randn(self.num_latent_tokens, self.encoder.width) |
|
) |
|
|
|
self.apply(self._init_weights) |
|
|
|
if self.quantize_mode == "vq": |
|
self.quantize = VectorQuantizer( |
|
codebook_size=config.model.vq_model.codebook_size, |
|
token_size=config.model.vq_model.token_size, |
|
commitment_cost=config.model.vq_model.commitment_cost, |
|
use_l2_norm=config.model.vq_model.use_l2_norm, |
|
) |
|
elif self.quantize_mode == "vae": |
|
self.quantize = DiagonalGaussianDistribution |
|
else: |
|
raise NotImplementedError |
|
|
|
if self.finetune_decoder: |
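            # Decoder-only finetuning: keep the latent tokens, encoder, and
            # quantizer frozen so gradients flow only through the decoder.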
|
|
|
self.latent_tokens.requires_grad_(False) |
|
self.encoder.eval() |
|
self.encoder.requires_grad_(False) |
|
self.quantize.eval() |
|
self.quantize.requires_grad_(False) |
|
|
|
|
|
self.pixel_quantize = Pixel_Quantizer( |
|
num_embeddings=1024, embedding_dim=256, commitment_cost=0.25 |
|
) |
|
self.pixel_decoder = Pixel_Decoder( |
|
OmegaConf.create( |
|
{ |
|
"channel_mult": [1, 1, 2, 2, 4], |
|
"num_resolutions": 5, |
|
"dropout": 0.0, |
|
"hidden_channels": 128, |
|
"num_channels": 3, |
|
"num_res_blocks": 2, |
|
"resolution": 256, |
|
"z_channels": 256, |
|
} |
|
) |
|
) |
|
|
|
def _save_pretrained(self, save_directory: Path) -> None: |
|
"""Save weights and config to a local directory.""" |
|
|
|
|
|
dict_config = OmegaConf.to_container(self.config) |
|
|
|
file_path = Path(save_directory) / "config.json" |
|
with open(file_path, "w") as json_file: |
|
json.dump(dict_config, json_file, indent=4) |
|
super()._save_pretrained(save_directory) |
|
|
|
def _init_weights(self, module): |
|
"""Initialize the weights. |
|
:param: |
|
module -> torch.nn.Module: module to initialize |
|
""" |
|
        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
|
module.weight.data = nn.init.trunc_normal_( |
|
module.weight.data, mean=0.0, std=0.02 |
|
) |
|
if module.bias is not None: |
|
module.bias.data.zero_() |
|
elif isinstance(module, nn.Embedding): |
|
module.weight.data = nn.init.trunc_normal_( |
|
module.weight.data, mean=0.0, std=0.02 |
|
) |
|
elif isinstance(module, nn.LayerNorm): |
|
module.bias.data.zero_() |
|
module.weight.data.fill_(1.0) |
|
|
|
def encode(self, x): |
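        """Encode images into quantized latent tokens.

        When finetune_decoder is set, the encoder and quantizer run frozen
        under no_grad and the quantizer losses are zeroed out.
        """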
|
if self.finetune_decoder: |
|
with torch.no_grad(): |
|
self.encoder.eval() |
|
self.quantize.eval() |
|
z = self.encoder(pixel_values=x, latent_tokens=self.latent_tokens) |
|
z_quantized, result_dict = self.quantize(z) |
|
result_dict["quantizer_loss"] *= 0 |
|
result_dict["commitment_loss"] *= 0 |
|
result_dict["codebook_loss"] *= 0 |
|
else: |
|
z = self.encoder(pixel_values=x, latent_tokens=self.latent_tokens) |
|
if self.quantize_mode == "vq": |
|
z_quantized, result_dict = self.quantize(z) |
|
elif self.quantize_mode == "vae": |
|
posteriors = self.quantize(z) |
|
z_quantized = posteriors.sample() |
|
result_dict = posteriors |
|
|
|
return z_quantized, result_dict |
|
|
|
def decode(self, z_quantized): |
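        """Decode quantized latent tokens back to pixel space."""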
|
decoded = self.decoder(z_quantized) |
|
if self.finetune_decoder: |
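            # Interpret decoder logits as soft weights over the MaskGIT-VQGAN
            # codebook, then render pixels with the pretrained pixel decoder.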
|
quantized_states = torch.einsum( |
|
"nchw,cd->ndhw", |
|
decoded.softmax(1), |
|
self.pixel_quantize.embedding.weight, |
|
) |
|
decoded = self.pixel_decoder(quantized_states) |
|
return decoded |
|
|
|
def decode_tokens(self, tokens): |
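        """Decode token indices ("vq" mode) or sampled latents ("vae" mode)."""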
|
if self.quantize_mode == "vq": |
|
tokens = tokens.squeeze(1) |
|
batch, seq_len = tokens.shape |
|
z_quantized = self.quantize.get_codebook_entry(tokens.reshape(-1)).reshape( |
|
batch, 1, seq_len, -1 |
|
) |
|
z_quantized = rearrange(z_quantized, "b h w c -> b c h w").contiguous() |
|
elif self.quantize_mode == "vae": |
|
z_quantized = tokens |
|
decoded = self.decode(z_quantized) |
|
return decoded |
|
|
|
def forward(self, x): |
|
z_quantized, result_dict = self.encode(x) |
|
decoded = self.decode(z_quantized) |
|
return decoded, result_dict["min_encoding_indices"], z_quantized |
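

# Usage sketch for TiTok (the config path and Hub repo id are illustrative;
# the config must provide the config.model.vq_model keys accessed above):
#
#   config = OmegaConf.load("configs/titok_l32.yaml")
#   titok = TiTok(config)
#   images = torch.rand(2, 3, 256, 256)
#   recon, token_indices, z_quantized = titok(images)
#
# Or load a published checkpoint through the PyTorchModelHubMixin:
#
#   titok = TiTok.from_pretrained("yucornetto/tokenizer_titok_l32_imagenet")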
|
|