diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..5ffb87283bb83c401f264a29459b38982f46110c 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/app_examples/0.png filter=lfs diff=lfs merge=lfs -text +assets/app_examples/1.png filter=lfs diff=lfs merge=lfs -text +assets/overview.png filter=lfs diff=lfs merge=lfs -text +assets/app_examples/2.png filter=lfs diff=lfs merge=lfs -text +assets/app_examples/4.png filter=lfs diff=lfs merge=lfs -text +assets/comparison_of_generation.png filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..bc9f14e52ca7bade02bd33bc8df6d665a7314d9d --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +__pycache__/ +*.ckpt +checkpoints/ +results/ +VTBench_models/ +README.md diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..76c4a42f30a5bdace651c883c365a845c71b6cde --- /dev/null +++ b/app.py @@ -0,0 +1,166 @@ +import os +import spaces +import subprocess +import sys + +# REQUIREMENTS_FILE = "requirements.txt" +# if os.path.exists(REQUIREMENTS_FILE): +# try: +# print("Installing dependencies from requirements.txt...") +# subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", REQUIREMENTS_FILE]) +# print("Dependencies installed successfully.") +# except subprocess.CalledProcessError as e: +# print(f"Failed to install dependencies: {e}") +# else: +# print("requirements.txt not found.") + +import gradio as gr +from src.data_processing import pil_to_tensor, tensor_to_pil +from PIL import Image +from src.model_processing import get_model +from huggingface_hub import snapshot_download +import torch + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print(f"Running on: {device}") + +MODEL_DIR = "./VTBench_models" +if not os.path.exists(MODEL_DIR): + print("Downloading VTBench_models from Hugging Face...") + snapshot_download( + repo_id="huaweilin/VTBench_models", + local_dir=MODEL_DIR, + local_dir_use_symlinks=False + ) + print("Download complete.") + +example_image_paths = [f"assets/app_examples/{i}.png" for i in range(0, 5)] + +model_name_mapping = { + "SD3.5L": "SD3.5L", + "chameleon": "Chameleon", + # "flowmo_lo": "FlowMo Lo", + # "flowmo_hi": "FlowMo Hi", + # "gpt4o": "GPT-4o", + "janus_pro_1b": "Janus Pro 1B/7B", + # "llamagen-ds8": "LlamaGen ds8", + # "llamagen-ds16": "LlamaGen ds16", + # "llamagen-ds16-t2i": "LlamaGen ds16 T2I", + # "maskbit_16bit": "MaskBiT 16bit", + # "maskbit_18bit": "MaskBiT 18bit", + # "open_magvit2": "OpenMagViT", + # "titok_b64": "Titok-b64", + # "titok_bl64": "Titok-bl64", + # "titok_s128": "Titok-s128", + # "titok_bl128": "Titok-bl128", + # "titok_l32": "Titok-l32", + # "titok_sl256": "Titok-sl256", + # "var_256": "VAR-256", + # "var_512": "VAR-512", + # "FLUX.1-dev": "FLUX.1-dev", + # "infinity_d32": "Infinity-d32", + # "infinity_d64": "Infinity-d64", + # "bsqvit": "BSQ-VIT", +} + +def load_model(model_name): + model, data_params = get_model(MODEL_DIR, model_name) + model = model.to(device) + model.eval() + return model, data_params + +model_dict = { + model_name: load_model(model_name) + for model_name in model_name_mapping +} + +placeholder_image = Image.new("RGBA", (512, 512), (0, 0, 0, 0)) 
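+# Illustrative sketch (not part of the app's control flow): each model_dict entry
+# holds a (model, data_params) pair, so a single reconstruction outside Gradio
+# would look roughly like this, assuming a local "input.png" (hypothetical path):
+#
+#   model, data_params = model_dict["SD3.5L"]
+#   pixel_values = pil_to_tensor(Image.open("input.png"), **data_params).unsqueeze(0).to(device)
+#   with torch.no_grad():
+#       output = model(pixel_values)[0]
+#   tensor_to_pil(output[0].cpu(), **data_params).save("reconstruction.png")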
+ +@spaces.GPU +def process_selected_models(uploaded_image, selected_models): + results = [] + for model_name in model_name_mapping: + if uploaded_image is None: + results.append(gr.update(value=placeholder_image, label=f"{model_name_mapping[model_name]} (No input)")) + elif model_name in selected_models: + try: + model, data_params = model_dict[model_name] + pixel_values = pil_to_tensor(uploaded_image, **data_params).unsqueeze(0).to(device) + output = model(pixel_values)[0] + reconstructed_image = tensor_to_pil(output[0].cpu(), **data_params) + results.append(gr.update(value=reconstructed_image, label=model_name_mapping[model_name])) + except Exception as e: + print(f"Error in model {model_name}: {e}") + results.append(gr.update(value=placeholder_image, label=f"{model_name_mapping[model_name]} (Error)")) + else: + results.append(gr.update(value=placeholder_image, label=f"{model_name_mapping[model_name]} (Not selected)")) + return results + +with gr.Blocks() as demo: + gr.Markdown("## VTBench") + + gr.Markdown("---") + + image_input = gr.Image( + type="pil", + label="Upload an image", + width=512, + height=512, + ) + + gr.Markdown("### Click on an example image to use it as input:") + example_rows = [example_image_paths[i:i+5] for i in range(0, len(example_image_paths), 5)] + for row in example_rows: + with gr.Row(): + for path in row: + ex_img = gr.Image( + value=path, + show_label=False, + interactive=True, + width=256, + height=256, + ) + + def make_loader(p=path): + def load_img(): + return Image.open(p) + return load_img + + ex_img.select(fn=make_loader(), outputs=image_input) + + gr.Markdown("---") + + gr.Markdown("⚠️ **The more models you select, the longer the processing time will be.**") + model_selector = gr.CheckboxGroup( + choices=list(model_name_mapping.keys()), + label="Select models to run", + value=["SD3.5L", "chameleon", "janus_pro_1b"], + interactive=True, + ) + run_button = gr.Button("Start Processing") + + image_outputs = [] + model_items = list(model_name_mapping.items()) + + n_columns = 5 + output_rows = [model_items[i:i+n_columns] for i in range(0, len(model_items), n_columns)] + + with gr.Column(): + for row in output_rows: + with gr.Row(): + for model_name, display_name in row: + out_img = gr.Image( + label=f"{display_name} (Not run)", + value=placeholder_image, + width=512, + height=512, + ) + image_outputs.append(out_img) + + run_button.click( + fn=process_selected_models, + inputs=[image_input, model_selector], + outputs=image_outputs + ) + +demo.launch() diff --git a/assets/app_examples/0.png b/assets/app_examples/0.png new file mode 100644 index 0000000000000000000000000000000000000000..ecb1083a1936433d09c6c7e3156ae4dabdb26fe4 --- /dev/null +++ b/assets/app_examples/0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4ea7967c763311298a587d3c8a7486913d63df0640e79e9e5cfd8cfdc9a4a558 +size 2854383 diff --git a/assets/app_examples/1.png b/assets/app_examples/1.png new file mode 100644 index 0000000000000000000000000000000000000000..557a7531eafc8b69e847822c7c06bee57cf51c0a --- /dev/null +++ b/assets/app_examples/1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5d54dacbad49976f3105b8e69fede50f7a1cf7abe96ec5244c46eaaadfd688a6 +size 1286181 diff --git a/assets/app_examples/2.png b/assets/app_examples/2.png new file mode 100644 index 0000000000000000000000000000000000000000..51bd8f73ef0e28778285974b07346abb643881b2 --- /dev/null +++ b/assets/app_examples/2.png @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:656a438b675122e0384097c91554fe4d810e8a0025b770443998e473574e5056 +size 2094389 diff --git a/assets/app_examples/3.png b/assets/app_examples/3.png new file mode 100644 index 0000000000000000000000000000000000000000..88b64188e5b7768699b9c46fe29ae8758aa97151 --- /dev/null +++ b/assets/app_examples/3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fdee23b36fae13bb5806738e96607f8690941b647f7d6db0865f5efd745d8360 +size 89379 diff --git a/assets/app_examples/4.png b/assets/app_examples/4.png new file mode 100644 index 0000000000000000000000000000000000000000..ab6eccfc37c4860bbb2807fa16f93f033791950c --- /dev/null +++ b/assets/app_examples/4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9040083842809702de7efc1944aed09b099919f1ed2267a7280fc2bab82df0b1 +size 1663988 diff --git a/assets/comparison_of_generation.png b/assets/comparison_of_generation.png new file mode 100644 index 0000000000000000000000000000000000000000..9f1527192e2d75127b20edcef8e49deda06a802e --- /dev/null +++ b/assets/comparison_of_generation.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e8ce5f1b645dceb72c01cce066b2d91bb19935877ac03a4ea69c74ed612e8212 +size 3476633 diff --git a/assets/overview.png b/assets/overview.png new file mode 100644 index 0000000000000000000000000000000000000000..ba2e84fe9458df7268cdeb3b68c7da2bf88be0b7 --- /dev/null +++ b/assets/overview.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d80cb837fa4594b76be6f8c912e1dcb13c5dc6a4a57b09e7861f2d20ce5c92e2 +size 2149201 diff --git a/evaluations/character_error_rate.py b/evaluations/character_error_rate.py new file mode 100644 index 0000000000000000000000000000000000000000..ea153566eeabb16e1c38e989bf04122a9993cf60 --- /dev/null +++ b/evaluations/character_error_rate.py @@ -0,0 +1,27 @@ +import torch +from torchmetrics import Metric +from ocr import OCR +import Levenshtein + + +class CharacterErrorRate(Metric): + def __init__(self, ocr, dist_sync_on_step=False): + # super().__init__(dist_sync_on_step=dist_sync_on_step) + super().__init__() + self.add_state("total_errors", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total_chars", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.ocr = ocr + + def update(self, pred_images, target_images): + for pred_img, target_img in zip(pred_images, target_images): + pred_text = self.ocr.predict(pred_img) + target_text = self.ocr.predict(target_img) + + dist = Levenshtein.distance(pred_text, target_text) + self.total_errors += dist + self.total_chars += len(target_text) + + def compute(self): + if self.total_chars == 0: + return torch.tensor(0.0) + return self.total_errors / self.total_chars diff --git a/evaluations/evaluate_images.py b/evaluations/evaluate_images.py new file mode 100644 index 0000000000000000000000000000000000000000..c353c53383e5d9d5ef0e50e0adeb2539506df5a9 --- /dev/null +++ b/evaluations/evaluate_images.py @@ -0,0 +1,130 @@ +import os +import argparse +from PIL import Image +from tqdm import tqdm +from torchvision import transforms +from torch.utils.data import Dataset, DataLoader +import torch +import torch.nn.functional as F +from ocr import OCR +from character_error_rate import CharacterErrorRate +from word_error_rate import WordErrorRate +from torchmetrics.image import ( + PeakSignalNoiseRatio, + StructuralSimilarityIndexMeasure, + LearnedPerceptualImagePatchSimilarity, + FrechetInceptionDistance, +) + + +class 
ImageFolderPairDataset(Dataset): + def __init__(self, dir1, dir2, transform=None): + self.dir1 = dir1 + self.dir2 = dir2 + self.filenames = sorted(os.listdir(dir1)) + self.transform = transform + + def __len__(self): + return len(self.filenames) + + def __getitem__(self, idx): + name = self.filenames[idx] + img1 = Image.open(os.path.join(self.dir1, name)).convert("RGB") + img2 = Image.open(os.path.join(self.dir2, name)).convert("RGB") + if self.transform: + img1 = self.transform(img1) + img2 = self.transform(img2) + return img1, img2 + + +def evaluate(args): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + transform = transforms.Compose( + [transforms.Resize((args.image_size, args.image_size)), transforms.ToTensor()] + ) + + dataset = ImageFolderPairDataset( + args.original_dir, args.reconstructed_dir, transform + ) + loader = DataLoader( + dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers + ) + + if "cer" in args.metrics or "wer" in args.metrics: + ocr = OCR(device) + + # Metrics init + metrics = {} + + if "psnr" in args.metrics: + metrics["psnr"] = PeakSignalNoiseRatio().to(device) + if "ssim" in args.metrics: + metrics["ssim"] = StructuralSimilarityIndexMeasure().to(device) + if "lpips" in args.metrics: + metrics["lpips"] = LearnedPerceptualImagePatchSimilarity().to(device) + if "fid" in args.metrics: + metrics["fid"] = FrechetInceptionDistance().to(device) + if "cer" in args.metrics: + metrics["cer"] = CharacterErrorRate(ocr) + if "wer" in args.metrics: + metrics["wer"] = WordErrorRate(ocr) + + for batch in tqdm(loader, desc="Evaluating"): + # img1, img1_path, img2, img2_path = [b.to(device) for b in batch] + img1, img2 = [b.to(device) for b in batch] + + if "psnr" in metrics: + metrics["psnr"].update(img2, img1) + if "ssim" in metrics: + metrics["ssim"].update(img2, img1) + if "lpips" in metrics: + metrics["lpips"].update(img2, img1) + if "cer" in metrics: + metrics["cer"].update(img2, img1) + if "wer" in metrics: + metrics["wer"].update(img2, img1) + if "fid" in metrics: + img1_uint8 = (img1 * 255).clamp(0, 255).to(torch.uint8) + img2_uint8 = (img2 * 255).clamp(0, 255).to(torch.uint8) + metrics["fid"].update(img1_uint8, real=True) + metrics["fid"].update(img2_uint8, real=False) + + print("\nResults:") + for name, metric in metrics.items(): + print(f"{name.upper()}", end="\t") + print() + for name, metric in metrics.items(): + result = metric.compute().item() + print(f"{result:.4f}", end="\t") + print() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--original_dir", type=str, required=True, help="Path to original images" + ) + parser.add_argument( + "--reconstructed_dir", + type=str, + required=True, + help="Path to reconstructed images", + ) + parser.add_argument( + "--metrics", + nargs="+", + default=["psnr", "ssim", "lpips", "fid"], + help="Metrics to compute: psnr, ssim, lpips, fid", + ) + parser.add_argument( + "--batch_size", type=int, default=8, help="Batch size for processing" + ) + parser.add_argument("--image_size", type=int, default=256, help="Image resize size") + parser.add_argument( + "--num_workers", type=int, default=4, help="Number of workers for DataLoader" + ) + args = parser.parse_args() + + evaluate(args) diff --git a/evaluations/ocr.py b/evaluations/ocr.py new file mode 100644 index 0000000000000000000000000000000000000000..b87992eba14c333af9a36630dcad2df137677b5c --- /dev/null +++ b/evaluations/ocr.py @@ -0,0 +1,44 @@ +from 
PIL import Image +from transformers import AutoProcessor, AutoModelForImageTextToText +import torch + + +class OCR: + def __init__(self, device="cpu"): + self.device = torch.device(device) + self.model = AutoModelForImageTextToText.from_pretrained( + "google/gemma-3-12b-it", + torch_dtype=torch.bfloat16, + ).to(self.device) + self.processor = AutoProcessor.from_pretrained("google/gemma-3-12b-it") + + self.messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + { + "type": "text", + "text": "Extract and output only the text from the image in its original language. If there is no text, return nothing.", + }, + ], + }, + ] + + def predict(self, image): + image = ( + (image * 255).clamp(0, 255).to(torch.uint8).permute((1, 2, 0)).cpu().numpy() + ) + image = Image.fromarray(image).convert("RGB").resize((1024, 1024)) + prompt = self.processor.apply_chat_template( + self.messages, add_generation_prompt=True + ) + inputs = self.processor(text=prompt, images=[image], return_tensors="pt").to( + self.device + ) + with torch.no_grad(): + generated_ids = self.model.generate(**inputs, max_new_tokens=1024) + generated_text = self.processor.batch_decode( + generated_ids[:, inputs.input_ids.shape[-1] :], skip_special_tokens=True + )[0] + return generated_text diff --git a/evaluations/word_error_rate.py b/evaluations/word_error_rate.py new file mode 100644 index 0000000000000000000000000000000000000000..75be349fafe60520d5e5f6b9aa0bd65536f92a53 --- /dev/null +++ b/evaluations/word_error_rate.py @@ -0,0 +1,30 @@ +import torch +from torchmetrics import Metric +import Levenshtein + + +class WordErrorRate(Metric): + def __init__(self, ocr, dist_sync_on_step=False): + # super().__init__(dist_sync_on_step=dist_sync_on_step) + super().__init__() + self.ocr = ocr + self.add_state("total_errors", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total_words", default=torch.tensor(0.0), dist_reduce_fx="sum") + + def update(self, pred_images, target_images): + for pred_img, target_img in zip(pred_images, target_images): + pred_text = self.ocr.predict(pred_img) + target_text = self.ocr.predict(target_img) + + pred_words = pred_text.strip().split() + target_words = target_text.strip().split() + + dist = Levenshtein.distance(" ".join(pred_words), " ".join(target_words)) + + self.total_errors += dist + self.total_words += len(target_words) + + def compute(self): + if self.total_words == 0: + return torch.tensor(0.0) + return self.total_errors / self.total_words diff --git a/examples/get_result.py b/examples/get_result.py new file mode 100644 index 0000000000000000000000000000000000000000..ed9549004355e03970e9c3049c921fddef967a5a --- /dev/null +++ b/examples/get_result.py @@ -0,0 +1,94 @@ +import os +import pandas as pd + +root_dir = "./" + +model_name_mapping = { + "flowmo_lo": "FlowMo Lo", + "flowmo_hi": "FlowMo Hi", + "gpt4o": "GPT-4o", + "janus_pro_1b": "Janus Pro 1B/7B", + "llamagen-ds8": "LlamaGen ds8", + "llamagen-ds16": "LlamaGen ds16", + "llamagen-ds16-t2i": "LlamaGen ds16 T2I", + "maskbit_16bit": "MaskBiT 16bit", + "maskbit_18bit": "MaskBiT 18bit", + "open_magvit2": "OpenMagViT", + "titok_b64": "Titok-b64", + "titok_bl64": "Titok-bl64", + "titok_s128": "Titok-s128", + "titok_bl128": "Titok-bl128", + "titok_l32": "Titok-l32", + "titok_sl256": "Titok-sl256", + "var_256": "VAR-256", + "var_512": "VAR-512", + "SD3.5L": "SD3.5L", + "FLUX.1-dev": "FLUX.1-dev", + "infinity_d32": "Infinity-d32", + "infinity_d64": "Infinity-d64", + "chameleon": "Chameleon", + "bsqvit": "BSQ-VIT", +} + 
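+# Assumed layout (inferred from run.sh/submit.sh, illustrative only): this script
+# walks ./<dataset_name>/<model_dir>/result.txt, where the last two lines of each
+# result.txt are the tab-separated metric names and their values, e.g.
+#
+#   PSNR    SSIM    LPIPS   FID
+#   24.1300 0.7100  0.1900  12.3400
+#
+# Those two lines are what the loop below parses into `metrics` and `values`.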
+output_order = [ + "FlowMo Lo", + "FlowMo Hi", + "MaskBiT 16bit", + "MaskBiT 18bit", + "Titok-l32", + "Titok-b64", + "Titok-s128", + "Titok-bl64", + "Titok-bl128", + "Titok-sl256", + "OpenMagViT", + "LlamaGen ds8", + "BSQ-VIT", + "VAR-256", + "Janus Pro 1B/7B", + "Chameleon", + "LlamaGen ds16", + "LlamaGen ds16 T2I", + "VAR-512", + "Infinity-d32", + "Infinity-d64", + "SD3.5L", + "FLUX.1-dev", + "GPT-4o", +] + +for dataset_name in os.listdir(root_dir): + dataset_path = os.path.join(root_dir, dataset_name) + if not os.path.isdir(dataset_path): + continue + + results = {} + + for model_dir in os.listdir(dataset_path): + model_path = os.path.join(dataset_path, model_dir) + result_file = os.path.join(model_path, "result.txt") + + if os.path.isfile(result_file): + with open(result_file, "r", encoding="utf-8") as f: + lines = f.readlines() + + if len(lines) >= 2: + metrics_line = lines[-2].strip() + values_line = lines[-1].strip() + + metrics = metrics_line.split() + values = values_line.split() + + mapped_name = model_name_mapping.get(model_dir, model_dir) + results[mapped_name] = values + + if results: + header = "\t".join(metrics) + print(f"{dataset_name}\t{header}") + for model_name in output_order: + if model_name in results: + values = results[model_name] + print(f"{model_name}\t" + "\t".join(values)) + else: + print(f"{model_name}\t" + "no result") + print() diff --git a/examples/run.sh b/examples/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..f06fd0930cbd9ec667b8722403963dbf52665dd3 --- /dev/null +++ b/examples/run.sh @@ -0,0 +1,85 @@ +#!/bin/bash + +dataset_name_list=("task1-imagenet" "task1-high-resolution" "task1-varying-resolution" "task2-detail-preservation" "task3-movie-posters" "task3-arxiv-abstracts" "task3-multilingual_Chinese" "task3-multilingual_Hindi" "task3-multilingual_Japanese" "task3-multilingual_Korean") +model_name_list=("chameleon" "llamagen-ds16" "llamagen-ds8" "flowmo_lo" "flowmo_hi" "open_magvit2" "titok_l32" "titok_b64" "titok_s128" "titok_bl64" "titok_bl128" "titok_sl256" "janus_pro_1b" "maskbit_18bit" "maskbit_16bit" "var_256" "var_512" "SD3.5L" "gpt4o" "llamagen-ds16-t2i" "infinity_d32" "infinity_d64" "bsqvit" "FLUX.1-dev") + +batch_size=1 + +if command -v sbatch >/dev/null 2>&1; then + has_slurm=true +else + has_slurm=false +fi + +shell_dir=$(cd "$(dirname "$0")";pwd) +echo "shell_dir: ${shell_dir}" +base_path="${shell_dir}/../" + +for dataset_name in "${dataset_name_list[@]}" +do + cd ${shell_dir} + folder_dir="${dataset_name}" + mkdir ${folder_dir} + + metrics="fid ssim psnr lpips" + split_name="test" + n_take=-1 + + if [[ $dataset_name == task3-multilingual_* ]]; then + split_name="${dataset_name##*_}" + dataset_name="${dataset_name%_*}" + fi + if [ "$dataset_name" = "task1-imagenet" ]; then + split_name="val" + fi + + if [ "$dataset_name" = "task1-varying-resolution" ]; then + batch_size=1 + fi + if [ "$dataset_name" = "task3-movie-posters" ]; then + metrics="fid ssim psnr lpips cer wer" + fi + if [ "$dataset_name" = "task3-arxiv-abstracts" ]; then + metrics="fid ssim psnr lpips cer wer" + fi + if [ "$dataset_name" = "task3-multilingual" ]; then + metrics="fid ssim psnr lpips cer" + fi + + for model_name in "${model_name_list[@]}" + do + if [ "$dataset_name" = "task1-imagenet" ] && [ "$model_name" = "gpt4o" ]; then + n_take=100 + fi + cd ${shell_dir} + + work_dir="${folder_dir}/${model_name}" + echo "model_name: ${model_name}, work_dir: ${work_dir}" + mkdir ${work_dir} + + cp submit.sh ${work_dir} + + cd ${work_dir} + sed -i 
"s|{model_name}|${model_name}|g" submit.sh + sed -i "s|{split_name}|${split_name}|g" submit.sh + sed -i "s|{dataset_name}|${dataset_name}|g" submit.sh + sed -i "s|{batch_size}|${batch_size}|g" submit.sh + sed -i "s|{base_path}|${base_path}|g" submit.sh + sed -i "s|{metrics}|${metrics}|g" submit.sh + sed -i "s|{n_take}|${n_take}|g" submit.sh + +# if [ "$has_slurm" = true ]; then +# res=$(sbatch ./submit.sh) +# res=($res) +# task_id=${res[-1]} +# echo "task_id: ${task_id}" +# touch "task_id_${task_id}" +# else +# echo "Slurm not detected, running with bash..." +# bash ./submit.sh +# fi + + bash ./submit.sh + + done +done diff --git a/examples/submit.sh b/examples/submit.sh new file mode 100644 index 0000000000000000000000000000000000000000..485b5967d84b36cae94492be2ed10074ec83c2d8 --- /dev/null +++ b/examples/submit.sh @@ -0,0 +1,10 @@ +#!/bin/bash +# Put your slurm commands here + +accelerate launch --num_processes=1 {base_path}/main.py --batch_size {batch_size} --model_name {model_name} --split_name {split_name} --dataset_name {dataset_name} --output_dir {model_name}_results --n_take {n_take} +python {base_path}/evaluations/evaluate_images.py \ + --original_dir {model_name}_results/original_images \ + --reconstructed_dir {model_name}_results/reconstructed_images/ \ + --metrics {metrics} \ + --batch_size 16 \ + --num_workers 8 | tee result.txt diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..599727aa0959dc16fce999c4f73831d263b3fda0 --- /dev/null +++ b/main.py @@ -0,0 +1,99 @@ +import numpy as np +import os +import PIL +import pickle +import torch +import argparse +import json +from PIL import Image +import torch.nn as nn +import torch +from transformers import AutoProcessor, AutoModelForImageTextToText +from src.data_loader import DataCollatorForSupervisedDataset, get_dataset +from src.data_processing import tensor_to_pil +from src.model_processing import get_model +from PIL import Image +from accelerate import Accelerator +from torch.utils.data import DataLoader +from tqdm import tqdm +from concurrent.futures import ThreadPoolExecutor + +parser = argparse.ArgumentParser() +parser.add_argument("--model_name", type=str, default="chameleon") +parser.add_argument("--model_path", type=str, default=None) +parser.add_argument("--dataset_name", type=str, default="task3-movie-posters") +parser.add_argument("--split_name", type=str, default="test") +parser.add_argument("--batch_size", default=8, type=int) +parser.add_argument("--output_dir", type=str, default=None) +parser.add_argument("--begin_id", default=0, type=int) +parser.add_argument("--n_take", default=-1, type=int) +args = parser.parse_args() + +batch_size = args.batch_size +output_dir = args.output_dir + +accelerator = Accelerator() + +if accelerator.is_main_process and output_dir is not None: + os.makedirs(output_dir, exist_ok=True) + os.makedirs(f"{output_dir}/original_images", exist_ok=True) + os.makedirs(f"{output_dir}/reconstructed_images", exist_ok=True) + os.makedirs(f"{output_dir}/results", exist_ok=True) + +model, data_params = get_model(args.model_path, args.model_name) +dataset = get_dataset(args.dataset_name, args.split_name, None if args.n_take <= 0 else args.n_take) +data_collator = DataCollatorForSupervisedDataset(args.dataset_name, **data_params) +dataloader = DataLoader( + dataset, batch_size=batch_size, num_workers=0, collate_fn=data_collator +) + +model, dataloader = accelerator.prepare(model, dataloader) +print("Model prepared...") + + +def save_results( + pixel_values, 
reconstructed_image, idx, output_dir, data_params +): + if reconstructed_image is None: + return + + ori_img = tensor_to_pil(pixel_values, **data_params) + rec_img = tensor_to_pil(reconstructed_image, **data_params) + + ori_img.save(f"{output_dir}/original_images/{idx:08d}.png") + rec_img.save(f"{output_dir}/reconstructed_images/{idx:08d}.png") + + result = { + "ori_img": ori_img, + "rec_img": rec_img, + } + + with open(f"{output_dir}/results/{idx:08d}.pickle", "wb") as fw: + pickle.dump(result, fw) + + +executor = ThreadPoolExecutor(max_workers=16) +with torch.no_grad(): + print("Begin data loading...") + for batch in tqdm(dataloader): + pixel_values = batch["image"] + reconstructed_images = model(pixel_values) + if isinstance(reconstructed_images, tuple): + reconstructed_images = reconstructed_images[0] + + if output_dir is not None: + idx_list = batch["idx"] + original_images = pixel_values.detach().cpu() + if not isinstance(reconstructed_images, list): + reconstructed_images = reconstructed_images.detach().cpu() + for i in range(pixel_values.shape[0]): + executor.submit( + save_results, + original_images[i], + reconstructed_images[i], + idx_list[i], + output_dir, + data_params, + ) + +executor.shutdown(wait=True) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..5b1940354060a25772e14f199a1cb1d7b65ab402 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,21 @@ +numpy +mup==1.0.0 +einops +omegaconf +lightning==2.3.3 +piq +python-Levenshtein +verovio +pytorch_fid +transformers +torch-fidelity +accelerate +datasets +git+https://github.com/deepseek-ai/Janus.git +diffusers +openai +imageio +huggingface_hub +gradio +torch +torchvision diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/data_loader.py b/src/data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..0bd975ad4178e497db9d5f255883fb47797963ed --- /dev/null +++ b/src/data_loader.py @@ -0,0 +1,61 @@ +import PIL +from PIL import Image +from dataclasses import dataclass, field +from datasets import load_dataset +import torch +from .data_processing import pil_to_tensor + + +@dataclass +class DataCollatorForSupervisedDataset(object): + """Collate examples for supervised fine-tuning.""" + + def __init__(self, dataset_name, **kwargs): + override_params = {} + if dataset_name == "DIV2K": + override_params = { + "target_image_size": -1, + "lock_ratio": True, + "center_crop": False, + "padding": False, + } + if dataset_name == "imagenet": + override_params = {"center_crop": True, "padding": False} + if dataset_name == "movie_posters": + override_params = {"center_crop": True, "padding": False} + if dataset_name == "high_quality_1024": + override_params = {"target_image_size": (1024, 1024)} + + self.data_params = {**kwargs, **override_params} + + def __call__(self, instances): + images = torch.stack( + [ + pil_to_tensor(instance["image"], **self.data_params) + for instance in instances + ], + dim=0, + ) + idx = [instance["idx"] for instance in instances] + return dict(image=images, idx=idx) + + +class ImagenetDataset(torch.utils.data.Dataset): + def __init__(self, dataset_name, split_name="test", n_take=None): + print(dataset_name, split_name) + ds = load_dataset("huaweilin/VTBench", name=dataset_name, split=split_name if n_take is None else f"{split_name}[:{n_take}]") + self.image_list = ds["image"] + + def __len__(self): + return 
len(self.image_list) + + def __getitem__(self, idx): + return dict( + image=self.image_list[idx], + idx=idx, + ) + + +def get_dataset(dataset_name, split_name, n_take): + dataset = ImagenetDataset(dataset_name, split_name, n_take) + return dataset diff --git a/src/data_processing.py b/src/data_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..2381611522565f4de0cc29af12f57ff8ec92eb9b --- /dev/null +++ b/src/data_processing.py @@ -0,0 +1,89 @@ +import numpy as np +import PIL +from PIL import Image +import torch + + +def pil_to_tensor( + img: Image.Image, + target_image_size=512, + lock_ratio=True, + center_crop=True, + padding=False, + standardize=True, + **kwarg +) -> torch.Tensor: + if img.mode != "RGB": + img = img.convert("RGB") + + if isinstance(target_image_size, int): + target_size = (target_image_size, target_image_size) + if target_image_size < 0: + target_size = img.size + else: + target_size = target_image_size # (width, height) + + if lock_ratio: + original_width, original_height = img.size + target_width, target_height = target_size + + scale_w = target_width / original_width + scale_h = target_height / original_height + + if center_crop: + scale = max(scale_w, scale_h) + elif padding: + scale = min(scale_w, scale_h) + else: + scale = 1.0 # fallback + + new_size = (round(original_width * scale), round(original_height * scale)) + img = img.resize(new_size, Image.LANCZOS) + + if center_crop: + left = (img.width - target_width) // 2 + top = (img.height - target_height) // 2 + img = img.crop((left, top, left + target_width, top + target_height)) + elif padding: + new_img = Image.new("RGB", target_size, (0, 0, 0)) + left = (target_width - img.width) // 2 + top = (target_height - img.height) // 2 + new_img.paste(img, (left, top)) + img = new_img + else: + img = img.resize(target_size, Image.LANCZOS) + + np_img = np.array(img) / 255.0 # Normalize to [0, 1] + if standardize: + np_img = np_img * 2 - 1 # Scale to [-1, 1] + tensor_img = torch.from_numpy(np_img).permute(2, 0, 1).float() # (C, H, W) + + return tensor_img + + +def tensor_to_pil(chw_tensor: torch.Tensor, standardize=True, **kwarg) -> PIL.Image: + # Ensure detachment and move tensor to CPU. + detached_chw_tensor = chw_tensor.detach().cpu() + + # Normalize tensor to [0, 1] range from [-1, 1] range. + if standardize: + normalized_chw_tensor = ( + torch.clamp(detached_chw_tensor, -1.0, 1.0) + 1.0 + ) / 2.0 + else: + normalized_chw_tensor = torch.clamp(detached_chw_tensor, 0.0, 1.0) + + # Permute CHW tensor to HWC format and convert to NumPy array. + hwc_array = normalized_chw_tensor.permute(1, 2, 0).numpy() + + # Convert to an 8-bit unsigned integer format. + image_array_uint8 = (hwc_array * 255).astype(np.uint8) + + # Convert NumPy array to PIL Image. + pil_image = Image.fromarray(image_array_uint8) + + # Convert image to RGB if it is not already. 
+ if pil_image.mode != "RGB": + pil_image = pil_image.convert("RGB") + + return pil_image diff --git a/src/model_processing.py b/src/model_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..859535b84eeb8179b900032022d402e9ee71e169 --- /dev/null +++ b/src/model_processing.py @@ -0,0 +1,409 @@ +import requests +import os +import yaml +from .utils import get_ckpt, get_yaml_config + + +def download_ckpt_yaml(model_path, model_name, ckpt_path, yaml_url=None): + def download_file(url, save_path): + response = requests.get(url) + response.raise_for_status() + with open(save_path, 'wb') as f: + f.write(response.content) + + # os.makedirs(model_path, exist_ok=True) + local_dir = os.path.join(model_path, model_name) + os.makedirs(local_dir, exist_ok=True) + + ckpt_name = ckpt_path.split("/")[-1] + local_ckpt_path = os.path.join(local_dir, ckpt_name) + if not os.path.exists(local_ckpt_path): + print(f"Downloading CKPT to {local_ckpt_path}") + download_file(ckpt_path, local_ckpt_path) + + if yaml_url: + yaml_name = yaml_url.split("/")[-1] + local_yaml_path = os.path.join(local_dir, yaml_name) + if not os.path.exists(local_yaml_path): + print(f"Downloading YAML to {local_yaml_path}") + download_file(yaml_url, local_yaml_path) + return local_ckpt_path, local_yaml_path + + return local_ckpt_path, None + + +def get_model(model_path, model_name): + model = None + data_params = { + "target_image_size": (512, 512), + "lock_ratio": True, + "center_crop": True, + "padding": False, + } + + if model_name.lower() == "anole": + from src.vqvaes.anole.anole import VQModel + yaml_url = "https://huggingface.co/GAIR/Anole-7b-v0.1/resolve/main/tokenizer/vqgan.yaml" + ckpt_path = "https://huggingface.co/GAIR/Anole-7b-v0.1/resolve/main/tokenizer/vqgan.ckpt" + + if model_path is not None: + ckpt_path, yaml_url = download_ckpt_yaml(model_path, "anole", ckpt_path, yaml_url) + config = get_yaml_config(yaml_url) + + params = config["model"]["params"] + if "lossconfig" in params: + del params["lossconfig"] + params["ckpt_path"] = ckpt_path + model = VQModel(**params) + data_params = { + "target_image_size": (512, 512), + "lock_ratio": True, + "center_crop": True, + "padding": False, + } + + elif model_name.lower() == "chameleon": + from src.vqvaes.anole.anole import VQModel + + yaml_url = "https://huggingface.co/huaweilin/chameleon_vqvae/resolve/main/vqgan.yaml" + ckpt_path = "https://huggingface.co/huaweilin/chameleon_vqvae/resolve/main/vqgan.ckpt" + if model_path is not None: + ckpt_path, yaml_url = download_ckpt_yaml(model_path, "chameleon", ckpt_path, yaml_url) + config = get_yaml_config(yaml_url) + + params = config["model"]["params"] + if "lossconfig" in params: + del params["lossconfig"] + params["ckpt_path"] = ckpt_path + model = VQModel(**params) + data_params = { + "target_image_size": (512, 512), + "lock_ratio": True, + "center_crop": True, + "padding": False, + } + + elif model_name.lower() == "llamagen-ds16": + from src.vqvaes.llamagen.llamagen import VQ_models + ckpt_path = "https://huggingface.co/FoundationVision/LlamaGen/resolve/main/vq_ds16_c2i.pt" + if model_path is not None: + ckpt_path, _ = download_ckpt_yaml(model_path, "llamagen-ds16", ckpt_path, None) + + model = VQ_models["VQ-16"](codebook_size=16384, codebook_embed_dim=8) + model.load_state_dict(get_ckpt(ckpt_path, key="model")) + data_params = { + "target_image_size": (512, 512), + "lock_ratio": True, + "center_crop": True, + "padding": False, + } + + elif model_name.lower() == "llamagen-ds16-t2i": + from 
src.vqvaes.llamagen.llamagen import VQ_models + ckpt_path = "https://huggingface.co/peizesun/llamagen_t2i/resolve/main/vq_ds16_t2i.pt" + if model_path is not None: + ckpt_path, _ = download_ckpt_yaml(model_path, "llamagen-ds16-t2i", ckpt_path, None) + + model = VQ_models["VQ-16"](codebook_size=16384, codebook_embed_dim=8) + model.load_state_dict(get_ckpt(ckpt_path, key="model")) + data_params = { + "target_image_size": (512, 512), + "lock_ratio": True, + "center_crop": True, + "padding": False, + } + + elif model_name.lower() == "llamagen-ds8": + from src.vqvaes.llamagen.llamagen import VQ_models + ckpt_path = "https://huggingface.co/FoundationVision/LlamaGen/resolve/main/vq_ds8_c2i.pt" + if model_path is not None: + ckpt_path, _ = download_ckpt_yaml(model_path, "llamagen-ds8", ckpt_path, None) + + model = VQ_models["VQ-8"](codebook_size=16384, codebook_embed_dim=8) + model.load_state_dict(get_ckpt(ckpt_path, key="model")) + data_params = { + "target_image_size": (256, 256), + "lock_ratio": True, + "center_crop": True, + "padding": False, + } + + elif model_name.lower() == "flowmo_lo": + from src.vqvaes.flowmo.flowmo import build_model + yaml_url = "https://raw.githubusercontent.com/kylesargent/FlowMo/refs/heads/main/flowmo/configs/base.yaml" + ckpt_path = "https://huggingface.co/ksarge/FlowMo/resolve/main/flowmo_lo.pth" + if model_path is not None: + ckpt_path, yaml_url = download_ckpt_yaml(model_path, "flowmo_lo", ckpt_path, yaml_url) + config = get_yaml_config(yaml_url) + + config.model.context_dim = 18 + model = build_model(config) + model.load_state_dict( + get_ckpt(ckpt_path, key="model_ema_state_dict") + ) + data_params = { + "target_image_size": (256, 256), + "lock_ratio": True, + "center_crop": True, + "padding": False, + } + + elif model_name.lower() == "flowmo_hi": + from src.vqvaes.flowmo.flowmo import build_model + + yaml_url = "https://raw.githubusercontent.com/kylesargent/FlowMo/refs/heads/main/flowmo/configs/base.yaml" + ckpt_path = "https://huggingface.co/ksarge/FlowMo/resolve/main/flowmo_hi.pth" + if model_path is not None: + ckpt_path, yaml_url = download_ckpt_yaml(model_path, "flowmo_hi", ckpt_path, yaml_url) + config = get_yaml_config(yaml_url) + + config.model.context_dim = 56 + config.model.codebook_size_for_entropy = 14 + model = build_model(config) + model.load_state_dict( + get_ckpt(ckpt_path, key="model_ema_state_dict") + ) + data_params = { + "target_image_size": (256, 256), + "lock_ratio": True, + "center_crop": True, + "padding": False, + } + + elif model_name.lower() == "open_magvit2": + from src.vqvaes.open_magvit2.open_magvit2 import VQModel + + yaml_url = "https://raw.githubusercontent.com/TencentARC/SEED-Voken/refs/heads/main/configs/Open-MAGVIT2/gpu/imagenet_lfqgan_256_L.yaml" + ckpt_path = "https://huggingface.co/TencentARC/Open-MAGVIT2-Tokenizer-256-resolution/resolve/main/imagenet_256_L.ckpt" + if model_path is not None: + ckpt_path, yaml_url = download_ckpt_yaml(model_path, "open_magvit2", ckpt_path, yaml_url) + config = get_yaml_config(yaml_url) + + model = VQModel(**config.model.init_args) + model.load_state_dict(get_ckpt(ckpt_path, key="state_dict")) + data_params = { + "target_image_size": (256, 256), + "lock_ratio": True, + "center_crop": True, + "padding": False, + } + + elif "maskbit" in model_name.lower(): + from src.vqvaes.maskbit.maskbit import ConvVQModel + + if "16bit" in model_name.lower(): + yaml_url = "https://raw.githubusercontent.com/markweberdev/maskbit/refs/heads/main/configs/tokenizer/maskbit_tokenizer_16bit.yaml" + ckpt_path = 
"https://huggingface.co/markweber/maskbit_tokenizer_16bit/resolve/main/maskbit_tokenizer_16bit.bin" + if model_path is not None: + ckpt_path, yaml_url = download_ckpt_yaml(model_path, "maskbit-16bit", ckpt_path, yaml_url) + elif "18bit" in model_name.lower(): + yaml_url = "https://raw.githubusercontent.com/markweberdev/maskbit/refs/heads/main/configs/tokenizer/maskbit_tokenizer_18bit.yaml" + ckpt_path = "https://huggingface.co/markweber/maskbit_tokenizer_18bit/resolve/main/maskbit_tokenizer_18bit.bin" + if model_path is not None: + ckpt_path, yaml_url = download_ckpt_yaml(model_path, "maskbit-18bit", ckpt_path, yaml_url) + else: + raise Exception(f"Unsupported model: {model_name}") + + config = get_yaml_config(yaml_url) + model = ConvVQModel(config.model.vq_model, legacy=False) + model.load_pretrained(get_ckpt(ckpt_path, key=None)) + data_params = { + "target_image_size": (256, 256), + "lock_ratio": True, + "center_crop": True, + "padding": False, + "standardize": False, + } + + elif "bsqvit" in model_name.lower(): + from src.vqvaes.bsqvit.bsqvit import VITBSQModel + + yaml_url = "https://huggingface.co/huaweilin/bsqvit_256x256/resolve/main/config.yaml" + ckpt_path = "https://huggingface.co/huaweilin/bsqvit_256x256/resolve/main/checkpoint.pt" + if model_path is not None: + ckpt_path, yaml_url = download_ckpt_yaml(model_path, "bsqvit", ckpt_path, yaml_url) + + config = get_yaml_config(yaml_url) + model = VITBSQModel(**config["model"]["params"]) + model.init_from_ckpt(get_ckpt(ckpt_path, key="state_dict")) + data_params = { + "target_image_size": (256, 256), + "lock_ratio": True, + "center_crop": True, + "padding": False, + "standardize": False, + } + + elif "titok" in model_name.lower(): + from src.vqvaes.titok.titok import TiTok + + ckpt_path = None + if "bl64" in model_name.lower(): + ckpt_path = "yucornetto/tokenizer_titok_bl64_vq8k_imagenet" + elif "bl128" in model_name.lower(): + ckpt_path = "yucornetto/tokenizer_titok_bl128_vq8k_imagenet" + elif "sl256" in model_name.lower(): + ckpt_path = "yucornetto/tokenizer_titok_sl256_vq8k_imagenet" + elif "l32" in model_name.lower(): + ckpt_path = "yucornetto/tokenizer_titok_l32_imagenet" + elif "b64" in model_name.lower(): + ckpt_path = "yucornetto/tokenizer_titok_b64_imagenet" + elif "s128" in model_name.lower(): + ckpt_path = "yucornetto/tokenizer_titok_s128_imagenet" + else: + raise Exception(f"Unsupported model: {model_name}") + + model = TiTok.from_pretrained(ckpt_path) + data_params = { + "target_image_size": (256, 256), + "lock_ratio": True, + "center_crop": True, + "padding": False, + "standardize": False, + } + + elif "janus_pro" in model_name.lower(): + from janus.models import MultiModalityCausalLM + from src.vqvaes.janus_pro.janus_pro import forward + import types + + model = MultiModalityCausalLM.from_pretrained( + "deepseek-ai/Janus-Pro-7B", trust_remote_code=True + ).gen_vision_model + model.forward = types.MethodType(forward, model) + data_params = { + "target_image_size": (384, 384), + "lock_ratio": True, + "center_crop": False, + "padding": True, + } + + elif "var" in model_name.lower(): + from src.vqvaes.var.var_vq import VQVAE + + ckpt_path = "https://huggingface.co/FoundationVision/var/resolve/main/vae_ch160v4096z32.pth" + if model_path is not None: + ckpt_path, _ = download_ckpt_yaml(model_path, "var", ckpt_path, None) + + v_patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16) + if "512" in model_name.lower(): + v_patch_nums = (1, 2, 3, 4, 6, 9, 13, 18, 24, 32) + model = VQVAE( + vocab_size=4096, + z_channels=32, + ch=160, + 
test_mode=True, + share_quant_resi=4, + v_patch_nums=v_patch_nums, + ) + model.load_state_dict(get_ckpt(ckpt_path, key=None)) + data_params = { + "target_image_size": ( + (512, 512) if "512" in model_name.lower() else (256, 256) + ), + "lock_ratio": True, + "center_crop": False, + "padding": True, + "standardize": False, + } + + elif ( + "infinity" in model_name.lower() + ): # "infinity_d32", "infinity_d64", "infinity_d56_f8_14_patchify" + from src.vqvaes.infinity.vae import vae_model + + if "d32" in model_name: + ckpt_path = "https://huggingface.co/FoundationVision/Infinity/resolve/main/infinity_vae_d32.pth" + codebook_dim = 32 + if model_path is not None: + ckpt_path, _ = download_ckpt_yaml(model_path, "infinity-d32", ckpt_path, None) + elif "d64" in model_name: + ckpt_path = "https://huggingface.co/FoundationVision/Infinity/resolve/main/infinity_vae_d64.pth" + codebook_dim = 64 + if model_path is not None: + ckpt_path, _ = download_ckpt_yaml(model_path, "infinity-d64", ckpt_path, None) + + schedule_mode = "dynamic" + codebook_size = 2**codebook_dim + patch_size = 16 + encoder_ch_mult = [1, 2, 4, 4, 4] + decoder_ch_mult = [1, 2, 4, 4, 4] + + ckpt = get_ckpt(ckpt_path, key=None) + model = vae_model( + ckpt, + schedule_mode, + codebook_dim, + codebook_size, + patch_size=patch_size, + encoder_ch_mult=encoder_ch_mult, + decoder_ch_mult=decoder_ch_mult, + test_mode=True, + ) + + data_params = { + "target_image_size": (1024, 1024), + "lock_ratio": True, + "center_crop": False, + "padding": True, + "standardize": False, + } + + elif "sd3.5l" in model_name.lower(): # SD3.5L + from src.vaes.stable_diffusion.vae import forward + from diffusers import AutoencoderKL + import types + + model = AutoencoderKL.from_pretrained( + "huaweilin/stable-diffusion-3.5-large-vae", subfolder="vae" + ) + model.forward = types.MethodType(forward, model) + data_params = { + "target_image_size": (1024, 1024), + "lock_ratio": True, + "center_crop": False, + "padding": True, + "standardize": True, + } + + elif "FLUX.1-dev".lower() in model_name.lower(): # SD3.5L + from src.vaes.stable_diffusion.vae import forward + from diffusers import AutoencoderKL + import types + + model = AutoencoderKL.from_pretrained( + "black-forest-labs/FLUX.1-dev", subfolder="vae" + ) + model.forward = types.MethodType(forward, model) + data_params = { + "target_image_size": (1024, 1024), + "lock_ratio": True, + "center_crop": False, + "padding": True, + "standardize": True, + } + + elif "gpt4o" in model_name.lower(): + from src.vaes.gpt_image.gpt_image import GPTImage + + data_params = { + "target_image_size": (1024, 1024), + "lock_ratio": True, + "center_crop": False, + "padding": True, + "standardize": False, + } + model = GPTImage(data_params) + + else: + raise Exception(f"Unsupported model: \"{model_name}\"") + + try: + trainable_params = sum(p.numel() for p in model.parameters()) + print("trainable_params:", trainable_params) + except Exception as e: + print(e) + pass + + model.eval() + return model, data_params diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6723621c4bd02a8de1f10d589d644e1c9141b9ad --- /dev/null +++ b/src/utils.py @@ -0,0 +1,47 @@ +import os +from omegaconf import OmegaConf +import torch +import tempfile +from safetensors.torch import load_file +import requests +import yaml + +def get_ckpt(path, key="state_dict"): + is_url = path.startswith("http://") or path.startswith("https://") + suffix = os.path.splitext(path)[-1] + + if is_url: + print(f"Loading 
checkpoint from URL: {path}") + with tempfile.NamedTemporaryFile(suffix=suffix) as tmp_file: + response = requests.get(path) + response.raise_for_status() + tmp_file.write(response.content) + tmp_file.flush() + ckpt_path = tmp_file.name + + if suffix == ".safetensors": + checkpoint = load_file(ckpt_path) + else: + checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False) + else: + print(f"Loading checkpoint from local path: {path}") + if suffix == ".safetensors": + checkpoint = load_file(path) + else: + checkpoint = torch.load(path, map_location="cpu", weights_only=False) + + if key is not None and key in checkpoint: + checkpoint = checkpoint[key] + + return checkpoint + + +def get_yaml_config(path): + if path.startswith("http://") or path.startswith("https://"): + response = requests.get(path) + response.raise_for_status() + config = OmegaConf.create(response.text) + else: + with open(path, 'r') as f: + config = OmegaConf.load(f) + return config diff --git a/src/vaes/gpt_image/gpt_image.py b/src/vaes/gpt_image/gpt_image.py new file mode 100644 index 0000000000000000000000000000000000000000..5a23527da558b0e63189333b8793bc4efd3e5176 --- /dev/null +++ b/src/vaes/gpt_image/gpt_image.py @@ -0,0 +1,48 @@ +import base64 +from torchvision.transforms.functional import to_pil_image +from openai import OpenAI +import io +import torch +import numpy as np +from PIL import Image +from ...data_processing import tensor_to_pil, pil_to_tensor + + +class GPTImage: + def __init__(self, data_params): + self.client = OpenAI(organization="org-xZTnLOf1k9s04LEoKKjl4jOB") + self.prompt = "Please recreate the exact same image without any alterations. Please preserve the original resolution (1024*1024)." + self.data_params = data_params + + def eval(self): + pass + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def forward(self, input): + results = [] + for image in input: + image = tensor_to_pil(image, **self.data_params) + buffer = io.BytesIO() + image.save(buffer, format="PNG") + buffer.seek(0) + image_file = ("image.png", buffer, "image/png") + + try: + result = self.client.images.edit( + model="gpt-image-1", + image=image_file, + prompt=self.prompt, + n=1, + size="1024x1024", + ) + image_base64 = result.data[0].b64_json + image_bytes = base64.b64decode(image_base64) + image = Image.open(io.BytesIO(image_bytes)) + results.append(pil_to_tensor(image, **self.data_params)) + except Exception as e: + print("💥 Unexpected error occurred:", e) + results.append(None) + + return results, None, None diff --git a/src/vaes/stable_diffusion/vae.py b/src/vaes/stable_diffusion/vae.py new file mode 100644 index 0000000000000000000000000000000000000000..2fa855c6f553a23560264b8fb0b8a4eec80cbc27 --- /dev/null +++ b/src/vaes/stable_diffusion/vae.py @@ -0,0 +1,23 @@ +def forward( + self, + sample, + sample_posterior=False, + return_dict=True, + generator=None, +): + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. 
+ """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + return dec, None, None diff --git a/src/vqvaes/__init__.py b/src/vqvaes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/vqvaes/anole/anole.py b/src/vqvaes/anole/anole.py new file mode 100644 index 0000000000000000000000000000000000000000..581528a8c93c86d410abe4de94fe6ff569fb39dc --- /dev/null +++ b/src/vqvaes/anole/anole.py @@ -0,0 +1,706 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# This source code is licensed under the Chameleon License found in the +# LICENSE file in the root directory of this source tree. + +""" +Contents of this file are taken from https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/models/vqgan.py +[with minimal dependencies] + +This implementation is inference-only -- training steps and optimizer components +introduce significant additional dependencies +""" + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from ...utils import get_ckpt + + +class VectorQuantizer2(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly + avoids costly matrix multiplications and allows for post-hoc remapping of indices. + """ + + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__( + self, + n_e, + e_dim, + beta, + remap=None, + unknown_index="random", + sane_index_shape=False, + legacy=True, + ): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print( + f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." 
+ ) + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to( + device=new.device + ) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, rescale_logits=False, return_logits=False): + assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel" + assert rescale_logits is False, "Only for interface compatible with Gumbel" + assert return_logits is False, "Only for interface compatible with Gumbel" + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) + - 2 + * torch.einsum( + "bd,dn->bn", z_flattened, self.embedding.weight.transpose(0, 1) + ) + ) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean( + (z_q - z.detach()) ** 2 + ) + else: + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean( + (z_q - z.detach()) ** 2 + ) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape( + z.shape[0], -1 + ) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape( + z_q.shape[0], z_q.shape[2], z_q.shape[3] + ) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +# Alias +VectorQuantizer = VectorQuantizer2 + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + 
self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = F.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert 
attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown" + # print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + raise ValueError("Unexpected attention type") + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + **ignore_kwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = 
resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class VQModel(nn.Module): + def __init__( + self, + ddconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + scheduler_config=None, + lr_g_factor=1.0, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.quantize = VectorQuantizer( + n_embed, + embed_dim, + beta=0.25, + remap=remap, + sane_index_shape=sane_index_shape, + ) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.image_key = image_key + if colorize_nlabels is not None: + assert isinstance(colorize_nlabels, int) + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + self.scheduler_config = scheduler_config + self.lr_g_factor = lr_g_factor + + def 
init_from_ckpt(self, path, ignore_keys=list()): + if path.startswith("http://") or path.startswith("https://"): + sd = get_ckpt(path) + else: + print(f"Loading checkpoint from local path: {path}") + sd = torch.load(path, map_location="cpu")["state_dict"] + + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print(f"Deleting key {k} from state_dict.") + del sd[k] + + self.load_state_dict(sd, strict=False) + print(f"VQModel loaded from {path}") + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + # def forward(self, input): + # quant, diff, _ = self.encode(input) + # dec = self.decode(quant) + # return dec, diff + + def forward(self, input): + quant, diff, [_, _, img_toks] = self.encode(input) + + batch_size, n_channel, height, width = ( + input.shape[0], + quant.shape[-1], + quant.shape[-2], + quant.shape[-3], + ) + codebook_entry = self.quantize.get_codebook_entry( + img_toks, (batch_size, n_channel, height, width) + ) + pixels = self.decode(codebook_entry) + + return pixels, img_toks, quant + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) + return x.float() + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + return log + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 + return x diff --git a/src/vqvaes/bsqvit/attention_mask.py b/src/vqvaes/bsqvit/attention_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..edc24dc4b026288f59c315b5d92f5e273335aef3 --- /dev/null +++ b/src/vqvaes/bsqvit/attention_mask.py @@ -0,0 +1,42 @@ +import torch + + +def get_attention_mask(sequence_length, device, mask_type="block-causal", **kwargs): + if mask_type.lower() == 'none' or mask_type is None: + return None + elif mask_type.lower() == 'block-causal': + return _block_caulsal_mask_impl(sequence_length, device, **kwargs) + elif mask_type.lower() == 'causal': + return _caulsal_mask_impl(sequence_length, device, **kwargs) + else: + raise NotImplementedError(f"Mask type {mask_type} not implemented") + + +def _block_caulsal_mask_impl(sequence_length, device, block_size=16, **kwargs): + """ + Create a block-causal mask + """ + assert sequence_length % block_size == 0, "for block causal masks sequence length must be divisible by block size" + blocks = torch.ones(sequence_length // block_size, block_size, block_size, device=device) + block_diag_enable_mask = torch.block_diag(*blocks) + causal_enable_mask = 
torch.ones(sequence_length, sequence_length, device=device).tril_(0) + disable_mask = ((block_diag_enable_mask + causal_enable_mask) < 0.5) + return disable_mask + + +def _caulsal_mask_impl(sequence_length, device, **kwargs): + """ + Create a causal mask + """ + causal_disable_mask = torch.triu( + torch.full((sequence_length, sequence_length), float('-inf'), dtype=torch.float32, device=device), + diagonal=1, + ) + return causal_disable_mask + + +if __name__ == '__main__': + mask = get_attention_mask(9, "cuda", mask_type="block-causal", block_size=3) + print(mask) + mask = get_attention_mask(9, "cuda", mask_type="causal") + print(mask) \ No newline at end of file diff --git a/src/vqvaes/bsqvit/bsqvit.py b/src/vqvaes/bsqvit/bsqvit.py new file mode 100644 index 0000000000000000000000000000000000000000..56f9a3e97c5100b19a2cb5f5336a45375b654bfc --- /dev/null +++ b/src/vqvaes/bsqvit/bsqvit.py @@ -0,0 +1,150 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .quantizer.bsq import BinarySphericalQuantizer +from .quantizer.vq import VectorQuantizer +from .transformer import TransformerDecoder, TransformerEncoder + + +class VITVQModel(nn.Module): + def __init__(self, vitconfig, n_embed, embed_dim, + l2_norm=False, logit_laplace=False, ckpt_path=None, ignore_keys=[], + grad_checkpointing=False, selective_checkpointing=False, + clamp_range=(0, 1), + dvitconfig=None, + ): + super().__init__() + self.encoder = TransformerEncoder(**vitconfig) + dvitconfig = vitconfig if dvitconfig is None else dvitconfig + self.decoder = TransformerDecoder(**dvitconfig, logit_laplace=logit_laplace) + if self.training and grad_checkpointing: + self.encoder.set_grad_checkpointing(True, selective=selective_checkpointing) + self.decoder.set_grad_checkpointing(True, selective=selective_checkpointing) + + self.n_embed = n_embed + self.embed_dim = embed_dim + self.l2_norm = l2_norm + self.setup_quantizer() + + self.quant_embed = nn.Linear(in_features=vitconfig['width'], out_features=embed_dim) + self.post_quant_embed = nn.Linear(in_features=embed_dim, out_features=dvitconfig['width']) + self.l2_norm = l2_norm + self.logit_laplace = logit_laplace + self.clamp_range = clamp_range + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def setup_quantizer(self): + self.quantize = VectorQuantizer(self.n_embed, self.embed_dim, l2_norm=self.l2_norm, beta=0.25, input_format='blc') + + # def init_from_ckpt(self, ckpt_path, ignore_keys=[]): + def init_from_ckpt(self, state_dict, ignore_keys=[]): + state_dict = {k[7:]: v for k, v in state_dict.items() if k.startswith('module.')} + filtered_state_dict = {k: v for k, v in state_dict.items() if all([not k.startswith(ig) for ig in ignore_keys])} + missing_keys, unexpected_keys = self.load_state_dict(filtered_state_dict, strict=False) + print(f"missing_keys: {missing_keys}") + print(f"unexpected_keys: {unexpected_keys}") + + def encode(self, x, skip_quantize=False): + h = self.encoder(x) + h = self.quant_embed(h) + if skip_quantize: + assert not self.training, 'skip_quantize should be used in eval mode only.' 
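+            # skip_quantize bypasses the codebook entirely: the continuous encoder
+            # features (optionally l2-normalized below) are returned with empty loss
+            # and info dicts, e.g. for probing reconstructions without quantization error.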
+ if self.l2_norm: + h = F.normalize(h, dim=-1) + return h, {}, {} + quant, loss, info = self.quantize(h) + return quant, loss, info + + def decode(self, quant): + h = self.post_quant_embed(quant) + x = self.decoder(h) + return x + + def clamp(self, x): + if self.logit_laplace: + dec, _ = x.chunk(2, dim=1) + x = self.logit_laplace_loss.unmap(F.sigmoid(dec)) + else: + x = x.clamp_(self.clamp_range[0], self.clamp_range[1]) + return x + + def forward(self, input, skip_quantize=False): + if self.logit_laplace: + input = self.logit_laplace_loss.inmap(input) + quant, loss, info = self.encode(input, skip_quantize=skip_quantize) + dec = self.decode(quant) + if self.logit_laplace: + dec, lnb = dec.chunk(2, dim=1) + logit_laplace_loss = self.logit_laplace_loss(dec, lnb, input) + info.update({'logit_laplace_loss': logit_laplace_loss}) + dec = self.logit_laplace_loss.unmap(F.sigmoid(dec)) + else: + dec = dec.clamp_(self.clamp_range[0], self.clamp_range[1]) + return dec, loss, info + + def get_last_layer(self): + return self.decoder.conv_out.weight + + +class VITBSQModel(VITVQModel): + def __init__(self, vitconfig, embed_dim, embed_group_size=9, + l2_norm=False, logit_laplace=False, ckpt_path=None, ignore_keys=[], + grad_checkpointing=False, selective_checkpointing=False, + clamp_range=(0, 1), + dvitconfig=None, beta=0., gamma0=1.0, gamma=1.0, zeta=1.0, + persample_entropy_compute='group', + cb_entropy_compute='group', + post_q_l2_norm=False, + inv_temperature=1., + ): + # set quantizer params + self.beta = beta # commit loss + self.gamma0 = gamma0 # entropy + self.gamma = gamma # entropy penalty + self.zeta = zeta # lpips + self.embed_group_size = embed_group_size + self.persample_entropy_compute = persample_entropy_compute + self.cb_entropy_compute = cb_entropy_compute + self.post_q_l2_norm = post_q_l2_norm + self.inv_temperature = inv_temperature + + # call init + super().__init__( + vitconfig, + 2 ** embed_dim, + embed_dim, + l2_norm=l2_norm, + logit_laplace=logit_laplace, + ckpt_path=ckpt_path, + ignore_keys=ignore_keys, + grad_checkpointing=grad_checkpointing, + selective_checkpointing=selective_checkpointing, + clamp_range=clamp_range, + dvitconfig=dvitconfig, + ) + + + def setup_quantizer(self): + self.quantize = BinarySphericalQuantizer( + self.embed_dim, self.beta, self.gamma0, self.gamma, self.zeta, + group_size=self.embed_group_size, + persample_entropy_compute=self.persample_entropy_compute, + cb_entropy_compute=self.cb_entropy_compute, + input_format='blc', + l2_norm=self.post_q_l2_norm, + inv_temperature=self.inv_temperature, + ) + + def encode(self, x, skip_quantize=False): + h = self.encoder(x) + h = self.quant_embed(h) + if self.l2_norm: + h = F.normalize(h, dim=-1) + if skip_quantize: + assert not self.training, 'skip_quantize should be used in eval mode only.' 
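+            # Same escape hatch as in VITVQModel.encode: the features (already
+            # l2-normalized above when l2_norm is set) are returned without binary
+            # spherical quantization, i.e. without snapping each dimension to +/-1.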
+ return h, {}, {} + quant, loss, info = self.quantize(h) + return quant, loss, info diff --git a/src/vqvaes/bsqvit/quantizer/bsq.py b/src/vqvaes/bsqvit/quantizer/bsq.py new file mode 100644 index 0000000000000000000000000000000000000000..5742793b75d74a331d0494df62cef427e016a173 --- /dev/null +++ b/src/vqvaes/bsqvit/quantizer/bsq.py @@ -0,0 +1,223 @@ +from einops import rearrange, reduce +import torch +import torch.nn as nn +from torch.autograd import Function + + +class DifferentiableEntropyFunction(Function): + @staticmethod + def forward(ctx, zq, basis, K, eps): + zb = (zq + 1) / 2 + zi = ((zb * basis).sum(-1)).to(torch.int64) + cnt = torch.scatter_reduce(torch.zeros(2**K, device=zq.device, dtype=zq.dtype), + 0, + zi.flatten(), + torch.ones_like(zi.flatten()).to(zq.dtype), + 'sum') + prob = (cnt + eps) / (cnt + eps).sum() + H = -(prob * torch.log(prob)).sum() + ctx.save_for_backward(zq, zi, prob) + ctx.K = K + return H + + @staticmethod + def backward(ctx, grad_output): + zq, zi, prob= ctx.saved_tensors + grad_array = -grad_output * (torch.log(prob) + 1) / zi.numel() / ctx.K + reord_grad = grad_array[zi.flatten()].reshape(zi.shape) + grad_input = reord_grad.unsqueeze(-1) * zq + return grad_input, None, None, None, None + + +def codebook_entropy(zq, basis, K, eps=1e-4): + return DifferentiableEntropyFunction.apply(zq, basis, K, eps) + + +class BinarySphericalQuantizer(nn.Module): + def __init__(self, embed_dim, beta, gamma0, gamma, zeta, + input_format='bchw', + soft_entropy=True, group_size=9, + persample_entropy_compute='group', + cb_entropy_compute='group', + l2_norm=False, + inv_temperature=1): + super().__init__() + self.embed_dim = embed_dim + self.beta = beta # loss weight for commit loss + self.gamma0 = gamma0 # loss weight for entropy penalty + self.gamma = gamma # loss weight for entropy penalty + self.zeta = zeta # loss weight for entire entropy penalty + self.input_format = input_format + assert self.embed_dim % group_size == 0, "embed_dim must be divisible by group_size" + self.num_groups = self.embed_dim // group_size + self.group_size = group_size + assert persample_entropy_compute in ['group', 'analytical'], "persample_entropy_compute must be either 'group' or 'analytical'" + assert cb_entropy_compute in ['group', 'nce'], "cb_entropy_compute must be either 'group' or 'nce'" + self.persample_entropy_compute = persample_entropy_compute + self.cb_entropy_compute = cb_entropy_compute + self.l2_norm = l2_norm + self.inv_temperature = inv_temperature + + self.register_buffer('basis', 2 ** torch.arange(embed_dim - 1, -1, -1)) + self.register_buffer('group_basis', 2 ** torch.arange(group_size - 1, -1, -1)) + + self.num_dimensions = 2 ** embed_dim + self.bits_per_index = embed_dim + + # we only need to keep the codebook portion up to the group size + # because we approximate the H loss with this subcode + group_codes = torch.arange(2 ** self.group_size) + group_codebook = self.indexes_to_codes(group_codes).float()[:, -group_size:] + self.register_buffer('group_codebook', group_codebook, persistent=False) + + self.soft_entropy = soft_entropy # soft_entropy: Sec 3.2 of https://arxiv.org/pdf/1911.05894.pdf + + def quantize(self, z): + assert z.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {z.shape[-1]}" + + zhat = torch.where(z > 0, + torch.tensor(1, dtype=z.dtype, device=z.device), + torch.tensor(-1, dtype=z.dtype, device=z.device)) + return z + (zhat - z).detach() + + def forward(self, z): + if self.input_format == 'bchw': + z = rearrange(z, 'b c h w -> b 
h w c') + zq = self.quantize(z) + + indices = self.codes_to_indexes(zq.detach()) + group_indices = self.codes_to_group_indexes(zq.detach()) + if not self.training: + used_codes = torch.unique(indices, return_counts=False) + else: + used_codes = None + + q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1. + + if self.soft_entropy: + persample_entropy, cb_entropy, avg_prob = self.soft_entropy_loss(z) + entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy + else: + zb_by_sample= ((zq + 1)/2).reshape(z.shape[0], -1, z.shape[-1]).to(torch.float32) + persample_entropy = self.get_hard_per_sample_entropy(zb_by_sample) + cb_entropy = codebook_entropy(zq, self.basis, self.embed_dim) + entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy + + zq = zq * q_scale + + # commit loss + commit_loss = self.beta * torch.mean(((zq.detach() - z) ** 2).sum(dim=-1)) + + if self.input_format == 'bchw': + zq = rearrange(zq, 'b h w c -> b c h w') + + return ( + zq, + commit_loss + self.zeta * entropy_penalty / self.inv_temperature, + {"H": cb_entropy, "used_codes": used_codes, "indices": indices, "group_indices": group_indices, + "avg_prob": avg_prob} + ) + + def soft_entropy_loss(self, z): + # if we divide the code in subgroups of size group_size, the codebook will be of size 2 ** group_size + # the sub-code is the last group_size bits of the full code + group_code_book = self.group_codebook / (self.embed_dim ** 0.5 if self.l2_norm else 1) + divided_z = rearrange(z, '... (g c) -> ... g c', c=self.group_size) + + # we calculate the distance between the divided_z and the codebook for each subgroup + distance = - 2 * torch.einsum('... g c, d c ->... g d', divided_z, group_code_book) + prob = (-distance * self.inv_temperature).softmax(dim = -1) + if self.persample_entropy_compute == 'analytical': + if self.l2_norm: + p = torch.sigmoid(-4 * z / (self.embed_dim ** 0.5) * self.inv_temperature) + else: + p = torch.sigmoid(-4 * z * self.inv_temperature) + prob = torch.stack([p, 1-p], dim=-1) + per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean() + else: + per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean() + + # macro average of the probability of each subgroup + avg_prob = reduce(prob, '... g d ->g d', 'mean') + codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False) + + # the approximation of the entropy is the sum of the entropy of each subgroup + return per_sample_entropy, codebook_entropy.sum(), avg_prob + + def get_hard_per_sample_entropy(self, zb_by_sample): + probs_per_dim = zb_by_sample.sum(1) / zb_by_sample.shape[1] + persample_entropy = - probs_per_dim * torch.log(probs_per_dim + 1e-8) - (1 - probs_per_dim) * torch.log(1 - probs_per_dim + 1e-8) + persample_entropy = persample_entropy.sum(-1) + return persample_entropy.mean() + + def codes_to_indexes(self, zhat): + """Converts a `code` to an index in the codebook. + Args: + zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1} + """ + assert zhat.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {zhat.shape[-1]}" + return ((zhat + 1) / 2 * self.basis).sum(axis=-1).to(torch.int64) + + def codes_to_group_indexes(self, zhat): + """Converts a `code` to a list of indexes (in groups) in the codebook. + Args: + zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1} + """ + zhat_in_group = rearrange(zhat, 'b ... (g c) -> b ... 
g c', c=self.group_size) + return ((zhat_in_group + 1) / 2 * self.group_basis).sum(axis=-1).to(torch.int64) + + def indexes_to_codes(self, indices): + """Inverse of `indexes_to_codes`.""" + indices = indices.unsqueeze(-1) + codes_non_centered = torch.remainder( + torch.floor_divide(indices, self.basis), 2 + ) + return codes_non_centered * 2 - 1 + + def group_indexes_to_codes(self, group_indices): + """Inverse of `group_indexes_to_codes`.""" + group_indices = group_indices.unsqueeze(-1) + codes_non_centered = torch.remainder( + torch.floor_divide(group_indices, self.group_basis), 2 + ) + codes_non_centered = rearrange(codes_non_centered, 'b ... g c -> b ... (g c)') + return codes_non_centered * 2 - 1 + + def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True): + if normalize: + probs = (count + eps) / (count + eps).sum(dim=dim, keepdim =True) + else: + probs = count + H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim) + return H + + def get_group_codebook_entry(self, group_indices): + z_q = self.group_indexes_to_codes(group_indices) + q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1. + z_q = z_q * q_scale + if self.input_format == 'bchw': + h, w = int(z_q.shape[1] ** 0.5) + assert h * w == z_q.shape[1], 'Invalid sequence length' + z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h) + return z_q + + def get_codebook_entry(self, indices): + z_q = self.indexes_to_codes(indices) + q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1. + z_q = z_q * q_scale + if self.input_format == 'bchw': + h, w = int(z_q.shape[1] ** 0.5) + assert h * w == z_q.shape[1], 'Invalid sequence length' + z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h) + return z_q + + +if __name__ == "__main__": + K = 8 + # zq = torch.randint(0, 2, (4, 32, K), dtype=torch.bfloat16, device='cuda') * 2 - 1 + zq = torch.zeros((4, 32, K), dtype=torch.bfloat16, device='cuda') * 2 - 1 + basis = (2 ** torch.arange(K - 1, -1, -1)).to(torch.bfloat16).cuda() + zq.requires_grad = True + h = codebook_entropy(zq, basis, K) + h.backward() + print(zq.grad, zq) diff --git a/src/vqvaes/bsqvit/quantizer/vq.py b/src/vqvaes/bsqvit/quantizer/vq.py new file mode 100644 index 0000000000000000000000000000000000000000..d8b49ff7c5f61a2b303f1886a6219c573b2fd966 --- /dev/null +++ b/src/vqvaes/bsqvit/quantizer/vq.py @@ -0,0 +1,152 @@ +from einops import rearrange +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + + +class VectorQuantizer(nn.Module): + def __init__(self, n_embed, embed_dim, l2_norm, beta, input_format='bchw'): + super().__init__() + + self.n_embed = n_embed + self.embed_dim = embed_dim + self.l2_norm = l2_norm + self.beta = beta + assert input_format in ['bchw', 'blc'] + self.input_format = input_format + + self.embedding = nn.Embedding(n_embed, embed_dim) + self.embedding.weight.data.uniform_(-1 / n_embed, 1 / n_embed) + self.bits_per_index = int(np.ceil(np.log2(n_embed))) + + def forward(self, z): + batch = z.shape[0] + if self.input_format == 'bchw': + z = rearrange(z, 'b c h w -> b h w c') + + if self.l2_norm: + z = F.normalize(z, dim=-1) + z_flatten = z.reshape(-1, self.embed_dim) + embedding_weight = F.normalize(self.embedding.weight, dim=-1) + d = -z_flatten @ embedding_weight.t() + else: + z_flatten = z.reshape(-1, self.embed_dim) + d = torch.sum(z_flatten ** 2, dim=1, keepdim=True) + torch.sum(self.embedding.weight ** 2, dim=1) - 2 * z_flatten @ self.embedding.weight.t() + + min_encoding_indices = torch.argmin(d.detach(), dim=1) + if 
not self.training: + used_codes = torch.unique(min_encoding_indices, return_counts=False) + else: + used_codes = None + cb_usage = F.one_hot(min_encoding_indices, self.n_embed).sum(0) + cb_entropy = self.get_entropy(cb_usage) + + z_q = self.embedding(min_encoding_indices).view(z.shape) + if self.l2_norm: + z_q = F.normalize(z_q, dim=-1) + + # fix the issue with loss scaling + # loss weight should not associate with the dimensionality of words + # loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) + loss = self.beta * torch.mean(((z_q.detach() - z) ** 2).sum(dim=-1)) + torch.mean(((z_q - z.detach()) ** 2).sum(dim=-1)) + + z_q = z + (z_q - z).detach() + if self.input_format == 'bchw': + z_q = rearrange(z_q, 'b h w c -> b c h w') + return z_q, loss, {"H":cb_entropy, "used_codes": used_codes, 'indices': min_encoding_indices.view(batch, -1)} + + def get_entropy(self, count, eps=1e-4): + probs = (count + eps) / (count + eps).sum() + H = -(probs * torch.log(probs)).sum() + return H + + + def get_codebook_entry(self, indices): + z_q = self.embedding(indices) + if self.l2_norm: + z_q = F.normalize(z_q, dim=-1) + + if self.input_format == 'bchw': + h = w = int(z_q.shape[1] ** 0.5) + assert h * w == z_q.shape[1], 'Invalid sequence length' + z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h) + return z_q + + +class EMAVectorQuantizer(nn.Module): + def __init__(self, n_embed, embed_dim, l2_norm, beta, decay=0.99, eps=1e-5, random_restart=True, restart_threshold=1.0, input_format='bchw'): + super().__init__() + + self.n_embed = n_embed + self.embed_dim = embed_dim + self.l2_norm = l2_norm + self.beta = beta + self.decay = decay + self.eps = eps + self.random_restart = random_restart + self.restart_threshold = restart_threshold + self.input_format = input_format + + self.embedding = nn.Embedding(n_embed, embed_dim) + self.embedding.weight.data.uniform_(-1 / n_embed, 1 / n_embed) # TODO (yzhao): test other initialization methods + self.register_buffer("ema_cluster_size", torch.zeros(self.n_embed)) + self.embedding_avg = nn.Parameter(torch.Tensor(self.n_embed, self.embed_dim)) + self.embedding_avg.data.copy_(self.embedding.weight.data) + + def _tile(self, z): + n_z, embedding_dim = z.shape + if n_z < self.n_embed: + n_repeats = (self.n_embed + n_z - 1) // n_z + std = 0.01 / np.sqrt(embedding_dim) + z = z.repeat(n_repeats, 1) + z = z + torch.randn_like(z) * std + return z + + def forward(self, z): + if self.input_format == 'bchw': + z = rearrange(z, 'b c h w -> b h w c') + z_flatten = z.reshape(-1, self.embed_dim) + + d = torch.sum(z_flatten ** 2, dim=1, keepdim=True) + torch.sum(self.embedding.weight ** 2, dim=1) - 2 * z_flatten @ self.embedding.weight.t() + + encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) + encodings = torch.zeros(encoding_indices.size(0), self.n_embed, device=z.device) + encodings.scatter_(1, encoding_indices, 1) + + z_q = self.embedding(encoding_indices).view(z.shape) + if self.l2_norm: + z = F.normalize(z, dim=-1) + z_q = F.normalize(z_q, dim=-1) + + if self.training: + # EMA update cluster size + encodings_sum = encodings.sum(0) + if dist.is_initialized(): dist.all_reduce(encodings_sum) + self.ema_cluster_size.data.mul_(self.decay).add_(encodings_sum, alpha=1-self.decay) + + # EMA update of the embedding vectors + dw = encodings.t() @ z_flatten + if dist.is_initialized(): dist.all_reduce(dw) + self.embedding_avg.data.mul_(self.decay).add_(dw, alpha=1-self.decay) + + # Laplace smoothing of the cluster size + n = 
torch.sum(self.ema_cluster_size) + weights = (self.ema_cluster_size + self.eps) / (n + self.n_embed * self.eps) * n + self.embedding.weight.data = self.embedding_avg.data / weights.unsqueeze(1) + + if self.random_restart: + zz = self._tile(z_flatten) + _k_rand = zz[torch.randperm(zz.size(0))][:self.n_embed] + if dist.is_initialized(): dist.broadcast(_k_rand, 0) + usage = (self.ema_cluster_size.view(-1, 1) > self.restart_threshold).float() + self.embedding.weight.data.mul_(usage).add_(_k_rand * (1 - usage)) + + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + + z_q = z + (z_q - z).detach() + if self.input_format == 'bchw': + z_q = rearrange(z_q, 'b h w c -> b c h w') + # TODO (yzhao): monitor utility of the dictionary + return z_q, loss, {} diff --git a/src/vqvaes/bsqvit/stylegan_utils/custom_ops.py b/src/vqvaes/bsqvit/stylegan_utils/custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c754c0ca49cfb91e527eaca17d8f8b445b90fa07 --- /dev/null +++ b/src/vqvaes/bsqvit/stylegan_utils/custom_ops.py @@ -0,0 +1,126 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import os +import glob +import torch +import torch.utils.cpp_extension +import importlib +import hashlib +import shutil +from pathlib import Path + +from torch.utils.file_baton import FileBaton + +#---------------------------------------------------------------------------- +# Global options. + +verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' + +#---------------------------------------------------------------------------- +# Internal helper funcs. + +def _find_compiler_bindir(): + patterns = [ + 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', + ] + for pattern in patterns: + matches = sorted(glob.glob(pattern)) + if len(matches): + return matches[-1] + return None + +#---------------------------------------------------------------------------- +# Main entry point for compiling and loading C++/CUDA plugins. + +_cached_plugins = dict() + +def get_plugin(module_name, sources, **build_kwargs): + assert verbosity in ['none', 'brief', 'full'] + + # Already cached? + if module_name in _cached_plugins: + return _cached_plugins[module_name] + + # Print status. + if verbosity == 'full': + print(f'Setting up PyTorch plugin "{module_name}"...') + elif verbosity == 'brief': + print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) + + try: # pylint: disable=too-many-nested-blocks + # Make sure we can find the necessary compiler binaries. + if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: + compiler_bindir = _find_compiler_bindir() + if compiler_bindir is None: + raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') + os.environ['PATH'] += ';' + compiler_bindir + + # Compile and load. 
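+        # Compilation goes through torch.utils.cpp_extension.load (a JIT build); on
+        # failure the exception is re-raised after the "Failed!" status line so that
+        # callers (e.g. the bias_act plugin loader) can fall back to a reference path.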
+ verbose_build = (verbosity == 'full') + + # Incremental build md5sum trickery. Copies all the input source files + # into a cached build directory under a combined md5 digest of the input + # source files. Copying is done only if the combined digest has changed. + # This keeps input file timestamps and filenames the same as in previous + # extension builds, allowing for fast incremental rebuilds. + # + # This optimization is done only in case all the source files reside in + # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR + # environment variable is set (we take this as a signal that the user + # actually cares about this.) + source_dirs_set = set(os.path.dirname(source) for source in sources) + if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ): + all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file())) + + # Compute a combined hash digest for all source files in the same + # custom op directory (usually .cu, .cpp, .py and .h files). + hash_md5 = hashlib.md5() + for src in all_source_files: + with open(src, 'rb') as f: + hash_md5.update(f.read()) + build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access + digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest()) + + if not os.path.isdir(digest_build_dir): + os.makedirs(digest_build_dir, exist_ok=True) + baton = FileBaton(os.path.join(digest_build_dir, 'lock')) + if baton.try_acquire(): + try: + for src in all_source_files: + shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src))) + finally: + baton.release() + else: + # Someone else is copying source files under the digest dir, + # wait until done and continue. + baton.wait() + digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources] + torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir, + verbose=verbose_build, sources=digest_sources, **build_kwargs) + else: + torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) + module = importlib.import_module(module_name) + + except: + if verbosity == 'brief': + print('Failed!') + raise + + # Print status and add to cache. + if verbosity == 'full': + print(f'Done setting up PyTorch plugin "{module_name}".') + elif verbosity == 'brief': + print('Done.') + _cached_plugins[module_name] = module + return module + +#---------------------------------------------------------------------------- \ No newline at end of file diff --git a/src/vqvaes/bsqvit/stylegan_utils/misc.py b/src/vqvaes/bsqvit/stylegan_utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..57bbba993d9236a063e3d4dbabdb40fa978bdcda --- /dev/null +++ b/src/vqvaes/bsqvit/stylegan_utils/misc.py @@ -0,0 +1,40 @@ +import torch +import warnings + + +#---------------------------------------------------------------------------- +# Symbolic assert. + +try: + symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access +except AttributeError: + symbolic_assert = torch.Assert # 1.7.0 + +#---------------------------------------------------------------------------- +# Context manager to suppress known warnings in torch.jit.trace(). 
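+# suppress_tracer_warnings builds on warnings.catch_warnings: entering the context
+# installs an 'ignore' filter for torch.jit.TracerWarning, and the parent class
+# restores the previous warning filters on exit. Hypothetical usage:
+#     with suppress_tracer_warnings():
+#         symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), '...')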
+ +class suppress_tracer_warnings(warnings.catch_warnings): + def __enter__(self): + super().__enter__() + warnings.simplefilter('ignore', category=torch.jit.TracerWarning) + return self + +#---------------------------------------------------------------------------- +# Assert that the shape of a tensor matches the given list of integers. +# None indicates that the size of a dimension is allowed to vary. +# Performs symbolic assertion when used in torch.jit.trace(). + +def assert_shape(tensor, ref_shape): + if tensor.ndim != len(ref_shape): + raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') + for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): + if ref_size is None: + pass + elif isinstance(ref_size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') + elif isinstance(size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') + elif size != ref_size: + raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') diff --git a/src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.cpp b/src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7370ea8c9f4a2556b147a5233cbc4672ae67fa52 --- /dev/null +++ b/src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.cpp @@ -0,0 +1,99 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include +#include +#include "bias_act.h" + +//------------------------------------------------------------------------ + +static bool has_same_layout(torch::Tensor x, torch::Tensor y) +{ + if (x.dim() != y.dim()) + return false; + for (int64_t i = 0; i < x.dim(); i++) + { + if (x.size(i) != y.size(i)) + return false; + if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) + return false; + } + return true; +} + +//------------------------------------------------------------------------ + +static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) +{ + // Validate arguments. 
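+    // The checks below require each optional tensor (b, xref, yref, dy) to be either
+    // empty or to match x in dtype/device (and shape where noted); empty tensors are
+    // later passed to the kernel as NULL pointers.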
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); + TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); + TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); + TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(b.dim() == 1, "b must have rank 1"); + TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); + TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); + TORCH_CHECK(grad >= 0, "grad must be non-negative"); + + // Validate layout. + TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); + TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); + TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); + TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); + TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); + + // Create output tensor. + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + torch::Tensor y = torch::empty_like(x); + TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); + + // Initialize CUDA kernel parameters. + bias_act_kernel_params p; + p.x = x.data_ptr(); + p.b = (b.numel()) ? b.data_ptr() : NULL; + p.xref = (xref.numel()) ? xref.data_ptr() : NULL; + p.yref = (yref.numel()) ? yref.data_ptr() : NULL; + p.dy = (dy.numel()) ? dy.data_ptr() : NULL; + p.y = y.data_ptr(); + p.grad = grad; + p.act = act; + p.alpha = alpha; + p.gain = gain; + p.clamp = clamp; + p.sizeX = (int)x.numel(); + p.sizeB = (int)b.numel(); + p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; + + // Choose CUDA kernel. + void* kernel; + AT_DISPATCH_REDUCED_FLOATING_TYPES(x.scalar_type(), "upfirdn2d_cuda", [&] + { + kernel = choose_bias_act_kernel(p); + }); + TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); + + // Launch CUDA kernel. + p.loopX = 4; + int blockSize = 4 * 32; + int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; + void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("bias_act", &bias_act); +} + +//------------------------------------------------------------------------ diff --git a/src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.cu b/src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.cu new file mode 100644 index 0000000000000000000000000000000000000000..06448585947d2660c07e12a039367225ddef7b7f --- /dev/null +++ b/src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.cu @@ -0,0 +1,176 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
+// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include +#include "bias_act.h" + +//------------------------------------------------------------------------ +// Helpers. + +template struct InternalType; +template <> struct InternalType { typedef double scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; + +//------------------------------------------------------------------------ +// CUDA kernel. + +template +__global__ void bias_act_kernel(bias_act_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + int G = p.grad; + scalar_t alpha = (scalar_t)p.alpha; + scalar_t gain = (scalar_t)p.gain; + scalar_t clamp = (scalar_t)p.clamp; + scalar_t one = (scalar_t)1; + scalar_t two = (scalar_t)2; + scalar_t expRange = (scalar_t)80; + scalar_t halfExpRange = (scalar_t)40; + scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; + scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; + + // Loop over elements. + int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; + for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) + { + // Load. + scalar_t x = (scalar_t)((const T*)p.x)[xi]; + scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; + scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; + scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; + scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; + scalar_t yy = (gain != 0) ? yref / gain : 0; + scalar_t y = 0; + + // Apply bias. + ((G == 0) ? x : xref) += b; + + // linear + if (A == 1) + { + if (G == 0) y = x; + if (G == 1) y = x; + } + + // relu + if (A == 2) + { + if (G == 0) y = (x > 0) ? x : 0; + if (G == 1) y = (yy > 0) ? x : 0; + } + + // lrelu + if (A == 3) + { + if (G == 0) y = (x > 0) ? x : x * alpha; + if (G == 1) y = (yy > 0) ? x : x * alpha; + } + + // tanh + if (A == 4) + { + if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } + if (G == 1) y = x * (one - yy * yy); + if (G == 2) y = x * (one - yy * yy) * (-two * yy); + } + + // sigmoid + if (A == 5) + { + if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); + if (G == 1) y = x * yy * (one - yy); + if (G == 2) y = x * yy * (one - yy) * (one - two * yy); + } + + // elu + if (A == 6) + { + if (G == 0) y = (x >= 0) ? x : exp(x) - one; + if (G == 1) y = (yy >= 0) ? x : x * (yy + one); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); + } + + // selu + if (A == 7) + { + if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); + if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); + } + + // softplus + if (A == 8) + { + if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); + if (G == 1) y = x * (one - exp(-yy)); + if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } + } + + // swish + if (A == 9) + { + if (G == 0) + y = (x < -expRange) ? 
0 : x / (exp(-x) + one); + else + { + scalar_t c = exp(xref); + scalar_t d = c + one; + if (G == 1) + y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); + else + y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); + yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; + } + } + + // Apply gain. + y *= gain * dy; + + // Clamp. + if (clamp >= 0) + { + if (G == 0) + y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; + else + y = (yref > -clamp & yref < clamp) ? y : 0; + } + + // Store. + ((T*)p.y)[xi] = (T)y; + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template void* choose_bias_act_kernel(const bias_act_kernel_params& p) +{ + if (p.act == 1) return (void*)bias_act_kernel; + if (p.act == 2) return (void*)bias_act_kernel; + if (p.act == 3) return (void*)bias_act_kernel; + if (p.act == 4) return (void*)bias_act_kernel; + if (p.act == 5) return (void*)bias_act_kernel; + if (p.act == 6) return (void*)bias_act_kernel; + if (p.act == 7) return (void*)bias_act_kernel; + if (p.act == 8) return (void*)bias_act_kernel; + if (p.act == 9) return (void*)bias_act_kernel; + return NULL; +} + +//------------------------------------------------------------------------ +// Template specializations. + +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.h b/src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.h new file mode 100644 index 0000000000000000000000000000000000000000..a32187e1fb7e3bae509d4eceaf900866866875a4 --- /dev/null +++ b/src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.h @@ -0,0 +1,38 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct bias_act_kernel_params +{ + const void* x; // [sizeX] + const void* b; // [sizeB] or NULL + const void* xref; // [sizeX] or NULL + const void* yref; // [sizeX] or NULL + const void* dy; // [sizeX] or NULL + void* y; // [sizeX] + + int grad; + int act; + float alpha; + float gain; + float clamp; + + int sizeX; + int sizeB; + int stepB; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template void* choose_bias_act_kernel(const bias_act_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.py b/src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.py new file mode 100644 index 0000000000000000000000000000000000000000..3aea3f18446dfd4d6339f500148a3e57d423afb3 --- /dev/null +++ b/src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.py @@ -0,0 +1,226 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. 
All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom PyTorch ops for efficient bias and activation.""" + +import os +import warnings +import numpy as np +import torch +import traceback +from typing import Any + +from .. import custom_ops + + +class EasyDict(dict): + """Convenience class that behaves like a dict but allows access with the attribute syntax.""" + + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name: str, value: Any) -> None: + self[name] = value + + def __delattr__(self, name: str) -> None: + del self[name] + +#---------------------------------------------------------------------------- + +activation_funcs = { + 'linear': EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), + 'relu': EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), + 'lrelu': EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), + 'tanh': EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), + 'sigmoid': EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), + 'elu': EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), + 'selu': EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), + 'softplus': EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), + 'swish': EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), +} + +#---------------------------------------------------------------------------- + +_inited = False +_plugin = None +_null_tensor = torch.empty([0]) + +def _init(): + global _inited, _plugin + if not _inited: + _inited = True + sources = ['bias_act.cpp', 'bias_act.cu'] + sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] + try: + _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) + except: + warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) + return _plugin is not None + +#---------------------------------------------------------------------------- + +def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): + r"""Fused bias and activation function. + + Adds bias `b` to activation tensor `x`, evaluates activation function `act`, + and scales the result by `gain`. Each of the steps is optional. In most cases, + the fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports first and second order gradients, + but not third order gradients. + + Args: + x: Input activation tensor. 
Can be of any shape. + b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type + as `x`. The shape must be known, and it must match the dimension of `x` + corresponding to `dim`. + dim: The dimension in `x` corresponding to the elements of `b`. + The value of `dim` is ignored if `b` is not specified. + act: Name of the activation function to evaluate, or `"linear"` to disable. + Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. + See `activation_funcs` for a full list. `None` is not allowed. + alpha: Shape parameter for the activation function, or `None` to use the default. + gain: Scaling factor for the output tensor, or `None` to use default. + See `activation_funcs` for the default scaling of each activation function. + If unsure, consider specifying 1. + clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable + the clamping (default). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the same shape and datatype as `x`. + """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda' and _init(): + return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) + return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) + +#---------------------------------------------------------------------------- + +def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): + """Slow reference implementation of `bias_act()` using standard TensorFlow ops. + """ + assert isinstance(x, torch.Tensor) + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Add bias. + if b is not None: + assert isinstance(b, torch.Tensor) and b.ndim == 1 + assert 0 <= dim < x.ndim + assert b.shape[0] == x.shape[dim] + x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) + + # Evaluate activation function. + alpha = float(alpha) + x = spec.func(x, alpha=alpha) + + # Scale by gain. + gain = float(gain) + if gain != 1: + x = x * gain + + # Clamp. + if clamp >= 0: + x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type + return x + +#---------------------------------------------------------------------------- + +_bias_act_cuda_cache = dict() + +def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): + """Fast CUDA implementation of `bias_act()` using custom ops. + """ + # Parse arguments. + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Lookup from cache. + key = (dim, act, alpha, gain, clamp) + if key in _bias_act_cuda_cache: + return _bias_act_cuda_cache[key] + + # Forward op. 
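+    # BiasActCuda calls the compiled kernel with grad=0 and saves only the tensors the
+    # chosen activation's spec needs; gradients are delegated to BiasActCudaGrad,
+    # which reuses the same kernel with grad=1 (and grad=2 for the second-order case),
+    # so both gradient orders stay on the fused CUDA path.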
+ class BiasActCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, b): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format + x = x.contiguous(memory_format=ctx.memory_format) + b = b.contiguous() if b is not None else _null_tensor + y = x + if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: + y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + y if 'y' in spec.ref else _null_tensor) + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + dy = dy.contiguous(memory_format=ctx.memory_format) + x, b, y = ctx.saved_tensors + dx = None + db = None + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + dx = dy + if act != 'linear' or gain != 1 or clamp >= 0: + dx = BiasActCudaGrad.apply(dy, x, b, y) + + if ctx.needs_input_grad[1]: + db = dx.sum([i for i in range(dx.ndim) if i != dim]) + + return dx, db + + # Backward op. + class BiasActCudaGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format + dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + dy if spec.has_2nd_grad else _null_tensor, + x, b, y) + return dx + + @staticmethod + def backward(ctx, d_dx): # pylint: disable=arguments-differ + d_dx = d_dx.contiguous(memory_format=ctx.memory_format) + dy, x, b, y = ctx.saved_tensors + d_dy = None + d_x = None + d_b = None + d_y = None + + if ctx.needs_input_grad[0]: + d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) + + if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): + d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) + + if spec.has_2nd_grad and ctx.needs_input_grad[2]: + d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) + + return d_dy, d_x, d_b, d_y + + # Add to cache. + _bias_act_cuda_cache[key] = BiasActCuda + return BiasActCuda + +#---------------------------------------------------------------------------- diff --git a/src/vqvaes/bsqvit/stylegan_utils/ops/conv2d_gradfix.py b/src/vqvaes/bsqvit/stylegan_utils/ops/conv2d_gradfix.py new file mode 100644 index 0000000000000000000000000000000000000000..ba27f9be24918fc6ed5f2bf4b868f3b68ce1e4c2 --- /dev/null +++ b/src/vqvaes/bsqvit/stylegan_utils/ops/conv2d_gradfix.py @@ -0,0 +1,170 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
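For reference, the fused `bias_act()` op defined above is numerically equivalent to composing its three optional steps (bias add, activation, gain/clamp) out of standard PyTorch ops. A minimal sketch of that equivalence, forcing the reference path with `impl='ref'`; the import path is only an assumption and depends on how the `stylegan_utils` package is laid out:

    import torch
    from stylegan_utils.ops.bias_act import bias_act  # adjust to the actual package path

    x = torch.randn(4, 8, 16, 16)   # activations
    b = torch.randn(8)              # one bias element per channel (dim=1)

    # Fused call: add bias, apply leaky ReLU, scale by sqrt(2), clamp to [-256, 256].
    y = bias_act(x, b, dim=1, act='lrelu', clamp=256, impl='ref')

    # The same computation spelled out with standard ops
    # ('lrelu' defaults: alpha=0.2, gain=sqrt(2); see activation_funcs above).
    y_ref = torch.nn.functional.leaky_relu(x + b.reshape(1, -1, 1, 1), 0.2)
    y_ref = (y_ref * (2 ** 0.5)).clamp(-256, 256)

    assert torch.allclose(y, y_ref)

With `impl='cuda'` (the default) on a CUDA tensor, the compiled plugin performs the same computation in a single fused kernel.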
+ +"""Custom replacement for `torch.nn.functional.conv2d` that supports +arbitrarily high order gradients with zero performance penalty.""" + +import warnings +import contextlib +import torch + +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access + +#---------------------------------------------------------------------------- + +enabled = False # Enable the custom op by setting this to true. +weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. + +@contextlib.contextmanager +def no_weight_gradients(): + global weight_gradients_disabled + old = weight_gradients_disabled + weight_gradients_disabled = True + yield + weight_gradients_disabled = old + +#---------------------------------------------------------------------------- + +def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + if _should_use_custom_op(input): + return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) + return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + +def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): + if _should_use_custom_op(input): + return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) + return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) + +#---------------------------------------------------------------------------- + +def _should_use_custom_op(input): + assert isinstance(input, torch.Tensor) + if (not enabled) or (not torch.backends.cudnn.enabled): + return False + if input.device.type != 'cuda': + return False + if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): + return True + warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().') + return False + +def _tuple_of_ints(xs, ndim): + xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim + assert len(xs) == ndim + assert all(isinstance(x, int) for x in xs) + return xs + +#---------------------------------------------------------------------------- + +_conv2d_gradfix_cache = dict() + +def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): + # Parse arguments. + ndim = 2 + weight_shape = tuple(weight_shape) + stride = _tuple_of_ints(stride, ndim) + padding = _tuple_of_ints(padding, ndim) + output_padding = _tuple_of_ints(output_padding, ndim) + dilation = _tuple_of_ints(dilation, ndim) + + # Lookup from cache. + key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) + if key in _conv2d_gradfix_cache: + return _conv2d_gradfix_cache[key] + + # Validate arguments. 
+ assert groups >= 1 + assert len(weight_shape) == ndim + 2 + assert all(stride[i] >= 1 for i in range(ndim)) + assert all(padding[i] >= 0 for i in range(ndim)) + assert all(dilation[i] >= 0 for i in range(ndim)) + if not transpose: + assert all(output_padding[i] == 0 for i in range(ndim)) + else: # transpose + assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) + + # Helpers. + common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) + def calc_output_padding(input_shape, output_shape): + if transpose: + return [0, 0] + return [ + input_shape[i + 2] + - (output_shape[i + 2] - 1) * stride[i] + - (1 - 2 * padding[i]) + - dilation[i] * (weight_shape[i + 2] - 1) + for i in range(ndim) + ] + + # Forward & backward. + class Conv2d(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias): + assert weight.shape == weight_shape + if not transpose: + output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) + else: # transpose + output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) + ctx.save_for_backward(input, weight) + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + grad_input = None + grad_weight = None + grad_bias = None + + if ctx.needs_input_grad[0]: + p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) + grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None) + assert grad_input.shape == input.shape + + if ctx.needs_input_grad[1] and not weight_gradients_disabled: + grad_weight = Conv2dGradWeight.apply(grad_output, input) + assert grad_weight.shape == weight_shape + + if ctx.needs_input_grad[2]: + grad_bias = grad_output.sum([0, 2, 3]) + + return grad_input, grad_weight, grad_bias + + # Gradient with respect to the weights. 
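Routing convolutions through this wrapper buys two things: arbitrarily high order gradients, and the ability to switch off gradient computation with respect to the weights wholesale via `no_weight_gradients()` (useful when only input gradients are needed). A small usage sketch; the import path is an assumption, and note that `enabled` defaults to `False`, so without enabling it (or off CUDA) the calls simply fall back to `torch.nn.functional.conv2d`:

    import torch
    from stylegan_utils.ops import conv2d_gradfix  # adjust to the actual package path

    x = torch.randn(2, 3, 32, 32, requires_grad=True)
    w = torch.randn(8, 3, 3, 3, requires_grad=True)

    # Behaves like torch.nn.functional.conv2d unless the custom op is active
    # (CUDA input, conv2d_gradfix.enabled = True, supported PyTorch version).
    y = conv2d_gradfix.conv2d(x, w, padding=1)

    # When the custom op is active, weight gradients are skipped inside this context;
    # on the fallback path the context manager has no effect.
    with conv2d_gradfix.no_weight_gradients():
        loss = conv2d_gradfix.conv2d(x, w, padding=1).square().sum()
    (g_x,) = torch.autograd.grad(loss, [x])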
+ class Conv2dGradWeight(torch.autograd.Function): + @staticmethod + def forward(ctx, grad_output, input): + op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight') + flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] + grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) + assert grad_weight.shape == weight_shape + ctx.save_for_backward(grad_output, input) + return grad_weight + + @staticmethod + def backward(ctx, grad2_grad_weight): + grad_output, input = ctx.saved_tensors + grad2_grad_output = None + grad2_input = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) + assert grad2_grad_output.shape == grad_output.shape + + if ctx.needs_input_grad[1]: + p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) + grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None) + assert grad2_input.shape == input.shape + + return grad2_grad_output, grad2_input + + _conv2d_gradfix_cache[key] = Conv2d + return Conv2d + +#---------------------------------------------------------------------------- \ No newline at end of file diff --git a/src/vqvaes/bsqvit/stylegan_utils/ops/conv2d_resample.py b/src/vqvaes/bsqvit/stylegan_utils/ops/conv2d_resample.py new file mode 100644 index 0000000000000000000000000000000000000000..ae27b7fc7fc43ab9f92b109606dc8f9541e2453e --- /dev/null +++ b/src/vqvaes/bsqvit/stylegan_utils/ops/conv2d_resample.py @@ -0,0 +1,155 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""2D convolution with optional up/downsampling.""" + +import torch + +from .. import misc +from . import conv2d_gradfix +from . import upfirdn2d +from .upfirdn2d import _parse_padding +from .upfirdn2d import _get_filter_size + +#---------------------------------------------------------------------------- + +def _get_weight_shape(w): + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + shape = [int(sz) for sz in w.shape] + misc.assert_shape(w, shape) + return shape + +#---------------------------------------------------------------------------- + +def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): + """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. + """ + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + + # Flip weight if requested. + if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). + w = w.flip([2, 3]) + + # Workaround performance pitfall in cuDNN 8.0.5, triggered when using + # 1x1 kernel + memory_format=channels_last + less than 64 channels. 
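The workaround described just above replaces a 1x1 convolution with a plain matrix product over the flattened spatial dimensions, which is mathematically the same operation. A self-contained sanity check of that identity using only standard PyTorch ops (shapes are arbitrary):

    import torch

    x = torch.randn(2, 16, 8, 8)   # [N, C_in, H, W]
    w = torch.randn(4, 16, 1, 1)   # [C_out, C_in, 1, 1] -- a 1x1 kernel

    # Reference: ordinary 1x1 convolution.
    y_conv = torch.nn.functional.conv2d(x, w)

    # Same result as a batched matmul: [C_out, C_in] @ [N, C_in, H*W] -> [N, C_out, H*W].
    n, c_in, h, width = x.shape
    y_mm = (w.squeeze(3).squeeze(2) @ x.reshape(n, c_in, -1)).reshape(n, w.shape[0], h, width)

    assert torch.allclose(y_conv, y_mm, atol=1e-5)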
+ if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose: + if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: + if out_channels <= 4 and groups == 1: + in_shape = x.shape + x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1]) + x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) + else: + x = x.to(memory_format=torch.contiguous_format) + w = w.to(memory_format=torch.contiguous_format) + x = conv2d_gradfix.conv2d(x, w, groups=groups) + return x.to(memory_format=torch.channels_last) + + # Otherwise => execute using conv2d_gradfix. + op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d + return op(x, w, stride=stride, padding=padding, groups=groups) + +#---------------------------------------------------------------------------- + +def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): + r"""2D convolution with optional up/downsampling. + + Padding is performed only once at the beginning, not between the operations. + + Args: + x: Input tensor of shape + `[batch_size, in_channels, in_height, in_width]`. + w: Weight tensor of shape + `[out_channels, in_channels//groups, kernel_height, kernel_width]`. + f: Low-pass filter for up/downsampling. Must be prepared beforehand by + calling upfirdn2d.setup_filter(). None = identity (default). + up: Integer upsampling factor (default: 1). + down: Integer downsampling factor (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + groups: Split input channels into N groups (default: 1). + flip_weight: False = convolution, True = correlation (default: True). + flip_filter: False = convolution, True = correlation (default: False). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and (x.ndim == 4) + assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) + assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) + assert isinstance(up, int) and (up >= 1) + assert isinstance(down, int) and (down >= 1) + assert isinstance(groups, int) and (groups >= 1) + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + fw, fh = _get_filter_size(f) + px0, px1, py0, py1 = _parse_padding(padding) + + # Adjust padding to account for up/downsampling. + if up > 1: + px0 += (fw + up - 1) // 2 + px1 += (fw - up) // 2 + py0 += (fh + up - 1) // 2 + py1 += (fh - up) // 2 + if down > 1: + px0 += (fw - down + 1) // 2 + px1 += (fw - down) // 2 + py0 += (fh - down + 1) // 2 + py1 += (fh - down) // 2 + + # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. + if kw == 1 and kh == 1 and (down > 1 and up == 1): + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + return x + + # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 
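The padding adjustment above keeps the output aligned with the input by adding roughly half the filter width on each side, measured on the upsampled grid. Evaluating the same formulas for one concrete configuration (2x upsampling with a 4-tap filter and no user padding; numbers purely illustrative) shows that the result is an exact 2x resize:

    # Padding adjustment for up=2, down=1, fw=fh=4, user padding 0 (same formulas as above).
    fw = fh = 4
    up, down = 2, 1
    px0 = px1 = py0 = py1 = 0

    px0 += (fw + up - 1) // 2   # -> 2
    px1 += (fw - up) // 2       # -> 1
    py0 += (fh + up - 1) // 2   # -> 2
    py1 += (fh - up) // 2       # -> 1

    # upfirdn2d produces (in*up + pad0 + pad1 - f + down) // down samples per axis,
    # so for an H x W input the output is exactly 2H x 2W.
    H = W = 16
    out_h = (H * up + py0 + py1 - fh + down) // down
    out_w = (W * up + px0 + px1 - fw + down) // down
    assert (out_h, out_w) == (2 * H, 2 * W)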
+ if kw == 1 and kh == 1 and (up > 1 and down == 1): + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) + return x + + # Fast path: downsampling only => use strided convolution. + if down > 1 and up == 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) + return x + + # Fast path: upsampling with optional downsampling => use transpose strided convolution. + if up > 1: + if groups == 1: + w = w.transpose(0, 1) + else: + w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) + w = w.transpose(1, 2) + w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) + px0 -= kw - 1 + px1 -= kw - up + py0 -= kh - 1 + py1 -= kh - up + pxt = max(min(-px0, -px1), 0) + pyt = max(min(-py0, -py1), 0) + x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) + if down > 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) + return x + + # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. + if up == 1 and down == 1: + if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: + return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) + + # Fallback: Generic reference implementation. + x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + if down > 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) + return x + +#---------------------------------------------------------------------------- \ No newline at end of file diff --git a/src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.cpp b/src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ade900a70f6d96648395021af14f1443a66878bb --- /dev/null +++ b/src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include +#include +#include "upfirdn2d.h" + +//------------------------------------------------------------------------ + +static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) +{ + // Validate arguments. 
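Further down, this op distinguishes contiguous from channels_last inputs purely by inspecting the channel stride (`inStride.z == 1`) and picks its major/minor indexing scheme accordingly. The stride difference is easy to see on the PyTorch side; a short illustration with an arbitrary NCHW tensor:

    import torch

    x = torch.randn(2, 8, 16, 16)                    # NCHW, contiguous
    x_cl = x.to(memory_format=torch.channels_last)   # same data, NHWC layout

    # Contiguous: channel stride is H*W; channels_last: channel stride is 1,
    # which is exactly the condition the kernel selection code checks.
    print(x.stride())      # (2048, 256, 16, 1)
    print(x_cl.stride())   # (2048, 1, 128, 8)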
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); + TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); + TORCH_CHECK(x.dim() == 4, "x must be rank 4"); + TORCH_CHECK(f.dim() == 2, "f must be rank 2"); + TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); + TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); + TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); + + // Create output tensor. + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; + int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; + TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); + torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); + TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); + + // Initialize CUDA kernel parameters. + upfirdn2d_kernel_params p; + p.x = x.data_ptr(); + p.f = f.data_ptr(); + p.y = y.data_ptr(); + p.up = make_int2(upx, upy); + p.down = make_int2(downx, downy); + p.pad0 = make_int2(padx0, pady0); + p.flip = (flip) ? 1 : 0; + p.gain = gain; + p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); + p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); + p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); + p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); + p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); + p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); + p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; + p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; + + // Choose CUDA kernel. + upfirdn2d_kernel_spec spec; + AT_DISPATCH_REDUCED_FLOATING_TYPES(x.scalar_type(), "upfirdn2d_cuda", [&] + { + spec = choose_upfirdn2d_kernel(p); + }); + + // Set looping options. + p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; + p.loopMinor = spec.loopMinor; + p.loopX = spec.loopX; + p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; + p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; + + // Compute grid size. + dim3 blockSize, gridSize; + if (spec.tileOutW < 0) // large + { + blockSize = dim3(4, 32, 1); + gridSize = dim3( + ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, + (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, + p.launchMajor); + } + else // small + { + blockSize = dim3(256, 1, 1); + gridSize = dim3( + ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, + (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, + p.launchMajor); + } + + // Launch CUDA kernel. 
+ void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("upfirdn2d", &upfirdn2d); +} + +//------------------------------------------------------------------------ diff --git a/src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.cu b/src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.cu new file mode 100644 index 0000000000000000000000000000000000000000..d6a76b59fd9bed88598c8c8ab1cc80b7061b4053 --- /dev/null +++ b/src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.cu @@ -0,0 +1,353 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include +#include "upfirdn2d.h" + +//------------------------------------------------------------------------ +// Helpers. + +template struct InternalType; +template <> struct InternalType { typedef double scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; + +static __device__ __forceinline__ int floor_div(int a, int b) +{ + int t = 1 - a / b; + return (a + t * b) / b - t; +} + +//------------------------------------------------------------------------ +// Generic CUDA implementation for large filters. + +template static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + + // Calculate thread index. + int minorBase = blockIdx.x * blockDim.x + threadIdx.x; + int outY = minorBase / p.launchMinor; + minorBase -= outY * p.launchMinor; + int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y; + int majorBase = blockIdx.z * p.loopMajor; + if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Setup Y receptive field. + int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y; + int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y); + int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY; + int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y; + if (p.flip) + filterY = p.filterSize.y - 1 - filterY; + + // Loop over major, minor, and X. + for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor) + { + int nc = major * p.sizeMinor + minor; + int n = nc / p.inSize.z; + int c = nc - n * p.inSize.z; + for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y) + { + // Setup X receptive field. + int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x; + int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x); + int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX; + int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x; + if (p.flip) + filterX = p.filterSize.x - 1 - filterX; + + // Initialize pointers. 
+ const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y]; + int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x; + int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y; + + // Inner loop. + scalar_t v = 0; + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + v += (scalar_t)(*xp) * (scalar_t)(*fp); + xp += p.inStride.x; + fp += filterStepX; + } + xp += p.inStride.y - w * p.inStride.x; + fp += filterStepY - w * filterStepX; + } + + // Store result. + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } +} + +//------------------------------------------------------------------------ +// Specialized CUDA implementation for small filters. + +template +static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1; + const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1; + __shared__ volatile scalar_t sf[filterH][filterW]; + __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor]; + + // Calculate tile index. + int minorBase = blockIdx.x; + int tileOutY = minorBase / p.launchMinor; + minorBase -= tileOutY * p.launchMinor; + minorBase *= loopMinor; + tileOutY *= tileOutH; + int tileOutXBase = blockIdx.y * p.loopX * tileOutW; + int majorBase = blockIdx.z * p.loopMajor; + if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Load filter (flipped). + for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x) + { + int fy = tapIdx / filterW; + int fx = tapIdx - fy * filterW; + scalar_t v = 0; + if (fx < p.filterSize.x & fy < p.filterSize.y) + { + int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx; + int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy; + v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y]; + } + sf[fy][fx] = v; + } + + // Loop over major and X. + for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + { + int baseNC = major * p.sizeMinor + minorBase; + int n = baseNC / p.inSize.z; + int baseC = baseNC - n * p.inSize.z; + for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW) + { + // Load input pixels. + int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x; + int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y; + int tileInX = floor_div(tileMidX, upx); + int tileInY = floor_div(tileMidY, upy); + __syncthreads(); + for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x) + { + int relC = inIdx; + int relInX = relC / loopMinor; + int relInY = relInX / tileInW; + relC -= relInX * loopMinor; + relInX -= relInY * tileInW; + int c = baseC + relC; + int inX = tileInX + relInX; + int inY = tileInY + relInY; + scalar_t v = 0; + if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z) + v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + sx[relInY][relInX][relC] = v; + } + + // Loop over output pixels. 
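The shared-memory tile loaded by the small-filter kernel must cover the receptive field of a whole output tile, which is where the `tileInW`/`tileInH` expressions above come from. Evaluating them for one concrete configuration (2x upsampling, a 4x4 filter, a 64x16 output tile; values chosen only for illustration):

    # tile_in = ((tile_out - 1) * down + filt - 1) // up + 1, per axis (same formula as above).
    def tile_in(tile_out, down, filt, up):
        return ((tile_out - 1) * down + filt - 1) // up + 1

    print(tile_in(64, 1, 4, 2))   # tileInW = 34
    print(tile_in(16, 1, 4, 2))   # tileInH = 10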
+ __syncthreads(); + for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x) + { + int relC = outIdx; + int relOutX = relC / loopMinor; + int relOutY = relOutX / tileOutW; + relC -= relOutX * loopMinor; + relOutX -= relOutY * tileOutW; + int c = baseC + relC; + int outX = tileOutX + relOutX; + int outY = tileOutY + relOutY; + + // Setup receptive field. + int midX = tileMidX + relOutX * downx; + int midY = tileMidY + relOutY * downy; + int inX = floor_div(midX, upx); + int inY = floor_div(midY, upy); + int relInX = inX - tileInX; + int relInY = inY - tileInY; + int filterX = (inX + 1) * upx - midX - 1; // flipped + int filterY = (inY + 1) * upy - midY - 1; // flipped + + // Inner loop. + if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z) + { + scalar_t v = 0; + #pragma unroll + for (int y = 0; y < filterH / upy; y++) + #pragma unroll + for (int x = 0; x < filterW / upx; x++) + v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx]; + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } + } + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p) +{ + int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y; + + upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large, -1,-1,1, 4}; // contiguous + if (s == 1) spec = {(void*)upfirdn2d_kernel_large, -1,-1,4, 1}; // channels_last + + if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous + { + if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last + { + if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if 
(fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous + { + if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + } + if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last + { + if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + } + if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous + { + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + } + if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last + { + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + } + if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous + { + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last + { + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 
2) // contiguous + { + if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last + { + if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + } + if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous + { + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last + { + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + } + if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous + { + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last + { + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + } + return spec; +} + +//------------------------------------------------------------------------ +// Template specializations. 
+ +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.h b/src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.h new file mode 100644 index 0000000000000000000000000000000000000000..c9e2032bcac9d2abde7a75eea4d812da348afadd --- /dev/null +++ b/src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.h @@ -0,0 +1,59 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct upfirdn2d_kernel_params +{ + const void* x; + const float* f; + void* y; + + int2 up; + int2 down; + int2 pad0; + int flip; + float gain; + + int4 inSize; // [width, height, channel, batch] + int4 inStride; + int2 filterSize; // [width, height] + int2 filterStride; + int4 outSize; // [width, height, channel, batch] + int4 outStride; + int sizeMinor; + int sizeMajor; + + int loopMinor; + int loopMajor; + int loopX; + int launchMinor; + int launchMajor; +}; + +//------------------------------------------------------------------------ +// CUDA kernel specialization. + +struct upfirdn2d_kernel_spec +{ + void* kernel; + int tileOutW; + int tileOutH; + int loopMinor; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.py b/src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.py new file mode 100644 index 0000000000000000000000000000000000000000..e2f34cc7a8ba69bad5d670d693968a57a51ad24c --- /dev/null +++ b/src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.py @@ -0,0 +1,382 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom PyTorch ops for efficient resampling of 2D images.""" + +import os +import warnings +import numpy as np +import torch +import traceback + +from .. import custom_ops, misc +from . 
import conv2d_gradfix + +#---------------------------------------------------------------------------- + +_inited = False +_plugin = None + +def _init(): + global _inited, _plugin + if not _inited: + sources = ['upfirdn2d.cpp', 'upfirdn2d.cu'] + sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] + try: + _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) + except: + warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) + return _plugin is not None + +def _parse_scaling(scaling): + if isinstance(scaling, int): + scaling = [scaling, scaling] + assert isinstance(scaling, (list, tuple)) + assert all(isinstance(x, int) for x in scaling) + sx, sy = scaling + assert sx >= 1 and sy >= 1 + return sx, sy + +def _parse_padding(padding): + if isinstance(padding, int): + padding = [padding, padding] + assert isinstance(padding, (list, tuple)) + assert all(isinstance(x, int) for x in padding) + if len(padding) == 2: + padx, pady = padding + padding = [padx, padx, pady, pady] + padx0, padx1, pady0, pady1 = padding + return padx0, padx1, pady0, pady1 + +def _get_filter_size(f): + if f is None: + return 1, 1 + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + fw = f.shape[-1] + fh = f.shape[0] + with misc.suppress_tracer_warnings(): + fw = int(fw) + fh = int(fh) + misc.assert_shape(f, [fh, fw][:f.ndim]) + assert fw >= 1 and fh >= 1 + return fw, fh + +#---------------------------------------------------------------------------- + +def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): + r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. + + Args: + f: Torch tensor, numpy array, or python list of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), + `[]` (impulse), or + `None` (identity). + device: Result device (default: cpu). + normalize: Normalize the filter so that it retains the magnitude + for constant input signal (DC)? (default: True). + flip_filter: Flip the filter? (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + separable: Return a separable filter? (default: select automatically). + + Returns: + Float32 tensor of the shape + `[filter_height, filter_width]` (non-separable) or + `[filter_taps]` (separable). + """ + # Validate. + if f is None: + f = 1 + f = torch.as_tensor(f, dtype=torch.float32) + assert f.ndim in [0, 1, 2] + assert f.numel() > 0 + if f.ndim == 0: + f = f[np.newaxis] + + # Separable? + if separable is None: + separable = (f.ndim == 1 and f.numel() >= 8) + if f.ndim == 1 and not separable: + f = f.ger(f) + assert f.ndim == (1 if separable else 2) + + # Apply normalize, flip, gain, and device. + if normalize: + f /= f.sum() + if flip_filter: + f = f.flip(list(range(f.ndim))) + f = f * (gain ** (f.ndim / 2)) + f = f.to(device=device) + return f + +#---------------------------------------------------------------------------- + +def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Pad, upsample, filter, and downsample a batch of 2D images. + + Performs the following sequence of operations for each channel: + + 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). + + 2. Pad the image with the specified number of zeros on each side (`padding`). + Negative padding corresponds to cropping the image. + + 3. 
Convolve the image with the specified 2D FIR filter (`f`), shrinking it + so that the footprint of all output pixels lies within the input image. + + 4. Downsample the image by keeping every Nth pixel (`down`). + + This sequence of operations bears close resemblance to scipy.signal.upfirdn(). + The fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports gradients of arbitrary order. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda' and _init(): + return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f) + return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) + +#---------------------------------------------------------------------------- + +def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): + """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + assert f.dtype == torch.float32 and not f.requires_grad + batch_size, num_channels, in_height, in_width = x.shape + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Upsample by inserting zeros. + x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) + x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) + x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. + x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) + x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)] + + # Setup filter. + f = f * (gain ** (f.ndim / 2)) + f = f.to(x.dtype) + if not flip_filter: + f = f.flip(list(range(f.ndim))) + + # Convolve with the filter. + f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) + if f.ndim == 4: + x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels) + else: + x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) + x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) + + # Downsample by throwing away pixels. 
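Chaining these four steps on a tiny tensor makes the behaviour concrete: with 2x zero-insertion upsampling, one pixel of leading padding per axis, and a 2x2 box filter, the pipeline reduces to nearest-neighbour upsampling. A self-contained sketch that mirrors the reference steps with plain PyTorch ops (it does not call `upfirdn2d` itself; shapes and filter are illustrative):

    import torch

    # Tiny 1x1x2x2 image.
    x = torch.arange(4.0).reshape(1, 1, 2, 2)
    up = 2
    b, c, h, w = x.shape

    # Step 1: upsample by 2 by inserting zeros after each pixel.
    y = torch.nn.functional.pad(x.reshape(b, c, h, 1, w, 1), [0, up - 1, 0, 0, 0, up - 1])
    y = y.reshape(b, c, h * up, w * up)

    # Step 2: pad by one pixel before each axis (padding = [1, 0, 1, 0] in upfirdn2d terms).
    y = torch.nn.functional.pad(y, [1, 0, 1, 0])

    # Step 3: filter with a 2x2 box filter, one filter per channel.
    f = torch.ones(1, 1, 2, 2)
    y = torch.nn.functional.conv2d(y, f, groups=c)

    # Step 4: downsample (down = 1 here, so nothing is thrown away).
    assert torch.equal(y, x.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3))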
+ x = x[:, :, ::downy, ::downx] + return x + +#---------------------------------------------------------------------------- + +_upfirdn2d_cuda_cache = dict() + +def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): + """Fast CUDA implementation of `upfirdn2d()` using custom ops. + """ + # Parse arguments. + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Lookup from cache. + key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + if key in _upfirdn2d_cuda_cache: + return _upfirdn2d_cuda_cache[key] + + # Forward op. + class Upfirdn2dCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, f): # pylint: disable=arguments-differ + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + y = x + if f.ndim == 2: + y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + else: + y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain)) + y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain)) + ctx.save_for_backward(f) + ctx.x_shape = x.shape + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + f, = ctx.saved_tensors + _, _, ih, iw = ctx.x_shape + _, _, oh, ow = dy.shape + fw, fh = _get_filter_size(f) + p = [ + fw - padx0 - 1, + iw * upx - ow * downx + padx0 - upx + 1, + fh - pady0 - 1, + ih * upy - oh * downy + pady0 - upy + 1, + ] + dx = None + df = None + + if ctx.needs_input_grad[0]: + dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) + + assert not ctx.needs_input_grad[1] + return dx, df + + # Add to cache. + _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda + return Upfirdn2dCuda + +#---------------------------------------------------------------------------- + +def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Filter a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape matches the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 
+ """ + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + fw // 2, + padx1 + (fw - 1) // 2, + pady0 + fh // 2, + pady1 + (fh - 1) // 2, + ] + return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + +#---------------------------------------------------------------------------- + +def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Upsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a multiple of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + upx, upy = _parse_scaling(up) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw + upx - 1) // 2, + padx1 + (fw - upx) // 2, + pady0 + (fh + upy - 1) // 2, + pady1 + (fh - upy) // 2, + ] + return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl) + +#---------------------------------------------------------------------------- + +def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Downsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a fraction of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the input. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 
+ """ + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw - downx + 1) // 2, + padx1 + (fw - downx) // 2, + pady0 + (fh - downy + 1) // 2, + pady1 + (fh - downy) // 2, + ] + return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + +#---------------------------------------------------------------------------- diff --git a/src/vqvaes/bsqvit/transformer.py b/src/vqvaes/bsqvit/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..823f09fcf4230309f47d47e3649b2f4c560dbaeb --- /dev/null +++ b/src/vqvaes/bsqvit/transformer.py @@ -0,0 +1,416 @@ +from collections import OrderedDict +from typing import Callable, Optional, Union +from einops import rearrange +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +from timm.models.layers import to_2tuple +from timm.models.layers import trunc_normal_ +from timm.models.layers import DropPath + +from .attention_mask import get_attention_mask + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + use_preln: bool = True, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = nn.MultiheadAttention(d_model, n_head, dropout=attn_drop) + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + # disable this following JAX implementation. + # Reference: https://github.com/google-research/magvit/blob/main/videogvt/models/simplified_bert.py#L112 + # ("drop1", nn.Dropout(drop)), + ("c_proj", nn.Linear(mlp_width, d_model)), + ("drop2", nn.Dropout(drop)), + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.use_preln = use_preln + + def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False): + attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask, is_causal=is_causal)[0] + + def checkpoint_forward(self, x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + is_causal: bool = False): + state = x + if self.use_preln: + x = checkpoint(self.ln_1, x, use_reentrant=False) + x = self.attention(x, attn_mask, is_causal) + x = checkpoint(self.ls_1, x, use_reentrant=False) + state = state + self.drop_path(x) + x = checkpoint(self.ln_2, state, use_reentrant=False) + x = self.mlp(x) + x = checkpoint(self.ls_2, x, use_reentrant=False) + state = state + self.drop_path(x) + else: + x = self.attention(x, attn_mask, is_causal) + x = state + self.drop_path(x) + state = checkpoint(self.ln_1, x, use_reentrant=False) + x = self.mlp(state) + state = state + self.drop_path(x) + state = checkpoint(self.ln_2, state, use_reentrant=False) + return state + + def forward(self, x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, is_causal: bool =False, + selective_checkpointing: bool = False): + if selective_checkpointing: + return self.checkpoint_forward(x, attn_mask, is_causal=is_causal) + if self.use_preln: + x = x + self.drop_path(self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask, is_causal=is_causal))) + x = x + self.drop_path(self.ls_2(self.mlp(self.ln_2(x)))) + else: + x = x + self.drop_path(self.attention(x, attn_mask=attn_mask, is_causal=is_causal)) + x = self.ln_1(x) + x = x + self.drop_path(self.mlp(x)) + x = self.ln_2(x) + return x + + +class Transformer(nn.Module): + def __init__(self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + use_preln: bool = True, + ): + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = False + self.selective_checkpointing = False + self.grad_checkpointing_params = {'use_reentrant': False} + if attn_drop == 0 and drop_path == 0 and drop_path == 0: + self.grad_checkpointing_params.update({'preserve_rng_state': False}) + else: + self.grad_checkpointing_params.update({'preserve_rng_state': True}) + + self.resblocks = nn.ModuleList([ + ResidualAttentionBlock( + width, heads, mlp_ratio, ls_init_value=ls_init_value, + drop=drop, attn_drop=attn_drop, drop_path=drop_path, + act_layer=act_layer, norm_layer=norm_layer, + use_preln=use_preln) + for _ in range(layers) + ]) + + def forward(self, x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + is_causal: bool =False): + for r in self.resblocks: + if self.training and self.grad_checkpointing and not torch.jit.is_scripting(): + if not self.selective_checkpointing: + x = checkpoint(r, x, attn_mask, is_causal=is_causal, **self.grad_checkpointing_params) + else: + x = r(x, attn_mask=attn_mask, is_causal=is_causal, selective_checkpointing=True) + else: + x = r(x, attn_mask=attn_mask) + return x + + +class TransformerEncoder(nn.Module): + def __init__(self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + num_frames: int = 1, + cross_frames: bool = True, + ls_init_value: float = None, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + 
ln_pre: bool = True, + ln_post: bool = True, + act_layer: str = 'gelu', + norm_layer: str = 'layer_norm', + mask_type: Union[str, None] = 'none', + mask_block_size: int = -1 + ): + super().__init__() + self.image_size = to_2tuple(image_size) + self.patch_size = to_2tuple(patch_size) + self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1]) + self.patches_per_frame = self.grid_size[0] * self.grid_size[1] + self.mask_type = mask_type + self.mask_block_size = mask_block_size + + if act_layer.lower() == 'gelu': + self.act_layer = nn.GELU + else: + raise ValueError(f"Unsupported activation function: {act_layer}") + if norm_layer.lower() == 'layer_norm': + self.norm_layer = nn.LayerNorm + else: + raise ValueError(f"Unsupported normalization: {norm_layer}") + + self.conv1 = nn.Linear( + in_features=3 * self.patch_size[0] * self.patch_size[1], + out_features=width, + bias=not ln_pre + ) + + scale = width ** -0.5 + self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1], width)) + assert num_frames >= 1 + self.num_frames = num_frames + self.cross_frames = cross_frames + if num_frames > 1 and cross_frames: + self.temporal_positional_embedding = nn.Parameter(torch.zeros(num_frames, width)) + else: + self.temporal_positional_embedding = None + + self.ln_pre = self.norm_layer(width) if ln_pre else nn.Identity() + + self.transformer = Transformer( + width, layers, heads, mlp_ratio, ls_init_value=ls_init_value, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate, + act_layer=self.act_layer, norm_layer=self.norm_layer, + ) + + self.ln_post = self.norm_layer(width) + + self.init_parameters() + + def init_parameters(self): + if self.positional_embedding is not None: + nn.init.normal_(self.positional_embedding, std=0.02) + trunc_normal_(self.conv1.weight, std=0.02) + for block in self.transformer.resblocks: + for n, p in block.named_parameters(): + if 'weight' in n: + if 'ln' not in n: + trunc_normal_(p, std=0.02) + elif 'bias' in n: + nn.init.zeros_(p) + else: + raise NotImplementedError(f'Unknown parameters named {n}') + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True, selective=False): + self.transformer.grad_checkpointing = enable + self.transformer.selective_checkpointing = selective + + + def forward(self, x): + if self.num_frames == 1: + x = rearrange( + x, "b c (hh sh) (ww sw) -> b (hh ww) (c sh sw)", + sh=self.patch_size[0], sw=self.patch_size[1] + ) + x = self.conv1(x) + x = x + self.positional_embedding.to(x.dtype) + elif self.cross_frames: + num_frames = x.shape[2] + assert num_frames <= self.num_frames, 'Number of frames should be less or equal to the model setting' + x = rearrange( + x, "b c t (hh sh) (ww sw) -> b (t hh ww) (c sh sw)", + sh=self.patch_size[0], sw=self.patch_size[1] + ) + x = self.conv1(x) + tile_pos_embed = self.positional_embedding.repeat(num_frames, 1) + tile_tem_embed = self.temporal_positional_embedding[:num_frames].repeat_interleave(self.patches_per_frame, 0) + total_pos_embed = tile_pos_embed + tile_tem_embed + x = x + total_pos_embed.to(x.dtype).squeeze(0) + else: + x = rearrange( + x, "b c t (hh sh) (ww sw) -> (b t) (hh ww) (c sh sw)", + sh=self.patch_size[0], sw=self.patch_size[1] + ) + x = self.conv1(x) + x = x + self.positional_embedding.to(x.dtype) + + x = self.ln_pre(x) + x = x.permute(1, 0, 2) + block_size = self.grid_size[0] * self.grid_size[1] if self.mask_block_size <= 0 else self.mask_block_size + attn_mask = 
get_attention_mask(x.size(0), x.device, mask_type=self.mask_type, block_size=block_size) + x = self.transformer(x, attn_mask, is_causal=self.mask_type == 'causal') + x = x.permute(1, 0, 2) + x = self.ln_post(x) + + return x + + +class TransformerDecoder(nn.Module): + def __init__(self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + num_frames: int = 1, + cross_frames: bool = True, + ls_init_value: float = None, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + ln_pre: bool = True, + ln_post: bool = True, + act_layer: str = 'gelu', + norm_layer: str = 'layer_norm', + use_ffn_output: bool = True, + dim_ffn_output: int = 3072, + logit_laplace: bool = False, + mask_type: Union[str, None] = 'none', + mask_block_size: int = -1 + ): + super().__init__() + self.image_size = to_2tuple(image_size) + self.patch_size = to_2tuple(patch_size) + self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1]) + self.patches_per_frame = self.grid_size[0] * self.grid_size[1] + self.mask_type = mask_type + self.mask_block_size = mask_block_size + + if act_layer.lower() == 'gelu': + self.act_layer = nn.GELU + else: + raise ValueError(f"Unsupported activation function: {act_layer}") + if norm_layer.lower() == 'layer_norm': + self.norm_layer = nn.LayerNorm + else: + raise ValueError(f"Unsupported normalization: {norm_layer}") + + self.use_ffn_output = use_ffn_output + if use_ffn_output: + self.ffn = nn.Sequential( + nn.Linear(width, dim_ffn_output), + nn.Tanh(), + ) + self.conv_out = nn.Linear( + in_features=dim_ffn_output, + out_features=3 * self.patch_size[0] * self.patch_size[1] * (1 + logit_laplace) + ) + else: + self.ffn = nn.Identity() + self.conv_out = nn.Linear( + in_features=width, + out_features=3 * self.patch_size[0] * self.patch_size[1] * (1 + logit_laplace) + ) + + scale = width ** -0.5 + self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1], width)) + assert num_frames >= 1 + self.num_frames = num_frames + self.cross_frames = cross_frames + if num_frames > 1 and cross_frames: + self.temporal_positional_embedding = nn.Parameter(torch.zeros(num_frames, width)) + else: + self.temporal_positional_embedding = None + + self.ln_pre = self.norm_layer(width) if ln_pre else nn.Identity() + + self.transformer = Transformer( + width, layers, heads, mlp_ratio, ls_init_value=ls_init_value, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate, + act_layer=self.act_layer, norm_layer=self.norm_layer, + ) + + self.ln_post = self.norm_layer(width) if ln_post else nn.Identity() + + self.init_parameters() + + def init_parameters(self): + if self.positional_embedding is not None: + nn.init.normal_(self.positional_embedding, std=0.02) + + for block in self.transformer.resblocks: + for n, p in block.named_parameters(): + if 'weight' in n: + if 'ln' not in n: + trunc_normal_(p, std=0.02) + elif 'bias' in n: + nn.init.zeros_(p) + else: + raise NotImplementedError(f'Unknown parameters named {n}') + if self.use_ffn_output: + trunc_normal_(self.ffn[0].weight, std=0.02) + trunc_normal_(self.conv_out.weight, std=0.02) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True, selective=False): + self.transformer.grad_checkpointing = enable + self.transformer.selective_checkpointing = selective + + def forward(self, x): + if self.num_frames == 1 or not self.cross_frames: + x = x + self.positional_embedding.to(x.dtype) + else: + 
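+            # Multi-frame path: tile the spatial positional embedding across frames
+            # and add a per-frame temporal embedding before running the transformer.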
num_frames = x.shape[1] // self.patches_per_frame + assert num_frames <= self.num_frames, 'Number of frames should be less or equal to the model setting' + tile_pos_embed = self.positional_embedding.repeat(num_frames, 1) + tile_tem_embed = self.temporal_positional_embedding[:num_frames].repeat_interleave(self.patches_per_frame, 0) + total_pos_embed = tile_pos_embed + tile_tem_embed + x = x + total_pos_embed.to(x.dtype).squeeze(0) + x = self.ln_pre(x) + x = x.permute(1, 0, 2) + block_size = self.grid_size[0] * self.grid_size[1] if self.mask_block_size <= 0 else self.mask_block_size + attn_mask = get_attention_mask(x.size(0), x.device, mask_type=self.mask_type, block_size=block_size) + x = self.transformer(x, attn_mask, is_causal=self.mask_type == 'causal') + x = x.permute(1, 0, 2) + x = self.ln_post(x) + x = self.ffn(x) + x = self.conv_out(x) + if self.num_frames == 1: + x = rearrange( + x, "b (hh ww) (c sh sw) -> b c (hh sh) (ww sw)", + hh = self.grid_size[0], ww=self.grid_size[1], + sh=self.patch_size[0], sw=self.patch_size[1] + ) + elif self.cross_frames: + x = rearrange( + x, "b (t hh ww) (c sh sw) -> b c t (hh sh) (ww sw)", + t = num_frames, hh = self.grid_size[0], ww=self.grid_size[1], + sh=self.patch_size[0], sw=self.patch_size[1] + ) + else: + x = rearrange( + x, "(b t) (hh ww) (c sh sw) -> b c t (hh sh) (ww sw)", + t = num_frames, hh = self.grid_size[0], ww=self.grid_size[1], + sh=self.patch_size[0], sw=self.patch_size[1] + ) + + return x diff --git a/src/vqvaes/flowmo/flowmo.py b/src/vqvaes/flowmo/flowmo.py new file mode 100644 index 0000000000000000000000000000000000000000..09d0bc8a7f68845336a19b80d037e7f1b528a035 --- /dev/null +++ b/src/vqvaes/flowmo/flowmo.py @@ -0,0 +1,945 @@ +"""Model code for FlowMo. + +Sources: https://github.com/feizc/FluxMusic/blob/main/train.py +https://github.com/black-forest-labs/flux/tree/main/src/flux +""" + +import ast +import itertools +import math +from dataclasses import dataclass +from typing import List, Tuple + +import einops +import torch +from einops import rearrange, repeat +from mup import MuReadout +from torch import Tensor, nn +import argparse +import contextlib +import copy +import glob +import os +import subprocess +import tempfile +import time + +import fsspec +import psutil +import torch +import torch.distributed as dist +from mup import MuReadout, set_base_shapes +from omegaconf import OmegaConf +from torch.utils.data import DataLoader + +from .lookup_free_quantize import LFQ + +MUP_ENABLED = True + + +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: + b, h, l, d = q.shape + q, k = apply_rope(q, k, pe) + + if torch.__version__ == "2.0.1+cu117": # tmp workaround + if d != 64: + print("MUP is broken in this setting! 
Be careful!") + x = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + ) + else: + x = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + scale=8.0 / d if MUP_ENABLED else None, + ) + assert x.shape == q.shape + x = rearrange(x, "B H L D -> B L (H D)") + return x + + +def rope(pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack( + [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], + dim=-1, + ) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.float() + + +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + +def _get_diagonal_gaussian(parameters): + mean, logvar = torch.chunk(parameters, 2, dim=1) + logvar = torch.clamp(logvar, -30.0, 20.0) + return mean, logvar + + +def _sample_diagonal_gaussian(mean, logvar): + std = torch.exp(0.5 * logvar) + x = mean + std * torch.randn(mean.shape, device=mean.device) + return x + + +def _kl_diagonal_gaussian(mean, logvar): + var = torch.exp(logvar) + return 0.5 * torch.sum(torch.pow(mean, 2) + var - 1.0 - logvar, dim=1).mean() + + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
+ """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(t.device) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + def forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor): + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + return (x * rrms).to(dtype=x_dtype) * self.scale + + +class QKNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.query_norm = RMSNorm(dim) + self.key_norm = RMSNorm(dim) + + def forward(self, q: Tensor, k: Tensor, v: Tensor): + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.norm = QKNorm(head_dim) + self.proj = nn.Linear(dim, dim) + + def forward(self, x: Tensor, pe: Tensor) -> Tensor: + qkv = self.qkv(x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = self.proj(x) + return x + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + + self.lin.weight[dim * 2 : dim * 3].data[:] = 0.0 + self.lin.bias[dim * 2 : dim * 3].data[:] = 0.0 + self.lin.weight[dim * 5 : dim * 6].data[:] = 0.0 + self.lin.bias[dim * 5 : dim * 6].data[:] = 0.0 + + def forward(self, vec: Tensor) -> Tuple[ModulationOut, ModulationOut]: + out = self.lin(nn.functional.silu(vec))[:, None, :].chunk( + self.multiplier, dim=-1 + ) + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if self.is_double else None, + ) + + +class DoubleStreamBlock(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float, + qkv_bias: bool = False, + ): + super().__init__() + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + + self.img_mod = Modulation(hidden_size, double=True) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_attn = SelfAttention( + dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias + ) + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_mod = 
Modulation(hidden_size, double=True) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = SelfAttention( + dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias + ) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor): + pe_single, pe_double = pe + p = 1 + if vec is None: + img_mod1, img_mod2 = ModulationOut(0, 1 - p, 1), ModulationOut(0, 1 - p, 1) + txt_mod1, txt_mod2 = ModulationOut(0, 1 - p, 1), ModulationOut(0, 1 - p, 1) + else: + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + + # prepare image for attention + img_modulated = self.img_norm1(img) + img_modulated = (p + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = self.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange( + img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads + ) + img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = self.txt_norm1(txt) + txt_modulated = (p + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = self.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange( + txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads + ) + txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn = attention(q, k, v, pe=pe_double) + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] + + # calculate the img bloks + img = img + img_mod1.gate * self.img_attn.proj(img_attn) + img = img + img_mod2.gate * self.img_mlp( + (p + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift + ) + + # calculate the txt bloks + txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * self.txt_mlp( + (p + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift + ) + return img, txt + + +class LastLayer(nn.Module): + def __init__( + self, + hidden_size: int, + patch_size: int, + out_channels: int, + readout_zero_init=False, + ): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + if MUP_ENABLED: + self.linear = MuReadout( + hidden_size, + patch_size * patch_size * out_channels, + bias=True, + readout_zero_init=readout_zero_init, + ) + else: + self.linear = nn.Linear( + hidden_size, patch_size * patch_size * out_channels, bias=True + ) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def forward(self, x: Tensor, vec) -> Tensor: + if vec is None: + pass + else: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.norm_final(x) + x = self.linear(x) + return x + + +@dataclass +class FluxParams: + in_channels: int + patch_size: int + context_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + axes_dim: List[int] + theta: int + qkv_bias: bool + + +DIT_ZOO = dict( + dit_xl_4=dict( + hidden_size=1152, + mlp_ratio=4.0, + num_heads=16, + axes_dim=[8, 28, 28], + theta=10_000, + qkv_bias=True, + ), + dit_l_4=dict( + hidden_size=1024, + mlp_ratio=4.0, + num_heads=16, + axes_dim=[8, 
28, 28], + theta=10_000, + qkv_bias=True, + ), + dit_b_4=dict( + hidden_size=768, + mlp_ratio=4.0, + num_heads=12, + axes_dim=[8, 28, 28], + theta=10_000, + qkv_bias=True, + ), + dit_s_4=dict( + hidden_size=384, + mlp_ratio=4.0, + num_heads=6, + axes_dim=[8, 28, 28], + theta=10_000, + qkv_bias=True, + ), + dit_mup_test=dict( + hidden_size=768, + mlp_ratio=4.0, + num_heads=12, + axes_dim=[8, 28, 28], + theta=10_000, + qkv_bias=True, + ), +) + + +def prepare_idxs(img, code_length, patch_size): + bs, c, h, w = img.shape + + img_ids = torch.zeros(h // patch_size, w // patch_size, 3, device=img.device) + img_ids[..., 1] = ( + img_ids[..., 1] + torch.arange(h // patch_size, device=img.device)[:, None] + ) + img_ids[..., 2] = ( + img_ids[..., 2] + torch.arange(w // patch_size, device=img.device)[None, :] + ) + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + txt_ids = ( + torch.zeros((bs, code_length, 3), device=img.device) + + torch.arange(code_length, device=img.device)[None, :, None] + ) + return img_ids, txt_ids + + +class Flux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, params: FluxParams, name="", lsg=False): + super().__init__() + + self.name = name + self.lsg = lsg + self.params = params + self.in_channels = params.in_channels + self.patch_size = params.patch_size + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError( + f"Got {params.axes_dim} but expected positional dim {pe_dim}" + ) + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND( + dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim + ) + + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.txt_in = nn.Linear(params.context_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for idx in range(params.depth) + ] + ) + + self.final_layer_img = LastLayer( + self.hidden_size, 1, self.out_channels, readout_zero_init=False + ) + self.final_layer_txt = LastLayer( + self.hidden_size, 1, params.context_dim, readout_zero_init=False + ) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + ) -> Tensor: + b, c, h, w = img.shape + + img = rearrange( + img, + "b c (gh ph) (gw pw) -> b (gh gw) (ph pw c)", + ph=self.patch_size, + pw=self.patch_size, + ) + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + img = self.img_in(img) + + if timesteps is None: + vec = None + else: + vec = self.time_in(timestep_embedding(timesteps, 256)) + + txt = self.txt_in(txt) + pe_single = self.pe_embedder(torch.cat((txt_ids,), dim=1)) + pe_double = self.pe_embedder(torch.cat((txt_ids, img_ids), dim=1)) + + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, pe=(pe_single, pe_double), vec=vec) + + img = self.final_layer_img(img, vec=vec) + img = rearrange( + img, + "b (gh gw) (ph pw c) -> b c (gh ph) (gw pw)", + ph=self.patch_size, + pw=self.patch_size, + gh=h // self.patch_size, + gw=w // self.patch_size, + ) + + txt = 
self.final_layer_txt(txt, vec=vec) + return img, txt, {"final_txt": txt} + + +def get_weights_to_fix(model): + with torch.no_grad(): + for name, module in itertools.chain(model.named_modules()): + if "double_blocks" in name and isinstance(module, torch.nn.Linear): + yield name, module.weight + + +class FlowMo(nn.Module): + def __init__(self, width, config): + super().__init__() + code_length = config.model.code_length + context_dim = config.model.context_dim + enc_depth = config.model.enc_depth + dec_depth = config.model.dec_depth + + patch_size = config.model.patch_size + self.config = config + + self.image_size = config.data.image_size + self.patch_size = config.model.patch_size + self.code_length = code_length + self.dit_mode = "dit_b_4" + self.context_dim = context_dim + self.encoder_context_dim = context_dim * ( + 1 + (self.config.model.quantization_type == "kl") + ) + + if config.model.quantization_type == "lfq": + self.quantizer = LFQ( + codebook_size=2**self.config.model.codebook_size_for_entropy, + dim=self.config.model.codebook_size_for_entropy, + num_codebooks=1, + token_factorization=False, + ) + + if self.config.model.enc_mup_width is not None: + enc_width = self.config.model.enc_mup_width + else: + enc_width = width + + encoder_params = FluxParams( + in_channels=3 * patch_size**2, + context_dim=self.encoder_context_dim, + patch_size=patch_size, + depth=enc_depth, + **DIT_ZOO[self.dit_mode], + ) + decoder_params = FluxParams( + in_channels=3 * patch_size**2, + context_dim=context_dim + 1, + patch_size=patch_size, + depth=dec_depth, + **DIT_ZOO[self.dit_mode], + ) + + # width=4, dit_b_4 is the usual model + encoder_params.hidden_size = enc_width * (encoder_params.hidden_size // 4) + decoder_params.hidden_size = width * (decoder_params.hidden_size // 4) + encoder_params.axes_dim = [ + (d // 4) * enc_width for d in encoder_params.axes_dim + ] + decoder_params.axes_dim = [(d // 4) * width for d in decoder_params.axes_dim] + + self.encoder = Flux(encoder_params, name="encoder") + self.decoder = Flux(decoder_params, name="decoder") + + @torch.compile + def encode(self, img): + b, c, h, w = img.shape + + img_idxs, txt_idxs = prepare_idxs(img, self.code_length, self.patch_size) + txt = torch.zeros( + (b, self.code_length, self.encoder_context_dim), device=img.device + ) + + _, code, aux = self.encoder(img, img_idxs, txt, txt_idxs, timesteps=None) + + return code, aux + + def _decode(self, img, code, timesteps): + b, c, h, w = img.shape + + img_idxs, txt_idxs = prepare_idxs( + img, + self.code_length, + self.patch_size, + ) + pred, _, decode_aux = self.decoder( + img, img_idxs, code, txt_idxs, timesteps=timesteps + ) + return pred, decode_aux + + @torch.compile + def decode(self, *args, **kwargs): + return self._decode(*args, **kwargs) + + @torch.compile + def decode_checkpointed(self, *args, **kwargs): + # Need to compile(checkpoint), not checkpoint(compile) + assert not kwargs, kwargs + return torch.utils.checkpoint.checkpoint( + self._decode, + *args, + # WARNING: Do not use_reentrant=True with compile, it will silently + # produce incorrect gradients! 
+ use_reentrant=False, + ) + + @torch.compile + def _quantize(self, code): + """ + Args: + code: [b codelength context dim] + + Returns: + quantized code of the same shape + """ + b, t, f = code.shape + indices = None + if self.config.model.quantization_type == "noop": + quantized = code + quantizer_loss = torch.tensor(0.0).to(code.device) + elif self.config.model.quantization_type == "kl": + # colocating features of same token before split is maybe slightly + # better? + mean, logvar = _get_diagonal_gaussian( + einops.rearrange(code, "b t f -> b (f t)") + ) + code = einops.rearrange( + _sample_diagonal_gaussian(mean, logvar), + "b (f t) -> b t f", + f=f // 2, + t=t, + ) + quantizer_loss = _kl_diagonal_gaussian(mean, logvar) + elif self.config.model.quantization_type == "lfq": + assert f % self.config.model.codebook_size_for_entropy == 0, f + code = einops.rearrange( + code, + "b t (fg fh) -> b fg (t fh)", + fg=self.config.model.codebook_size_for_entropy, + ) + + (quantized, entropy_aux_loss, indices), breakdown = self.quantizer( + code, return_loss_breakdown=True + ) + assert quantized.shape == code.shape + quantized = einops.rearrange(quantized, "b fg (t fh) -> b t (fg fh)", t=t) + + quantizer_loss = ( + entropy_aux_loss * self.config.model.entropy_loss_weight + + breakdown.commitment * self.config.model.commit_loss_weight + ) + code = quantized + else: + raise NotImplementedError + return code, indices, quantizer_loss + + # def forward( + # self, + # img, + # noised_img, + # timesteps, + # enable_cfg=True, + # ): + # aux = {} + # + # code, encode_aux = self.encode(img) + # + # aux["original_code"] = code + # + # b, t, f = code.shape + # + # code, _, aux["quantizer_loss"] = self._quantize(code) + # + # mask = torch.ones_like(code[..., :1]) + # code = torch.concatenate([code, mask], axis=-1) + # code_pre_cfg = code + # + # if self.config.model.enable_cfg and enable_cfg: + # cfg_mask = (torch.rand((b,), device=code.device) > 0.1)[:, None, None] + # code = code * cfg_mask + # + # v_est, decode_aux = self.decode(noised_img, code, timesteps) + # aux.update(decode_aux) + # + # if self.config.model.posttrain_sample: + # aux["posttrain_sample"] = self.reconstruct_checkpoint(code_pre_cfg) + # + # return v_est, aux + + def forward(self, img): + return self.reconstruct(img) + + def reconstruct_checkpoint(self, code): + with torch.autocast( + "cuda", + dtype=torch.bfloat16, + ): + bs, *_ = code.shape + + z = torch.randn((bs, 3, self.image_size, self.image_size)).cuda() + ts = ( + torch.rand((bs, self.config.model.posttrain_sample_k + 1)) + .cumsum(dim=1) + .cuda() + ) + ts = ts - ts[:, :1] + ts = (ts / ts[:, -1:]).flip(dims=(1,)) + dts = ts[:, :-1] - ts[:, 1:] + + for i, (t, dt) in enumerate((zip(ts.T, dts.T))): + if self.config.model.posttrain_sample_enable_cfg: + mask = (torch.rand((bs,), device=code.device) > 0.1)[ + :, None, None + ].to(code.dtype) + code_t = code * mask + else: + code_t = code + + vc, _ = self.decode_checkpointed(z, code_t, t) + + z = z - dt[:, None, None, None] * vc + return z + + @torch.no_grad() + def reconstruct(self, images, dtype=torch.bfloat16, code=None): + """ + Args: + images in [bchw] [-1, 1] + + Returns: + images in [bchw] [-1, 1] + """ + model = self + config = self.config.eval.sampling + + with torch.autocast( + "cuda", + dtype=dtype, + ): + bs, c, h, w = images.shape + if code is None: + x = images.cuda() + prequantized_code = model.encode(x)[0].cuda() + code, indices, _ = model._quantize(prequantized_code) + + z = torch.randn((bs, 3, h, w)).cuda() + + mask = 
torch.ones_like(code[..., :1]) + code = torch.concatenate([code * mask, mask], axis=-1) + + cfg_mask = 0.0 + null_code = code * cfg_mask if config.cfg != 1.0 else None + + samples = rf_sample( + model, + z, + code, + null_code=null_code, + sample_steps=config.sample_steps, + cfg=config.cfg, + schedule=config.schedule, + )[-1].clip(-1, 1) + return samples.to(torch.float32), code, prequantized_code + + +def rf_loss(config, model, batch, aux_state): + x = batch["image"] + b = x.size(0) + + if config.opt.schedule == "lognormal": + nt = torch.randn((b,)).to(x.device) + t = torch.sigmoid(nt) + elif config.opt.schedule == "fat_lognormal": + nt = torch.randn((b,)).to(x.device) + t = torch.sigmoid(nt) + t = torch.where(torch.rand_like(t) <= 0.9, t, torch.rand_like(t)) + elif config.opt.schedule == "uniform": + t = torch.rand((b,), device=x.device) + elif config.opt.schedule.startswith("debug"): + p = float(config.opt.schedule.split("_")[1]) + t = torch.ones((b,), device=x.device) * p + else: + raise NotImplementedError + + t = t.view([b, *([1] * len(x.shape[1:]))]) + z1 = torch.randn_like(x) + zt = (1 - t) * x + t * z1 + + zt, t = zt.to(x.dtype), t.to(x.dtype) + + vtheta, aux = model( + img=x, + noised_img=zt, + timesteps=t.reshape((b,)), + ) + + diff = z1 - vtheta - x + x_pred = zt - vtheta * t + + loss = ((diff) ** 2).mean(dim=list(range(1, len(x.shape)))) + loss = loss.mean() + + aux["loss_dict"] = {} + aux["loss_dict"]["diffusion_loss"] = loss + aux["loss_dict"]["quantizer_loss"] = aux["quantizer_loss"] + + if config.opt.lpips_weight != 0.0: + aux_loss = 0.0 + if config.model.posttrain_sample: + x_pred = aux["posttrain_sample"] + + lpips_dist = aux_state["lpips_model"](x, x_pred) + lpips_dist = (config.opt.lpips_weight * lpips_dist).mean() + aux_loss + aux["loss_dict"]["lpips_loss"] = lpips_dist + else: + lpips_dist = 0.0 + + loss = loss + aux["quantizer_loss"] + lpips_dist + aux["loss_dict"]["total_loss"] = loss + return loss, aux + + +def _edm_to_flow_convention(noise_level): + # z = x + \sigma z' + return noise_level / (1 + noise_level) + + +def rf_sample( + model, + z, + code, + null_code=None, + sample_steps=25, + cfg=2.0, + schedule="linear", +): + b = z.size(0) + if schedule == "linear": + ts = torch.arange(1, sample_steps + 1).flip(0) / sample_steps + dts = torch.ones_like(ts) * (1.0 / sample_steps) + elif schedule.startswith("pow"): + p = float(schedule.split("_")[1]) + ts = torch.arange(0, sample_steps + 1).flip(0) ** (1 / p) / sample_steps ** ( + 1 / p + ) + dts = ts[:-1] - ts[1:] + else: + raise NotImplementedError + + if model.config.eval.sampling.cfg_interval is None: + interval = None + else: + cfg_lo, cfg_hi = ast.literal_eval(model.config.eval.sampling.cfg_interval) + interval = _edm_to_flow_convention(cfg_lo), _edm_to_flow_convention(cfg_hi) + + images = [] + for i, (t, dt) in enumerate((zip(ts, dts))): + timesteps = torch.tensor([t] * b).to(z.device) + vc, decode_aux = model.decode(img=z, timesteps=timesteps, code=code) + + if null_code is not None and ( + interval is None + or ((t.item() >= interval[0]) and (t.item() <= interval[1])) + ): + vu, _ = model.decode(img=z, timesteps=timesteps, code=null_code) + vc = vu + cfg * (vc - vu) + + z = z - dt * vc + images.append(z) + return images + + +def build_model(config): + with tempfile.TemporaryDirectory() as log_dir: + MUP_ENABLED = config.model.enable_mup + model_partial = FlowMo + + shared_kwargs = dict(config=config) + model = model_partial( + **shared_kwargs, + width=config.model.mup_width, + ).cuda() + + if 
config.model.enable_mup: + print("Mup enabled!") + with torch.device("cpu"): + base_model = model_partial( + **shared_kwargs, width=config.model.mup_width + ) + delta_model = model_partial( + **shared_kwargs, + width=( + config.model.mup_width * 4 if config.model.mup_width == 1 else 1 + ), + ) + true_model = model_partial( + **shared_kwargs, width=config.model.mup_width + ) + + if torch.distributed.is_initialized(): + bsh_path = os.path.join(log_dir, f"{dist.get_rank()}.bsh") + else: + bsh_path = os.path.join(log_dir, "0.bsh") + set_base_shapes( + true_model, base_model, delta=delta_model, savefile=bsh_path + ) + + model = set_base_shapes(model, base=bsh_path) + + for module in model.modules(): + if isinstance(module, MuReadout): + module.width_mult = lambda: module.weight.infshape.width_mult() + return model diff --git a/src/vqvaes/flowmo/lookup_free_quantize.py b/src/vqvaes/flowmo/lookup_free_quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..e0d09a7c94cc51bc8d7e4aa597ec4fd30472ea1d --- /dev/null +++ b/src/vqvaes/flowmo/lookup_free_quantize.py @@ -0,0 +1,396 @@ +""" +Code is from https://github.com/TencentARC/SEED-Voken. Thanks! + +Lookup Free Quantization +Proposed in https://arxiv.org/abs/2310.05737 + +In the simplest setup, each dimension is quantized into {-1, 1}. +An entropy penalty is used to encourage utilization. + +Refer to +https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/lookup_free_quantization.py +https://github.com/theAdamColton/ijepa-enhanced/blob/7edef5f7288ae8f537f0db8a10044a2a487f70c9/ijepa_enhanced/lfq.py +""" + +from collections import namedtuple +from math import ceil, log2 + +import torch +import torch.nn.functional as F +from einops import pack, rearrange, reduce, unpack +from torch import einsum +from torch.nn import Module + +# constants + +LossBreakdown = namedtuple( + "LossBreakdown", + ["per_sample_entropy", "codebook_entropy", "commitment", "avg_probs"], +) + +# helper functions + + +def exists(v): + return v is not None + + +def default(*args): + for arg in args: + if exists(arg): + return arg() if callable(arg) else arg + return None + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +# entropy + +# def log(t, eps = 1e-5): +# return t.clamp(min = eps).log() + + +def entropy(prob): + return (-prob * torch.log(prob + 1e-5)).sum(dim=-1) + + +# class + + +def mult_along_first_dims(x, y): + """ + returns x * y elementwise along the leading dimensions of y + """ + ndim_to_expand = x.ndim - y.ndim + for _ in range(ndim_to_expand): + y = y.unsqueeze(-1) + return x * y + + +def masked_mean(x, m): + """ + takes the mean of the elements of x that are not masked + the mean is taken along the shared leading dims of m + equivalent to: x[m].mean(tuple(range(m.ndim))) + + The benefit of using masked_mean rather than using + tensor indexing is that masked_mean is much faster + for torch-compile on batches. 
+ + The drawback is larger floating point errors + """ + x = mult_along_first_dims(x, m) + x = x / m.sum() + return x.sum(tuple(range(m.ndim))) + + +def entropy_loss( + logits, + mask=None, + # temperature=0.01, + sample_minimization_weight=1.0, + batch_maximization_weight=1.0, + eps=1e-5, +): + """ + Entropy loss of unnormalized logits + + logits: Affinities are over the last dimension + + https://github.com/google-research/magvit/blob/05e8cfd6559c47955793d70602d62a2f9b0bdef5/videogvt/train_lib/losses.py#L279 + LANGUAGE MODEL BEATS DIFFUSION — TOKENIZER IS KEY TO VISUAL GENERATION (2024) + """ + # import pdb + # pdb.set_trace() + # print(logits.shape) + # raise + + temperature = 0.1 + probs = F.softmax(logits / temperature, -1) + log_probs = F.log_softmax(logits / temperature + eps, -1) + + if mask is not None: + # avg_probs = probs[mask].mean(tuple(range(probs.ndim - 1))) + # avg_probs = einx.mean("... D -> D", probs[mask]) + + avg_probs = masked_mean(probs, mask) + # avg_probs = einx.mean("... D -> D", avg_probs) + else: + avg_probs = reduce(probs, "... D -> D", "mean") + + avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + eps)) + + sample_entropy = -torch.sum(probs * log_probs, -1) + if mask is not None: + # sample_entropy = sample_entropy[mask].mean() + sample_entropy = masked_mean(sample_entropy, mask).mean() + else: + sample_entropy = torch.mean(sample_entropy) + + loss = (sample_minimization_weight * sample_entropy) - ( + batch_maximization_weight * avg_entropy + ) + + return sample_entropy, avg_entropy, loss + + +class LFQ(Module): + def __init__( + self, + *, + dim=None, + codebook_size=None, + num_codebooks=1, + sample_minimization_weight=1.0, + batch_maximization_weight=1.0, + token_factorization=False, + factorized_bits=[9, 9], + ): + super().__init__() + + # some assert validations + + assert exists(dim) or exists( + codebook_size + ), "either dim or codebook_size must be specified for LFQ" + assert ( + not exists(codebook_size) or log2(codebook_size).is_integer() + ), f"your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})" + + self.codebook_size = default(codebook_size, lambda: 2**dim) + self.codebook_dim = int(log2(codebook_size)) + + codebook_dims = self.codebook_dim * num_codebooks + dim = default(dim, codebook_dims) + + has_projections = dim != codebook_dims + self.has_projections = has_projections + + self.dim = dim + self.codebook_dim = self.codebook_dim + self.num_codebooks = num_codebooks + + # for entropy loss + self.sample_minimization_weight = sample_minimization_weight + self.batch_maximization_weight = batch_maximization_weight + + # for no auxiliary loss, during inference + self.token_factorization = token_factorization + if not self.token_factorization: # for first stage model + self.register_buffer( + "mask", 2 ** torch.arange(self.codebook_dim), persistent=False + ) + else: + self.factorized_bits = factorized_bits + self.register_buffer( + "pre_mask", 2 ** torch.arange(factorized_bits[0]), persistent=False + ) + self.register_buffer( + "post_mask", 2 ** torch.arange(factorized_bits[1]), persistent=False + ) + + self.register_buffer("zero", torch.tensor(0.0), persistent=False) + + # codes + + all_codes = torch.arange(codebook_size) + bits = self.indices_to_bits(all_codes) + codebook = bits * 2.0 - 1.0 + + self.register_buffer("codebook", codebook, persistent=False) + + @property + def dtype(self): + return self.codebook.dtype + + def indices_to_bits(self, x): + """ + x: long tensor of 
indices + + returns big endian bits + """ + mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long) + # x is now big endian bits, the last dimension being the bits + x = (x.unsqueeze(-1) & mask) != 0 + return x + + def get_codebook_entry(self, x, bhwc, order): # 0610 + if self.token_factorization: + if order == "pre": + mask = 2 ** torch.arange( + self.factorized_bits[0], device=x.device, dtype=torch.long + ) + else: + mask = 2 ** torch.arange( + self.factorized_bits[1], device=x.device, dtype=torch.long + ) + else: + mask = 2 ** torch.arange( + self.codebook_dim, device=x.device, dtype=torch.long + ) + + x = (x.unsqueeze(-1) & mask) != 0 + x = x * 2.0 - 1.0 # back to the float + ## scale back to the + b, h, w, c = bhwc + x = rearrange(x, "b (h w) c -> b h w c", h=h, w=w, c=c) + x = rearrange(x, "b h w c -> b c h w") + return x + + def bits_to_indices(self, bits): + """ + bits: bool tensor of big endian bits, where the last dimension is the bit dimension + + returns indices, which are long integers from 0 to self.codebook_size + """ + assert bits.shape[-1] == self.codebook_dim + indices = 2 ** torch.arange( + 0, + self.codebook_dim, + 1, + dtype=torch.long, + device=bits.device, + ) + return (bits * indices).sum(-1) + + def decode(self, x): + """ + x: ... NH + where NH is number of codebook heads + A longtensor of codebook indices, containing values from + 0 to self.codebook_size + """ + x = self.indices_to_bits(x) + # to some sort of float + x = x.to(self.dtype) + # -1 or 1 + x = x * 2 - 1 + x = rearrange(x, "... NC Z-> ... (NC Z)") + return x + + def forward( + self, + x, + inv_temperature=100.0, + return_loss_breakdown=False, + mask=None, + return_loss=True, + ): + """ + einstein notation + b - batch + n - sequence (or flattened spatial dimensions) + d - feature dimension, which is also log2(codebook size) + c - number of codebook dim + """ + # x = x.tanh() * 1.5 + + x = rearrange(x, "b d ... -> b ... d") + x, ps = pack_one(x, "b * d") + # split out number of codebooks + + x = rearrange(x, "b n (c d) -> b n c d", c=self.num_codebooks) + + codebook_value = torch.Tensor([1.0]).to(device=x.device, dtype=x.dtype) + quantized = torch.where( + x > 0, codebook_value, -codebook_value + ) # higher than 0 filled + + # calculate indices + if self.token_factorization: + indices_pre = reduce( + (quantized[..., : self.factorized_bits[0]] > 0).int() + * self.pre_mask.int(), + "b n c d -> b n c", + "sum", + ) + indices_post = reduce( + (quantized[..., self.factorized_bits[0] :] > 0).int() + * self.post_mask.int(), + "b n c d -> b n c", + "sum", + ) + else: + # print(quantized.shape) + indices = reduce( + (quantized > 0).int() * self.mask.int(), "b n c d -> b n c", "sum" + ) + # print(indices.shape) + + # entropy aux loss + + if self.training and return_loss: + logits = 2 * einsum("... i d, j d -> ... i j", x, self.codebook) + # the same as euclidean distance up to a constant + # import pdb + # pdb.set_trace() + + per_sample_entropy, codebook_entropy, entropy_aux_loss = entropy_loss( + logits=logits, + sample_minimization_weight=self.sample_minimization_weight, + batch_maximization_weight=self.batch_maximization_weight, + ) + + avg_probs = self.zero + else: + # logits = 2 * einsum('... i d, j d -> ... 
i j', x, self.codebook) + # probs = F.softmax(logits / 0.01, -1) + # avg_probs = reduce(probs, "b n c d -> b d", "mean") + # avg_probs = torch.sum(avg_probs, 0) #batch dimension + # if not training, just return dummy 0 + per_sample_entropy = codebook_entropy = self.zero + ## calculate the codebook_entropy needed for one batch evaluation + entropy_aux_loss = self.zero + avg_probs = self.zero + + # commit loss + + if self.training: + commit_loss = F.mse_loss(x, quantized.detach(), reduction="none") + + if exists(mask): + commit_loss = commit_loss[mask] + + commit_loss = commit_loss.mean() + else: + commit_loss = self.zero + + # use straight-through gradients (optionally with custom activation fn) if training + + quantized = x + (quantized - x).detach() # transfer to quantized + + # merge back codebook dim + + quantized = rearrange(quantized, "b n c d -> b n (c d)") + + # reconstitute image or video dimensions + + quantized = unpack_one(quantized, ps, "b * d") + quantized = rearrange(quantized, "b ... d -> b d ...") + + if self.token_factorization: + indices_pre = unpack_one(indices_pre, ps, "b * c") + indices_post = unpack_one(indices_post, ps, "b * c") + indices_pre = indices_pre.flatten() + indices_post = indices_post.flatten() + indices = (indices_pre, indices_post) + else: + # print(indices.shape, ps) + indices = unpack_one(indices, ps, "b * c") + # print(indices.shape) + indices = indices.flatten() + + ret = (quantized, entropy_aux_loss, indices) + + if not return_loss_breakdown: + return ret + + return ret, LossBreakdown( + per_sample_entropy, codebook_entropy, commit_loss, avg_probs + ) diff --git a/src/vqvaes/infinity/conv.py b/src/vqvaes/infinity/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..9523b385b41acef1296f02841c97b02660edb3d4 --- /dev/null +++ b/src/vqvaes/infinity/conv.py @@ -0,0 +1,107 @@ +import torch +import torch.nn as nn +from einops import rearrange +import torch.nn.functional as F + + +class Conv(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + cnn_type="2d", + causal_offset=0, + temporal_down=False, + ): + super().__init__() + self.cnn_type = cnn_type + self.slice_seq_len = 17 + + if cnn_type == "2d": + self.conv = nn.Conv2d( + in_channels, out_channels, kernel_size, stride=stride, padding=padding + ) + if cnn_type == "3d": + if temporal_down == False: + stride = (1, stride, stride) + else: + stride = (stride, stride, stride) + self.conv = nn.Conv3d( + in_channels, out_channels, kernel_size, stride=stride, padding=0 + ) + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size, kernel_size) + self.padding = ( + kernel_size[0] - 1 + causal_offset, # Temporal causal padding + padding, # Height padding + padding, # Width padding + ) + self.causal_offset = causal_offset + self.stride = stride + self.kernel_size = kernel_size + + def forward(self, x): + if self.cnn_type == "2d": + if x.ndim == 5: + B, C, T, H, W = x.shape + x = rearrange(x, "B C T H W -> (B T) C H W") + x = self.conv(x) + x = rearrange(x, "(B T) C H W -> B C T H W", T=T) + return x + else: + return self.conv(x) + if self.cnn_type == "3d": + assert ( + self.stride[0] == 1 or self.stride[0] == 2 + ), f"only temporal stride = 1 or 2 are supported" + xs = [] + for i in range(0, x.shape[2], self.slice_seq_len + self.stride[0] - 1): + st = i + en = min(i + self.slice_seq_len, x.shape[2]) + _x = x[:, :, st:en, :, :] + if i == 0: + _x = F.pad( + _x, + ( + self.padding[2], + self.padding[2], # Width + 
self.padding[1], + self.padding[1], # Height + self.padding[0], + 0, + ), + ) # Temporal + else: + padding_0 = self.kernel_size[0] - 1 + _x = F.pad( + _x, + ( + self.padding[2], + self.padding[2], # Width + self.padding[1], + self.padding[1], # Height + padding_0, + 0, + ), + ) # Temporal + _x[ + :, + :, + :padding_0, + self.padding[1] : _x.shape[-2] - self.padding[1], + self.padding[2] : _x.shape[-1] - self.padding[2], + ] += x[:, :, i - padding_0 : i, :, :] + _x = self.conv(_x) + xs.append(_x) + try: + x = torch.cat(xs, dim=2) + except: + device = x.device + del x + xs = [_x.cpu().pin_memory() for _x in xs] + torch.cuda.empty_cache() + x = torch.cat([_x.cpu() for _x in xs], dim=2).to(device=device) + return x diff --git a/src/vqvaes/infinity/dynamic_resolution.py b/src/vqvaes/infinity/dynamic_resolution.py new file mode 100644 index 0000000000000000000000000000000000000000..59ea3ea7b3ac4ca493e3ac27e4503f4b5d41d614 --- /dev/null +++ b/src/vqvaes/infinity/dynamic_resolution.py @@ -0,0 +1,147 @@ +import json +import numpy as np +import tqdm + +vae_stride = 16 +ratio2hws = { + 1.000: [ + (1, 1), + (2, 2), + (4, 4), + (6, 6), + (8, 8), + (12, 12), + (16, 16), + (20, 20), + (24, 24), + (32, 32), + (40, 40), + (48, 48), + (64, 64), + ], + 1.250: [ + (1, 1), + (2, 2), + (3, 3), + (5, 4), + (10, 8), + (15, 12), + (20, 16), + (25, 20), + (30, 24), + (35, 28), + (45, 36), + (55, 44), + (70, 56), + ], + 1.333: [ + (1, 1), + (2, 2), + (4, 3), + (8, 6), + (12, 9), + (16, 12), + (20, 15), + (24, 18), + (28, 21), + (36, 27), + (48, 36), + (60, 45), + (72, 54), + ], + 1.500: [ + (1, 1), + (2, 2), + (3, 2), + (6, 4), + (9, 6), + (15, 10), + (21, 14), + (27, 18), + (33, 22), + (39, 26), + (48, 32), + (63, 42), + (78, 52), + ], + 1.750: [ + (1, 1), + (2, 2), + (3, 3), + (7, 4), + (11, 6), + (14, 8), + (21, 12), + (28, 16), + (35, 20), + (42, 24), + (56, 32), + (70, 40), + (84, 48), + ], + 2.000: [ + (1, 1), + (2, 2), + (4, 2), + (6, 3), + (10, 5), + (16, 8), + (22, 11), + (30, 15), + (38, 19), + (46, 23), + (60, 30), + (74, 37), + (90, 45), + ], + 2.500: [ + (1, 1), + (2, 2), + (5, 2), + (10, 4), + (15, 6), + (20, 8), + (25, 10), + (30, 12), + (40, 16), + (50, 20), + (65, 26), + (80, 32), + (100, 40), + ], + 3.000: [ + (1, 1), + (2, 2), + (6, 2), + (9, 3), + (15, 5), + (21, 7), + (27, 9), + (36, 12), + (45, 15), + (54, 18), + (72, 24), + (90, 30), + (111, 37), + ], +} +full_ratio2hws = {} +for ratio, hws in ratio2hws.items(): + full_ratio2hws[ratio] = hws + full_ratio2hws[int(1 / ratio * 1000) / 1000] = [(item[1], item[0]) for item in hws] + +dynamic_resolution_h_w = {} +predefined_HW_Scales_dynamic = {} +for ratio in full_ratio2hws: + dynamic_resolution_h_w[ratio] = {} + for ind, leng in enumerate([7, 10, 13]): + h, w = ( + full_ratio2hws[ratio][leng - 1][0], + full_ratio2hws[ratio][leng - 1][1], + ) # feature map size + pixel = (h * vae_stride, w * vae_stride) # The original image (H, W) + dynamic_resolution_h_w[ratio][pixel[1]] = { + "pixel": pixel, + "scales": full_ratio2hws[ratio][:leng], + } # W as key + predefined_HW_Scales_dynamic[(h, w)] = full_ratio2hws[ratio][:leng] diff --git a/src/vqvaes/infinity/flux_vqgan.py b/src/vqvaes/infinity/flux_vqgan.py new file mode 100644 index 0000000000000000000000000000000000000000..58fde04a3dbd3c3a450b543a338dd4af787ae93e --- /dev/null +++ b/src/vqvaes/infinity/flux_vqgan.py @@ -0,0 +1,771 @@ +import argparse +import os +import imageio +import torch +import numpy as np +from einops import rearrange +from torch import Tensor, nn +import torch.nn.functional as F 
+import torchvision +from torchvision import transforms +from safetensors.torch import load_file +import torch.utils.checkpoint as checkpoint + +from .conv import Conv +from .multiscale_bsq import MultiScaleBSQ + +ptdtype = {None: torch.float32, "fp32": torch.float32, "bf16": torch.bfloat16} + + +class Normalize(nn.Module): + def __init__(self, in_channels, norm_type, norm_axis="spatial"): + super().__init__() + self.norm_axis = norm_axis + assert norm_type in ["group", "batch", "no"] + if norm_type == "group": + if in_channels % 32 == 0: + self.norm = nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + elif in_channels % 24 == 0: + self.norm = nn.GroupNorm( + num_groups=24, num_channels=in_channels, eps=1e-6, affine=True + ) + else: + raise NotImplementedError + elif norm_type == "batch": + self.norm = nn.SyncBatchNorm( + in_channels, track_running_stats=False + ) # Runtime Error: grad inplace if set track_running_stats to True + elif norm_type == "no": + self.norm = nn.Identity() + + def forward(self, x): + if self.norm_axis == "spatial": + if x.ndim == 4: + x = self.norm(x) + else: + B, C, T, H, W = x.shape + x = rearrange(x, "B C T H W -> (B T) C H W") + x = self.norm(x) + x = rearrange(x, "(B T) C H W -> B C T H W", T=T) + elif self.norm_axis == "spatial-temporal": + x = self.norm(x) + else: + raise NotImplementedError + return x + + +def swish(x: Tensor) -> Tensor: + try: + return x * torch.sigmoid(x) + except: + device = x.device + x = x.cpu().pin_memory() + return (x * torch.sigmoid(x)).to(device=device) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels, norm_type="group", cnn_param=None): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize( + in_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"] + ) + + self.q = Conv(in_channels, in_channels, kernel_size=1) + self.k = Conv(in_channels, in_channels, kernel_size=1) + self.v = Conv(in_channels, in_channels, kernel_size=1) + self.proj_out = Conv(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + B, _, T, _, _ = h_.shape + h_ = self.norm(h_) + h_ = rearrange(h_, "B C T H W -> (B T) C H W") # spatial attention only + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "(b t) 1 (h w) c -> b c t h w", h=h, w=w, c=c, b=B, t=T) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + def __init__( + self, in_channels: int, out_channels: int, norm_type="group", cnn_param=None + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = Normalize( + in_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"] + ) + if cnn_param["res_conv_2d"] in ["half", "full"]: + self.conv1 = Conv( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + cnn_type="2d", + ) + else: + self.conv1 = Conv( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + cnn_type=cnn_param["cnn_type"], + ) + self.norm2 = Normalize( + out_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"] + ) + if cnn_param["res_conv_2d"] in ["full"]: + self.conv2 
= Conv( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + cnn_type="2d", + ) + else: + self.conv2 = Conv( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + cnn_type=cnn_param["cnn_type"], + ) + if self.in_channels != self.out_channels: + self.nin_shortcut = Conv( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__( + self, in_channels, cnn_type="2d", spatial_down=False, temporal_down=False + ): + super().__init__() + assert spatial_down == True + if cnn_type == "2d": + self.pad = (0, 1, 0, 1) + if cnn_type == "3d": + self.pad = ( + 0, + 1, + 0, + 1, + 0, + 0, + ) # add padding to the right for h-axis and w-axis. No padding for t-axis + # no asymmetric padding in torch conv, must do it ourselves + self.conv = Conv( + in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0, + cnn_type=cnn_type, + temporal_down=temporal_down, + ) + + def forward(self, x: Tensor): + x = nn.functional.pad(x, self.pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__( + self, + in_channels, + cnn_type="2d", + spatial_up=False, + temporal_up=False, + use_pxsl=False, + ): + super().__init__() + if cnn_type == "2d": + self.scale_factor = 2 + self.causal_offset = 0 + else: + assert spatial_up == True + if temporal_up: + self.scale_factor = (2, 2, 2) + self.causal_offset = -1 + else: + self.scale_factor = (1, 2, 2) + self.causal_offset = 0 + self.use_pxsl = use_pxsl + if self.use_pxsl: + self.conv = Conv( + in_channels, + in_channels * 4, + kernel_size=3, + stride=1, + padding=1, + cnn_type=cnn_type, + causal_offset=self.causal_offset, + ) + self.pxsl = nn.PixelShuffle(2) + else: + self.conv = Conv( + in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1, + cnn_type=cnn_type, + causal_offset=self.causal_offset, + ) + + def forward(self, x: Tensor): + if self.use_pxsl: + x = self.conv(x) + x = self.pxsl(x) + else: + try: + x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest") + except: + # shard across channel + _xs = [] + for i in range(x.shape[1]): + _x = F.interpolate( + x[:, i : i + 1, ...], + scale_factor=self.scale_factor, + mode="nearest", + ) + _xs.append(_x) + x = torch.cat(_xs, dim=1) + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + in_channels=3, + patch_size=8, + temporal_patch_size=4, + norm_type="group", + cnn_param=None, + use_checkpoint=False, + use_vae=True, + ): + super().__init__() + self.max_down = np.log2(patch_size) + self.temporal_max_down = np.log2(temporal_patch_size) + self.temporal_down_offset = self.max_down - self.temporal_max_down + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.in_channels = in_channels + self.cnn_param = cnn_param + self.use_checkpoint = use_checkpoint + # downsampling + # self.conv_in = Conv(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + # cnn_param["cnn_type"] = "2d" for images, cnn_param["cnn_type"] = "3d" for videos + if cnn_param["conv_in_out_2d"] == "yes": # "yes" for video + self.conv_in = Conv( + in_channels, ch, kernel_size=3, stride=1, padding=1, 
cnn_type="2d" + ) + else: + self.conv_in = Conv( + in_channels, + ch, + kernel_size=3, + stride=1, + padding=1, + cnn_type=cnn_param["cnn_type"], + ) + + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + norm_type=norm_type, + cnn_param=cnn_param, + ) + ) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + # downsample, stride=1, stride=2, stride=2 for 4x8x8 Video VAE + spatial_down = True if i_level < self.max_down else False + temporal_down = ( + True + if i_level < self.max_down and i_level >= self.temporal_down_offset + else False + ) + if spatial_down or temporal_down: + down.downsample = Downsample( + block_in, + cnn_type=cnn_param["cnn_type"], + spatial_down=spatial_down, + temporal_down=temporal_down, + ) + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + norm_type=norm_type, + cnn_param=cnn_param, + ) + if cnn_param["cnn_attention"] == "yes": + self.mid.attn_1 = AttnBlock(block_in, norm_type, cnn_param=cnn_param) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + norm_type=norm_type, + cnn_param=cnn_param, + ) + + # end + self.norm_out = Normalize( + block_in, norm_type, norm_axis=cnn_param["cnn_norm_axis"] + ) + if cnn_param["conv_inner_2d"] == "yes": + self.conv_out = Conv( + block_in, + (int(use_vae) + 1) * z_channels, + kernel_size=3, + stride=1, + padding=1, + cnn_type="2d", + ) + else: + self.conv_out = Conv( + block_in, + (int(use_vae) + 1) * z_channels, + kernel_size=3, + stride=1, + padding=1, + cnn_type=cnn_param["cnn_type"], + ) + + def forward(self, x, return_hidden=False): + if not self.use_checkpoint: + return self._forward(x, return_hidden=return_hidden) + else: + return checkpoint.checkpoint( + self._forward, x, return_hidden, use_reentrant=False + ) + + def _forward(self, x: Tensor, return_hidden=False) -> Tensor: + # downsampling + h0 = self.conv_in(x) + hs = [h0] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if hasattr(self.down[i_level], "downsample"): + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + hs_mid = [h] + h = self.mid.block_1(h) + if self.cnn_param["cnn_attention"] == "yes": + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + hs_mid.append(h) + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + if return_hidden: + return h, hs, hs_mid + else: + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + out_ch=3, + patch_size=8, + temporal_patch_size=4, + norm_type="group", + cnn_param=None, + use_checkpoint=False, + use_freq_dec=False, # use frequency features for decoder + use_pxsf=False, + ): + super().__init__() + self.max_up = np.log2(patch_size) + self.temporal_max_up = np.log2(temporal_patch_size) + self.temporal_up_offset = self.max_up - self.temporal_max_up + self.ch = ch + self.num_resolutions = len(ch_mult) + 
self.num_res_blocks = num_res_blocks + self.ffactor = 2 ** (self.num_resolutions - 1) + self.cnn_param = cnn_param + self.use_checkpoint = use_checkpoint + self.use_freq_dec = use_freq_dec + self.use_pxsf = use_pxsf + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + + # z to block_in + if cnn_param["conv_inner_2d"] == "yes": + self.conv_in = Conv( + z_channels, block_in, kernel_size=3, stride=1, padding=1, cnn_type="2d" + ) + else: + self.conv_in = Conv( + z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1, + cnn_type=cnn_param["cnn_type"], + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + norm_type=norm_type, + cnn_param=cnn_param, + ) + if cnn_param["cnn_attention"] == "yes": + self.mid.attn_1 = AttnBlock( + block_in, norm_type=norm_type, cnn_param=cnn_param + ) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + norm_type=norm_type, + cnn_param=cnn_param, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + norm_type=norm_type, + cnn_param=cnn_param, + ) + ) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + # upsample, stride=1, stride=2, stride=2 for 4x8x8 Video VAE, offset 1 compared with encoder + # https://github.com/black-forest-labs/flux/blob/b4f689aaccd40de93429865793e84a734f4a6254/src/flux/modules/autoencoder.py#L228 + spatial_up = True if 1 <= i_level <= self.max_up else False + temporal_up = ( + True + if 1 <= i_level <= self.max_up + and i_level >= self.temporal_up_offset + 1 + else False + ) + if spatial_up or temporal_up: + up.upsample = Upsample( + block_in, + cnn_type=cnn_param["cnn_type"], + spatial_up=spatial_up, + temporal_up=temporal_up, + use_pxsl=self.use_pxsf, + ) + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize( + block_in, norm_type, norm_axis=cnn_param["cnn_norm_axis"] + ) + if cnn_param["conv_in_out_2d"] == "yes": + self.conv_out = Conv( + block_in, out_ch, kernel_size=3, stride=1, padding=1, cnn_type="2d" + ) + else: + self.conv_out = Conv( + block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1, + cnn_type=cnn_param["cnn_type"], + ) + + def forward(self, z): + if not self.use_checkpoint: + return self._forward(z) + else: + return checkpoint.checkpoint(self._forward, z, use_reentrant=False) + + def _forward(self, z: Tensor) -> Tensor: + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + if self.cnn_param["cnn_attention"] == "yes": + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if hasattr(self.up[i_level], "upsample"): + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class AutoEncoder(nn.Module): + def __init__(self, args): + super().__init__() + self.args = args + cnn_param = dict( + cnn_type=args.cnn_type, + conv_in_out_2d=args.conv_in_out_2d, + res_conv_2d=args.res_conv_2d, + 
cnn_attention=args.cnn_attention, + cnn_norm_axis=args.cnn_norm_axis, + conv_inner_2d=args.conv_inner_2d, + ) + self.encoder = Encoder( + ch=args.base_ch, + ch_mult=args.encoder_ch_mult, + num_res_blocks=args.num_res_blocks, + z_channels=args.codebook_dim, + patch_size=args.patch_size, + temporal_patch_size=args.temporal_patch_size, + cnn_param=cnn_param, + use_checkpoint=args.use_checkpoint, + use_vae=args.use_vae, + ) + self.decoder = Decoder( + ch=args.base_ch, + ch_mult=args.decoder_ch_mult, + num_res_blocks=args.num_res_blocks, + z_channels=args.codebook_dim, + patch_size=args.patch_size, + temporal_patch_size=args.temporal_patch_size, + cnn_param=cnn_param, + use_checkpoint=args.use_checkpoint, + use_freq_dec=args.use_freq_dec, + use_pxsf=args.use_pxsf, # pixelshuffle for upsampling + ) + self.z_drop = nn.Dropout(args.z_drop) + self.scale_factor = 0.3611 + self.shift_factor = 0.1159 + self.codebook_dim = self.embed_dim = args.codebook_dim + + self.gan_feat_weight = args.gan_feat_weight + self.video_perceptual_weight = args.video_perceptual_weight + self.recon_loss_type = args.recon_loss_type + self.l1_weight = args.l1_weight + self.use_vae = args.use_vae + self.kl_weight = args.kl_weight + self.lfq_weight = args.lfq_weight + self.image_gan_weight = args.image_gan_weight # image GAN loss weight + self.video_gan_weight = args.video_gan_weight # video GAN loss weight + self.perceptual_weight = args.perceptual_weight + self.flux_weight = args.flux_weight + self.cycle_weight = args.cycle_weight + self.cycle_feat_weight = args.cycle_feat_weight + self.cycle_gan_weight = args.cycle_gan_weight + + self.flux_image_encoder = None + + if not args.use_vae: + if args.quantizer_type == "MultiScaleBSQ": + self.quantizer = MultiScaleBSQ( + dim=args.codebook_dim, # this is the input feature dimension, defaults to log2(codebook_size) if not defined + codebook_size=args.codebook_size, # codebook size, must be a power of 2 + entropy_loss_weight=args.entropy_loss_weight, # how much weight to place on entropy loss + diversity_gamma=args.diversity_gamma, # within entropy loss, how much weight to give to diversity of codes, taken from https://arxiv.org/abs/1911.05894 + preserve_norm=args.preserve_norm, # preserve norm of the input for BSQ + ln_before_quant=args.ln_before_quant, # use layer norm before quantization + ln_init_by_sqrt=args.ln_init_by_sqrt, # layer norm init value 1/sqrt(d) + commitment_loss_weight=args.commitment_loss_weight, # loss weight of commitment loss + new_quant=args.new_quant, + use_decay_factor=args.use_decay_factor, + mask_out=args.mask_out, + use_stochastic_depth=args.use_stochastic_depth, + drop_rate=args.drop_rate, + schedule_mode=args.schedule_mode, + keep_first_quant=args.keep_first_quant, + keep_last_quant=args.keep_last_quant, + remove_residual_detach=args.remove_residual_detach, + use_out_phi=args.use_out_phi, + use_out_phi_res=args.use_out_phi_res, + random_flip=args.random_flip, + flip_prob=args.flip_prob, + flip_mode=args.flip_mode, + max_flip_lvl=args.max_flip_lvl, + random_flip_1lvl=args.random_flip_1lvl, + flip_lvl_idx=args.flip_lvl_idx, + drop_when_test=args.drop_when_test, + drop_lvl_idx=args.drop_lvl_idx, + drop_lvl_num=args.drop_lvl_num, + ) + self.quantize = self.quantizer + self.vocab_size = args.codebook_size + else: + raise NotImplementedError(f"{args.quantizer_type} not supported") + + def forward(self, x): + is_image = x.ndim == 4 + if not is_image: + B, C, T, H, W = x.shape + else: + B, C, H, W = x.shape + T = 1 + enc_dtype = ptdtype[self.args.encoder_dtype] 
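+        # Encode under autocast in the configured encoder dtype, quantize the
+        # float32 latent with the multi-scale BSQ quantizer, then decode the summed
+        # codes; the detached hidden states are not used further in this path.
+        # Minimal usage sketch (assuming `args` is built as in vae.py and the
+        # input batch is scaled to [-1, 1]):
+        #     recon, _, z = AutoEncoder(args)(images)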
+ + with torch.amp.autocast("cuda", dtype=enc_dtype): + h, hs, hs_mid = self.encoder(x, return_hidden=True) # B C H W or B C T H W + hs = [_h.detach() for _h in hs] + hs_mid = [_h.detach() for _h in hs_mid] + h = h.to(dtype=torch.float32) + # print(z.shape) + # Multiscale LFQ + z, all_indices, _, _, all_loss, _ = self.quantizer(h) + x_recon = self.decoder(z) + vq_output = { + "commitment_loss": torch.mean(all_loss) + * self.lfq_weight, # here commitment loss is sum of commitment loss and entropy penalty + "encodings": all_indices, + } + # return x_recon, vq_output + return x_recon, None, z + + def encode_for_raw_features( + self, x, scale_schedule, return_residual_norm_per_scale=False + ): + is_image = x.ndim == 4 + if not is_image: + B, C, T, H, W = x.shape + else: + B, C, H, W = x.shape + T = 1 + + enc_dtype = ptdtype[self.args.encoder_dtype] + with torch.amp.autocast("cuda", dtype=enc_dtype): + h, hs, hs_mid = self.encoder(x, return_hidden=True) # B C H W or B C T H W + + hs = [_h.detach() for _h in hs] + hs_mid = [_h.detach() for _h in hs_mid] + h = h.to(dtype=torch.float32) + return h, hs, hs_mid + + def encode(self, x, scale_schedule, return_residual_norm_per_scale=False): + h, hs, hs_mid = self.encode_for_raw_features( + x, scale_schedule, return_residual_norm_per_scale + ) + # Multiscale LFQ + ( + z, + all_indices, + all_bit_indices, + residual_norm_per_scale, + all_loss, + var_input, + ) = self.quantizer( + h, + scale_schedule=scale_schedule, + return_residual_norm_per_scale=return_residual_norm_per_scale, + ) + return h, z, all_indices, all_bit_indices, residual_norm_per_scale, var_input + + def decode(self, z): + x_recon = self.decoder(z) + x_recon = torch.clamp(x_recon, min=-1, max=1) + return x_recon + + def decode_from_indices(self, all_indices, scale_schedule, label_type): + summed_codes = 0 + for idx_Bl in all_indices: + codes = self.quantizer.lfq.indices_to_codes(idx_Bl, label_type) + summed_codes += F.interpolate( + codes, size=scale_schedule[-1], mode=self.quantizer.z_interplote_up + ) + assert summed_codes.shape[-3] == 1 + x_recon = self.decoder(summed_codes.squeeze(-3)) + x_recon = torch.clamp(x_recon, min=-1, max=1) + return summed_codes, x_recon + + @staticmethod + def add_model_specific_args(parent_parser): + parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False) + parser.add_argument("--flux_weight", type=float, default=0) + parser.add_argument("--cycle_weight", type=float, default=0) + parser.add_argument("--cycle_feat_weight", type=float, default=0) + parser.add_argument("--cycle_gan_weight", type=float, default=0) + parser.add_argument("--cycle_loop", type=int, default=0) + parser.add_argument("--z_drop", type=float, default=0.0) + return parser diff --git a/src/vqvaes/infinity/multiscale_bsq.py b/src/vqvaes/infinity/multiscale_bsq.py new file mode 100644 index 0000000000000000000000000000000000000000..22898b317cdb1909577cc8f68be82bf2afda6160 --- /dev/null +++ b/src/vqvaes/infinity/multiscale_bsq.py @@ -0,0 +1,893 @@ +""" +Binary Spherical Quantization +Proposed in https://arxiv.org/abs/2406.07548 + +In the simplest setup, each dimension is quantized into {-1, 1}. +An entropy penalty is used to encourage utilization. 
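+
+Illustrative sketch (added for intuition): a d-dimensional latent u is projected
+onto the unit sphere and quantized to sign(u) / sqrt(d), one of 2^d points on the
+sphere; the token index is simply the d-bit sign pattern of u.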
+""" + +import random +from math import log2, ceil +from functools import partial, cache +from collections import namedtuple +from contextlib import nullcontext + +import torch.distributed as dist +from torch.distributed import nn as dist_nn + +import torch +from torch import nn, einsum +import torch.nn.functional as F +from torch.nn import Module +from torch.amp import autocast +import numpy as np + +from einops import rearrange, reduce, pack, unpack + +# from einx import get_at + +from .dynamic_resolution import predefined_HW_Scales_dynamic + +# constants + +Return = namedtuple( + "Return", ["quantized", "indices", "bit_indices", "entropy_aux_loss"] +) + +LossBreakdown = namedtuple( + "LossBreakdown", ["per_sample_entropy", "batch_entropy", "commitment"] +) + +# distributed helpers + + +@cache +def is_distributed(): + return dist.is_initialized() and dist.get_world_size() > 1 + + +def maybe_distributed_mean(t): + if not is_distributed(): + return t + + dist_nn.all_reduce(t) + t = t / dist.get_world_size() + return t + + +# helper functions + + +def exists(v): + return v is not None + + +def identity(t): + return t + + +def default(*args): + for arg in args: + if exists(arg): + return arg() if callable(arg) else arg + return None + + +def round_up_multiple(num, mult): + return ceil(num / mult) * mult + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +def l2norm(t): + return F.normalize(t, dim=-1) + + +# entropy + + +def log(t, eps=1e-5): + return t.clamp(min=eps).log() + + +def entropy(prob): + return (-prob * log(prob)).sum(dim=-1) + + +# cosine sim linear + + +class CosineSimLinear(Module): + def __init__(self, dim_in, dim_out, scale=1.0): + super().__init__() + self.scale = scale + self.weight = nn.Parameter(torch.randn(dim_in, dim_out)) + + def forward(self, x): + x = F.normalize(x, dim=-1) + w = F.normalize(self.weight, dim=0) + return (x @ w) * self.scale + + +def get_latent2scale_schedule(T: int, H: int, W: int, mode="original"): + assert mode in ["original", "dynamic", "dense", "same1", "same2", "same3"] + predefined_HW_Scales = { + # 256 * 256 + (32, 32): [ + (1, 1), + (2, 2), + (3, 3), + (4, 4), + (6, 6), + (9, 9), + (13, 13), + (18, 18), + (24, 24), + (32, 32), + ], + (16, 16): [ + (1, 1), + (2, 2), + (3, 3), + (4, 4), + (5, 5), + (6, 6), + (8, 8), + (10, 10), + (13, 13), + (16, 16), + ], + # 1024x1024 + (64, 64): [ + (1, 1), + (2, 2), + (3, 3), + (4, 4), + (5, 5), + (7, 7), + (9, 9), + (12, 12), + (16, 16), + (21, 21), + (27, 27), + (36, 36), + (48, 48), + (64, 64), + ], + (36, 64): [ + (1, 1), + (2, 2), + (3, 3), + (4, 4), + (6, 6), + (9, 12), + (13, 16), + (18, 24), + (24, 32), + (32, 48), + (36, 64), + ], + } + if mode == "dynamic": + predefined_HW_Scales.update(predefined_HW_Scales_dynamic) + elif mode == "dense": + predefined_HW_Scales[(16, 16)] = [(x, x) for x in range(1, 16 + 1)] + predefined_HW_Scales[(32, 32)] = predefined_HW_Scales[(16, 16)] + [ + (20, 20), + (24, 24), + (28, 28), + (32, 32), + ] + predefined_HW_Scales[(64, 64)] = predefined_HW_Scales[(32, 32)] + [ + (40, 40), + (48, 48), + (56, 56), + (64, 64), + ] + elif mode.startswith("same"): + num_quant = int(mode[len("same") :]) + predefined_HW_Scales[(16, 16)] = [(16, 16) for _ in range(num_quant)] + predefined_HW_Scales[(32, 32)] = [(32, 32) for _ in range(num_quant)] + predefined_HW_Scales[(64, 64)] = [(64, 64) for _ in range(num_quant)] + + predefined_T_Scales = [1, 2, 3, 4, 5, 6, 7, 9, 11, 13, 15, 17, 17, 17, 17, 17] + 
patch_THW_shape_per_scale = predefined_HW_Scales[(H, W)] + if len(predefined_T_Scales) < len(patch_THW_shape_per_scale): + # print("warning: the length of predefined_T_Scales is less than the length of patch_THW_shape_per_scale!") + predefined_T_Scales += [predefined_T_Scales[-1]] * ( + len(patch_THW_shape_per_scale) - len(predefined_T_Scales) + ) + patch_THW_shape_per_scale = [ + (min(T, t), h, w) + for (h, w), t in zip( + patch_THW_shape_per_scale, + predefined_T_Scales[: len(patch_THW_shape_per_scale)], + ) + ] + return patch_THW_shape_per_scale + + +class LayerNorm(nn.Module): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + normalized_shape: int + """ + + def __init__( + self, + normalized_shape, + norm_weight=False, + eps=1e-6, + data_format="channels_first", + ): + super().__init__() + if norm_weight: + self.weight = nn.Parameter( + torch.ones(normalized_shape) / (normalized_shape**0.5) + ) + else: + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape,) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm( + x, self.normalized_shape, self.weight, self.bias, self.eps + ) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + if x.ndim == 4: # (b, c, h, w) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + elif x.ndim == 5: # (b, c, t, h, w) + x = ( + self.weight[:, None, None, None] * x + + self.bias[:, None, None, None] + ) + else: + raise ValueError( + "the number of dimensions of the input should be 4 or 5" + ) + return x + + +class MultiScaleBSQ(Module): + """Follows Algorithm 1. 
in https://arxiv.org/pdf/2107.03312.pdf""" + + def __init__( + self, + *, + dim, + codebook_size, + soft_clamp_input_value=None, + aux_loss=False, # intermediate auxiliary loss + ln_before_quant=False, # add a LN before multi-scale RQ + ln_init_by_sqrt=False, # weight init by 1/sqrt(d) + use_decay_factor=False, + use_stochastic_depth=False, + drop_rate=0.0, + schedule_mode="original", # ["original", "dynamic", "dense"] + keep_first_quant=False, + keep_last_quant=False, + remove_residual_detach=False, + random_flip=False, + flip_prob=0.5, + flip_mode="stochastic", # "stochastic", "deterministic" + max_flip_lvl=1, + random_flip_1lvl=False, # random flip one level each time + flip_lvl_idx=None, + drop_when_test=False, + drop_lvl_idx=None, + drop_lvl_num=0, + **kwargs, + ): + super().__init__() + codebook_dim = int(log2(codebook_size)) + + requires_projection = codebook_dim != dim + self.project_in = ( + nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity() + ) + self.project_out = ( + nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity() + ) + self.has_projections = requires_projection + self.layernorm = ( + LayerNorm(codebook_dim, norm_weight=ln_init_by_sqrt) + if ln_before_quant + else nn.Identity() + ) + self.use_stochastic_depth = use_stochastic_depth + self.drop_rate = drop_rate + self.remove_residual_detach = remove_residual_detach + self.random_flip = random_flip + self.flip_prob = flip_prob + self.flip_mode = flip_mode + self.max_flip_lvl = max_flip_lvl + self.random_flip_1lvl = random_flip_1lvl + self.flip_lvl_idx = flip_lvl_idx + assert (random_flip and random_flip_1lvl) == False + self.drop_when_test = drop_when_test + self.drop_lvl_idx = drop_lvl_idx + self.drop_lvl_num = drop_lvl_num + if self.drop_when_test: + assert drop_lvl_idx is not None + assert drop_lvl_num > 0 + + self.lfq = BSQ( + dim=codebook_dim, + codebook_scale=1 / np.sqrt(codebook_dim), + soft_clamp_input_value=soft_clamp_input_value, + # experimental_softplus_entropy_loss=True, + # entropy_loss_offset=2, + **kwargs, + ) + + self.z_interplote_up = "trilinear" + self.z_interplote_down = "area" + + self.use_decay_factor = use_decay_factor + self.schedule_mode = schedule_mode + self.keep_first_quant = keep_first_quant + self.keep_last_quant = keep_last_quant + if self.use_stochastic_depth and self.drop_rate > 0: + assert self.keep_first_quant or self.keep_last_quant + + @property + def codebooks(self): + return self.lfq.codebook + + def get_codes_from_indices(self, indices_list): + all_codes = [] + for indices in indices_list: + codes = self.lfq.indices_to_codes(indices) + all_codes.append(codes) + _, _, T, H, W = all_codes[-1].size() + summed_codes = 0 + for code in all_codes: + summed_codes += F.interpolate( + code, size=(T, H, W), mode=self.z_interplote_up + ) + return summed_codes + + def get_output_from_indices(self, indices): + codes = self.get_codes_from_indices(indices) + codes_summed = reduce(codes, "q ... 
-> ...", "sum") + return self.project_out(codes_summed) + + def flip_quant(self, x): + assert self.flip_mode == "stochastic" + flip_mask = torch.rand_like(x) < self.flip_prob + x = x.clone() + x[flip_mask] = -x[flip_mask] + return x + + def forward( + self, + x, + scale_schedule=None, + mask=None, + return_all_codes=False, + return_residual_norm_per_scale=False, + ): + if x.ndim == 4: + x = x.unsqueeze(2) + B, C, T, H, W = x.size() + + if scale_schedule is None: + if self.schedule_mode.startswith("same"): + scale_num = int(self.schedule_mode[len("same") :]) + assert T == 1 + scale_schedule = [(1, H, W)] * scale_num + else: + scale_schedule = get_latent2scale_schedule( + T, H, W, mode=self.schedule_mode + ) + scale_num = len(scale_schedule) + + # x = self.project_in(x) + x = x.permute(0, 2, 3, 4, 1).contiguous() # (b, c, t, h, w) => (b, t, h, w, c) + x = self.project_in(x) + x = x.permute(0, 4, 1, 2, 3).contiguous() # (b, t, h, w, c) => (b, c, t, h, w) + x = self.layernorm(x) + + quantized_out = 0.0 + residual = x + + all_losses = [] + all_indices = [] + all_bit_indices = [] + var_inputs = [] + residual_norm_per_scale = [] + + # go through the layers + out_fact = init_out_fact = 1.0 + # residual_list = [] + # interpolate_residual_list = [] + # quantized_list = [] + if self.drop_when_test: + drop_lvl_start = self.drop_lvl_idx + drop_lvl_end = self.drop_lvl_idx + self.drop_lvl_num + scale_num = len(scale_schedule) + with autocast("cuda", enabled=False): + for si, (pt, ph, pw) in enumerate(scale_schedule): + out_fact = ( + max(0.1, out_fact) if self.use_decay_factor else init_out_fact + ) + if (pt, ph, pw) != (T, H, W): + interpolate_residual = F.interpolate( + residual, size=(pt, ph, pw), mode=self.z_interplote_down + ) + else: + interpolate_residual = residual + if return_residual_norm_per_scale: + residual_norm_per_scale.append( + ( + torch.abs(interpolate_residual) + < 0.05 * self.lfq.codebook_scale + ).sum() + / interpolate_residual.numel() + ) + # residual_list.append(torch.norm(residual.detach(), dim=1).mean()) + # interpolate_residual_list.append(torch.norm(interpolate_residual.detach(), dim=1).mean()) + if ( + self.training + and self.use_stochastic_depth + and random.random() < self.drop_rate + ): + if (si == 0 and self.keep_first_quant) or ( + si == scale_num - 1 and self.keep_last_quant + ): + quantized, indices, _, loss = self.lfq(interpolate_residual) + quantized = quantized * out_fact + all_indices.append(indices) + all_losses.append(loss) + else: + quantized = torch.zeros_like(interpolate_residual) + elif self.drop_when_test and drop_lvl_start <= si < drop_lvl_end: + continue + else: + # residual_norm = torch.norm(interpolate_residual.detach(), dim=1) # (b, t, h, w) + # print(si, residual_norm.min(), residual_norm.max(), residual_norm.mean()) + quantized, indices, bit_indices, loss = self.lfq( + interpolate_residual + ) + if self.random_flip and si < self.max_flip_lvl: + quantized = self.flip_quant(quantized) + if self.random_flip_1lvl and si == self.flip_lvl_idx: + quantized = self.flip_quant(quantized) + quantized = quantized * out_fact + all_indices.append(indices) + # quantized_list.append(torch.norm(quantized.detach(), dim=1).mean()) + if (pt, ph, pw) != (T, H, W): + quantized = F.interpolate( + quantized, size=(T, H, W), mode=self.z_interplote_up + ).contiguous() + + if self.remove_residual_detach: + residual = residual - quantized + else: + residual = residual - quantized.detach() + quantized_out = quantized_out + quantized + + all_bit_indices.append(bit_indices) + 
all_losses.append(loss) + if si != scale_num - 1: + var_inputs.append( + F.interpolate( + quantized_out, + size=scale_schedule[si + 1], + mode=self.z_interplote_down, + ).contiguous() + ) + + if self.use_decay_factor: + out_fact -= 0.1 + # print("residual_list:", residual_list) + # print("interpolate_residual_list:", interpolate_residual_list) + # print("quantized_list:", quantized_list) + # import ipdb; ipdb.set_trace() + # project out, if needed + quantized_out = quantized_out.permute( + 0, 2, 3, 4, 1 + ).contiguous() # (b, c, t, h, w) => (b, t, h, w, c) + quantized_out = self.project_out(quantized_out) + quantized_out = quantized_out.permute( + 0, 4, 1, 2, 3 + ).contiguous() # (b, t, h, w, c) => (b, c, t, h, w) + + # image + if quantized_out.size(2) == 1: + quantized_out = quantized_out.squeeze(2) + + # stack all losses and indices + + all_losses = torch.stack(all_losses, dim=-1) + + ret = ( + quantized_out, + all_indices, + all_bit_indices, + residual_norm_per_scale, + all_losses, + var_inputs, + ) + + if not return_all_codes: + return ret + + # whether to return all codes from all codebooks across layers + all_codes = self.get_codes_from_indices(all_indices) + + # will return all codes in shape (quantizer, batch, sequence length, codebook dimension) + + return (*ret, all_codes) + + +class BSQ(Module): + def __init__( + self, + *, + dim=None, + codebook_size=None, + entropy_loss_weight=0.1, + commitment_loss_weight=0.25, + diversity_gamma=1.0, + straight_through_activation=nn.Identity(), + num_codebooks=1, + keep_num_codebooks_dim=None, + codebook_scale=1.0, # for residual LFQ, codebook scaled down by 2x at each layer + frac_per_sample_entropy=1.0, # make less than 1. to only use a random fraction of the probs for per sample entropy + has_projections=None, + projection_has_bias=True, + soft_clamp_input_value=None, + cosine_sim_project_in=False, + cosine_sim_project_in_scale=None, + channel_first=None, + experimental_softplus_entropy_loss=False, + entropy_loss_offset=5.0, # how much to shift the loss before softplus + spherical=True, # from https://arxiv.org/abs/2406.07548 + force_quantization_f32=True, # will force the quantization step to be full precision + inv_temperature=100.0, + gamma0=1.0, + gamma=1.0, + zeta=1.0, + preserve_norm=False, # whether to preserve the original norm info + new_quant=False, # new quant function, + mask_out=False, # mask the output as 0 in some conditions + use_out_phi=False, # use output phi network + use_out_phi_res=False, # residual out phi + ): + super().__init__() + + # some assert validations + + assert exists(dim) or exists( + codebook_size + ), "either dim or codebook_size must be specified for LFQ" + assert ( + not exists(codebook_size) or log2(codebook_size).is_integer() + ), f"your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})" + + codebook_size = default(codebook_size, lambda: 2**dim) + self.codebook_size = codebook_size + + codebook_dim = int(log2(codebook_size)) + codebook_dims = codebook_dim * num_codebooks + dim = default(dim, codebook_dims) + self.codebook_dims = codebook_dims + + has_projections = default(has_projections, dim != codebook_dims) + + if cosine_sim_project_in: + cosine_sim_project_in = default(cosine_sim_project_in_scale, codebook_scale) + project_in_klass = partial(CosineSimLinear, scale=cosine_sim_project_in) + else: + project_in_klass = partial(nn.Linear, bias=projection_has_bias) + + self.project_in = ( + project_in_klass(dim, codebook_dims) if 
has_projections else nn.Identity() + ) # nn.Identity() + self.project_out = ( + nn.Linear(codebook_dims, dim, bias=projection_has_bias) + if has_projections + else nn.Identity() + ) # nn.Identity() + self.has_projections = has_projections + + self.out_phi = ( + nn.Linear(codebook_dims, codebook_dims) if use_out_phi else nn.Identity() + ) + self.use_out_phi_res = use_out_phi_res + if self.use_out_phi_res: + self.out_phi_scale = nn.Parameter( + torch.zeros(codebook_dims), requires_grad=True + ) # init as zero + + self.dim = dim + self.codebook_dim = codebook_dim + self.num_codebooks = num_codebooks + + keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) + assert not (num_codebooks > 1 and not keep_num_codebooks_dim) + self.keep_num_codebooks_dim = keep_num_codebooks_dim + + # channel first + + self.channel_first = channel_first + + # straight through activation + + self.activation = straight_through_activation + + # For BSQ (binary spherical quantization) + if not spherical: + raise ValueError("For BSQ, spherical must be True.") + self.persample_entropy_compute = "analytical" + self.inv_temperature = inv_temperature + self.gamma0 = gamma0 # loss weight for entropy penalty + self.gamma = gamma # loss weight for entropy penalty + self.zeta = zeta # loss weight for entire entropy penalty + self.preserve_norm = preserve_norm + self.new_quant = new_quant + self.mask_out = mask_out + + # entropy aux loss related weights + + assert 0 < frac_per_sample_entropy <= 1.0 + self.frac_per_sample_entropy = frac_per_sample_entropy + + self.diversity_gamma = diversity_gamma + self.entropy_loss_weight = entropy_loss_weight + + # codebook scale + + self.codebook_scale = codebook_scale + + # commitment loss + + self.commitment_loss_weight = commitment_loss_weight + + # whether to soft clamp the input value from -value to value + + self.soft_clamp_input_value = soft_clamp_input_value + assert ( + not exists(soft_clamp_input_value) + or soft_clamp_input_value >= codebook_scale + ) + + # whether to make the entropy loss positive through a softplus (experimental, please report if this worked or not in discussions) + + self.entropy_loss_offset = entropy_loss_offset + self.experimental_softplus_entropy_loss = experimental_softplus_entropy_loss + + # for no auxiliary loss, during inference + + self.register_buffer("mask", 2 ** torch.arange(codebook_dim - 1, -1, -1)) + self.register_buffer("zero", torch.tensor(0.0), persistent=False) + + # whether to force quantization step to be f32 + + self.force_quantization_f32 = force_quantization_f32 + + # codes + + # all_codes = torch.arange(codebook_size) + # bits = ((all_codes[..., None].int() & self.mask) != 0).float() + # codebook = self.bits_to_codes(bits) + + # self.register_buffer('codebook', codebook.float(), persistent = False) + + def bits_to_codes(self, bits): + return bits * self.codebook_scale * 2 - self.codebook_scale + + # @property + # def dtype(self): + # return self.codebook.dtype + + def indices_to_codes(self, indices, label_type="int_label", project_out=True): + assert label_type in ["int_label", "bit_label"] + is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) + should_transpose = default(self.channel_first, is_img_or_video) + + if not self.keep_num_codebooks_dim: + if label_type == "int_label": + indices = rearrange(indices, "... -> ... 
1") + else: + indices = indices.unsqueeze(-2) + + # indices to codes, which are bits of either -1 or 1 + + if label_type == "int_label": + assert indices[..., None].int().min() > 0 + bits = ( + (indices[..., None].int() & self.mask) != 0 + ).float() # .to(self.dtype) + else: + bits = indices + + codes = self.bits_to_codes(bits) + + codes = l2norm(codes) # must normalize when using BSQ + + codes = rearrange(codes, "... c d -> ... (c d)") + + # whether to project codes out to original dimensions + # if the input feature dimensions were not log2(codebook size) + + if project_out: + codes = self.project_out(codes) + + # rearrange codes back to original shape + + if should_transpose: + codes = rearrange(codes, "b ... d -> b d ...") + + return codes + + def quantize(self, z): + assert ( + z.shape[-1] == self.codebook_dims + ), f"Expected {self.codebook_dims} dimensions, got {z.shape[-1]}" + + zhat = torch.where( + z > 0, + torch.tensor(1, dtype=z.dtype, device=z.device), + torch.tensor(-1, dtype=z.dtype, device=z.device), + ) + return z + (zhat - z).detach() + + def quantize_new(self, z): + assert ( + z.shape[-1] == self.codebook_dims + ), f"Expected {self.codebook_dims} dimensions, got {z.shape[-1]}" + + zhat = torch.where( + z > 0, + torch.tensor(1, dtype=z.dtype, device=z.device), + torch.tensor(-1, dtype=z.dtype, device=z.device), + ) + + q_scale = 1.0 / (self.codebook_dims**0.5) + zhat = q_scale * zhat # on unit sphere + + return z + (zhat - z).detach() + + def soft_entropy_loss(self, z): + if self.persample_entropy_compute == "analytical": + # if self.l2_norm: + p = torch.sigmoid(-4 * z / (self.codebook_dims**0.5) * self.inv_temperature) + # else: + # p = torch.sigmoid(-4 * z * self.inv_temperature) + prob = torch.stack([p, 1 - p], dim=-1) # (b, h, w, 18, 2) + per_sample_entropy = ( + self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean() + ) # (b,h,w,18)->(b,h,w)->scalar + else: + per_sample_entropy = ( + self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean() + ) + + # macro average of the probability of each subgroup + avg_prob = reduce(prob, "... g d ->g d", "mean") # (18, 2) + codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False) + + # the approximation of the entropy is the sum of the entropy of each subgroup + return per_sample_entropy, codebook_entropy.sum(), avg_prob + + def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True): + if normalize: # False + probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True) + else: # True + probs = count + H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim) + return H + + def forward(self, x, return_loss_breakdown=False, mask=None, entropy_weight=0.1): + """ + einstein notation + b - batch + n - sequence (or flattened spatial dimensions) + d - feature dimension, which is also log2(codebook size) + c - number of codebook dim + """ + + is_img_or_video = x.ndim >= 4 + should_transpose = default(self.channel_first, is_img_or_video) + + # standardize image or video into (batch, seq, dimension) + + if should_transpose: + x = rearrange(x, "b d ... -> b ... 
d") + x, ps = pack_one(x, "b * d") # x.shape [b, hwt, c] + + assert ( + x.shape[-1] == self.dim + ), f"expected dimension of {self.dim} but received {x.shape[-1]}" + + x = self.project_in(x) + + # split out number of codebooks + + x = rearrange(x, "b n (c d) -> b n c d", c=self.num_codebooks) + + x = l2norm(x) + + # whether to force quantization step to be full precision or not + + force_f32 = self.force_quantization_f32 + + quantization_context = ( + partial(autocast, "cuda", enabled=False) if force_f32 else nullcontext + ) + + indices = None + with quantization_context(): + + if force_f32: + orig_dtype = x.dtype + x = x.float() + + # use straight-through gradients (optionally with custom activation fn) if training + if self.new_quant: + quantized = self.quantize_new(x) + + # calculate indices + bit_indices = (quantized > 0).int() + entropy_penalty = persample_entropy = cb_entropy = self.zero + commit_loss = self.zero + + # input back to original dtype if needed + + if force_f32: + x = x.type(orig_dtype) + + # merge back codebook dim + x = quantized # rename quantized to x for output + x = rearrange(x, "b n c d -> b n (c d)") + + # project out to feature dimension if needed + + x = self.project_out(x) + + # reconstitute image or video dimensions + + if should_transpose: + x = unpack_one(x, ps, "b * d") + x = rearrange(x, "b ... d -> b d ...") + + bit_indices = unpack_one(bit_indices, ps, "b * c d") + + # whether to remove single codebook dim + + if not self.keep_num_codebooks_dim: + bit_indices = rearrange(bit_indices, "... 1 d -> ... d") + + # complete aux loss + + aux_loss = ( + commit_loss * self.commitment_loss_weight + + (self.zeta * entropy_penalty / self.inv_temperature) * entropy_weight + ) + # returns + + ret = Return(x, indices, bit_indices, aux_loss) + + if not return_loss_breakdown: + return ret + + return ret, LossBreakdown(persample_entropy, cb_entropy, commit_loss) diff --git a/src/vqvaes/infinity/vae.py b/src/vqvaes/infinity/vae.py new file mode 100644 index 0000000000000000000000000000000000000000..5c3db35295a057d905d47df2d49b162844cd854f --- /dev/null +++ b/src/vqvaes/infinity/vae.py @@ -0,0 +1,287 @@ +import argparse +import torch + +from .flux_vqgan import AutoEncoder + + +def load_cnn(model, state_dict, prefix, expand=False, use_linear=False): + delete_keys = [] + loaded_keys = [] + for key in state_dict: + if key.startswith(prefix): + _key = key[len(prefix) :] + if _key in model.state_dict(): + # load nn.Conv2d or nn.Linear to nn.Linear + if use_linear and ( + ".q.weight" in key + or ".k.weight" in key + or ".v.weight" in key + or ".proj_out.weight" in key + ): + load_weights = state_dict[key].squeeze() + elif _key.endswith(".conv.weight") and expand: + if model.state_dict()[_key].shape == state_dict[key].shape: + # 2D cnn to 2D cnn + load_weights = state_dict[key] + else: + # 2D cnn to 3D cnn + _expand_dim = model.state_dict()[_key].shape[2] + load_weights = ( + state_dict[key].unsqueeze(2).repeat(1, 1, _expand_dim, 1, 1) + ) + else: + load_weights = state_dict[key] + model.state_dict()[_key].copy_(load_weights) + delete_keys.append(key) + loaded_keys.append(prefix + _key) + # load nn.Conv2d to Conv class + conv_list = ( + ["conv"] + if use_linear + else ["conv", ".q.", ".k.", ".v.", ".proj_out.", ".nin_shortcut."] + ) + if any(k in _key for k in conv_list): + if _key.endswith(".weight"): + conv_key = _key.replace(".weight", ".conv.weight") + if conv_key and conv_key in model.state_dict(): + if model.state_dict()[conv_key].shape == state_dict[key].shape: + # 2D cnn 
to 2D cnn + load_weights = state_dict[key] + else: + # 2D cnn to 3D cnn + _expand_dim = model.state_dict()[conv_key].shape[2] + load_weights = ( + state_dict[key] + .unsqueeze(2) + .repeat(1, 1, _expand_dim, 1, 1) + ) + model.state_dict()[conv_key].copy_(load_weights) + delete_keys.append(key) + loaded_keys.append(prefix + conv_key) + if _key.endswith(".bias"): + conv_key = _key.replace(".bias", ".conv.bias") + if conv_key and conv_key in model.state_dict(): + model.state_dict()[conv_key].copy_(state_dict[key]) + delete_keys.append(key) + loaded_keys.append(prefix + conv_key) + # load nn.GroupNorm to Normalize class + if "norm" in _key: + if _key.endswith(".weight"): + norm_key = _key.replace(".weight", ".norm.weight") + if norm_key and norm_key in model.state_dict(): + model.state_dict()[norm_key].copy_(state_dict[key]) + delete_keys.append(key) + loaded_keys.append(prefix + norm_key) + if _key.endswith(".bias"): + norm_key = _key.replace(".bias", ".norm.bias") + if norm_key and norm_key in model.state_dict(): + model.state_dict()[norm_key].copy_(state_dict[key]) + delete_keys.append(key) + loaded_keys.append(prefix + norm_key) + + for key in delete_keys: + del state_dict[key] + + return model, state_dict, loaded_keys + + +def vae_model( + vqgan_ckpt, + schedule_mode, + codebook_dim, + codebook_size, + test_mode=True, + patch_size=16, + encoder_ch_mult=[1, 2, 4, 4, 4], + decoder_ch_mult=[1, 2, 4, 4, 4], +): + args = argparse.Namespace( + vqgan_ckpt=vqgan_ckpt, + sd_ckpt=None, + inference_type="image", + save="./imagenet_val_bsq", + save_prediction=True, + image_recon4video=False, + junke_old=False, + device="cuda", + max_steps=1000000.0, + log_every=1, + visu_every=1000, + ckpt_every=1000, + default_root_dir="", + compile="no", + ema="no", + lr=0.0001, + beta1=0.9, + beta2=0.95, + warmup_steps=0, + optim_type="Adam", + disc_optim_type=None, + lr_min=0.0, + warmup_lr_init=0.0, + max_grad_norm=1.0, + max_grad_norm_disc=1.0, + disable_sch=False, + patch_size=patch_size, + temporal_patch_size=4, + embedding_dim=256, + codebook_dim=codebook_dim, + num_quantizers=8, + quantizer_type="MultiScaleBSQ", + use_vae=False, + use_freq_enc=False, + use_freq_dec=False, + preserve_norm=False, + ln_before_quant=False, + ln_init_by_sqrt=False, + use_pxsf=False, + new_quant=True, + use_decay_factor=False, + mask_out=False, + use_stochastic_depth=False, + drop_rate=0.0, + schedule_mode=schedule_mode, + lr_drop=None, + lr_drop_rate=0.1, + keep_first_quant=False, + keep_last_quant=False, + remove_residual_detach=False, + use_out_phi=False, + use_out_phi_res=False, + use_lecam_reg=False, + lecam_weight=0.05, + perceptual_model="vgg16", + base_ch_disc=64, + random_flip=False, + flip_prob=0.5, + flip_mode="stochastic", + max_flip_lvl=1, + not_load_optimizer=False, + use_lecam_reg_zero=False, + freeze_encoder=False, + rm_downsample=False, + random_flip_1lvl=False, + flip_lvl_idx=0, + drop_when_test=False, + drop_lvl_idx=0, + drop_lvl_num=1, + disc_version="v1", + magvit_disc=False, + sigmoid_in_disc=False, + activation_in_disc="leaky_relu", + apply_blur=False, + apply_noise=False, + dis_warmup_steps=0, + dis_lr_multiplier=1.0, + dis_minlr_multiplier=False, + disc_channels=64, + disc_layers=3, + discriminator_iter_start=0, + disc_pretrain_iter=0, + disc_optim_steps=1, + disc_warmup=0, + disc_pool="no", + disc_pool_size=1000, + advanced_disc=False, + recon_loss_type="l1", + video_perceptual_weight=0.0, + image_gan_weight=1.0, + video_gan_weight=1.0, + image_disc_weight=0.0, + video_disc_weight=0.0, + l1_weight=4.0, + 
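+        # The GAN/perceptual loss weights in this block are training-time settings;
+        # vae_model() is normally called with test_mode=True, which returns the
+        # autoencoder frozen and in eval mode for reconstruction only.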
gan_feat_weight=0.0, + perceptual_weight=0.0, + kl_weight=0.0, + lfq_weight=0.0, + entropy_loss_weight=0.1, + commitment_loss_weight=0.25, + diversity_gamma=1, + norm_type="group", + disc_loss_type="hinge", + use_checkpoint=False, + precision="fp32", + encoder_dtype="fp32", + upcast_attention="", + upcast_tf32=False, + tokenizer="flux", + pretrained=None, + pretrained_mode="full", + inflation_pe=False, + init_vgen="no", + no_init_idis=False, + init_idis="keep", + init_vdis="no", + enable_nan_detector=False, + turn_on_profiler=False, + profiler_scheduler_wait_steps=10, + debug=True, + video_logger=False, + bytenas="", + username="", + seed=1234, + vq_to_vae=False, + load_not_strict=False, + zero=0, + bucket_cap_mb=40, + manual_gc_interval=1000, + data_path=[""], + data_type=[""], + dataset_list=["imagenet"], + fps=-1, + dataaug="resizecrop", + multi_resolution=False, + random_bucket_ratio=0.0, + sequence_length=16, + resolution=[256, 256], + batch_size=[1], + num_workers=0, + image_channels=3, + codebook_size=codebook_size, + codebook_l2_norm=True, + codebook_show_usage=True, + commit_loss_beta=0.25, + entropy_loss_ratio=0.0, + base_ch=128, + num_res_blocks=2, + encoder_ch_mult=encoder_ch_mult, + decoder_ch_mult=decoder_ch_mult, + dropout_p=0.0, + cnn_type="2d", + cnn_version="v1", + conv_in_out_2d="no", + conv_inner_2d="no", + res_conv_2d="no", + cnn_attention="no", + cnn_norm_axis="spatial", + flux_weight=0, + cycle_weight=0, + cycle_feat_weight=0, + cycle_gan_weight=0, + cycle_loop=0, + z_drop=0.0, + ) + + vae = AutoEncoder(args) + use_vae = vae.use_vae + if not use_vae: + num_codes = args.codebook_size + if isinstance(vqgan_ckpt, str): + state_dict = torch.load( + args.vqgan_ckpt, map_location=torch.device("cpu"), weights_only=True + ) + else: + state_dict = args.vqgan_ckpt + if state_dict: + if args.ema == "yes": + vae, new_state_dict, loaded_keys = load_cnn( + vae, state_dict["ema"], prefix="", expand=False + ) + else: + vae, new_state_dict, loaded_keys = load_cnn( + vae, state_dict["vae"], prefix="", expand=False + ) + if test_mode: + vae.eval() + [p.requires_grad_(False) for p in vae.parameters()] + return vae diff --git a/src/vqvaes/janus_pro/janus_pro.py b/src/vqvaes/janus_pro/janus_pro.py new file mode 100644 index 0000000000000000000000000000000000000000..e23e23d84dd5384ee34ff91204be78180171fab7 --- /dev/null +++ b/src/vqvaes/janus_pro/janus_pro.py @@ -0,0 +1,15 @@ +def forward(self, input): + quant, diff, [_, _, img_toks] = self.encode(input) + + batch_size, height, width, n_channel = ( + input.shape[0], + quant.shape[-1], + quant.shape[-2], + quant.shape[-3], + ) + codebook_entry = self.quantize.get_codebook_entry( + img_toks, (batch_size, n_channel, height, width) + ) + pixels = self.decode(codebook_entry) + + return pixels, img_toks, quant diff --git a/src/vqvaes/llamagen/llamagen.py b/src/vqvaes/llamagen/llamagen.py new file mode 100644 index 0000000000000000000000000000000000000000..214e279ae326c8e5efc0d9e21b09a56014a8be4b --- /dev/null +++ b/src/vqvaes/llamagen/llamagen.py @@ -0,0 +1,530 @@ +# Modified from: +# taming-transformers: https://github.com/CompVis/taming-transformers +# maskgit: https://github.com/google-research/maskgit +from dataclasses import dataclass, field +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +@dataclass +class ModelArgs: + codebook_size: int = 16384 + codebook_embed_dim: int = 8 + codebook_l2_norm: bool = True + codebook_show_usage: bool = True + commit_loss_beta: float = 0.25 + 
entropy_loss_ratio: float = 0.0 + + encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4]) + decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4]) + z_channels: int = 256 + dropout_p: float = 0.0 + + +class VQModel(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + self.encoder = Encoder( + ch_mult=config.encoder_ch_mult, + z_channels=config.z_channels, + dropout=config.dropout_p, + ) + self.decoder = Decoder( + ch_mult=config.decoder_ch_mult, + z_channels=config.z_channels, + dropout=config.dropout_p, + ) + + self.quantize = VectorQuantizer( + config.codebook_size, + config.codebook_embed_dim, + config.commit_loss_beta, + config.entropy_loss_ratio, + config.codebook_l2_norm, + config.codebook_show_usage, + ) + self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1) + self.post_quant_conv = nn.Conv2d( + config.codebook_embed_dim, config.z_channels, 1 + ) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b, shape=None, channel_first=True): + quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first) + dec = self.decode(quant_b) + return dec + + # def forward(self, input): + # quant, diff, _ = self.encode(input) + # dec = self.decode(quant) + # return dec, diff + + def forward(self, input): + quant, diff, [_, _, img_toks] = self.encode(input) + batch_size, n_channel, height, width = ( + input.shape[0], + quant.shape[-3], + quant.shape[-1], + quant.shape[-2], + ) + codebook_entry = self.quantize.get_codebook_entry( + img_toks, (batch_size, n_channel, height, width) + ) + pixels = self.decode(codebook_entry) + + return pixels, img_toks, quant + + +class Encoder(nn.Module): + def __init__( + self, + in_channels=3, + ch=128, + ch_mult=(1, 1, 2, 2, 4), + num_res_blocks=2, + norm_type="group", + dropout=0.0, + resamp_with_conv=True, + z_channels=256, + ): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1) + + # downsampling + in_ch_mult = (1,) + tuple(ch_mult) + self.conv_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + conv_block = nn.Module() + # res & attn + res_block = nn.ModuleList() + attn_block = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + res_block.append( + ResnetBlock( + block_in, block_out, dropout=dropout, norm_type=norm_type + ) + ) + block_in = block_out + if i_level == self.num_resolutions - 1: + attn_block.append(AttnBlock(block_in, norm_type)) + conv_block.res = res_block + conv_block.attn = attn_block + # downsample + if i_level != self.num_resolutions - 1: + conv_block.downsample = Downsample(block_in, resamp_with_conv) + self.conv_blocks.append(conv_block) + + # middle + self.mid = nn.ModuleList() + self.mid.append( + ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) + ) + self.mid.append(AttnBlock(block_in, norm_type=norm_type)) + self.mid.append( + ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) + ) + + # end + self.norm_out = Normalize(block_in, norm_type) + self.conv_out = nn.Conv2d( + block_in, z_channels, kernel_size=3, stride=1, padding=1 + ) 
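+
+    # Every resolution level except the last ends in a stride-2 downsample, so the
+    # overall spatial reduction is 2 ** (len(ch_mult) - 1): the VQ-16 config
+    # ([1, 1, 2, 2, 4]) maps a 256x256 image to a 16x16 latent grid, and VQ-8
+    # ([1, 2, 2, 4]) to a 32x32 grid.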
+ + def forward(self, x): + h = self.conv_in(x) + # downsampling + for i_level, block in enumerate(self.conv_blocks): + for i_block in range(self.num_res_blocks): + h = block.res[i_block](h) + if len(block.attn) > 0: + h = block.attn[i_block](h) + if i_level != self.num_resolutions - 1: + h = block.downsample(h) + + # middle + for mid_block in self.mid: + h = mid_block(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + z_channels=256, + ch=128, + ch_mult=(1, 1, 2, 2, 4), + num_res_blocks=2, + norm_type="group", + dropout=0.0, + resamp_with_conv=True, + out_channels=3, + ): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + + block_in = ch * ch_mult[self.num_resolutions - 1] + # z to block_in + self.conv_in = nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.ModuleList() + self.mid.append( + ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) + ) + self.mid.append(AttnBlock(block_in, norm_type=norm_type)) + self.mid.append( + ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) + ) + + # upsampling + self.conv_blocks = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + conv_block = nn.Module() + # res & attn + res_block = nn.ModuleList() + attn_block = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + res_block.append( + ResnetBlock( + block_in, block_out, dropout=dropout, norm_type=norm_type + ) + ) + block_in = block_out + if i_level == self.num_resolutions - 1: + attn_block.append(AttnBlock(block_in, norm_type)) + conv_block.res = res_block + conv_block.attn = attn_block + # downsample + if i_level != 0: + conv_block.upsample = Upsample(block_in, resamp_with_conv) + self.conv_blocks.append(conv_block) + + # end + self.norm_out = Normalize(block_in, norm_type) + self.conv_out = nn.Conv2d( + block_in, out_channels, kernel_size=3, stride=1, padding=1 + ) + + @property + def last_layer(self): + return self.conv_out.weight + + def forward(self, z): + # z to block_in + h = self.conv_in(z) + + # middle + for mid_block in self.mid: + h = mid_block(h) + + # upsampling + for i_level, block in enumerate(self.conv_blocks): + for i_block in range(self.num_res_blocks + 1): + h = block.res[i_block](h) + if len(block.attn) > 0: + h = block.attn[i_block](h) + if i_level != self.num_resolutions - 1: + h = block.upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class VectorQuantizer(nn.Module): + def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.entropy_loss_ratio = entropy_loss_ratio + self.l2_norm = l2_norm + self.show_usage = show_usage + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + if self.l2_norm: + self.embedding.weight.data = F.normalize( + self.embedding.weight.data, p=2, dim=-1 + ) + if self.show_usage: + self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536))) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + z = torch.einsum("b c h w -> b h w c", z).contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + if self.l2_norm: + z = 
F.normalize(z, p=2, dim=-1) + z_flattened = F.normalize(z_flattened, p=2, dim=-1) + embedding = F.normalize(self.embedding.weight, p=2, dim=-1) + else: + embedding = self.embedding.weight + + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(embedding**2, dim=1) + - 2 + * torch.einsum( + "bd,dn->bn", z_flattened, torch.einsum("n d -> d n", embedding) + ) + ) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = embedding[min_encoding_indices].view(z.shape) + perplexity = None + min_encodings = None + vq_loss = None + commit_loss = None + entropy_loss = None + codebook_usage = 0 + + if self.show_usage and self.training: + cur_len = min_encoding_indices.shape[0] + self.codebook_used[:-cur_len] = self.codebook_used[cur_len:].clone() + self.codebook_used[-cur_len:] = min_encoding_indices + codebook_usage = len(torch.unique(self.codebook_used)) / self.n_e + + # compute loss for embedding + if self.training: + vq_loss = torch.mean((z_q - z.detach()) ** 2) + commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = torch.einsum("b h w c -> b c h w", z_q) + + return ( + z_q, + (vq_loss, commit_loss, entropy_loss, codebook_usage), + (perplexity, min_encodings, min_encoding_indices), + ) + + def get_codebook_entry(self, indices, shape=None, channel_first=True): + # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel) + if self.l2_norm: + embedding = F.normalize(self.embedding.weight, p=2, dim=-1) + else: + embedding = self.embedding.weight + z_q = embedding[indices] # (b*h*w, c) + + if shape is not None: + if channel_first: + z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1]) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + else: + z_q = z_q.view(shape) + return z_q + + +class ResnetBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + norm_type="group", + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels, norm_type) + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.norm2 = Normalize(out_channels, norm_type) + self.dropout = nn.Dropout(dropout) + self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + return x + h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels, norm_type="group"): + super().__init__() + self.norm = Normalize(in_channels, norm_type) + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = 
nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, norm_type="group"): + assert norm_type in ["group", "batch"] + if norm_type == "group": + return nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + elif norm_type == "batch": + return nn.SyncBatchNorm(in_channels) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = F.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01): + flat_affinity = affinity.reshape(-1, affinity.shape[-1]) + flat_affinity /= temperature + probs = F.softmax(flat_affinity, dim=-1) + log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1) + if loss_type == "softmax": + target_probs = probs + else: + raise ValueError("Entropy loss {} not supported".format(loss_type)) + avg_probs = torch.mean(target_probs, dim=0) + avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-5)) + sample_entropy = -torch.mean(torch.sum(target_probs * log_probs, dim=-1)) + loss = sample_entropy - avg_entropy + return loss + + +################################################################################# +# VQ Model Configs # +################################################################################# +def VQ_8(**kwargs): + return VQModel( + ModelArgs(encoder_ch_mult=[1, 2, 2, 4], decoder_ch_mult=[1, 2, 2, 4], **kwargs) + ) + + +def VQ_16(**kwargs): + return VQModel( + ModelArgs( + encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs + ) + ) + + +VQ_models = {"VQ-16": VQ_16, "VQ-8": VQ_8} diff --git a/src/vqvaes/maskbit/maskbit.py b/src/vqvaes/maskbit/maskbit.py new file mode 100644 index 0000000000000000000000000000000000000000..417986d77cacba6802d69064045e78fcf3ca3cdb --- /dev/null +++ b/src/vqvaes/maskbit/maskbit.py @@ -0,0 +1,154 @@ +"""This file contains 
the definition of the our tokenizer, which can use VQ or LFQ.""" + +import math +from typing import Mapping, Text, Tuple + +import torch +from einops import rearrange + +from .modules import BaseModel, ConvDecoder, ConvDecoderLegacy, ConvEncoder +from .quantizer import LookupFreeQuantizer, SimpleVectorizer + + +def choose_vector_quantizer_class(config): + if config.quantizer_type == "lookup": + return SimpleVectorizer( + config.codebook_size, + config.token_size, + config.commitment_cost, + config.entropy_loss_weight, + config.entropy_loss_temperature, + config.entropy_gamma, + config.get("use_l2_normalisation", False), + ) + elif config.quantizer_type == "lookup-free": + return LookupFreeQuantizer( + config.token_size, + config.commitment_cost, + config.entropy_loss_weight, + config.entropy_loss_temperature, + config.entropy_gamma, + ) + elif config.quantizer_type == "vae": + return NotImplementedError( + "Currently not supported. We welcome a well tested PR." + ) + else: + raise ValueError("Unknown vector quantizer class") + + +class ConvVQModel(BaseModel): + def __init__(self, config, legacy: bool = False, finetune_decoder: bool = False): + """Initializes the convolutional VQ-VAE model. + + Args: + config: The configuration for the model. + legacy -> bool: Whether to use the legacy decoder, which is a different implementation of the same architecture. + finetune_decoder -> bool: Whether to finetune the decoder. + """ + super().__init__() + self.config = config + self.encoder = ConvEncoder(self.config) + if legacy: + # To support older weights and MaskGIT + self.decoder = ConvDecoderLegacy(self.config) + else: + self.decoder = ConvDecoder(self.config) + + self.finetune_decoder = finetune_decoder + if self.finetune_decoder: + self.encoder.eval() + self.encoder.requires_grad_(False) + self.quantize = choose_vector_quantizer_class(self.config) + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def encode( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: + """Encodes the input tensor, i.e. runs the encoder. + + Args: + x -> torch.Tensor: The input tensor. + + Returns: + z_quantized -> torch.Tensor: The quantized latent representation. + result_dict -> Mapping[Text, torch.Tensor]: A dictionary containing additional results + and losses from the quantizer. + """ + z = self.encoder(x) + z_quantized, result_dict = self.quantize(z) + return z_quantized, result_dict + + def decode(self, z_quantized: torch.Tensor) -> torch.Tensor: + """Decodes the quantized latent representation, i.e. runs the decoder. + + Args: + z_quantized -> torch.Tensor: The quantized latent representation. + + Returns: + decoded -> torch.Tensor: The decoded image. + """ + decoded = self.decoder(z_quantized) + return decoded + + def decode_tokens(self, tokens: torch.Tensor) -> torch.Tensor: + """Decodes from tokens, i.e. runs the decoder after converting tokens to latent representations. + + Args: + tokens -> torch.Tensor: The tokens. + + Returns: + decoded -> torch.Tensor: The decoded image. + """ + z_quantized = self.quantize.get_codebook_entry(tokens) + ss = int(math.sqrt(float(z_quantized.size(1)))) + z_quantized = z_quantized.reshape(z_quantized.size(0), ss, ss, -1) + z_quantized = rearrange(z_quantized, "b h w c -> b c h w").contiguous() + decoded = self.decode(z_quantized) + return decoded + + def forward( + self, input: torch.Tensor + ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: + """Runs the model on the input tensor. 
+ + Args: + input -> torch.Tensor: The input image. + + Returns: + decoded -> torch.Tensor: The decoded image. + result_dict -> Mapping[Text, torch.Tensor]: A dictionary containing additional results + and losses from the quantizer. + """ + if self.finetune_decoder: + self.encoder.eval() + z_quantized, result_dict = self._finetuning_encoder_forward(input) + else: + z_quantized, result_dict = self.encode(input) + + decoded = self.decode(z_quantized) + return decoded, result_dict["min_encoding_indices"], z_quantized + + def _finetuning_encoder_forward( + self, input: torch.Tensor + ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: + """Runs the encoder on the input tensor without gradients and sets quantizer losses to 0. + + Args: + input -> torch.Tensor: The input image. + + Returns: + z_quantized -> torch.Tensor: The quantized latent representation. + result_dict -> Mapping[Text, torch.Tensor]: A dictionary containing additional results + and losses from the quantizer. + """ + with torch.no_grad(): + z_quantized, result_dict = self.encode(input) + result_dict["quantizer_loss"] *= 0 + result_dict["commitment_loss"] *= 0 + if "codebook_loss" in result_dict: + result_dict["codebook_loss"] *= 0 + result_dict["entropy_loss"] *= 0 + return z_quantized, result_dict diff --git a/src/vqvaes/maskbit/modules/__init__.py b/src/vqvaes/maskbit/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aea8c7a60fdee23b500370e863f372afc888f301 --- /dev/null +++ b/src/vqvaes/maskbit/modules/__init__.py @@ -0,0 +1,17 @@ +from .autoencoder import ( + ConvEncoder, + ConvDecoder, + ConvDecoderLegacy, + Conv2dSame, + ResidualStage, + GroupNorm, +) +from .base_model import BaseModel +from .ema_model import EMAModel +from .discriminator import OriginalNLayerDiscriminator, NLayerDiscriminatorv2 +from .losses import VQGANLoss, MLMLoss +from .perceptual_loss import PerceptualLoss +from .lpips import LPIPS +from .masking import get_mask_tokens, get_masking_ratio +from .factorization import combine_factorized_tokens, split_factorized_tokens +from .sampling import sample diff --git a/src/vqvaes/maskbit/modules/autoencoder.py b/src/vqvaes/maskbit/modules/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..25498406403ead6a75ce7dbecf5f6c8618234eb7 --- /dev/null +++ b/src/vqvaes/maskbit/modules/autoencoder.py @@ -0,0 +1,530 @@ +"""This file contains the definition of the the autoencoder parts""" + +import math +import torch +import torch.nn.functional as F + + +class Conv2dSame(torch.nn.Conv2d): + """Convolution wrapper for 2D convolutions using `SAME` padding.""" + + def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int: + """Calculate padding such that the output has the same height/width when stride=1. + + Args: + i -> int: Input size. + k -> int: Kernel size. + s -> int: Stride size. + d -> int: Dilation rate. + """ + return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of the convolution applying explicit `same` padding. + + Args: + x -> torch.Tensor: Input tensor. + + Returns: + torch.Tensor: Output tensor. 
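+
+        Example (illustrative sketch; with the default stride of 1 the spatial size is preserved):
+            >>> conv = Conv2dSame(3, 8, kernel_size=3)
+            >>> conv(torch.randn(1, 3, 17, 17)).shape
+            torch.Size([1, 8, 17, 17])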
+ """ + ih, iw = x.size()[-2:] + + pad_h = self.calc_same_pad( + i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0] + ) + pad_w = self.calc_same_pad( + i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1] + ) + + if pad_h > 0 or pad_w > 0: + x = F.pad( + x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] + ) + return super().forward(x) + + +def GroupNorm(in_channels): + """GroupNorm with 32 groups.""" + if in_channels % 32 != 0: + raise ValueError( + f"GroupNorm requires in_channels to be divisible by 32, got {in_channels}." + ) + return torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class ResidualBlock(torch.nn.Module): + """Residual block with two convolutional layers.""" + + def __init__(self, in_channels: int, out_channels: int = None, norm_func=GroupNorm): + """Initializes the residual block. + + Args: + in_channels -> int: Number of input channels. + out_channels -> int: Number of output channels. Default is in_channels. + norm_func -> Callable: Normalization function. Default is GroupNorm. + """ + super().__init__() + + self.in_channels = in_channels + self.out_channels = self.in_channels if out_channels is None else out_channels + + self.norm1 = norm_func(self.in_channels) + self.conv1 = Conv2dSame( + self.in_channels, self.out_channels, kernel_size=3, bias=False + ) + + self.norm2 = norm_func(self.out_channels) + self.conv2 = Conv2dSame( + self.out_channels, self.out_channels, kernel_size=3, bias=False + ) + + if self.in_channels != self.out_channels: + self.nin_shortcut = Conv2dSame( + self.out_channels, self.out_channels, kernel_size=1, bias=False + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Forward pass of the residual block. + + Args: + hidden_states -> torch.Tensor: Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states = F.silu(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states = F.silu(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels: + residual = self.nin_shortcut(hidden_states) + + return hidden_states + residual + + +class ResidualStage(torch.nn.Module): + """Residual stage with multiple residual blocks.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + num_res_blocks: int, + norm_func=GroupNorm, + ): + """Initializes the residual stage. + + Args: + in_channels -> int: Number of input channels. + out_channels -> int: Number of output channels. + num_res_blocks -> int: Number of residual blocks. + norm_func -> Callable: Normalization function. Default is GroupNorm. + """ + super().__init__() + + self.res_blocks = torch.nn.ModuleList() + for _ in range(num_res_blocks): + self.res_blocks.append( + ResidualBlock(in_channels, out_channels, norm_func=norm_func) + ) + in_channels = out_channels + + def forward(self, hidden_states: torch.Tensor, *unused_args) -> torch.Tensor: + """Forward pass of the residual stage. + + Args: + hidden_states -> torch.Tensor: Input tensor. + + Returns: + torch.Tensor: Output tensor. 
+ """ + for res_block in self.res_blocks: + hidden_states = res_block(hidden_states) + + return hidden_states + + +class DownsamplingStage(torch.nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_res_blocks: int, + sample_with_conv: bool = False, + norm_func=GroupNorm, + ): + """Initializes the downsampling stage. + + Args: + in_channels -> int: Number of input channels. + out_channels -> int: Number of output channels. + num_res_blocks -> int: Number of residual blocks. + sample_with_conv -> bool: Whether to sample with a convolution or with a stride. Default is False. + norm_func -> Callable: Normalization function. Default is GroupNorm. + """ + super().__init__() + + self.res_blocks = torch.nn.ModuleList() + for _ in range(num_res_blocks): + self.res_blocks.append(ResidualBlock(in_channels, out_channels, norm_func)) + in_channels = out_channels + + self.sample_with_conv = sample_with_conv + if self.sample_with_conv: + self.down_conv = Conv2dSame( + in_channels, in_channels, kernel_size=3, stride=2 + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Forward pass of the downsampling stage. + + Args: + hidden_states -> torch.Tensor: Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ + for res_block in self.res_blocks: + hidden_states = res_block(hidden_states) + + if self.sample_with_conv: + hidden_states = self.down_conv(hidden_states) + else: + hidden_states = F.avg_pool2d(hidden_states, kernel_size=2, stride=2) + + return hidden_states + + +class UpsamplingStage(torch.nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_res_blocks: int, + norm_func=GroupNorm, + ): + """Initializes the upsampling stage. + + Args: + in_channels -> int: Number of input channels. + out_channels -> int: Number of output channels. + num_res_blocks -> int: Number of residual blocks. + norm_func -> Callable: Normalization function. Default is GroupNorm. + """ + super().__init__() + + self.res_blocks = torch.nn.ModuleList() + for _ in range(num_res_blocks): + self.res_blocks.append(ResidualBlock(in_channels, out_channels, norm_func)) + in_channels = out_channels + + self.upsample_conv = Conv2dSame(out_channels, out_channels, kernel_size=3) + + def forward(self, hidden_states: torch.Tensor, *unused_args) -> torch.Tensor: + """Forward pass of the upsampling stage. + + Args: + hidden_states -> torch.Tensor: Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ + for res_block in self.res_blocks: + hidden_states = res_block(hidden_states) + + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + hidden_states = self.upsample_conv(hidden_states) + + return hidden_states + + +class ConvEncoder(torch.nn.Module): + def __init__(self, config): + """Initializes the convolutional encoder. + + Args: + config: Configuration of the model architecture. 
+ """ + super().__init__() + self.config = config + + self.conv_in = Conv2dSame( + self.config.num_channels, + self.config.hidden_channels, + kernel_size=3, + bias=False, + ) + + in_channel_mult = (1,) + tuple(self.config.channel_mult) + num_res_blocks = self.config.num_res_blocks + hidden_channels = self.config.hidden_channels + + encoder_blocks = [] + for i_level in range(self.config.num_resolutions): + in_channels = hidden_channels * in_channel_mult[i_level] + out_channels = hidden_channels * in_channel_mult[i_level + 1] + + if i_level < (self.config.num_resolutions - 1): + encoder_blocks.append( + DownsamplingStage( + in_channels, + out_channels, + num_res_blocks, + self.config.sample_with_conv, + ) + ) + else: + encoder_blocks.append( + ResidualStage(in_channels, out_channels, num_res_blocks) + ) + self.down = torch.nn.ModuleList(encoder_blocks) + + # middle + mid_channels = out_channels + self.mid = ResidualStage(mid_channels, mid_channels, num_res_blocks) + + # end + self.norm_out = GroupNorm(mid_channels) + self.conv_out = Conv2dSame(mid_channels, self.config.token_size, kernel_size=1) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + """Forward pass of the convolutional encoder. + + Args: + pixel_values -> torch.Tensor: Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ + # downsampling + hidden_states = self.conv_in(pixel_values) + + for block in self.down: + hidden_states = block(hidden_states) + # middle + hidden_states = self.mid(hidden_states) + + # end + hidden_states = self.norm_out(hidden_states) + hidden_states = F.silu(hidden_states) + hidden_states = self.conv_out(hidden_states) + return hidden_states + + +class ConvDecoderLegacy(torch.nn.Module): + """ + This is a legacy decoder class. It is used to support older weights. + """ + + def __init__(self, config): + """Initializes the convolutional decoder in a legacy variant. + + Args: + config: Configuration of the model architecture. + """ + super().__init__() + + self.config = config + + # compute in_channel_mult, block_in and curr_res at lowest res + block_in = ( + self.config.hidden_channels + * self.config.channel_mult[self.config.num_resolutions - 1] + ) + num_res_blocks = self.config.num_res_blocks + hidden_channels = self.config.hidden_channels + in_channel_mult = tuple(self.config.channel_mult) + ( + self.config.channel_mult[-1], + ) + + # z to block_in + self.conv_in = Conv2dSame(self.config.token_size, block_in, kernel_size=3) + + # middle + self.mid = ResidualStage(block_in, block_in, num_res_blocks) + + # upsampling + decoder_blocks = [] + for i_level in reversed(range(self.config.num_resolutions)): + in_channels = hidden_channels * in_channel_mult[i_level + 1] + out_channels = hidden_channels * in_channel_mult[i_level] + if i_level > 0: + decoder_blocks.append( + UpsamplingStage(in_channels, out_channels, num_res_blocks) + ) + else: + decoder_blocks.append( + ResidualStage(in_channels, out_channels, num_res_blocks) + ) + + self.up = torch.nn.ModuleList(list(reversed(decoder_blocks))) + + # end + self.norm_out = GroupNorm(out_channels) + self.conv_out = Conv2dSame( + out_channels, self.config.num_channels, kernel_size=3 + ) + + def forward(self, z_quantized: torch.Tensor) -> torch.Tensor: + """Forward pass of the convolutional decoder. + + Args: + z_quantized -> torch.Tensor: Input tensor. + + Returns: + torch.Tensor: Output tensor. 
+ """ + # z to block_in + hidden_states = self.conv_in(z_quantized) + + # middle + hidden_states = self.mid(hidden_states) + + # upsampling decoder + for block in reversed(self.up): + hidden_states = block(hidden_states, z_quantized) + + # end + hidden_states = self.norm_out(hidden_states) + hidden_states = F.silu(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +class ConvDecoder(torch.nn.Module): + def __init__(self, config): + """Initializes the convolutional decoder. + + Args: + config: Configuration of the model architecture. + """ + super().__init__() + + self.config = config + + # compute in_channel_mult, block_in and curr_res at lowest res + block_in = ( + self.config.hidden_channels + * self.config.channel_mult[self.config.num_resolutions - 1] + ) + num_res_blocks = self.config.get( + "num_res_blocks_decoder", self.config.num_res_blocks + ) + hidden_channels = self.config.hidden_channels + in_channel_mult = tuple(self.config.channel_mult) + ( + self.config.channel_mult[-1], + ) + + # z to block_in + if config.quantizer_type == "vae": + self.conv_in = Conv2dSame( + self.config.token_size // 2, block_in, kernel_size=3 + ) + else: + self.conv_in = Conv2dSame(self.config.token_size, block_in, kernel_size=3) + + # middle + self.mid = ResidualStage(block_in, block_in, num_res_blocks) + + # upsampling + decoder_blocks = [] + for i_level in reversed(range(self.config.num_resolutions)): + in_channels = hidden_channels * in_channel_mult[i_level + 1] + out_channels = hidden_channels * in_channel_mult[i_level] + if i_level > 0: + decoder_blocks.append( + UpsamplingStage(in_channels, out_channels, num_res_blocks) + ) + else: + decoder_blocks.append( + ResidualStage(in_channels, out_channels, num_res_blocks) + ) + self.up = torch.nn.ModuleList(decoder_blocks) + + # end + self.norm_out = GroupNorm(out_channels) + self.conv_out = Conv2dSame( + out_channels, self.config.num_channels, kernel_size=3 + ) + + def forward(self, z_quantized: torch.Tensor) -> torch.Tensor: + """Forward pass of the convolutional decoder. + + Args: + z_quantized -> torch.Tensor: Input tensor. + + Returns: + torch.Tensor: Output tensor. 
+ """ + # z to block_in + hidden_states = self.conv_in(z_quantized) + + # middle + hidden_states = self.mid(hidden_states) + + # upsampling decoder + for block in self.up: + hidden_states = block(hidden_states, z_quantized) + + # end + hidden_states = self.norm_out(hidden_states) + hidden_states = F.silu(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +if __name__ == "__main__": + + class Config: + def __init__(self, **kwargs): + for key in kwargs: + setattr(self, key, kwargs[key]) + + def get(self, key, default): + return getattr(self, key, default) + + config_dict = dict( + resolution=256, + num_channels=3, + hidden_channels=128, + channel_mult=(1, 2, 2, 4), + num_res_blocks=2, + codebook_size=1024, + token_size=256, + num_resolutions=4, + sample_with_conv=False, + quantizer_type="lookup", + ) + config = Config(**config_dict) + + encoder = ConvEncoder(config) + decoder = ConvDecoder(config) + + config.sample_with_conv = True + encoder_conv_down = ConvEncoder(config) + + print("Encoder:\n{}".format(encoder)) + print("Encoder downsampling with conv:\n{}".format(encoder_conv_down)) + print("Decoder:\n{}".format(decoder)) + + x = torch.randn((1, 3, 256, 256)) + x_enc = encoder(x) + x_enc_down_with_conv = encoder_conv_down(x) + x_dec = decoder(x_enc) + x_dec_down_with_conv = decoder(x_enc_down_with_conv) + + print(f"Input shape: {x.shape}") + print(f"Encoder output shape: {x_enc.shape}") + print(f"Encoder with conv as down output shape: {x_enc_down_with_conv.shape}") + print(f"Decoder output shape: {x_dec.shape}") + print(f"Decoder with conv as down output shape: {x_dec_down_with_conv.shape}") diff --git a/src/vqvaes/maskbit/modules/base_model.py b/src/vqvaes/maskbit/modules/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..6c1d186fdffe15359fdc0a89c44c5f6103cb51e3 --- /dev/null +++ b/src/vqvaes/maskbit/modules/base_model.py @@ -0,0 +1,203 @@ +"""This file contains the definition of base classes. 
+ +We thank the following public implementations for inspiring this code: + https://github.com/huggingface/open-muse/blob/main/muse/modeling_utils.py +""" + +import copy +import os +from typing import Union, Callable, Tuple, Dict, Optional, List + +import torch + + +def get_parameter_device(parameter: torch.nn.Module): + try: + return next(parameter.parameters()).device + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes( + module: torch.nn.Module, + ) -> List[Tuple[str, torch.Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].device + + +def get_parameter_dtype(parameter: torch.nn.Module): + try: + return next(parameter.parameters()).dtype + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes( + module: torch.nn.Module, + ) -> List[Tuple[str, torch.Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].dtype + + +class BaseModel(torch.nn.Module): + def __init__(self): + super().__init__() + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + save_function: Callable = None, + state_dict: Optional[Dict[str, torch.Tensor]] = None, + ): + """Save a model to a directory, so that it can be re-loaded using the + load_pretrained class method. + + Args: + save_directory -> Union[str, os.PathLike]: Directory to which to save. Will be created + if it doesn't exist. + save_function -> Optional[Callable]: The function to use to save the state dictionary. + Useful on distributed training like TPUs when one need to replace `torch.save` by another method. + state_dict -> Optional[Dict[str, torch.Tensor]]: The state dictionary to save. If `None`, the model's + state dictionary will be saved. + """ + if os.path.isfile(save_directory): + print(f"Provided path ({save_directory}) should be a directory, not a file") + return + + if save_function is None: + save_function = torch.save + + os.makedirs(save_directory, exist_ok=True) + + model_to_save = self + + # Save the model + if state_dict is None: + state_dict = model_to_save.state_dict() + + weights_name = "pytorch_model.bin" + + # Save the model + save_function(state_dict, os.path.join(save_directory, weights_name)) + + print(f"Model weights saved in {os.path.join(save_directory, weights_name)}") + + def load_pretrained( + self, + # pretrained_model_path: Union[str, os.PathLike], + checkpoint, + strict_loading: bool = True, + torch_dtype: Optional[torch.dtype] = None, + rename_keys: Optional[Dict[str, str]] = None, + ): + """Instantiate a pretrained pytorch model from a weights path. + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). + To train the model, you should first set it back in training mode with `model.train()`. + + Args: + pretrained_model_path -> Union[str, os.PathLike]: Path to a pretrained model. + strict_loading -> bool: Whether or not to strictly enforce that the provided weights file matches the + architecture of this model. + torch_dtype -> Optional[torch.dtype]: The dtype to use for the model. Defaults to `None`, which means + no conversion. 
+ rename_keys -> Optional[Dict[str, str]]: A dictionary containing the keys to rename. + Defaults to `None`, which means no renaming. + """ + # if os.path.isfile(pretrained_model_path): + # model_file = pretrained_model_path + # elif os.path.isdir(pretrained_model_path): + # pretrained_model_path = os.path.join(pretrained_model_path, "pytorch_model.bin") + # if os.path.isfile(pretrained_model_path): + # model_file = pretrained_model_path + # else: + # raise ValueError(f"{pretrained_model_path} does not exist") + # else: + # raise ValueError(f"{pretrained_model_path} does not exist") + # + # checkpoint = torch.load(model_file, map_location="cpu") + new_checkpoint = copy.deepcopy(checkpoint) + + if rename_keys is not None: + for p_key in checkpoint: + for r_key in rename_keys: + if p_key.startswith(r_key): + new_checkpoint[p_key.replace(r_key, rename_keys[r_key])] = ( + checkpoint[p_key] + ) + new_checkpoint.pop(p_key) + break + + checkpoint = new_checkpoint + + self.load_state_dict(checkpoint, strict=strict_loading) + + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." + ) + elif torch_dtype is not None: + self.to(torch_dtype) + + # Set model in evaluation mode to deactivate DropOut modules by default + self.eval() + + @property + def device(self): + """Returns the device of the model. + + Returns: + `torch.device`: The device of the model. + """ + return get_parameter_device(self) + + @property + def dtype(self) -> torch.dtype: + """Returns the dtype of the model.""" + return get_parameter_dtype(self) + + def num_parameters( + self, only_trainable: bool = False, exclude_embeddings: bool = False + ) -> int: + """ + Get number of (optionally, trainable or non-embeddings) parameters in the module. + + Args: + only_trainable (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of trainable parameters + + exclude_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of non-embeddings parameters + + Returns: + `int`: The number of parameters. + """ + + if exclude_embeddings: + embedding_param_names = [ + f"{name}.weight" + for name, module_type in self.named_modules() + if isinstance(module_type, torch.nn.Embedding) + ] + non_embedding_parameters = [ + parameter + for name, parameter in self.named_parameters() + if name not in embedding_param_names + ] + return sum( + p.numel() + for p in non_embedding_parameters + if p.requires_grad or not only_trainable + ) + else: + return sum( + p.numel() + for p in self.parameters() + if p.requires_grad or not only_trainable + ) diff --git a/src/vqvaes/maskbit/modules/discriminator.py b/src/vqvaes/maskbit/modules/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..0a8d11b0040892318a04eaddb40442a153952b56 --- /dev/null +++ b/src/vqvaes/maskbit/modules/discriminator.py @@ -0,0 +1,302 @@ +"""This file contains the definition of the discriminator.""" + +import functools +import math +from typing import Tuple + +import torch +import torch.nn.functional as F + +from .autoencoder import Conv2dSame + + +class BlurBlock(torch.nn.Module): + def __init__(self, kernel: Tuple[int] = (1, 3, 3, 1)): + """Initializes the blur block. + + Args: + kernel -> Tuple[int]: The kernel size. 
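+
+        Example (illustrative sketch; the normalized outer-product kernel is applied depthwise
+        with stride 2, halving the spatial size):
+            >>> blur = BlurBlock((1, 3, 3, 1))
+            >>> blur(torch.randn(1, 64, 32, 32)).shape
+            torch.Size([1, 64, 16, 16])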
+ """ + super().__init__() + + self.kernel_size = len(kernel) + + kernel = torch.tensor(kernel, dtype=torch.float32, requires_grad=False) + kernel = kernel[None, :] * kernel[:, None] + kernel /= kernel.sum() + kernel = kernel.unsqueeze(0).unsqueeze(0) + self.register_buffer("kernel", kernel) + + def calc_same_pad(self, i: int, k: int, s: int) -> int: + """Calculates the same padding for the BlurBlock. + + Args: + i -> int: Input size. + k -> int: Kernel size. + s -> int: Stride. + + Returns: + pad -> int: The padding. + """ + return max((math.ceil(i / s) - 1) * s + (k - 1) + 1 - i, 0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x -> torch.Tensor: The input tensor. + + Returns: + out -> torch.Tensor: The output tensor. + """ + ic, ih, iw = x.size()[-3:] + pad_h = self.calc_same_pad(i=ih, k=self.kernel_size, s=2) + pad_w = self.calc_same_pad(i=iw, k=self.kernel_size, s=2) + if pad_h > 0 or pad_w > 0: + x = F.pad( + x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] + ) + + weight = self.kernel.expand(ic, -1, -1, -1) + + out = F.conv2d(input=x, weight=weight, stride=2, groups=x.shape[1]) + return out + + +class NLayerDiscriminatorv2(torch.nn.Module): + def __init__( + self, + num_channels: int = 3, + hidden_channels: int = 64, + num_stages: int = 3, + activation_fn: str = "leaky_relu", + blur_resample: bool = False, + blur_kernel_size: int = 4, + ): + """Initializes the NLayerDiscriminatorv2. + + Args: + num_channels -> int: The number of input channels. + hidden_channels -> int: The number of hidden channels. + num_stages -> int: The number of stages. + activation_fn -> str: The activation function. + blur_resample -> bool: Whether to use blur resampling. + blur_kernel_size -> int: The blur kernel size. + """ + super().__init__() + assert num_stages > 0, "Discriminator cannot have 0 stages" + assert (not blur_resample) or ( + blur_kernel_size >= 3 and blur_kernel_size <= 5 + ), "Blur kernel size must be in [3,5] when sampling]" + + in_channel_mult = (1,) + tuple(map(lambda t: 2**t, range(num_stages))) + init_kernel_size = 5 + if activation_fn == "leaky_relu": + activation = functools.partial(torch.nn.LeakyReLU, negative_slope=0.1) + else: + activation = torch.nn.SiLU + + self.block_in = torch.nn.Sequential( + Conv2dSame(num_channels, hidden_channels, kernel_size=init_kernel_size), + activation(), + ) + + BLUR_KERNEL_MAP = { + 3: (1, 2, 1), + 4: (1, 3, 3, 1), + 5: (1, 4, 6, 4, 1), + } + + discriminator_blocks = [] + for i_level in range(num_stages): + in_channels = hidden_channels * in_channel_mult[i_level] + out_channels = hidden_channels * in_channel_mult[i_level + 1] + block = torch.nn.Sequential( + Conv2dSame( + in_channels, + out_channels, + kernel_size=3, + ), + ( + torch.nn.AvgPool2d(kernel_size=2, stride=2) + if not blur_resample + else BlurBlock(BLUR_KERNEL_MAP[blur_kernel_size]) + ), + torch.nn.GroupNorm(32, out_channels), + activation(), + ) + discriminator_blocks.append(block) + + self.blocks = torch.nn.ModuleList(discriminator_blocks) + + self.pool = torch.nn.AdaptiveMaxPool2d((16, 16)) + + self.to_logits = torch.nn.Sequential( + Conv2dSame(out_channels, out_channels, 1), + activation(), + Conv2dSame(out_channels, 1, kernel_size=5), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x -> torch.Tensor: The input tensor. + + Returns: + output -> torch.Tensor: The output tensor. 
+ """ + hidden_states = self.block_in(x) + for block in self.blocks: + hidden_states = block(hidden_states) + + hidden_states = self.pool(hidden_states) + + return self.to_logits(hidden_states) + + +class OriginalNLayerDiscriminator(torch.nn.Module): + """Defines a PatchGAN discriminator like in Pix2Pix as used by Taming VQGAN + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + + def __init__( + self, + num_channels: int = 3, + hidden_channels: int = 64, + num_stages: int = 3, + ): + """Initializes a PatchGAN discriminator. + + Args: + num_channels -> int: The number of input channels. + hidden_channels -> int: The number of hidden channels. + num_stages -> int: The number of stages. + """ + super(OriginalNLayerDiscriminator, self).__init__() + norm_layer = torch.nn.BatchNorm2d + + sequence = [ + torch.nn.Conv2d( + num_channels, hidden_channels, kernel_size=4, stride=2, padding=1 + ), + torch.nn.LeakyReLU(0.2, True), + ] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, num_stages): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2**n, 8) + sequence += [ + torch.nn.Conv2d( + hidden_channels * nf_mult_prev, + hidden_channels * nf_mult, + kernel_size=4, + stride=2, + padding=1, + bias=False, + ), + norm_layer(hidden_channels * nf_mult), + torch.nn.LeakyReLU(0.2, True), + ] + + nf_mult_prev = nf_mult + nf_mult = min(2**num_stages, 8) + sequence += [ + torch.nn.Conv2d( + hidden_channels * nf_mult_prev, + hidden_channels * nf_mult, + kernel_size=4, + stride=1, + padding=1, + bias=False, + ), + norm_layer(hidden_channels * nf_mult), + torch.nn.LeakyReLU(0.2, True), + ] + + sequence += [ + torch.nn.Conv2d( + hidden_channels * nf_mult, 1, kernel_size=4, stride=1, padding=1 + ) + ] # output 1 channel prediction map + self.main = torch.nn.Sequential(*sequence) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x -> torch.Tensor: The input tensor. + + Returns: + output -> torch.Tensor: The output tensor. 
+ """ + return self.main(x) + + +if __name__ == "__main__": + patch_discriminator_v2 = NLayerDiscriminatorv2( + num_channels=3, hidden_channels=128, num_stages=3 + ) + patch_discriminator_v2_blur = NLayerDiscriminatorv2( + num_channels=3, hidden_channels=128, num_stages=3, blur_resample=True + ) + original_discriminiator = OriginalNLayerDiscriminator( + num_channels=3, hidden_channels=128, num_stages=3 + ) + + from torchinfo import summary + + print("Original Discriminator") + summary( + original_discriminiator, + input_size=(1, 3, 256, 256), + depth=3, + col_names=( + "input_size", + "output_size", + "num_params", + "params_percent", + "kernel_size", + "mult_adds", + ), + ) + print("Patch Discriminator v2") + summary( + patch_discriminator_v2, + input_size=(1, 3, 256, 256), + depth=3, + col_names=( + "input_size", + "output_size", + "num_params", + "params_percent", + "kernel_size", + "mult_adds", + ), + ) + print("Patch Discriminator v2 (blur)") + summary( + patch_discriminator_v2_blur, + input_size=(1, 3, 256, 256), + depth=3, + col_names=( + "input_size", + "output_size", + "num_params", + "params_percent", + "kernel_size", + "mult_adds", + ), + ) + + x = torch.randn((1, 3, 256, 256)).to(next(original_discriminiator.parameters())) + + out_original = original_discriminiator(x) + out_patch_v2 = patch_discriminator_v2(x) + out_patch_v2_blur = patch_discriminator_v2_blur(x) + + print(f"Input shape: {x.shape}") + print(f"Patch Discriminator v2 output shape: {out_patch_v2.shape}") + print(f"Patch Discriminator v2 (blur) output shape: {out_patch_v2_blur.shape}") + print(f"Original Discriminator output shape: {out_original.shape}") diff --git a/src/vqvaes/maskbit/modules/ema_model.py b/src/vqvaes/maskbit/modules/ema_model.py new file mode 100644 index 0000000000000000000000000000000000000000..cb716998d769f560478bfbbac82114c2c68c3e64 --- /dev/null +++ b/src/vqvaes/maskbit/modules/ema_model.py @@ -0,0 +1,276 @@ +"""This file contains the definition of the EMA class. + +We thank the following public implementations for inspiring this code: + https://github.com/fadel/pytorch_ema +""" + +import copy +from typing import Any, Iterable, Optional, Union + +import torch + + +class EMAModel: + """ + Exponential Moving Average of models weights + """ + + def __init__( + self, + parameters: Iterable[torch.nn.Parameter], + decay: float = 0.9999, + min_decay: float = 0.0, + update_after_step: int = 0, + update_every: int = 1, + current_step: int = 0, + use_ema_warmup: bool = False, + inv_gamma: Union[float, int] = 1.0, + power: Union[float, int] = 2 / 3, + model_cls: Optional[Any] = None, + **model_config_kwargs + ): + """ + Args: + parameters -> Iterable[torch.nn.Parameter]: The parameters to track. + decay -> float: The decay factor for the exponential moving average. + min_decay -> float: The minimum decay factor for the exponential moving average. + update_after_step -> int: The number of steps to wait before starting to update the EMA weights. + update_every -> int: The number of steps between each EMA update. + current_step -> int: The current training step. + use_ema_warmup -> bool: Whether to use EMA warmup. + inv_gamma -> float: Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` + is True. + power -> float: Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. + + notes on EMA Warmup: + If gamma=1 and power=1, implements a simple average. 
gamma=1, power=2/3 are good values for models you plan + to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), + gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 + at 215.4k steps). + """ + + parameters = list(parameters) + self.shadow_params = [p.clone().detach() for p in parameters] + self.temp_stored_params = None + + self.decay = decay + self.min_decay = min_decay + self.update_after_step = update_after_step + self.update_every = update_every + self.use_ema_warmup = use_ema_warmup + self.inv_gamma = inv_gamma + self.power = power + self.optimization_step = current_step + self.cur_decay_value = None # set in `step()` + + self.model_cls = model_cls + self.model_config_kwargs = model_config_kwargs + + @classmethod + def from_pretrained( + cls, checkpoint, model_cls, **model_config_kwargs + ) -> "EMAModel": + model = model_cls(**model_config_kwargs) + model.load_pretrained(checkpoint) + + ema_model = cls(model.parameters(), model_cls=model_cls, **model_config_kwargs) + return ema_model + + def save_pretrained(self, path): + if self.model_cls is None: + raise ValueError( + "`save_pretrained` can only be used if `model_cls` was defined at __init__." + ) + + if self.model_config_kwargs is None: + raise ValueError( + "`save_pretrained` can only be used if `model_config_kwargs` was defined at __init__." + ) + + model = self.model_cls(**self.model_config_kwargs) + self.copy_to(model.parameters()) + model.save_pretrained(path) + + def set_step(self, optimization_step: int): + """ + Set the current optimization step. + + Args: + optimization_step -> int: the current optimization step. + """ + self.optimization_step = optimization_step + + def get_decay(self, optimization_step: int) -> float: + """ + Compute the decay factor for the exponential moving average. + + Args: + optimization_step -> int: the current optimization step. + """ + step = max(0, optimization_step - self.update_after_step - 1) + + if step <= 0: + return 0.0 + + if self.use_ema_warmup: + cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power + else: + cur_decay_value = (1 + step) / (10 + step) + + cur_decay_value = min(cur_decay_value, self.decay) + # make sure decay is not smaller than min_decay + cur_decay_value = max(cur_decay_value, self.min_decay) + return cur_decay_value + + @torch.no_grad() + def step(self, parameters: Iterable[torch.nn.Parameter]): + """ + Update the stored parameters with the current parameters. + + Args: + parameters -> Iterable[torch.nn.Parameter]: the parameters used to update the EMA model. + """ + parameters = list(parameters) + + self.optimization_step += 1 + + if (self.optimization_step - 1) % self.update_every != 0: + return + + # Compute the decay factor for the exponential moving average. + decay = self.get_decay(self.optimization_step) + self.cur_decay_value = decay + one_minus_decay = 1 - decay + + for s_param, param in zip(self.shadow_params, parameters): + if param.requires_grad: + s_param.sub_(one_minus_decay * (s_param - param)) + else: + s_param.copy_(param) + + def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: + """ + Copy current averaged parameters into given collection of parameters. + + Args: + parameters -> Iterable[torch.nn.Parameter]: the parameters to be updated with the stored moving averages. + If `None`, the parameters with which this `ExponentialMovingAverage` was initialized will be used. 
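+
+        Example (typical evaluation round-trip; ``model`` stands for any ``torch.nn.Module``):
+            >>> ema = EMAModel(model.parameters(), decay=0.999)
+            >>> ema.store(model.parameters())    # stash the raw training weights
+            >>> ema.copy_to(model.parameters())  # evaluate with the EMA weights
+            >>> ema.restore(model.parameters())  # put the training weights back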
+ """ + parameters = list(parameters) + for s_param, param in zip(self.shadow_params, parameters): + param.data.copy_(s_param.to(param.device).data) + + def to(self, device=None, dtype=None) -> None: + """Move internal buffers of the ExponentialMovingAverage to `device`. + + Args: + device: like `device` argument to `torch.Tensor.to` + """ + # .to() on the tensors handles None correctly + self.shadow_params = [ + ( + p.to(device=device, dtype=dtype) + if p.is_floating_point() + else p.to(device=device) + ) + for p in self.shadow_params + ] + + def state_dict(self) -> dict: + """ + Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during + checkpointing to save the ema state dict. + """ + # Following PyTorch conventions, references to tensors are returned: + # "returns a reference to the state and not its copy!" - + # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict + return { + "decay": self.decay, + "min_decay": self.min_decay, + "optimization_step": self.optimization_step, + "update_after_step": self.update_after_step, + "use_ema_warmup": self.use_ema_warmup, + "inv_gamma": self.inv_gamma, + "power": self.power, + "shadow_params": self.shadow_params, + } + + def store(self, parameters: Iterable[torch.nn.Parameter]) -> None: + """ + Save the current parameters for restoring later. + + Args: + parameters -> Iterable[torch.nn.Parameter]: the parameters to be temporarily stored. + """ + self.temp_stored_params = [param.detach().cpu().clone() for param in parameters] + + def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: + """ + Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without: + affecting the original optimization process. Store the parameters before the `copy_to()` method. After + validation (or model saving), use this to restore the former parameters. + + Args: + parameters -> Iterable[torch.nn.Parameter]: the parameters to be updated with the stored parameters. + If `None`, the parameters with which this `ExponentialMovingAverage` was initialized will be used. + """ + if self.temp_stored_params is None: + raise RuntimeError( + "This ExponentialMovingAverage has no `store()`ed weights to `restore()`" + ) + for c_param, param in zip(self.temp_stored_params, parameters): + param.data.copy_(c_param.data) + + # Better memory-wise. + self.temp_stored_params = None + + def load_state_dict(self, state_dict: dict) -> None: + """ + Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the + ema state dict. + + Args: + state_dict -> dict: EMA state. Should be an object returned from a call to `state_dict`. 
+ """ + # deepcopy, to be consistent with module API + state_dict = copy.deepcopy(state_dict) + + self.decay = state_dict.get("decay", self.decay) + if self.decay < 0.0 or self.decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.min_decay = state_dict.get("min_decay", self.min_decay) + if not isinstance(self.min_decay, float): + raise ValueError("Invalid min_decay") + + self.optimization_step = state_dict.get( + "optimization_step", self.optimization_step + ) + if not isinstance(self.optimization_step, int): + raise ValueError("Invalid optimization_step") + + self.update_after_step = state_dict.get( + "update_after_step", self.update_after_step + ) + if not isinstance(self.update_after_step, int): + raise ValueError("Invalid update_after_step") + + self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup) + if not isinstance(self.use_ema_warmup, bool): + raise ValueError("Invalid use_ema_warmup") + + self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma) + if not isinstance(self.inv_gamma, (float, int)): + raise ValueError("Invalid inv_gamma") + + self.power = state_dict.get("power", self.power) + if not isinstance(self.power, (float, int)): + raise ValueError("Invalid power") + + shadow_params = state_dict.get("shadow_params", None) + if shadow_params is not None: + self.shadow_params = shadow_params + if not isinstance(self.shadow_params, list): + raise ValueError("shadow_params must be a list") + if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): + raise ValueError("shadow_params must all be Tensors") diff --git a/src/vqvaes/maskbit/modules/factorization.py b/src/vqvaes/maskbit/modules/factorization.py new file mode 100644 index 0000000000000000000000000000000000000000..6e36743e0279da61fca17736bfbbbcd78af48607 --- /dev/null +++ b/src/vqvaes/maskbit/modules/factorization.py @@ -0,0 +1,73 @@ +"""This file contains the definition of utility functions to group tokens.""" + +import math +import torch + + +def combine_factorized_tokens( + tokens: torch.Tensor, codebook_size: int, splits: int +) -> torch.Tensor: + """ + Combine the tokens into a single token. + + Args: + tokens -> torch.Tensor: Tensor of shape (batch_size, n, m). + codebook_size -> int: The size of the codebook. + splits -> int: Number of splits. + + Returns: + combined_tokens -> torch.Tensor: Tensor of shape (batch_size, n). + """ + combined_tokens = torch.zeros( + (tokens.shape[0], tokens.shape[1]), device=tokens.device + ) + bit_shift = int(math.log2(codebook_size)) // splits + for i in range(splits): + combined_tokens += tokens[..., i] << (i * bit_shift) + + return combined_tokens + + +def split_factorized_tokens( + tokens: torch.Tensor, codebook_size: int, splits: int +) -> torch.Tensor: + """ + Split the tokens into multiple tokens. + + Args: + tokens -> torch.Tensor: Tensor of shape (batch_size, n). + codebook_size -> int: The size of the codebook. + splits -> int: Number of splits. + + Returns: + split_tokens -> torch.Tensor: Tensor of shape (batch_size, n, m). 
+ """ + bit_shift = int(math.log2(codebook_size)) // splits + bit_mask = (1 << bit_shift) - 1 + + split_tokens = [] + for i in range(splits): + split_tokens.append((tokens & (bit_mask << (i * bit_shift))) >> (i * bit_shift)) + + return torch.stack(split_tokens, dim=2) + + +if __name__ == "__main__": + tokens = torch.randint(0, 1023, (1, 16)) + split_tokens = split_factorized_tokens(tokens, 1024, 1) + + assert split_tokens.shape == (1, 16, 1) + assert split_tokens.dtype == torch.int64 + + combined_tokens = combine_factorized_tokens(split_tokens, 1024, 1) + + assert (tokens == combined_tokens).all() + + split_tokens = split_factorized_tokens(tokens, 1024, 2) + combined_tokens = combine_factorized_tokens(split_tokens, 1024, 2) + + assert split_tokens.shape == (1, 16, 2) + assert (tokens == combined_tokens).all(), f"{tokens} != {combined_tokens}" + + assert (torch.bitwise_right_shift(tokens, 5) == split_tokens[..., 1]).all() + assert (tokens & 31 == split_tokens[..., 0]).all() diff --git a/src/vqvaes/maskbit/modules/gan_utils.py b/src/vqvaes/maskbit/modules/gan_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b3b0d5f2e6ff7d42ec4629620e0ca0c551f480e2 --- /dev/null +++ b/src/vqvaes/maskbit/modules/gan_utils.py @@ -0,0 +1,211 @@ +"""This file contains the definition of utility functions for GANs.""" + +import torch +import torch.nn.functional as F + +from . import OriginalNLayerDiscriminator, NLayerDiscriminatorv2 + + +def toggle_off_gradients(model: torch.nn.Module): + """Toggles off gradients for all parameters in a model.""" + for param in model.parameters(): + param.requires_grad = False + + +def toggle_on_gradients(model: torch.nn.Module): + """Toggles on gradients for all parameters in a model.""" + for param in model.parameters(): + param.requires_grad = True + + +def discriminator_weights_init(m): + """Initialize weights for convolutions in the discriminator.""" + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + torch.nn.init.normal_(m.weight.data, 0.0, 0.02) + + +def adopt_weight( + weight: float, global_step: int, threshold: int = 0, value: float = 0.0 +) -> float: + """If global_step is less than threshold, return value, else return weight.""" + if global_step < threshold: + weight = value + return weight + + +def compute_lecam_loss( + logits_real_mean: torch.Tensor, + logits_fake_mean: torch.Tensor, + ema_logits_real_mean: torch.Tensor, + ema_logits_fake_mean: torch.Tensor, +) -> torch.Tensor: + """Computes the LeCam loss for the given average real and fake logits. + + Args: + logits_real_mean -> torch.Tensor: The average real logits. + logits_fake_mean -> torch.Tensor: The average fake logits. + ema_logits_real_mean -> torch.Tensor: The EMA of the average real logits. + ema_logits_fake_mean -> torch.Tensor: The EMA of the average fake logits. + + Returns: + lecam_loss -> torch.Tensor: The LeCam loss. + """ + lecam_loss = torch.mean( + torch.pow(F.relu(logits_real_mean - ema_logits_fake_mean), 2) + ) + lecam_loss += torch.mean( + torch.pow(F.relu(ema_logits_real_mean - logits_fake_mean), 2) + ) + return lecam_loss + + +def hinge_g_loss(logits_fake: torch.Tensor) -> torch.Tensor: + """Computes the hinge loss for the generator given the fake logits. + + Args: + logits_fake -> torch.Tensor: The fake logits. + + Returns: + g_loss -> torch.Tensor: The hinge loss. 
+ """ + g_loss = -torch.mean(logits_fake) + return g_loss + + +def hinge_d_loss(logits_real: torch.Tensor, logits_fake: torch.Tensor) -> torch.Tensor: + """Computes the hinge loss for the discriminator given the real and fake logits. + + Args: + logits_real -> torch.Tensor: The real logits. + logits_fake -> torch.Tensor: The fake logits. + + Returns: + d_loss -> torch.Tensor: The hinge loss. + """ + loss_real = torch.mean(F.relu(1.0 - logits_real)) + loss_fake = torch.mean(F.relu(1.0 + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def sigmoid_cross_entropy_with_logits( + logits: torch.Tensor, label: torch.Tensor +) -> torch.Tensor: + """Credits to Magvit. + We use a stable formulation that is equivalent to the one used in TensorFlow. + The following derivation shows how we arrive at the formulation: + + .. math:: + z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) + = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x))) + = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x))) + = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)) + = (1 - z) * x + log(1 + exp(-x)) + = x - x * z + log(1 + exp(-x)) + + For x < 0, the following formula is more stable: + .. math:: + x - x * z + log(1 + exp(-x)) + = log(exp(x)) - x * z + log(1 + exp(-x)) + = - x * z + log(1 + exp(x)) + + We combine the two cases (x<0, x>=0) into one formula as follows: + .. math:: + max(x, 0) - x * z + log(1 + exp(-abs(x))) + """ + zeros = torch.zeros_like(logits) + cond = logits >= zeros + relu_logits = torch.where(cond, logits, zeros) + neg_abs_logits = torch.where(cond, -logits, logits) + loss = relu_logits - logits * label + torch.log1p(neg_abs_logits.exp()) + return loss + + +def non_saturating_d_loss( + logits_real: torch.Tensor, logits_fake: torch.Tensor +) -> torch.Tensor: + """Computes the non-saturating loss for the discriminator given the real and fake logits. + + Args: + logits_real -> torch.Tensor: The real logits. + logits_fake -> torch.Tensor: The fake logits. + + Returns: + loss -> torch.Tensor: The non-saturating loss. + """ + real_loss = torch.mean( + sigmoid_cross_entropy_with_logits( + logits_real, label=torch.ones_like(logits_real) + ) + ) + fake_loss = torch.mean( + sigmoid_cross_entropy_with_logits( + logits_fake, label=torch.zeros_like(logits_fake) + ) + ) + return torch.mean(real_loss) + torch.mean(fake_loss) + + +def non_saturating_g_loss(logits_fake: torch.Tensor) -> torch.Tensor: + """Computes the non-saturating loss for the generator given the fake logits. + + Args: + logits_fake -> torch.Tensor: The fake logits. + + Returns: + loss -> torch.Tensor: The non-saturating loss. + """ + return torch.mean( + sigmoid_cross_entropy_with_logits( + logits_fake, label=torch.ones_like(logits_fake) + ) + ) + + +def vanilla_d_loss( + logits_real: torch.Tensor, logits_fake: torch.Tensor +) -> torch.Tensor: + """Computes the vanilla loss for the discriminator given the real and fake logits. + + Args: + logits_real -> torch.Tensor: The real logits. + logits_fake -> torch.Tensor: The fake logits. + + Returns: + loss -> torch.Tensor: The vanilla loss. + """ + d_loss = 0.5 * ( + torch.mean(torch.nn.functional.softplus(-logits_real)) + + torch.mean(torch.nn.functional.softplus(logits_fake)) + ) + return d_loss + + +def create_discriminator(discriminator_config) -> torch.nn.Module: + """Creates a discriminator based on the given config. + + Args: + discriminator_config: The config for the discriminator. 
+ + Returns: + discriminator -> torch.nn.Module: The discriminator. + """ + if discriminator_config.name == "Original": + return OriginalNLayerDiscriminator( + num_channels=discriminator_config.num_channels, + num_stages=discriminator_config.num_stages, + hidden_channels=discriminator_config.hidden_channels, + ).apply(discriminator_weights_init) + elif discriminator_config.name == "VQGAN+Discriminator": + return NLayerDiscriminatorv2( + num_channels=discriminator_config.num_channels, + num_stages=discriminator_config.num_stages, + hidden_channels=discriminator_config.hidden_channels, + blur_resample=discriminator_config.blur_resample, + blur_kernel_size=discriminator_config.get("blur_kernel_size", 4), + ) + else: + raise ValueError( + f"Discriminator {discriminator_config.name} is not implemented." + ) diff --git a/src/vqvaes/maskbit/modules/losses.py b/src/vqvaes/maskbit/modules/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..d0416417cbaa7ac1957c0710206ed372c294e77e --- /dev/null +++ b/src/vqvaes/maskbit/modules/losses.py @@ -0,0 +1,392 @@ +from typing import Mapping, Text, Tuple +import torch +import torch.nn.functional as F + + +from .lpips import LPIPS +from .perceptual_loss import PerceptualLoss +from . import gan_utils + + +def create_perception_loss( + perception_loss: str, compute_on_logits: bool = True +) -> torch.nn.Module: + """Creates the perception loss. + + Args: + perception_loss -> str: The name of the perception loss. + compute_on_logits -> bool: Whether to compute the loss on logits or on multiple features. + + Returns: + perception_loss -> torch.nn.Module: The perception loss. + """ + if perception_loss == "lpips": + return LPIPS().eval() + elif perception_loss in ("resnet50", "convnext_s"): + return PerceptualLoss( + model_name=perception_loss, + compute_perceptual_loss_on_logits=compute_on_logits, + ).eval() + else: + raise ValueError(f"Perception loss {perception_loss} is not supported.") + + +class VQGANLoss(torch.nn.Module): + def __init__( + self, + discriminator_config, + loss_config, + ): + """Initializes the VQGAN loss. + + Args: + discriminator_config: The configuration of the discriminator. + loss_config: The configuration of the loss. 
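+ Among others, loss_config is expected to provide reconstruction_loss
+ ("l1" or "l2"), discriminator_loss ("hinge", "vanilla" or
+ "non-saturating"), quantizer_weight, perceptual_loss and
+ perceptual_weight, discriminator_start, discriminator_factor,
+ discriminator_weight and lecam_regularization_weight; see the
+ assertions and attribute reads below.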
+ """ + super().__init__() + assert loss_config.discriminator_loss in ("hinge", "vanilla", "non-saturating") + assert loss_config.reconstruction_loss in ("l2", "l1") + assert loss_config.discriminator_gradient_penalty in ("none", "adopt_weight") + + self.discriminator = gan_utils.create_discriminator(discriminator_config) + + self.reconstruction_loss = loss_config.reconstruction_loss + self.reconstruction_weight = loss_config.get("reconstruction_weight", 1.0) + self.quantizer_weight = loss_config.quantizer_weight + self.perceptual_loss = create_perception_loss( + loss_config.perceptual_loss, + loss_config.get("perceptual_loss_on_logits", True), + ) + self.perceptual_weight = loss_config.perceptual_weight + self.lecam_regularization_weight = loss_config.lecam_regularization_weight + self.ema_decay = loss_config.get("ema_decay", 0.999) + + self.entropy_annealing_steps = loss_config.get("entropy_annealing_steps", 2000) + self.entropy_annealing_factor = loss_config.get("entropy_annealing_factor", 0.0) + + self.discriminator_iter_start = loss_config.discriminator_start + + if loss_config.discriminator_loss == "hinge": + self.discriminator_loss = gan_utils.hinge_d_loss + elif loss_config.discriminator_loss == "vanilla": + self.discriminator_loss = gan_utils.vanilla_d_loss + elif loss_config.discriminator_loss == "non-saturating": + self.discriminator_loss = gan_utils.non_saturating_d_loss + else: + raise ValueError(f"Unknown GAN loss '{loss_config.discriminator_loss}'.") + + if loss_config.discriminator_loss == "hinge": + self.generator_loss = gan_utils.hinge_g_loss + elif loss_config.discriminator_loss == "vanilla": + self.generator_loss = gan_utils.hinge_g_loss + elif loss_config.discriminator_loss == "non-saturating": + self.generator_loss = gan_utils.non_saturating_g_loss + else: + raise ValueError(f"Unknown GAN loss '{loss_config.discriminator_loss}'.") + + self.discriminator_factor = loss_config.discriminator_factor + self.discriminator_weight = loss_config.discriminator_weight + + self.discriminator_gradient_penalty = ( + "" + if loss_config.discriminator_gradient_penalty == "none" + else loss_config.discriminator_gradient_penalty + ) + self.discriminator_penalty_cost = loss_config.discriminator_penalty_cost + + if self.lecam_regularization_weight > 0.0: + self.register_buffer("ema_real_logits_mean", torch.zeros((1))) + self.register_buffer("ema_fake_logits_mean", torch.zeros((1))) + + def calculate_adaptive_weight( + self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer + ) -> torch.Tensor: + """Calculates the adaptive weight for the discriminator loss. + + Args: + nll_loss -> torch.Tensor: The NLL loss. + g_loss -> torch.Tensor: The generator loss. + last_layer: The last layer of the model. + + Returns: + d_weight -> torch.Tensor: The adaptive weight for the discriminator loss. + """ + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + return d_weight + + def forward( + self, + inputs: torch.Tensor, + reconstructions: torch.Tensor, + extra_result_dict: Mapping[Text, torch.Tensor], + global_step: int, + last_layer, + mode: str = "gen", + ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: + """Computes the VQGAN loss for the generator or discriminator. + + Args: + inputs -> torch.Tensor: The input images. 
+ reconstructions -> torch.Tensor: The reconstructed images. + extra_result_dict -> Mapping[Text, torch.Tensor]: The extra result dictionary. + global_step -> int: The global step. + last_layer: The last layer of the model. + mode -> str: The mode. Must be either "gen" or "disc". + + Returns: + loss -> torch.Tensor: The loss. + loss_dict -> Mapping[Text, torch.Tensor]: The loss dictionary for logging individual losses. + """ + assert mode in ("gen", "disc") + if mode == "gen": + return self._forward_generator( + inputs, reconstructions, extra_result_dict, global_step, last_layer + ) + elif mode == "disc": + return self._forward_discriminator( + inputs, reconstructions, extra_result_dict, global_step + ) + + def should_discriminator_be_trained(self, global_step: int): + """Returns if the discriminator should be trained at given step.""" + return global_step >= self.discriminator_iter_start + + def _forward_generator( + self, + inputs: torch.Tensor, + reconstructions: torch.Tensor, + extra_result_dict: Mapping[Text, torch.Tensor], + global_step: int, + last_layer, + ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: + """Computes the VQGAN loss for the generator. + + Args: + inputs -> torch.Tensor: The input images. + reconstructions -> torch.Tensor: The reconstructed images. + extra_result_dict -> Mapping[Text, torch.Tensor]: The extra result dictionary. + global_step -> int: The global step. + last_layer: The last layer of the model. + + Returns: + loss -> torch.Tensor: The loss. + loss_dict -> Mapping[Text, torch.Tensor]: The loss dictionary for logging individual losses. + """ + inputs = inputs.contiguous() + reconstructions = reconstructions.contiguous() + + if self.reconstruction_loss == "l1": + reconstruction_loss = F.l1_loss(inputs, reconstructions, reduction="mean") + else: + reconstruction_loss = F.mse_loss(inputs, reconstructions, reduction="mean") + reconstruction_loss *= self.reconstruction_weight + + perceptual_loss = self.perceptual_loss(inputs, reconstructions).mean() + + generator_loss = torch.zeros((), device=inputs.device) + extra_generator_loss = torch.zeros((), device=inputs.device) + + discriminator_factor = gan_utils.adopt_weight( + self.discriminator_factor, + global_step, + threshold=self.discriminator_iter_start, + ) + + d_weight = 1.0 + if discriminator_factor > 0.0: + # Disable discriminator gradients + gan_utils.toggle_off_gradients(self.discriminator) + + logits_fake = self.discriminator(reconstructions) + generator_loss = self.generator_loss(logits_fake) + + if self.discriminator_gradient_penalty == "adopt_weight": + d_weight *= self.calculate_adaptive_weight( + reconstruction_loss + self.perceptual_weight * perceptual_loss, + generator_loss, + last_layer=last_layer, + ) + d_weight *= self.discriminator_weight + + quantizer_loss = extra_result_dict["quantizer_loss"] + if self.entropy_annealing_factor > 0.0: + quantizer_loss += ( + max(0.0, 1 - global_step / self.entropy_annealing_steps) + * self.entropy_annealing_factor + * extra_result_dict["entropy_loss"] + ) + + total_loss = ( + reconstruction_loss + + self.perceptual_weight * perceptual_loss + + self.quantizer_weight * quantizer_loss + + d_weight * discriminator_factor * (generator_loss + extra_generator_loss) + ) + + loss_dict = dict( + total_loss=total_loss.clone().detach(), + reconstruction_loss=reconstruction_loss.detach(), + perceptual_loss=(self.perceptual_weight * perceptual_loss).detach(), + quantizer_loss=(self.quantizer_weight * quantizer_loss).detach(), + weighted_gan_loss=( + d_weight + * 
discriminator_factor + * (generator_loss + extra_generator_loss) + ).detach(), + discriminator_factor=torch.tensor(discriminator_factor), + commitment_loss=extra_result_dict["commitment_loss"].detach(), + entropy_loss=extra_result_dict["entropy_loss"].detach(), + per_sample_entropy=extra_result_dict["per_sample_entropy"], + avg_entropy=extra_result_dict["avg_entropy"], + d_weight=d_weight, + gan_loss=generator_loss.detach(), + ) + if "codebook_loss" in extra_result_dict: + loss_dict["codebook_loss"] = extra_result_dict["codebook_loss"].detach() + + return total_loss, loss_dict + + def _forward_discriminator( + self, + inputs: torch.Tensor, + reconstructions: torch.Tensor, + extra_result_dict: Mapping[Text, torch.Tensor], + global_step: int, + ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: + """Computes the VQGAN loss for the discriminator. + + Args: + inputs -> torch.Tensor: The input images. + reconstructions -> torch.Tensor: The reconstructed images. + extra_result_dict -> Mapping[Text, torch.Tensor]: The extra result dictionary. + global_step -> int: The global step. + + Returns: + loss -> torch.Tensor: The loss. + loss_dict -> Mapping[Text, torch.Tensor]: The loss dictionary for logging individual losses. + """ + + discriminator_factor = gan_utils.adopt_weight( + self.discriminator_factor, + global_step, + threshold=self.discriminator_iter_start, + ) + loss_dict = {} + # Turn on gradients on + gan_utils.toggle_on_gradients(self.discriminator) + + real_images = inputs.detach().requires_grad_(True) + logits_real = self.discriminator(real_images) + logits_fake = self.discriminator(reconstructions.detach()) + + discriminator_loss = discriminator_factor * self.discriminator_loss( + logits_real=logits_real, logits_fake=logits_fake + ) + + lecam_loss = torch.zeros((), device=inputs.device) + if self.lecam_regularization_weight > 0.0: + lecam_loss = ( + gan_utils.compute_lecam_loss( + torch.mean(logits_real), + torch.mean(logits_fake), + self.ema_real_logits_mean, + self.ema_fake_logits_mean, + ) + * self.lecam_regularization_weight + ) + + self.ema_real_logits_mean = ( + self.ema_real_logits_mean * self.ema_decay + + torch.mean(logits_real).detach() * (1 - self.ema_decay) + ) + self.ema_fake_logits_mean = ( + self.ema_fake_logits_mean * self.ema_decay + + torch.mean(logits_fake).detach() * (1 - self.ema_decay) + ) + + discriminator_loss += lecam_loss + + loss_dict = dict( + discriminator_loss=discriminator_loss.detach(), + logits_real=logits_real.detach().mean(), + logits_fake=logits_fake.detach().mean(), + lecam_loss=lecam_loss.detach(), + ) + + return discriminator_loss, loss_dict + + +class MLMLoss(torch.nn.Module): + def __init__(self, label_smoothing: float = 0.1, sum_splits: bool = False): + """Initializes the MLM loss, which is essentially a CrossEntropy loss with label smoothing. + + Args: + label_smoothing -> float: The label smoothing factor. + sum_splits -> bool: Whether to sum the loss over the splits. + """ + super().__init__() + self.label_smoothing = label_smoothing + self.criterion = torch.nn.CrossEntropyLoss(label_smoothing=self.label_smoothing) + self.sum_splits = sum_splits + + def forward( + self, inputs: torch.Tensor, targets: torch.Tensor, masks: torch.Tensor + ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: + """Computes the MLM loss. + + Args: + inputs -> torch.Tensor: The input logits. + targets -> torch.Tensor: The target tokens. + masks -> torch.Tensor: The mask for the tokens. + + Returns: + loss -> torch.Tensor: The loss. 
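+ Cross-entropy with label smoothing computed over all token positions
+ (not only the masked ones); multiplied by the number of splits when
+ sum_splits is enabled.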
+ loss_dict -> Mapping[Text, torch.Tensor]: The loss dictionary for logging individual losses. + """ + b, n, m, codebook_size = inputs.shape + loss = self.criterion(inputs.reshape(-1, codebook_size), targets.view(-1)) + + correct_tokens = ( + torch.argmax(inputs.detach(), dim=-1) == targets + ).float().mean() ** m + + masked_input = inputs[masks, :].detach() + masked_loss = self.criterion(masked_input, targets[masks]) + masked_correct_tokens = ( + torch.argmax(masked_input, dim=-1) == targets[masks] + ).float().mean() ** m + + if self.sum_splits: + loss *= m + masked_loss *= m + + loss_dict = { + "mlm_loss": loss, + "correct_tokens": correct_tokens, + "masked_token_loss": masked_loss, + "masked_correct_tokens": masked_correct_tokens, + } + + return loss, loss_dict + + +if __name__ == "__main__": + loss_module = MLMLoss() + + batchsize = 2 + codebook_dim = 4 + num_codebooks = 1 + + logits = torch.rand((batchsize, 3, num_codebooks, codebook_dim)) + targets = torch.randint(0, codebook_dim, (batchsize, 3, num_codebooks)) + masks = torch.randint(0, 2, (batchsize, 3, num_codebooks), dtype=bool) + + loss, loss_dict = loss_module(logits, targets, masks) + print(logits) + print(targets) + print(masks) + print(loss, loss_dict) diff --git a/src/vqvaes/maskbit/modules/lpips.py b/src/vqvaes/maskbit/modules/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..8166a7521643073b7085acb2837d2377cb2bb411 --- /dev/null +++ b/src/vqvaes/maskbit/modules/lpips.py @@ -0,0 +1,152 @@ +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + +from pathlib import Path +from collections import namedtuple + +import torch +import torch.nn as nn + +from torchvision import models + + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.load_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_pretrained(self): + current_file_dir = Path(__file__).parent + vgg_path = ( + current_file_dir / Path("..") / Path("..") / "pretrained" / "vgg_lpips.pth" + ).resolve() + + self.load_state_dict( + torch.load(vgg_path, map_location=torch.device("cpu")), strict=False + ) + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor( + outs1[kk] + ) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [ + spatial_average(lins[kk].model(diffs[kk]), keepdim=True) + for kk in range(len(self.chns)) + ] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer( + "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] + ) + 
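+ # shift/scale are the per-channel constants from the original LPIPS
+ # implementation; forward() first maps inputs from [0, 1] to [-1, 1] and
+ # then applies this affine normalisation.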
self.register_buffer( + "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] + ) + + def forward(self, inp): + inp = inp * 2.0 - 1.0 # Rescale to [-1, 1], expects to be in range [0, 1] + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """A single linear layer which does a 1x1 conv""" + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = ( + [ + nn.Dropout(), + ] + if (use_dropout) + else [] + ) + layers += [ + nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), + ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16( + weights=models.VGG16_Weights.IMAGENET1K_V1 + ).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple( + "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] + ) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x / (norm_factor + eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keepdim=keepdim) + + +if __name__ == "__main__": + lpips = LPIPS() + + x = torch.randn((1, 3, 256, 256)) + offset = torch.randn((1, 3, 256, 256)) + out = lpips(x, x + offset) + print(f"Output shape: {out.shape}") + print(f"Output: {out}") diff --git a/src/vqvaes/maskbit/modules/masking.py b/src/vqvaes/maskbit/modules/masking.py new file mode 100644 index 0000000000000000000000000000000000000000..58b45da9cd2115b32bc2d8dacc809f9a3bd21ae4 --- /dev/null +++ b/src/vqvaes/maskbit/modules/masking.py @@ -0,0 +1,70 @@ +"""This file contains the definition of utility functions for masking.""" + +import math +from typing import Text, Tuple +import torch + + +def get_mask_tokens( + tokens: torch.Tensor, + mask_token: int, + mode: Text = "arccos", + min_masking_ratio: float = 0.0, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the masked tokens. + Args: + tokens -> torch.Tensor: The input tokens. + mask_token -> int: The special `mask` token. + mode -> Text: The masking function to use (default: "arccos"). + Returns: + masked_tokens -> torch.Tensor: The masked input tokens. Each masked token is set to mask_token. + mask -> torch.Tensor: A boolean tensor mask indicating which tokens are masked. 
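+
+ Schedule sketch (assuming min_masking_ratio = 0): a uniform r in [0, 1) is
+ drawn per sample and mapped to a masking probability, e.g. "arccos" gives
+ acos(r) / (pi / 2), so r near 0 masks almost everything and r near 1 masks
+ almost nothing; each token is then masked independently with that probability.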
+ """ + r = torch.rand(tokens.size(0)) * (1 - min_masking_ratio) + if mode == "linear": + val_to_mask = 1 - r + elif mode == "square": + val_to_mask = 1 - (r**2) + elif mode == "cosine": + val_to_mask = torch.cos(r * math.pi * 0.5) + elif mode == "arccos": + val_to_mask = torch.acos(r) / (math.pi * 0.5) + else: + raise ValueError( + "Invalid mode. Choose between 'linear','square', 'cosine', 'arccos'." + ) + + masked_tokens = tokens.detach().clone() + mask = torch.rand(tokens.size()) < val_to_mask.view(-1, 1, 1) + + masked_tokens[mask] = torch.full_like(masked_tokens[mask], mask_token) + return masked_tokens, mask + + +def get_masking_ratio(progress: float, mode: Text = "arccos") -> torch.Tensor: + """Get masking ratio. + Args: + progress -> float: The percentage of iterations already done. + mode -> Text: The masking function to use (default: "arccos"). + + Returns: + val_to_mask -> torch.Tensor: The masking ratio. + """ + r = torch.tensor(progress) + if mode == "root": + val_to_mask = 1 - (r**0.5) + elif mode == "square": + val_to_mask = 1 - (r**2) + elif mode == "cosine": + val_to_mask = torch.cos(r * math.pi * 0.5) + elif mode == "arccos": + val_to_mask = torch.acos(r) / (math.pi * 0.5) + elif mode == "linear": + val_to_mask = 1 - r + else: + raise ValueError( + "Invalid mode. Choose between 'linear','square', 'cosine', 'arccos', 'root'." + ) + + val_to_mask = torch.clamp(val_to_mask, 1e-6, 1.0) + return val_to_mask diff --git a/src/vqvaes/maskbit/modules/perceptual_loss.py b/src/vqvaes/maskbit/modules/perceptual_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..8be2e1fa99c9fdaa8332701e1db09530b6e29dd0 --- /dev/null +++ b/src/vqvaes/maskbit/modules/perceptual_loss.py @@ -0,0 +1,93 @@ +"""This file contains the definition of the perceptual loss.""" + +import torch + +from torchvision import models +from torchvision.models.feature_extraction import create_feature_extractor + + +class PerceptualLoss(torch.nn.Module): + def __init__( + self, + model_name: str = "resnet50", + compute_perceptual_loss_on_logits: bool = True, + ): + """Initialize the perceptual loss. + + Args: + model_name -> str: The name of the model to use. + compute_perceptual_loss_on_logits -> bool: Whether to compute the perceptual loss on the logits + or the features. + """ + super().__init__() + if model_name == "resnet50": + model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1) + return_nodes = {"layer4": "features", "fc": "logits"} + elif model_name == "convnext_s": + model = models.convnext_small( + weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1 + ) + return_nodes = {"features": "features", "classifier": "logits"} + + if compute_perceptual_loss_on_logits: + self.model = model + else: + self.model = create_feature_extractor(model, return_nodes=return_nodes) + + self.compute_perceptual_loss_on_logits = compute_perceptual_loss_on_logits + + self.register_buffer( + "mean", torch.Tensor([0.485, 0.456, 0.406])[None, :, None, None] + ) + self.register_buffer( + "std", torch.Tensor([0.229, 0.224, 0.225])[None, :, None, None] + ) + + for param in self.parameters(): + param.requires_grad = False + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Compute the perceptual loss. + + Args: + input -> torch.Tensor: The input tensor. + target -> torch.Tensor: The target tensor. + + Returns: + loss -> torch.Tensor: The perceptual loss. 
+ """ + input = torch.nn.functional.interpolate( + input, size=224, mode="bilinear", antialias=True, align_corners=False + ) + target = torch.nn.functional.interpolate( + target, size=224, mode="bilinear", antialias=True, align_corners=False + ) + + input = (input - self.mean) / self.std + target = (target - self.mean) / self.std + + features_input = self.model(input) + features_target = self.model(target) + + if self.compute_perceptual_loss_on_logits: + loss = torch.nn.functional.mse_loss( + features_input, features_target, reduction="mean" + ) + else: + loss = torch.nn.functional.mse_loss( + features_input["features"], + features_target["features"], + reduction="mean", + ) + loss += torch.nn.functional.mse_loss( + features_input["logits"], features_target["logits"], reduction="mean" + ) + return loss + + +if __name__ == "__main__": + model = PerceptualLoss() + input = torch.randn(2, 3, 256, 256).clamp_(0, 1) + target = torch.randn(2, 3, 256, 256).clamp_(0, 1) + loss = model(input, target) + print(loss) diff --git a/src/vqvaes/maskbit/modules/sampling.py b/src/vqvaes/maskbit/modules/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..901b8bb0a597715503d66512fb8b3d261d22b55c --- /dev/null +++ b/src/vqvaes/maskbit/modules/sampling.py @@ -0,0 +1,163 @@ +"""This file contains the definition of the sampling function.""" + +from typing import Optional, Tuple, List, Text +import tqdm + +import torch + +from .masking import get_masking_ratio +from .factorization import combine_factorized_tokens + + +@torch.no_grad() +def sample( + model, + vqgan_model, + num_samples: int = 10, + labels: Optional[torch.Tensor] = None, + softmax_temperature: float = 1.0, + randomize_temperature: float = 4.5, + mask_schedule_strategy: Text = "linear", + num_steps: int = 12, + guidance_scale: float = 3.0, + mask_token: int = 1024, + patch_size: int = 16, + guidance_annealing: Text = "none", + use_sampling_annealing: bool = False, + scale_pow: float = 4.0, + codebook_size: int = 1024, + codebook_splits: int = 1, + use_tqdm: bool = False, +) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """Sample from the model. + + Args: + model -> torch.nn.Module: The model to sample from. + vqgan_model -> torch.nn.Module: The VQGAN model. + num_samples -> int: The number of samples to generate. + labels -> Optional[torch.Tensor]: The labels to use for the generation. + softmax_temperature -> float: The temperature for the softmax. + randomize_temperature -> float: The temperature for the randomization. + mask_schedule_strategy -> Text: The strategy for the mask schedule. + num_steps -> int: The number of steps to use for the sampling. + guidance_scale -> float: The scale for the guidance. + mask_token -> int: The token to use for the masking. + patch_size -> int: The size of the patches. + guidance_annealing -> Text: The annealing strategy for the guidance. + use_sampling_annealing -> bool: Whether to use the sampling annealing. + scale_pow -> float: The power for the scaling. + codebook_size -> int: The size of the codebook. + codebook_splits -> int: The number of splits for the codebook. + + Returns: + Tuple[torch.Tensor, List[torch.Tensor]]: The generated samples and the tokens at each step. 
+ """ + device = model.device + + model.eval() + vqgan_model.eval() + + if labels is None: + # goldfish, chicken, tiger cat, hourglass, ship, dog, race car, airliner, teddy bear, random + labels = [ + 1, + 7, + 282, + 604, + 724, + 179, + 751, + 404, + 850, + torch.randint(0, 999, size=(1,)), + ] * (num_samples // 10) + labels = torch.LongTensor(labels).to(device) + + drop_labels = torch.ones(num_samples, dtype=bool, device=device) + spatial_size = int(patch_size**2) + num_splits = int(codebook_splits) + + masked_tokens = torch.full( + (num_samples, spatial_size, num_splits), mask_token, device=device + ) + num_maskable = spatial_size * num_splits + mask = masked_tokens == mask_token + + l_full_tokens = [] + gumbel = torch.distributions.Gumbel(loc=0.0, scale=1.0) + + if use_tqdm: + step_iterable = tqdm.tqdm(range(num_steps), desc="Sampling steps", position=1) + else: + step_iterable = range(num_steps) + + for i in step_iterable: + progress = (i + 1) / num_steps + if guidance_scale != 0.0: + logits = model( + torch.cat([masked_tokens.clone(), masked_tokens.clone()], dim=0), + torch.cat([labels, labels], dim=0), + torch.cat([~drop_labels, drop_labels], dim=0), + ) + # Classifier-free guidance + logits_with_class, logits_without_class = torch.chunk(logits, 2, dim=0) + if guidance_annealing == "none": + scale_step = 1.0 + elif guidance_annealing == "linear": + scale_step = i / num_steps + elif guidance_annealing == "cosine": + scale_pow = torch.ones((1), device=device) * scale_pow + scale_step = ( + (1 - torch.cos(((i / num_steps) ** scale_pow) * torch.pi)) * 1 / 2 + ) # power-cos scaling + scale = guidance_scale * scale_step + logits = logits_with_class + scale * ( + logits_with_class - logits_without_class + ) + else: + logits = model(masked_tokens.clone(), labels, ~drop_labels) + + if use_sampling_annealing: + softmax_temperature = 0.5 + 0.8 * (1 - progress) + probabilities = torch.softmax(logits / softmax_temperature, dim=-1) + distribution = torch.distributions.Categorical(probabilities) + predicted_tokens = distribution.sample() + + num_masked = torch.sum(mask, dim=(1, 2))[0] + + predicted_tokens = torch.where(mask, predicted_tokens, masked_tokens) + + confidence = torch.gather( + probabilities, -1, predicted_tokens.unsqueeze(-1) + ).squeeze(-1) + # Ignore existing tokens by overwriting the confidence. 
+ confidence = torch.where(mask, confidence, torch.inf) + + noise = ( + gumbel.sample(predicted_tokens.size()) + * randomize_temperature + * (1 - progress) + ) + confidence = torch.log(confidence) + noise.to(device) + + mask_ratio = get_masking_ratio(progress, mode=mask_schedule_strategy).to(device) + + # min = 1, max = num_masked - 1 + mask_len = torch.floor(mask_ratio * num_maskable) + num_tokens_to_mask = torch.clamp( + mask_len, torch.ones_like(num_masked), num_masked - 1 + ).long() + sorted_confidence = torch.sort(confidence.view(num_samples, -1), dim=-1).values + threshold = sorted_confidence[:, num_tokens_to_mask - 1] + + should_mask = confidence <= threshold.unsqueeze(-1).unsqueeze(-1) + masked_tokens = torch.where(should_mask, mask_token, predicted_tokens) + mask = masked_tokens == mask_token + l_full_tokens.append(predicted_tokens) + + predicted_tokens = combine_factorized_tokens( + predicted_tokens, codebook_size, codebook_splits + ) + + generated_image = vqgan_model.decode_tokens(predicted_tokens) + return generated_image, l_full_tokens diff --git a/src/vqvaes/maskbit/quantizer/__init__.py b/src/vqvaes/maskbit/quantizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1c09902f925203e523243fb37e3bd94a27dcc3ae --- /dev/null +++ b/src/vqvaes/maskbit/quantizer/__init__.py @@ -0,0 +1,2 @@ +from .quantizer import SimpleVectorizer +from .lookup_free import LookupFreeQuantizer diff --git a/src/vqvaes/maskbit/quantizer/lookup_free.py b/src/vqvaes/maskbit/quantizer/lookup_free.py new file mode 100644 index 0000000000000000000000000000000000000000..f6edb0e0072acc536ffcf1f16b23637c92b8bbbc --- /dev/null +++ b/src/vqvaes/maskbit/quantizer/lookup_free.py @@ -0,0 +1,168 @@ +"""This file contains the definition of the look-free quantizer.""" + +from typing import Mapping, Text, Tuple + +import torch +from einops import rearrange, reduce + +from .quantizer_utils import entropy_loss_fn + + +class LookupFreeQuantizer(torch.nn.Module): + def __init__( + self, + token_bits: int = 10, + commitment_cost: float = 0.25, + entropy_loss_weight: float = 0.1, + entropy_loss_temperature: float = 0.01, + entropy_gamma: float = 1.0, + ): + """Initializes the lookup-free quantizer. + + Args: + token_bits -> int: The number of bits per token. + commitment_cost -> float: The commitment cost. + entropy_loss_weight -> float: The weight of the entropy loss. + entropy_loss_temperature -> float: The temperature for the entropy loss. + entropy_gamma -> float: The gamma for the entropy loss. + """ + super().__init__() + self.token_size = token_bits + self.codebook_size = 2**token_bits + + self.commitment_cost = commitment_cost + self.entropy_loss_weight = entropy_loss_weight + self.entropy_loss_temperature = entropy_loss_temperature + self.entropy_gamma = entropy_gamma + + bits_to_indices = torch.pow( + 2.0, torch.arange(0, self.token_size, dtype=torch.float32) + ) + self.register_buffer("bits_to_indices", bits_to_indices.int()) + + all_codes = torch.arange(self.codebook_size) + bits = ((all_codes[..., None].int() & self.bits_to_indices) != 0).float() + self.register_buffer("codebook", bits * 2.0 - 1.0) + + def forward( + self, z: torch.Tensor + ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: + """Forward pass of the quantizer. + + Args: + z -> torch.Tensor: The input tensor. + + Returns: + z_quantized -> torch.Tensor: The quantized latent representation. + result_dict -> Mapping[Text, torch.Tensor]: A dictionary containing additional results + and losses from the quantizer. 
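+
+ Quantization is simply sign(z): every channel is mapped to -1 or +1, so a
+ token with token_bits channels indexes an implicit codebook of size
+ 2**token_bits, and a straight-through estimator keeps gradients flowing
+ back to z.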
+ """ + z = rearrange(z, "b c h w -> b h w c").contiguous() + ones = torch.ones_like(z) + sign_mask = z > 0.0 + z_quantized = torch.where(sign_mask, ones, -ones) + + min_encoding_indices = self.convert_bits_to_indices(z_quantized) + + # compute loss for embedding + commitment_loss = self.commitment_cost * torch.mean( + (z_quantized.detach() - z) ** 2 + ) + entropy_loss = torch.zeros((), device=z.device) + per_sample_entropy = torch.zeros((), device=z.device) + avg_entropy = torch.zeros((), device=z.device) + + # Use entropy loss on the codebook + if self.entropy_loss_weight != 0.0 and self.training: + d = -2 * torch.einsum("b h w c, n c -> b h w n", z, self.codebook) + + per_sample_entropy, avg_entropy = entropy_loss_fn( + -1 * d, self.entropy_loss_temperature, self.entropy_gamma + ) + entropy_loss = self.entropy_loss_weight * (per_sample_entropy - avg_entropy) + + loss = commitment_loss + entropy_loss + + # preserve gradients + z_quantized = z + (z_quantized - z).detach() + + # reshape back to match original input shape + z_quantized = rearrange(z_quantized, "b h w c -> b c h w").contiguous() + + result_dict = dict( + quantizer_loss=loss, + commitment_loss=commitment_loss, + entropy_loss=entropy_loss, + per_sample_entropy=per_sample_entropy, + avg_entropy=avg_entropy, + min_encoding_indices=min_encoding_indices, + ) + + return z_quantized, result_dict + + def get_codebook_entry(self, indices: torch.Tensor) -> torch.Tensor: + """Returns the `codebook entry` for the given indices. + + As the codebook exists only implicitly, this is mainly an integer conversion to a bit representation. + Note: The bits are represented by {-1, 1}. + + Args: + indices -> torch.Tensor: The indices in range 0 to codebook size - 1. + + Returns: + tokens -> torch.Tensor: The bit representation. + """ + indices = indices.long() + bits = ((indices[..., None].int() & self.bits_to_indices) != 0).float() + tokens = bits * 2.0 - 1.0 # scale to -1..1 + return tokens + + def convert_bits_to_indices(self, tokens: torch.Tensor) -> torch.Tensor: + """Converts the given tokens to index numbers. + + As the codebook exists only implicitly, this is mainly an integer conversion from a bit representation. + Note: The bits are represented by {-1, 1}. + + Args: + tokens -> torch.Tensor: The tokens. + + Returns: + indices -> torch.Tensor: The indices in range 0 to codebook size - 1. + """ + tokens = rearrange(tokens, "b h w c -> b h w c").contiguous() + sign_mask = tokens > 0.0 + return reduce(sign_mask.int() * self.bits_to_indices, "b h w c -> b h w", "sum") + + def convert_indices_to_bits(self, indices: torch.Tensor) -> torch.Tensor: + """Converts the given indices to tokens. + + As the codebook exists only implicitly, this is mainly an integer conversion to a bit representation. + Note: The bits are represented by {-1, 1}. + + Args: + indices -> torch.Tensor: The indices in range 0 to codebook size - 1. + + Returns: + tokens -> torch.Tensor: The bit representation. 
+ """ + indices = indices.long() + return self.get_codebook_entry(indices) + + +if __name__ == "__main__": + quantizer = LookupFreeQuantizer( + token_bits=10, + commitment_cost=0.25, + entropy_loss_weight=0.1, + entropy_loss_temperature=0.01, + entropy_gamma=1.0, + ) + all_entries = torch.arange(1024).reshape(1, 1, 1024) + indices = quantizer.convert_bits_to_indices( + quantizer.convert_indices_to_bits(all_entries) + ) + assert torch.equal(indices, all_entries) + assert torch.equal( + quantizer.convert_bits_to_indices(quantizer.codebook.reshape(1, 1, 1024, 10)), + all_entries, + ) diff --git a/src/vqvaes/maskbit/quantizer/quantizer.py b/src/vqvaes/maskbit/quantizer/quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..db9fff115202120beb8a7fb351b341e9862547d9 --- /dev/null +++ b/src/vqvaes/maskbit/quantizer/quantizer.py @@ -0,0 +1,130 @@ +"""This file contains the definition of the VQ quantizer.""" + +from typing import Mapping, Text, Tuple + +import torch +from einops import rearrange + +from .quantizer_utils import entropy_loss_fn + + +class SimpleVectorizer(torch.nn.Module): + def __init__( + self, + codebook_size: int = 1024, + token_size: int = 256, + commitment_cost: float = 0.25, + entropy_loss_weight: float = 0.0, + entropy_loss_temperature: float = 0.01, + entropy_gamma: float = 1.0, + use_l2_normalisation: bool = False, + ): + """Initializes the quantizer. + + Args: + codebook_size -> int: The size of the codebook. + token_size -> int: The feature dimensions of the tokens. + commitment_cost -> float: The commitment cost. + entropy_loss_weight -> float: The weight of the entropy loss. + entropy_loss_temperature -> float: The temperature of the entropy loss. + entropy_gamma -> float: The gamma of the entropy loss. + use_l2_normalisation -> bool: Whether to use L2 normalisation. + """ + + super().__init__() + self.commitment_cost = commitment_cost + + self.embedding = torch.nn.Embedding(codebook_size, token_size) + self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size) + + self.entropy_loss_weight = entropy_loss_weight + self.entropy_loss_temperature = entropy_loss_temperature + self.entropy_gamma = entropy_gamma + self.use_l2_normalisation = use_l2_normalisation + + def forward( + self, z: torch.Tensor + ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: + """Computes the quantization loss and returns the quantized latent representation. + + Args: + z -> torch.Tensor: The latent representation. + + Returns: + z_quantized -> torch.Tensor: The quantized latent representation. + result_dict -> Mapping[Text, torch.Tensor]: A dictionary containing additional results + and losses from the quantizer. 
+ """ + # reshape z -> (batch, height, width, channel) and flatten + z = rearrange(z, "b c h w -> b h w c").contiguous() + + if self.use_l2_normalisation: + z = torch.nn.functional.normalize(z, dim=-1) + embedding = torch.nn.functional.normalize(self.embedding.weight, dim=-1) + else: + embedding = self.embedding.weight + + z_flattened = rearrange(z, "b h w c -> (b h w) c") + + # distances from z to embeddings e_j d = (z - e)^2 = z^2 + e^2 - 2 e * z + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(embedding**2, dim=1) + - 2 * torch.einsum("bd,dn->bn", z_flattened, embedding.T) + ) + + min_encoding_indices = torch.argmin(d, dim=1) + z_quantized = self.get_codebook_entry(min_encoding_indices).view(z.shape) + + # compute loss for embedding + commitment_loss = self.commitment_cost * torch.mean( + (z_quantized.detach() - z) ** 2 + ) + codebook_loss = torch.mean((z_quantized - z.detach()) ** 2) + entropy_loss = torch.zeros((), device=z.device) + per_sample_entropy = torch.zeros((), device=z.device) + avg_entropy = torch.zeros((), device=z.device) + + # Use entropy loss on the codebook + if self.entropy_loss_weight != 0.0 and self.training: + per_sample_entropy, avg_entropy = entropy_loss_fn( + -1 * d, self.entropy_loss_temperature, self.entropy_gamma + ) + entropy_loss = self.entropy_loss_weight * (per_sample_entropy - avg_entropy) + + loss = commitment_loss + codebook_loss + entropy_loss + + # preserve gradients + z_quantized = z + (z_quantized - z).detach() + + # reshape back to match original input shape + z_quantized = rearrange(z_quantized, "b h w c -> b c h w").contiguous() + + result_dict = dict( + quantizer_loss=loss, + commitment_loss=commitment_loss, + codebook_loss=codebook_loss, + entropy_loss=entropy_loss, + per_sample_entropy=per_sample_entropy, + avg_entropy=avg_entropy, + min_encoding_indices=min_encoding_indices.view( + z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3] + ), + ) + + return z_quantized, result_dict + + def get_codebook_entry(self, indices: torch.Tensor) -> torch.Tensor: + """Returns the codebook entry for the given indices. + + Args: + indices -> torch.Tensor: The indices of the codebook entries. + + Returns: + z_quantized -> torch.Tensor: The codebook entries. + """ + # get quantized latent vectors + z_quantized = self.embedding(indices.int()) + if self.use_l2_normalisation: + z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1) + return z_quantized diff --git a/src/vqvaes/maskbit/quantizer/quantizer_utils.py b/src/vqvaes/maskbit/quantizer/quantizer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d1dd26f13296dda65b5d279b49fdd1296eac4b69 --- /dev/null +++ b/src/vqvaes/maskbit/quantizer/quantizer_utils.py @@ -0,0 +1,46 @@ +"""This file contains the definition of some utility functions for the quantizer.""" + +from typing import Tuple +import torch + + +def clamp_log(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor: + """Clamps the input tensor and computes the log. + + Args: + x -> torch.Tensor: The input tensor. + eps -> float: The epsilon value serving as the lower bound. + + Returns: + torch.Tensor: The log of the clamped input tensor. + """ + return torch.log(torch.clamp(x, eps)) + + +def entropy_loss_fn( + affinity: torch.Tensor, + temperature: float, + entropy_gamma: float = 1.0, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Computes the entropy loss. + + Args: + affinity -> torch.Tensor: The affinity matrix. + temperature -> float: The temperature. 
+ entropy_gamma -> float: The entropy gamma. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The per-sample and average entropy. + """ + flat_affinity = affinity.reshape(-1, affinity.shape[-1]) + flat_affinity /= temperature + + probability = flat_affinity.softmax(dim=-1) + average_probability = torch.mean(probability, dim=0) + + per_sample_entropy = -1 * torch.mean( + torch.sum(probability * clamp_log(probability), dim=-1) + ) + avg_entropy = torch.sum(-1 * average_probability * clamp_log(average_probability)) + + return (per_sample_entropy, avg_entropy * entropy_gamma) diff --git a/src/vqvaes/open_magvit2/ema.py b/src/vqvaes/open_magvit2/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..0ebc5197d13e631853f00d05e495a656647bdeab --- /dev/null +++ b/src/vqvaes/open_magvit2/ema.py @@ -0,0 +1,93 @@ +""" +Refer to +https://github.com/Stability-AI/stablediffusion/blob/main/ldm/modules/ema.py +""" + +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.999, use_num_upates=False): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.m_name2s_name = {} + self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) + self.register_buffer( + "num_updates", + ( + torch.tensor(0, dtype=torch.int) + if use_num_upates + else torch.tensor(-1, dtype=torch.int) + ), + ) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace(".", "") + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def reset_num_updates(self): + del self.num_updates + self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_( + one_minus_decay * (shadow_params[sname] - m_param[key]) + ) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. 
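+
+ Typical usage sketch: ema.store(model.parameters()); ema.copy_to(model);
+ run validation; ema.restore(model.parameters()).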
+ """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/src/vqvaes/open_magvit2/improved_model.py b/src/vqvaes/open_magvit2/improved_model.py new file mode 100644 index 0000000000000000000000000000000000000000..325e081584b109c855afbe19928e9834b7b41816 --- /dev/null +++ b/src/vqvaes/open_magvit2/improved_model.py @@ -0,0 +1,306 @@ +import torch +import torch.nn as nn +from einops import rearrange +import torch.nn.functional as F + + +def swish(x): + # swish + return x * torch.sigmoid(x) + + +class ResBlock(nn.Module): + def __init__( + self, + in_filters, + out_filters, + use_conv_shortcut=False, + use_agn=False, + ) -> None: + super().__init__() + + self.in_filters = in_filters + self.out_filters = out_filters + self.use_conv_shortcut = use_conv_shortcut + self.use_agn = use_agn + + if not use_agn: ## agn is GroupNorm likewise skip it if has agn before + self.norm1 = nn.GroupNorm(32, in_filters, eps=1e-6) + self.norm2 = nn.GroupNorm(32, out_filters, eps=1e-6) + + self.conv1 = nn.Conv2d( + in_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False + ) + self.conv2 = nn.Conv2d( + out_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False + ) + + if in_filters != out_filters: + if self.use_conv_shortcut: + self.conv_shortcut = nn.Conv2d( + in_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False + ) + else: + self.nin_shortcut = nn.Conv2d( + in_filters, out_filters, kernel_size=(1, 1), padding=0, bias=False + ) + + def forward(self, x, **kwargs): + residual = x + + if not self.use_agn: + x = self.norm1(x) + x = swish(x) + x = self.conv1(x) + x = self.norm2(x) + x = swish(x) + x = self.conv2(x) + if self.in_filters != self.out_filters: + if self.use_conv_shortcut: + residual = self.conv_shortcut(residual) + else: + residual = self.nin_shortcut(residual) + + return x + residual + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + in_channels, + num_res_blocks, + z_channels, + ch_mult=(1, 2, 2, 4), + resolution, + double_z=False, + ): + super().__init__() + + self.in_channels = in_channels + self.z_channels = z_channels + self.resolution = resolution + + self.num_res_blocks = num_res_blocks + self.num_blocks = len(ch_mult) + + self.conv_in = nn.Conv2d( + in_channels, ch, kernel_size=(3, 3), padding=1, bias=False + ) + + ## construct the model + self.down = nn.ModuleList() + + in_ch_mult = (1,) + tuple(ch_mult) + for i_level in range(self.num_blocks): + block = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] # [1, 1, 2, 2, 4] + block_out = ch * ch_mult[i_level] # [1, 2, 2, 4] + for _ in range(self.num_res_blocks): + block.append(ResBlock(block_in, block_out)) + block_in = block_out + + down = nn.Module() + down.block = block + if i_level < self.num_blocks - 1: + down.downsample = nn.Conv2d( + block_out, block_out, kernel_size=(3, 3), stride=(2, 2), padding=1 + ) + + self.down.append(down) + + ### mid + self.mid_block = nn.ModuleList() + for res_idx in range(self.num_res_blocks): + self.mid_block.append(ResBlock(block_in, block_in)) + + ### end + self.norm_out = nn.GroupNorm(32, block_out, eps=1e-6) + self.conv_out = nn.Conv2d(block_out, z_channels, kernel_size=(1, 1)) + + def forward(self, x): + + ## down + x = self.conv_in(x) + for i_level in range(self.num_blocks): + for i_block in range(self.num_res_blocks): + x = self.down[i_level].block[i_block](x) + + if i_level < self.num_blocks - 1: + x = self.down[i_level].downsample(x) + + ## mid + for res in range(self.num_res_blocks): 
+ x = self.mid_block[res](x) + + x = self.norm_out(x) + x = swish(x) + x = self.conv_out(x) + + return x + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + in_channels, + num_res_blocks, + z_channels, + ch_mult=(1, 2, 2, 4), + resolution, + double_z=False, + ) -> None: + super().__init__() + + self.ch = ch + self.num_blocks = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + block_in = ch * ch_mult[self.num_blocks - 1] + + self.conv_in = nn.Conv2d( + z_channels, block_in, kernel_size=(3, 3), padding=1, bias=True + ) + + self.mid_block = nn.ModuleList() + for res_idx in range(self.num_res_blocks): + self.mid_block.append(ResBlock(block_in, block_in)) + + self.up = nn.ModuleList() + + self.adaptive = nn.ModuleList() + + for i_level in reversed(range(self.num_blocks)): + block = nn.ModuleList() + block_out = ch * ch_mult[i_level] + self.adaptive.insert(0, AdaptiveGroupNorm(z_channels, block_in)) + for i_block in range(self.num_res_blocks): + # if i_block == 0: + # block.append(ResBlock(block_in, block_out, use_agn=True)) + # else: + block.append(ResBlock(block_in, block_out)) + block_in = block_out + + up = nn.Module() + up.block = block + if i_level > 0: + up.upsample = Upsampler(block_in) + self.up.insert(0, up) + + self.norm_out = nn.GroupNorm(32, block_in, eps=1e-6) + + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=(3, 3), padding=1) + + def forward(self, z): + + style = z.clone() # for adaptive groupnorm + + z = self.conv_in(z) + + ## mid + for res in range(self.num_res_blocks): + z = self.mid_block[res](z) + + ## upsample + for i_level in reversed(range(self.num_blocks)): + ### pass in each resblock first adaGN + z = self.adaptive[i_level](z, style) + for i_block in range(self.num_res_blocks): + z = self.up[i_level].block[i_block](z) + + if i_level > 0: + z = self.up[i_level].upsample(z) + + z = self.norm_out(z) + z = swish(z) + z = self.conv_out(z) + + return z + + +def depth_to_space(x: torch.Tensor, block_size: int) -> torch.Tensor: + """Depth-to-Space DCR mode (depth-column-row) core implementation. + + Args: + x (torch.Tensor): input tensor. The channels-first (*CHW) layout is supported. 
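+ For example, a (B, 16, 8, 8) input with block_size=2 becomes (B, 4, 16, 16).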
+ block_size (int): block side size + """ + # check inputs + if x.dim() < 3: + raise ValueError( + f"Expecting a channels-first (*CHW) tensor of at least 3 dimensions" + ) + c, h, w = x.shape[-3:] + + s = block_size**2 + if c % s != 0: + raise ValueError( + f"Expecting a channels-first (*CHW) tensor with C divisible by {s}, but got C={c} channels" + ) + + outer_dims = x.shape[:-3] + + # splitting two additional dimensions from the channel dimension + x = x.view(-1, block_size, block_size, c // s, h, w) + + # putting the two new dimensions along H and W + x = x.permute(0, 3, 4, 1, 5, 2) + + # merging the two new dimensions with H and W + x = x.contiguous().view(*outer_dims, c // s, h * block_size, w * block_size) + + return x + + +class Upsampler(nn.Module): + def __init__(self, dim, dim_out=None): + super().__init__() + dim_out = dim * 4 + self.conv1 = nn.Conv2d(dim, dim_out, (3, 3), padding=1) + self.depth2space = depth_to_space + + def forward(self, x): + """ + input_image: [B C H W] + """ + out = self.conv1(x) + out = self.depth2space(out, block_size=2) + return out + + +class AdaptiveGroupNorm(nn.Module): + def __init__(self, z_channel, in_filters, num_groups=32, eps=1e-6): + super().__init__() + self.gn = nn.GroupNorm( + num_groups=32, num_channels=in_filters, eps=eps, affine=False + ) + # self.lin = nn.Linear(z_channels, in_filters * 2) + self.gamma = nn.Linear(z_channel, in_filters) + self.beta = nn.Linear(z_channel, in_filters) + self.eps = eps + + def forward(self, x, quantizer): + B, C, _, _ = x.shape + # quantizer = F.adaptive_avg_pool2d(quantizer, (1, 1)) + ### calcuate var for scale + scale = rearrange(quantizer, "b c h w -> b c (h w)") + scale = scale.var(dim=-1) + self.eps # not unbias + scale = scale.sqrt() + scale = self.gamma(scale).view(B, C, 1, 1) + + ### calculate mean for bias + bias = rearrange(quantizer, "b c h w -> b c (h w)") + bias = bias.mean(dim=-1) + bias = self.beta(bias).view(B, C, 1, 1) + + x = self.gn(x) + x = scale * x + bias + + return x diff --git a/src/vqvaes/open_magvit2/lookup_free_quantize.py b/src/vqvaes/open_magvit2/lookup_free_quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..49266c54fe8dca4000ba408bd69f83d6749a42c9 --- /dev/null +++ b/src/vqvaes/open_magvit2/lookup_free_quantize.py @@ -0,0 +1,401 @@ +""" +Lookup Free Quantization +Proposed in https://arxiv.org/abs/2310.05737 + +In the simplest setup, each dimension is quantized into {-1, 1}. +An entropy penalty is used to encourage utilization. 
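+
+ In this implementation the encoder output is binarised per channel by its
+ sign, the resulting {-1, +1} bits are packed into integer token indices,
+ and training adds a commitment loss plus an entropy term that keeps
+ individual assignments confident while spreading usage across the codebook.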
+ +Refer to +https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/lookup_free_quantization.py +https://github.com/theAdamColton/ijepa-enhanced/blob/7edef5f7288ae8f537f0db8a10044a2a487f70c9/ijepa_enhanced/lfq.py +""" + +from math import log2, ceil +from collections import namedtuple + +import torch +from torch import nn, einsum +import torch.nn.functional as F +from torch.nn import Module + +from einops import rearrange, reduce, pack, unpack + +# constants + +LossBreakdown = namedtuple( + "LossBreakdown", + ["per_sample_entropy", "codebook_entropy", "commitment", "avg_probs"], +) + +# helper functions + + +def exists(v): + return v is not None + + +def default(*args): + for arg in args: + if exists(arg): + return arg() if callable(arg) else arg + return None + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +# entropy + +# def log(t, eps = 1e-5): +# return t.clamp(min = eps).log() + + +def entropy(prob): + return (-prob * torch.log(prob + 1e-5)).sum(dim=-1) + + +# class + + +def mult_along_first_dims(x, y): + """ + returns x * y elementwise along the leading dimensions of y + """ + ndim_to_expand = x.ndim - y.ndim + for _ in range(ndim_to_expand): + y = y.unsqueeze(-1) + return x * y + + +def masked_mean(x, m): + """ + takes the mean of the elements of x that are not masked + the mean is taken along the shared leading dims of m + equivalent to: x[m].mean(tuple(range(m.ndim))) + + The benefit of using masked_mean rather than using + tensor indexing is that masked_mean is much faster + for torch-compile on batches. + + The drawback is larger floating point errors + """ + x = mult_along_first_dims(x, m) + x = x / m.sum() + return x.sum(tuple(range(m.ndim))) + + +def entropy_loss( + logits, + mask=None, + temperature=0.01, + sample_minimization_weight=1.0, + batch_maximization_weight=1.0, + eps=1e-5, +): + """ + Entropy loss of unnormalized logits + + logits: Affinities are over the last dimension + + https://github.com/google-research/magvit/blob/05e8cfd6559c47955793d70602d62a2f9b0bdef5/videogvt/train_lib/losses.py#L279 + LANGUAGE MODEL BEATS DIFFUSION — TOKENIZER IS KEY TO VISUAL GENERATION (2024) + """ + probs = F.softmax(logits / temperature, -1) + log_probs = F.log_softmax(logits / temperature + eps, -1) + + if mask is not None: + # avg_probs = probs[mask].mean(tuple(range(probs.ndim - 1))) + # avg_probs = einx.mean("... D -> D", probs[mask]) + + avg_probs = masked_mean(probs, mask) + # avg_probs = einx.mean("... D -> D", avg_probs) + else: + avg_probs = reduce(probs, "... 
D -> D", "mean") + + avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + eps)) + + sample_entropy = -torch.sum(probs * log_probs, -1) + if mask is not None: + # sample_entropy = sample_entropy[mask].mean() + sample_entropy = masked_mean(sample_entropy, mask).mean() + else: + sample_entropy = torch.mean(sample_entropy) + + loss = (sample_minimization_weight * sample_entropy) - ( + batch_maximization_weight * avg_entropy + ) + + return sample_entropy, avg_entropy, loss + + +class LFQ(Module): + def __init__( + self, + *, + dim=None, + codebook_size=None, + num_codebooks=1, + sample_minimization_weight=1.0, + batch_maximization_weight=1.0, + token_factorization=False, + factorized_bits=[9, 9], + ): + super().__init__() + + # some assert validations + + assert exists(dim) or exists( + codebook_size + ), "either dim or codebook_size must be specified for LFQ" + assert ( + not exists(codebook_size) or log2(codebook_size).is_integer() + ), f"your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})" + + self.codebook_size = default(codebook_size, lambda: 2**dim) + self.codebook_dim = int(log2(codebook_size)) + + codebook_dims = self.codebook_dim * num_codebooks + dim = default(dim, codebook_dims) + + has_projections = dim != codebook_dims + self.has_projections = has_projections + + self.dim = dim + self.codebook_dim = self.codebook_dim + self.num_codebooks = num_codebooks + + # for entropy loss + self.sample_minimization_weight = sample_minimization_weight + self.batch_maximization_weight = batch_maximization_weight + + # for no auxiliary loss, during inference + self.token_factorization = token_factorization + if not self.token_factorization: # for first stage model + self.register_buffer( + "mask", 2 ** torch.arange(self.codebook_dim), persistent=False + ) + else: + self.factorized_bits = factorized_bits + self.register_buffer( + "pre_mask", 2 ** torch.arange(factorized_bits[0]), persistent=False + ) + self.register_buffer( + "post_mask", 2 ** torch.arange(factorized_bits[1]), persistent=False + ) + + self.register_buffer("zero", torch.tensor(0.0), persistent=False) + + # codes + + all_codes = torch.arange(codebook_size) + bits = self.indices_to_bits(all_codes) + codebook = bits * 2.0 - 1.0 + + self.register_buffer("codebook", codebook, persistent=False) + + @property + def dtype(self): + return self.codebook.dtype + + def indices_to_bits(self, x): + """ + x: long tensor of indices + + returns big endian bits + """ + mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long) + # x is now big endian bits, the last dimension being the bits + x = (x.unsqueeze(-1) & mask) != 0 + return x + + def get_codebook_entry(self, x, bhwc, order=None): # 0610 + if self.token_factorization: + if order == "pre": + mask = 2 ** torch.arange( + self.factorized_bits[0], device=x.device, dtype=torch.long + ) + else: + mask = 2 ** torch.arange( + self.factorized_bits[1], device=x.device, dtype=torch.long + ) + else: + mask = 2 ** torch.arange( + self.codebook_dim, device=x.device, dtype=torch.long + ) + + x = (x.unsqueeze(-1) & mask) != 0 + x = x * 2.0 - 1.0 # back to the float + ## scale back to the + b, h, w, c = bhwc + x = rearrange(x, "b (h w) c -> b h w c", h=h, w=w, c=c) + x = rearrange(x, "b h w c -> b c h w") + return x + + def bits_to_indices(self, bits): + """ + bits: bool tensor of big endian bits, where the last dimension is the bit dimension + + returns indices, which are long integers from 0 to self.codebook_size + """ + 
assert bits.shape[-1] == self.codebook_dim + indices = 2 ** torch.arange( + 0, + self.codebook_dim, + 1, + dtype=torch.long, + device=bits.device, + ) + return (bits * indices).sum(-1) + + def decode(self, x): + """ + x: ... NH + where NH is number of codebook heads + A longtensor of codebook indices, containing values from + 0 to self.codebook_size + """ + x = self.indices_to_bits(x) + # to some sort of float + x = x.to(self.dtype) + # -1 or 1 + x = x * 2 - 1 + x = rearrange(x, "... NC Z-> ... (NC Z)") + return x + + def forward( + self, + x, + inv_temperature=100.0, + return_loss_breakdown=False, + mask=None, + return_loss=True, + ): + """ + einstein notation + b - batch + n - sequence (or flattened spatial dimensions) + d - feature dimension, which is also log2(codebook size) + c - number of codebook dim + """ + + x = rearrange(x, "b d ... -> b ... d") + x, ps = pack_one(x, "b * d") + # split out number of codebooks + + x = rearrange(x, "b n (c d) -> b n c d", c=self.num_codebooks) + + codebook_value = torch.Tensor([1.0]).to(device=x.device, dtype=x.dtype) + quantized = torch.where( + x > 0, codebook_value, -codebook_value + ) # higher than 0 filled + + # calculate indices + if self.token_factorization: + indices_pre = reduce( + (quantized[..., : self.factorized_bits[0]] > 0).int() + * self.pre_mask.int(), + "b n c d -> b n c", + "sum", + ) + indices_post = reduce( + (quantized[..., self.factorized_bits[0] :] > 0).int() + * self.post_mask.int(), + "b n c d -> b n c", + "sum", + ) + else: + indices = reduce( + (quantized > 0).int() * self.mask.int(), "b n c d -> b n c", "sum" + ) + + # entropy aux loss + + if self.training and return_loss: + logits = 2 * einsum("... i d, j d -> ... i j", x, self.codebook) + # the same as euclidean distance up to a constant + per_sample_entropy, codebook_entropy, entropy_aux_loss = entropy_loss( + logits=logits, + sample_minimization_weight=self.sample_minimization_weight, + batch_maximization_weight=self.batch_maximization_weight, + ) + + avg_probs = self.zero + else: + # logits = 2 * einsum('... i d, j d -> ... i j', x, self.codebook) + # probs = F.softmax(logits / 0.01, -1) + # avg_probs = reduce(probs, "b n c d -> b d", "mean") + # avg_probs = torch.sum(avg_probs, 0) #batch dimension + # if not training, just return dummy 0 + per_sample_entropy = codebook_entropy = self.zero + ## calculate the codebook_entropy needed for one batch evaluation + entropy_aux_loss = self.zero + avg_probs = self.zero + + # commit loss + + if self.training: + commit_loss = F.mse_loss(x, quantized.detach(), reduction="none") + + if exists(mask): + commit_loss = commit_loss[mask] + + commit_loss = commit_loss.mean() + else: + commit_loss = self.zero + + # use straight-through gradients (optionally with custom activation fn) if training + + quantized = x + (quantized - x).detach() # transfer to quantized + + # merge back codebook dim + + quantized = rearrange(quantized, "b n c d -> b n (c d)") + + # reconstitute image or video dimensions + + quantized = unpack_one(quantized, ps, "b * d") + quantized = rearrange(quantized, "b ... 
d -> b d ...") + + if self.token_factorization: + indices_pre = unpack_one(indices_pre, ps, "b * c") + indices_post = unpack_one(indices_post, ps, "b * c") + indices_pre = indices_pre.flatten() + indices_post = indices_post.flatten() + indices = (indices_pre, indices_post) + else: + indices = unpack_one(indices, ps, "b * c") + indices = indices.flatten() + + ret = (quantized, entropy_aux_loss, indices) + + if not return_loss_breakdown: + return ret + + return ret, LossBreakdown( + per_sample_entropy, codebook_entropy, commit_loss, avg_probs + ) + + +if __name__ == "__main__": + quantizer = LFQ( + codebook_size=2**18, # codebook size, must be a power of 2 + dim=18, # this is the input feature dimension, defaults to log2(codebook_size) if not defined + sample_minimization_weight=1.0, # within entropy loss, how much weight to give to diversity of codes, taken from https://arxiv.org/abs/1911.05894 + batch_maximization_weight=1.0, + ) + + image_feats = torch.randn( + 2, 18, 16, 16 + ) # 16 is dim, must be power of 2 of codebook_size + + quantized, indices, entropy_aux_loss = quantizer( + image_feats, inv_temperature=100.0 + ) # you may want to experiment with temperature + + assert image_feats.shape == quantized.shape + assert (quantized == quantizer.indices_to_codes(indices)).all() diff --git a/src/vqvaes/open_magvit2/open_magvit2.py b/src/vqvaes/open_magvit2/open_magvit2.py new file mode 100644 index 0000000000000000000000000000000000000000..a13fd697a1704d7cc9a2001770ec72dc1b8699fc --- /dev/null +++ b/src/vqvaes/open_magvit2/open_magvit2.py @@ -0,0 +1,204 @@ +import torch +import torch.nn.functional as F +import lightning as L + +from contextlib import contextmanager +from collections import OrderedDict + +from .improved_model import Encoder, Decoder +from .lookup_free_quantize import LFQ +from .ema import LitEma + + +class VQModel(L.LightningModule): + def __init__( + self, + ddconfig, + lossconfig, + ## Quantize Related + n_embed, + embed_dim, + sample_minimization_weight, + batch_maximization_weight, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + learning_rate=None, + resume_lr=None, + ### scheduler config + warmup_epochs=1.0, # warmup epochs + scheduler_type="linear-warmup_cosine-decay", + min_learning_rate=0, + use_ema=False, + token_factorization=False, + stage=None, + lr_drop_epoch=None, + lr_drop_rate=0.1, + factorized_bits=[9, 9], + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.quantize = LFQ( + dim=embed_dim, + codebook_size=n_embed, + sample_minimization_weight=sample_minimization_weight, + batch_maximization_weight=batch_maximization_weight, + token_factorization=token_factorization, + factorized_bits=factorized_bits, + ) + + if colorize_nlabels is not None: + assert type(colorize_nlabels) == int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + self.use_ema = use_ema + if ( + self.use_ema and stage is None + ): # no need to construct EMA when training Transformer + self.model_ema = LitEma(self) + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, stage=stage) + self.resume_lr = resume_lr + self.learning_rate = learning_rate + self.lr_drop_epoch = lr_drop_epoch + self.lr_drop_rate = lr_drop_rate + self.scheduler_type = scheduler_type + self.warmup_epochs = warmup_epochs + self.min_learning_rate = min_learning_rate + 
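        # Illustrative sketch (an assumption about the intended shape, not the exact
        # upstream scheduler code) of the "linear-warmup_cosine-decay" schedule that
        # learning_rate, warmup_epochs and min_learning_rate parameterize:
        #
        #     if epoch < warmup_epochs:                      # linear warmup
        #         lr = learning_rate * (epoch + 1) / warmup_epochs
        #     else:                                          # cosine decay to min_learning_rate
        #         progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
        #         lr = min_learning_rate + 0.5 * (learning_rate - min_learning_rate) * (1 + math.cos(math.pi * progress))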
self.automatic_optimization = False + + self.strict_loading = False + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def load_state_dict(self, *args, strict=False): + """ + Resume not strict loading + """ + return super().load_state_dict(*args, strict=strict) + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + """ + filter out the non-used keys + """ + return { + k: v + for k, v in super() + .state_dict(*args, destination, prefix, keep_vars) + .items() + if ( + "inception_model" not in k + and "lpips_vgg" not in k + and "lpips_alex" not in k + ) + } + + def init_from_ckpt(self, path, ignore_keys=list(), stage="transformer"): + sd = torch.load(path, map_location="cpu")["state_dict"] + ema_mapping = {} + new_params = OrderedDict() + if stage == "transformer": ### directly use ema encoder and decoder parameter + if self.use_ema: + for k, v in sd.items(): + if "encoder" in k: + if "model_ema" in k: + k = k.replace( + "model_ema.", "" + ) # load EMA Encoder or Decoder + new_k = ema_mapping[k] + new_params[new_k] = v + s_name = k.replace(".", "") + ema_mapping.update({s_name: k}) + continue + if "decoder" in k: + if "model_ema" in k: + k = k.replace( + "model_ema.", "" + ) # load EMA Encoder or Decoder + new_k = ema_mapping[k] + new_params[new_k] = v + s_name = k.replace(".", "") + ema_mapping.update({s_name: k}) + continue + else: # also only load the Generator + for k, v in sd.items(): + if "encoder" in k: + new_params[k] = v + elif "decoder" in k: + new_params[k] = v + missing_keys, unexpected_keys = self.load_state_dict( + new_params, strict=False + ) # first stage + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + (quant, emb_loss, info), loss_breakdown = self.quantize( + h, return_loss_breakdown=True + ) + return quant, emb_loss, info, loss_breakdown + + def decode(self, quant): + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input): + quant, diff, img_toks, loss_break = self.encode(input) + pixels = self.decode(quant) + return pixels, img_toks, quant + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).contiguous() + return x.float() + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 + return x diff --git a/src/vqvaes/open_magvit2/quantize.py b/src/vqvaes/open_magvit2/quantize.py new file mode 100644 index 
0000000000000000000000000000000000000000..61f014fbcc6fee5e23b4d3c25a0f7ac6cf110e8c --- /dev/null +++ b/src/vqvaes/open_magvit2/quantize.py @@ -0,0 +1,505 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from torch import einsum +from einops import rearrange + + +class VectorQuantizer(nn.Module): + """ + see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py + ____________________________________________ + Discretization bottleneck part of the VQ-VAE. + Inputs: + - n_e : number of embeddings + - e_dim : dimension of embedding + - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 + _____________________________________________ + """ + + # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for + # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be + # used wherever VectorQuantizer has been used before and is additionally + # more efficient. + def __init__(self, n_e, e_dim, beta): + super(VectorQuantizer, self).__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + def forward(self, z): + """ + Inputs the output of the encoder network z and maps it to a discrete + one-hot vector that is the index of the closest embedding vector e_j + z (continuous) -> z_q (discrete) + z.shape = (batch, channel, height, width) + quantization pipeline: + 1. get encoder input (B,C,H,W) + 2. flatten input to (B*H*W,C) + """ + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) + - 2 * torch.matmul(z_flattened, self.embedding.weight.t()) + ) + + ## could possible replace this here + # #\start... + # find closest encodings + min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) + + min_encodings = torch.zeros(min_encoding_indices.shape[0], self.n_e).to(z) + min_encodings.scatter_(1, min_encoding_indices, 1) + + # dtype min encodings: torch.float32 + # min_encodings shape: torch.Size([2048, 512]) + # min_encoding_indices.shape: torch.Size([2048, 1]) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) + # .........\end + + # with: + # .........\start + # min_encoding_indices = torch.argmin(d, dim=1) + # z_q = self.embedding(min_encoding_indices) + # ......\end......... 
(TODO) + + # compute loss for embedding + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean( + (z_q - z.detach()) ** 2 + ) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # perplexity + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + # TODO: check for more easy handling with nn.Embedding + min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices) + min_encodings.scatter_(1, indices[:, None], 1) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings.float(), self.embedding.weight) + + if shape is not None: + z_q = z_q.view(shape) + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class GumbelQuantize(nn.Module): + """ + credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) + Gumbel Softmax trick quantizer + Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016 + https://arxiv.org/abs/1611.01144 + """ + + def __init__( + self, + num_hiddens, + embedding_dim, + n_embed, + straight_through=True, + kl_weight=5e-4, + temp_init=1.0, + use_vqinterface=True, + remap=None, + unknown_index="random", + ): + super().__init__() + + self.embedding_dim = embedding_dim + self.n_embed = n_embed + + self.straight_through = straight_through + self.temperature = temp_init + self.kl_weight = kl_weight + + self.proj = nn.Conv2d(num_hiddens, n_embed, 1) + self.embed = nn.Embedding(n_embed, embedding_dim) + + self.use_vqinterface = use_vqinterface + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print( + f"Remapping {self.n_embed} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." + ) + else: + self.re_embed = n_embed + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to( + device=new.device + ) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, return_logits=False): + # force hard = True when we are in eval mode, as we must quantize. 
actually, always true seems to work + hard = self.straight_through if self.training else True + temp = self.temperature if temp is None else temp + + logits = self.proj(z) + if self.remap is not None: + # continue only with used logits + full_zeros = torch.zeros_like(logits) + logits = logits[:, self.used, ...] + + soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) + if self.remap is not None: + # go back to all entries but unused set to zero + full_zeros[:, self.used, ...] = soft_one_hot + soft_one_hot = full_zeros + z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) + + # + kl divergence to the prior loss + qy = F.softmax(logits, dim=1) + diff = ( + self.kl_weight + * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() + ) + + ind = soft_one_hot.argmax(dim=1) + if self.remap is not None: + ind = self.remap_to_used(ind) + if self.use_vqinterface: + if return_logits: + return z_q, diff, (None, None, ind), logits + return z_q, diff, (None, None, ind) + return z_q, diff, ind + + def get_codebook_entry(self, indices, shape): + b, h, w, c = shape + assert b * h * w == indices.shape[0] + indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w) + if self.remap is not None: + indices = self.unmap_to_all(indices) + one_hot = ( + F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() + ) + z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight) + return z_q + + +class VectorQuantizer2(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly + avoids costly matrix multiplications and allows for post-hoc remapping of indices. + """ + + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__( + self, + n_e, + e_dim, + beta, + remap=None, + unknown_index="random", + sane_index_shape=False, + legacy=True, + ): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print( + f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." 
+ ) + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to( + device=new.device + ) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, rescale_logits=False, return_logits=False): + assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel" + assert rescale_logits == False, "Only for interface compatible with Gumbel" + assert return_logits == False, "Only for interface compatible with Gumbel" + # reshape z -> (batch, height, width, channel) and flatten + z = rearrange(z, "b c h w -> b h w c").contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) + - 2 + * torch.einsum( + "bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n") + ) + ) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean( + (z_q - z.detach()) ** 2 + ) + else: + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean( + (z_q - z.detach()) ** 2 + ) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = rearrange(z_q, "b h w c -> b c h w").contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape( + z.shape[0], -1 + ) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape( + z_q.shape[0], z_q.shape[2], z_q.shape[3] + ) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class EmbeddingEMA(nn.Module): + def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5): + super().__init__() + self.decay = decay + self.eps = eps + weight = torch.randn(num_tokens, codebook_dim) + self.weight = nn.Parameter(weight, requires_grad=False) + self.cluster_size = nn.Parameter(torch.zeros(num_tokens), 
requires_grad=False) + self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False) + self.update = True + + def forward(self, embed_id): + return F.embedding(embed_id, self.weight) + + def cluster_size_ema_update(self, new_cluster_size): + self.cluster_size.data.mul_(self.decay).add_( + new_cluster_size, alpha=1 - self.decay + ) + + def embed_avg_ema_update(self, new_embed_avg): + self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) + + def weight_update(self, num_tokens): + n = self.cluster_size.sum() + smoothed_cluster_size = ( + (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n + ) + # normalize embedding average with smoothed cluster size + embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) + self.weight.data.copy_(embed_normalized) + + +class EMAVectorQuantizer(nn.Module): + def __init__( + self, + n_embed, + embedding_dim, + beta, + decay=0.99, + eps=1e-5, + remap=None, + unknown_index="random", + ): + super().__init__() + self.codebook_dim = codebook_dim + self.num_tokens = num_tokens + self.beta = beta + self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print( + f"Remapping {self.n_embed} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." + ) + else: + self.re_embed = n_embed + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to( + device=new.device + ) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + # z, 'b c h w -> b h w c' + z = rearrange(z, "b c h w -> b h w c") + z_flattened = z.reshape(-1, self.codebook_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = ( + z_flattened.pow(2).sum(dim=1, keepdim=True) + + self.embedding.weight.pow(2).sum(dim=1) + - 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight) + ) # 'n d -> d n' + + encoding_indices = torch.argmin(d, dim=1) + + z_q = self.embedding(encoding_indices).view(z.shape) + encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) + avg_probs = torch.mean(encodings, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + if self.training and self.embedding.update: + # EMA cluster size + encodings_sum = encodings.sum(0) + self.embedding.cluster_size_ema_update(encodings_sum) + # EMA embedding average + embed_sum = encodings.transpose(0, 1) @ z_flattened + 
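            # EMA codebook update performed by the three EmbeddingEMA calls around this
            # point; with decay d and smoothing eps:
            #   cluster_size_i <- d * cluster_size_i + (1 - d) * sum_b 1[enc_b == i]
            #   embed_avg_i    <- d * embed_avg_i    + (1 - d) * sum_{b: enc_b == i} z_b
            #   weight_i       <- embed_avg_i / smoothed(cluster_size_i)
            # so each code drifts toward the running mean of the encoder outputs assigned to it.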
self.embedding.embed_avg_ema_update(embed_sum) + # normalize embed_avg and update weight + self.embedding.weight_update(self.num_tokens) + + # compute loss for embedding + loss = self.beta * F.mse_loss(z_q.detach(), z) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + # z_q, 'b h w c -> b c h w' + z_q = rearrange(z_q, "b h w c -> b c h w") + return z_q, loss, (perplexity, encodings, encoding_indices) diff --git a/src/vqvaes/titok/diffusion/__init__.py b/src/vqvaes/titok/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d0c3139d7f4471a0529b36386681bcefa1a9c7de --- /dev/null +++ b/src/vqvaes/titok/diffusion/__init__.py @@ -0,0 +1,47 @@ +# Adopted from DiT, which is modified from OpenAI's diffusion repos +# DiT: https://github.com/facebookresearch/DiT/diffusion +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from . import gaussian_diffusion as gd +from .respace import SpacedDiffusion, space_timesteps + + +def create_diffusion( + timestep_respacing, + noise_schedule="linear", + use_kl=False, + sigma_small=False, + predict_xstart=False, + learn_sigma=True, + rescale_learned_sigmas=False, + diffusion_steps=1000, +): + betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if timestep_respacing is None or timestep_respacing == "": + timestep_respacing = [diffusion_steps] + return SpacedDiffusion( + use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), + betas=betas, + model_mean_type=( + gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X + ), + model_var_type=( + ( + gd.ModelVarType.FIXED_LARGE + if not sigma_small + else gd.ModelVarType.FIXED_SMALL + ) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type, + # rescale_timesteps=rescale_timesteps, + ) diff --git a/src/vqvaes/titok/diffusion/diffusion_utils.py b/src/vqvaes/titok/diffusion/diffusion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..86ef6a9ed81767302fbe36900013dfe1abaf25f4 --- /dev/null +++ b/src/vqvaes/titok/diffusion/diffusion_utils.py @@ -0,0 +1,73 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +import torch as th +import numpy as np + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). 
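    # Closed form computed below, per element:
    #   KL( N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)) )
    #     = 0.5 * ( -1 + logvar2 - logvar1
    #               + exp(logvar1 - logvar2)
    #               + (mean1 - mean2) ** 2 * exp(-logvar2) )
    # Sanity check: with logvar1 == logvar2 == 0 this reduces to 0.5 * (mean1 - mean2) ** 2,
    # and it is exactly 0 when the two Gaussians coincide.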
+ logvar1, logvar2 = [ + x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + th.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * th.exp(-logvar2) + ) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs diff --git a/src/vqvaes/titok/diffusion/gaussian_diffusion.py b/src/vqvaes/titok/diffusion/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..bbae6e1ca6c76ab2480a78bffa8eee93c08749a6 --- /dev/null +++ b/src/vqvaes/titok/diffusion/gaussian_diffusion.py @@ -0,0 +1,904 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + + +import math + +import numpy as np +import torch as th +import enum + +from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. 
+ """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = np.linspace( + beta_start, beta_end, warmup_time, dtype=np.float64 + ) + return betas + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. + """ + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start**0.5, + beta_end**0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace( + beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 + ) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace( + num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 + ) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * 0.0001, + beta_end=scale * 0.02, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "cosine": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. 
+ """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + Original ported from this codebase: + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + """ + + def __init__(self, *, betas, model_mean_type, model_var_type, loss_type): + + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = ( + np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:])) + if len(self.posterior_variance) > 1 + else np.array([]) + ) + + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) + * np.sqrt(alphas) + / (1.0 - self.alphas_cumprod) + ) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + ) + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor( + self.log_one_minus_alphas_cumprod, t, x_start.shape + ) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. 
+ """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance( + self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None + ): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B,) + model_output = model(x, t, **model_kwargs) + if isinstance(model_output, tuple): + model_output, extra = model_output + else: + extra = None + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + min_log = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x.shape + ) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. 
+ ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + model_mean, _, _ = self.q_posterior_mean_variance( + x_start=pred_xstart, x_t=x, t=t + ) + + assert ( + model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + ) + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + "extra": extra, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, **model_kwargs) + new_mean = ( + p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + ) + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + See condition_mean() for details on cond_fn. + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance( + x_start=out["pred_xstart"], x_t=x, t=t + ) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + temperature=1.0, + ): + """ + Sample x_{t-1} from the model at the given timestep. + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. 
+ :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param temperature: temperature scaling during Diff Loss sampling. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean( + cond_fn, out, x, t, model_kwargs=model_kwargs + ) + # scale the noise by temperature + sample = ( + out["mean"] + + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise * temperature + ) + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + temperature=1.0, + ): + """ + Generate samples from the model. + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :param temperature: temperature scaling during Diff Loss sampling. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + temperature=temperature, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + temperature=1.0, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape).cuda() + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. 
+ from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0]).cuda() + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + temperature=temperature, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_next) + + th.sqrt(1 - alpha_bar_next) * eps + ) + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + Same usage as p_sample_loop(). 
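        Example (sketch; `model` is assumed to predict epsilon, and note that the loop
        below allocates noise and timesteps with .cuda(), so a GPU is assumed):
            samples = diffusion.ddim_sample_loop(
                model, shape=(4, 3, 32, 32), clip_denoised=True, progress=True, eta=0.0
            )
            # eta=0.0 gives the deterministic DDIM update; eta=1.0 recovers the
            # DDPM-style posterior noise level.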
+ """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + Same usage as p_sample_loop_progressive(). + """ + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape).cuda() + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0]).cuda() + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + def _vb_terms_bpd( + self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None + ): + """ + Get a term for the variational lower-bound. + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + out = self.p_mean_variance( + model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + kl = normal_kl( + true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] + ) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): + """ + Compute training losses for a single timestep. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. 
+ """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_output = model(x_t, t, **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert model_output.shape == (B, C * 2, *x_t.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + )[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + assert model_output.shape == target.shape == x_start.shape + terms["mse"] = mean_flat((target - model_output) ** 2) + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl( + mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 + ) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. 
+ """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + th.zeros(broadcast_shape, device=timesteps.device) diff --git a/src/vqvaes/titok/diffusion/respace.py b/src/vqvaes/titok/diffusion/respace.py new file mode 100644 index 0000000000000000000000000000000000000000..252206b80dd89f3ea3bf9311d5ee78f9216ab73a --- /dev/null +++ b/src/vqvaes/titok/diffusion/respace.py @@ -0,0 +1,127 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +import numpy as np +import torch as th + +from .gaussian_diffusion import GaussianDiffusion + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. 
+ """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def training_losses( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel(model, self.timestep_map, self.original_num_steps) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. 
+ return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, original_num_steps): + self.model = model + self.timestep_map = timestep_map + # self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + # if self.rescale_timesteps: + # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) diff --git a/src/vqvaes/titok/diffusion/timestep_sampler.py b/src/vqvaes/titok/diffusion/timestep_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..c4b1356edf5631f3b6019ea2cadf96e12d186d90 --- /dev/null +++ b/src/vqvaes/titok/diffusion/timestep_sampler.py @@ -0,0 +1,150 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from abc import ABC, abstractmethod + +import numpy as np +import torch as th +import torch.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = th.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. 
+ :param local_ts: an integer Tensor of timesteps. + :param local_losses: a 1D Tensor of losses. + """ + batch_sizes = [ + th.tensor([0], dtype=th.int32, device=local_ts.device) + for _ in range(dist.get_world_size()) + ] + dist.all_gather( + batch_sizes, + th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. + batch_sizes = [x.item() for x in batch_sizes] + max_bs = max(batch_sizes) + + timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] + loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] + dist.all_gather(timestep_batches, local_ts) + dist.all_gather(loss_batches, local_losses) + timesteps = [ + x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] + ] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts, losses): + """ + Update the reweighting using losses from a model. + Sub-classes should override this method to update the reweighting + using losses from the model. + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. + """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros( + [diffusion.num_timesteps, history_per_term], dtype=np.float64 + ) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history**2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. 
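+ # The per-timestep history acts as a fixed-size ring buffer: drop the
+ # oldest of the `history_per_term` stored losses and append the newest.
+ # weights() only departs from the uniform distribution once every
+ # timestep has a full history (see _warmed_up()).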
+ self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/src/vqvaes/titok/modules/__init__.py b/src/vqvaes/titok/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..87f485bfaba210d1ea113cb90c43d247d9061ebc --- /dev/null +++ b/src/vqvaes/titok/modules/__init__.py @@ -0,0 +1,12 @@ +from .base_model import BaseModel +from .ema_model import EMAModel +from .losses import ( + ReconstructionLoss_Stage1, + ReconstructionLoss_Stage2, + ReconstructionLoss_Single_Stage, + MLMLoss, + ARLoss, +) +from .blocks import TiTokEncoder, TiTokDecoder, TATiTokDecoder, UViTBlock +from .maskgit_vqgan import Decoder as Pixel_Decoder +from .maskgit_vqgan import VectorQuantizer as Pixel_Quantizer diff --git a/src/vqvaes/titok/modules/base_model.py b/src/vqvaes/titok/modules/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..97b640cab3cfddc3a5013b272289e77f3ff15661 --- /dev/null +++ b/src/vqvaes/titok/modules/base_model.py @@ -0,0 +1,140 @@ +"""This file contains some base class implementation for models. + +This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”). +All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates. + +Reference: + https://github.com/huggingface/open-muse/blob/main/muse/modeling_utils.py +""" + +import os +from typing import Union, Callable, Dict, Optional + +import torch + + +class BaseModel(torch.nn.Module): + + def __init__(self): + super().__init__() + + def save_pretrained_weight( + self, + save_directory: Union[str, os.PathLike], + save_function: Callable = None, + state_dict: Optional[Dict[str, torch.Tensor]] = None, + ): + """Saves a model and its configuration file to a directory. + + Args: + save_directory: A string or os.PathLike, directory to which to save. + Will be created if it doesn't exist. + save_function: A Callable function, the function to use to save the state dictionary. + Useful on distributed training like TPUs when one need to replace `torch.save` by + another method. Can be configured with the environment variable `DIFFUSERS_SAVE_MODE`. + state_dict: A dictionary from str to torch.Tensor, the state dictionary to save. + If `None`, the model's state dictionary will be saved. + """ + if os.path.isfile(save_directory): + print(f"Provided path ({save_directory}) should be a directory, not a file") + return + + if save_function is None: + save_function = torch.save + + os.makedirs(save_directory, exist_ok=True) + + model_to_save = self + + if state_dict is None: + state_dict = model_to_save.state_dict() + weights_name = "pytorch_model.bin" + + save_function(state_dict, os.path.join(save_directory, weights_name)) + + print(f"Model weights saved in {os.path.join(save_directory, weights_name)}") + + def load_pretrained_weight( + self, + # pretrained_model_path: Union[str, os.PathLike], + checkpoint, + strict_loading: bool = True, + torch_dtype: Optional[torch.dtype] = None, + ): + r"""Instantiates a pretrained pytorch model from a pre-trained model configuration. + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train + the model, you should first set it back in training mode with `model.train()`. 
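+
+ Illustrative usage (a sketch; the checkpoint path is an assumption):
+
+     checkpoint = torch.load("pytorch_model.bin", map_location="cpu")
+     model.load_pretrained_weight(checkpoint, strict_loading=True)
+
+ Note that the trailing print statement references `model_file`, which is
+ only defined inside the commented-out path-handling block.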
+ + Args: + pretrained_model_path: A string or os.PathLike, a path to a *directory* or *file* containing model weights. + + Raises: + ValueError: If pretrained_model_path does not exist. + """ + # If pretrained_model_path is a file, set model_file to this file. + # if os.path.isfile(pretrained_model_path): + # model_file = pretrained_model_path + # # If pretrained_model_path is a directory, set model_file to the path of the + # # file "pytorch_model.bin" in this directory. + # elif os.path.isdir(pretrained_model_path): + # pretrained_model_path = os.path.join(pretrained_model_path, "pytorch_model.bin") + # if os.path.isfile(pretrained_model_path): + # model_file = pretrained_model_path + # else: + # raise ValueError(f"{pretrained_model_path} does not exist") + # else: + # raise ValueError(f"{pretrained_model_path} does not exist") + # + # # Load model state from checkpoint. + # checkpoint = torch.load(model_file, map_location="cpu") + # Load state dictionary into self. + msg = self.load_state_dict(checkpoint, strict=strict_loading) + # Print information about loading weights. + print(f"loading weight from {model_file}, msg: {msg}") + # If torch_dtype is specified and is a valid torch.dtype, convert self to this dtype. + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." + ) + elif torch_dtype is not None: + self.to(torch_dtype) + + # Set model in evaluation mode to deactivate DropOut modules by default. + self.eval() + + def num_parameters( + self, only_trainable: bool = False, exclude_embeddings: bool = False + ) -> int: + """Gets the number of parameters in the module. + + Args: + only_trainable: A boolean, whether to only include trainable parameters. + exclude_embeddings: A boolean, whether to exclude parameters associated with embeddings. + + Returns: + An integer, the number of parameters. + """ + + if exclude_embeddings: + embedding_param_names = [ + f"{name}.weight" + for name, module_type in self.named_modules() + if isinstance(module_type, torch.nn.Embedding) + ] + non_embedding_parameters = [ + parameter + for name, parameter in self.named_parameters() + if name not in embedding_param_names + ] + return sum( + p.numel() + for p in non_embedding_parameters + if p.requires_grad or not only_trainable + ) + else: + return sum( + p.numel() + for p in self.parameters() + if p.requires_grad or not only_trainable + ) diff --git a/src/vqvaes/titok/modules/blocks.py b/src/vqvaes/titok/modules/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..59f8d9317fcbd26017e4bc152c9bac0f34b9e725 --- /dev/null +++ b/src/vqvaes/titok/modules/blocks.py @@ -0,0 +1,742 @@ +"""Building blocks for TiTok. + +Copyright (2024) Bytedance Ltd. and/or its affiliates + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ +Reference: + https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py + https://github.com/baofff/U-ViT/blob/main/libs/timm.py +""" + +import math +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +from collections import OrderedDict +import einops +from einops.layers.torch import Rearrange + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, d_model, n_head, mlp_ratio=4.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = nn.MultiheadAttention(d_model, n_head) + self.mlp_ratio = mlp_ratio + # optionally we can disable the FFN + if mlp_ratio > 0: + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential( + OrderedDict( + [ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)), + ] + ) + ) + + def attention(self, x: torch.Tensor): + return self.attn(x, x, x, need_weights=False)[0] + + def forward( + self, + x: torch.Tensor, + ): + attn_output = self.attention(x=self.ln_1(x)) + x = x + attn_output + if self.mlp_ratio > 0: + x = x + self.mlp(self.ln_2(x)) + return x + + +if hasattr(torch.nn.functional, "scaled_dot_product_attention"): + ATTENTION_MODE = "flash" +else: + try: + import xformers + import xformers.ops + + ATTENTION_MODE = "xformers" + except: + ATTENTION_MODE = "math" +print(f"attention mode is {ATTENTION_MODE}") + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, L, C = x.shape + + qkv = self.qkv(x) + if ATTENTION_MODE == "flash": + qkv = einops.rearrange( + qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads + ).float() + q, k, v = qkv[0], qkv[1], qkv[2] # B H L D + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = einops.rearrange(x, "B H L D -> B L (H D)") + elif ATTENTION_MODE == "xformers": + qkv = einops.rearrange( + qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads + ) + q, k, v = qkv[0], qkv[1], qkv[2] # B L H D + x = xformers.ops.memory_efficient_attention(q, k, v) + x = einops.rearrange(x, "B L H D -> B L (H D)", H=self.num_heads) + elif ATTENTION_MODE == "math": + qkv = einops.rearrange( + qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads + ) + q, k, v = qkv[0], qkv[1], qkv[2] # B H L D + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, L, C) + else: + raise NotImplemented + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... 
I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class UViTBlock(nn.Module): + + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + skip=False, + use_checkpoint=False, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + self.skip_linear = nn.Linear(2 * dim, dim) if skip else None + self.use_checkpoint = use_checkpoint + + def forward(self, x, skip=None): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, skip) + else: + return self._forward(x, skip) + + def _forward(self, x, skip=None): + if self.skip_linear is not None: + x = self.skip_linear(torch.cat([x, skip], dim=-1)) + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +def _expand_token(token, batch_size: int): + return token.unsqueeze(0).expand(batch_size, -1, -1) + + +class TiTokEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.image_size = config.dataset.preprocessing.crop_size + self.patch_size = config.model.vq_model.vit_enc_patch_size + self.grid_size = self.image_size // self.patch_size + self.model_size = config.model.vq_model.vit_enc_model_size + self.num_latent_tokens = config.model.vq_model.num_latent_tokens + self.token_size = config.model.vq_model.token_size + + if config.model.vq_model.get("quantize_mode", "vq") == "vae": + self.token_size = self.token_size * 2 # needs to split into mean and std + + self.is_legacy = config.model.vq_model.get("is_legacy", True) + + self.width = { + "small": 512, + "base": 768, + "large": 
1024, + }[self.model_size] + self.num_layers = { + "small": 8, + "base": 12, + "large": 24, + }[self.model_size] + self.num_heads = { + "small": 8, + "base": 12, + "large": 16, + }[self.model_size] + + self.patch_embed = nn.Conv2d( + in_channels=3, + out_channels=self.width, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=True, + ) + + scale = self.width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width)) + self.positional_embedding = nn.Parameter( + scale * torch.randn(self.grid_size**2 + 1, self.width) + ) + self.latent_token_positional_embedding = nn.Parameter( + scale * torch.randn(self.num_latent_tokens, self.width) + ) + self.ln_pre = nn.LayerNorm(self.width) + self.transformer = nn.ModuleList() + for i in range(self.num_layers): + self.transformer.append( + ResidualAttentionBlock(self.width, self.num_heads, mlp_ratio=4.0) + ) + self.ln_post = nn.LayerNorm(self.width) + self.conv_out = nn.Conv2d(self.width, self.token_size, kernel_size=1, bias=True) + + def forward(self, pixel_values, latent_tokens): + batch_size = pixel_values.shape[0] + x = pixel_values + x = self.patch_embed(x) + x = x.reshape(x.shape[0], x.shape[1], -1) + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + # class embeddings and positional embeddings + x = torch.cat( + [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1 + ) + x = x + self.positional_embedding.to( + x.dtype + ) # shape = [*, grid ** 2 + 1, width] + + latent_tokens = _expand_token(latent_tokens, x.shape[0]).to(x.dtype) + latent_tokens = latent_tokens + self.latent_token_positional_embedding.to( + x.dtype + ) + x = torch.cat([x, latent_tokens], dim=1) + + x = self.ln_pre(x) + x = x.permute(1, 0, 2) # NLD -> LND + for i in range(self.num_layers): + x = self.transformer[i](x) + x = x.permute(1, 0, 2) # LND -> NLD + + latent_tokens = x[:, 1 + self.grid_size**2 :] + latent_tokens = self.ln_post(latent_tokens) + # fake 2D shape + if self.is_legacy: + latent_tokens = latent_tokens.reshape( + batch_size, self.width, self.num_latent_tokens, 1 + ) + else: + # Fix legacy problem. 
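+ # Non-legacy layout: reshape to [B, num_latent_tokens, width, 1] and
+ # permute so the channel (width) dimension comes first, giving the NCHW
+ # layout expected by the 1x1 conv_out projection below.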
+ latent_tokens = latent_tokens.reshape( + batch_size, self.num_latent_tokens, self.width, 1 + ).permute(0, 2, 1, 3) + latent_tokens = self.conv_out(latent_tokens) + latent_tokens = latent_tokens.reshape( + batch_size, self.token_size, 1, self.num_latent_tokens + ) + return latent_tokens + + +class TiTokDecoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.image_size = config.dataset.preprocessing.crop_size + self.patch_size = config.model.vq_model.vit_dec_patch_size + self.grid_size = self.image_size // self.patch_size + self.model_size = config.model.vq_model.vit_dec_model_size + self.num_latent_tokens = config.model.vq_model.num_latent_tokens + self.token_size = config.model.vq_model.token_size + self.is_legacy = config.model.vq_model.get("is_legacy", True) + self.width = { + "small": 512, + "base": 768, + "large": 1024, + }[self.model_size] + self.num_layers = { + "small": 8, + "base": 12, + "large": 24, + }[self.model_size] + self.num_heads = { + "small": 8, + "base": 12, + "large": 16, + }[self.model_size] + + self.decoder_embed = nn.Linear(self.token_size, self.width, bias=True) + scale = self.width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width)) + self.positional_embedding = nn.Parameter( + scale * torch.randn(self.grid_size**2 + 1, self.width) + ) + # add mask token and query pos embed + self.mask_token = nn.Parameter(scale * torch.randn(1, 1, self.width)) + self.latent_token_positional_embedding = nn.Parameter( + scale * torch.randn(self.num_latent_tokens, self.width) + ) + self.ln_pre = nn.LayerNorm(self.width) + self.transformer = nn.ModuleList() + for i in range(self.num_layers): + self.transformer.append( + ResidualAttentionBlock(self.width, self.num_heads, mlp_ratio=4.0) + ) + self.ln_post = nn.LayerNorm(self.width) + + if self.is_legacy: + self.ffn = nn.Sequential( + nn.Conv2d(self.width, 2 * self.width, 1, padding=0, bias=True), + nn.Tanh(), + nn.Conv2d(2 * self.width, 1024, 1, padding=0, bias=True), + ) + self.conv_out = nn.Identity() + else: + # Directly predicting RGB pixels + self.ffn = nn.Sequential( + nn.Conv2d( + self.width, + self.patch_size * self.patch_size * 3, + 1, + padding=0, + bias=True, + ), + Rearrange( + "b (p1 p2 c) h w -> b c (h p1) (w p2)", + p1=self.patch_size, + p2=self.patch_size, + ), + ) + self.conv_out = nn.Conv2d(3, 3, 3, padding=1, bias=True) + + def forward(self, z_quantized): + N, C, H, W = z_quantized.shape + assert ( + H == 1 and W == self.num_latent_tokens + ), f"{H}, {W}, {self.num_latent_tokens}" + x = z_quantized.reshape(N, C * H, W).permute(0, 2, 1) # NLD + x = self.decoder_embed(x) + + batchsize, seq_len, _ = x.shape + + mask_tokens = self.mask_token.repeat(batchsize, self.grid_size**2, 1).to( + x.dtype + ) + mask_tokens = torch.cat( + [ + _expand_token(self.class_embedding, mask_tokens.shape[0]).to( + mask_tokens.dtype + ), + mask_tokens, + ], + dim=1, + ) + mask_tokens = mask_tokens + self.positional_embedding.to(mask_tokens.dtype) + x = x + self.latent_token_positional_embedding[:seq_len] + x = torch.cat([mask_tokens, x], dim=1) + + x = self.ln_pre(x) + x = x.permute(1, 0, 2) # NLD -> LND + for i in range(self.num_layers): + x = self.transformer[i](x) + x = x.permute(1, 0, 2) # LND -> NLD + x = x[:, 1 : 1 + self.grid_size**2] # remove cls embed + x = self.ln_post(x) + # N L D -> N D H W + x = x.permute(0, 2, 1).reshape( + batchsize, self.width, self.grid_size, self.grid_size + ) + x = self.ffn(x.contiguous()) + x = self.conv_out(x) + return x + + +class 
TATiTokDecoder(TiTokDecoder): + def __init__(self, config): + super().__init__(config) + scale = self.width**-0.5 + self.text_context_length = config.model.vq_model.get("text_context_length", 77) + self.text_embed_dim = config.model.vq_model.get("text_embed_dim", 768) + self.text_guidance_proj = nn.Linear(self.text_embed_dim, self.width) + self.text_guidance_positional_embedding = nn.Parameter( + scale * torch.randn(self.text_context_length, self.width) + ) + + def forward(self, z_quantized, text_guidance): + N, C, H, W = z_quantized.shape + assert ( + H == 1 and W == self.num_latent_tokens + ), f"{H}, {W}, {self.num_latent_tokens}" + x = z_quantized.reshape(N, C * H, W).permute(0, 2, 1) # NLD + x = self.decoder_embed(x) + + batchsize, seq_len, _ = x.shape + + mask_tokens = self.mask_token.repeat(batchsize, self.grid_size**2, 1).to( + x.dtype + ) + mask_tokens = torch.cat( + [ + _expand_token(self.class_embedding, mask_tokens.shape[0]).to( + mask_tokens.dtype + ), + mask_tokens, + ], + dim=1, + ) + mask_tokens = mask_tokens + self.positional_embedding.to(mask_tokens.dtype) + x = x + self.latent_token_positional_embedding[:seq_len] + x = torch.cat([mask_tokens, x], dim=1) + + text_guidance = self.text_guidance_proj(text_guidance) + text_guidance = text_guidance + self.text_guidance_positional_embedding + x = torch.cat([x, text_guidance], dim=1) + + x = self.ln_pre(x) + x = x.permute(1, 0, 2) # NLD -> LND + for i in range(self.num_layers): + x = self.transformer[i](x) + x = x.permute(1, 0, 2) # LND -> NLD + x = x[:, 1 : 1 + self.grid_size**2] # remove cls embed + x = self.ln_post(x) + # N L D -> N D H W + x = x.permute(0, 2, 1).reshape( + batchsize, self.width, self.grid_size, self.grid_size + ) + x = self.ffn(x.contiguous()) + x = self.conv_out(x) + return x + + +class WeightTiedLMHead(nn.Module): + def __init__(self, embeddings, target_codebook_size): + super().__init__() + self.weight = embeddings.weight + self.target_codebook_size = target_codebook_size + + def forward(self, x): + # x shape: [batch_size, seq_len, embed_dim] + # Get the weights for the target codebook size + weight = self.weight[ + : self.target_codebook_size + ] # Shape: [target_codebook_size, embed_dim] + # Compute the logits by matrix multiplication + logits = torch.matmul( + x, weight.t() + ) # Shape: [batch_size, seq_len, target_codebook_size] + return logits + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
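+
+ Example (illustrative):
+
+     t = torch.tensor([0, 10, 999])
+     emb = TimestepEmbedder.timestep_embedding(t, dim=256)  # shape (3, 256)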
+ """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class ResBlock(nn.Module): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + """ + + def __init__(self, channels): + super().__init__() + self.channels = channels + + self.in_ln = nn.LayerNorm(channels, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(channels, channels, bias=True), + nn.SiLU(), + nn.Linear(channels, channels, bias=True), + ) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(channels, 3 * channels, bias=True) + ) + + def forward(self, x, y): + shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1) + h = modulate(self.in_ln(x), shift_mlp, scale_mlp) + h = self.mlp(h) + return x + gate_mlp * h + + +class FinalLayer(nn.Module): + """ + The final layer adopted from DiT. + """ + + def __init__(self, model_channels, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm( + model_channels, elementwise_affine=False, eps=1e-6 + ) + self.linear = nn.Linear(model_channels, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(model_channels, 2 * model_channels, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class SimpleMLPAdaLN(nn.Module): + """ + The MLP for Diffusion Loss. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param z_channels: channels in the condition. + :param num_res_blocks: number of residual blocks per downsample. 
+ """ + + def __init__( + self, + in_channels, + model_channels, + out_channels, + z_channels, + num_res_blocks, + grad_checkpointing=False, + ): + super().__init__() + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.grad_checkpointing = grad_checkpointing + + self.time_embed = TimestepEmbedder(model_channels) + self.cond_embed = nn.Linear(z_channels, model_channels) + + self.input_proj = nn.Linear(in_channels, model_channels) + + res_blocks = [] + for i in range(num_res_blocks): + res_blocks.append( + ResBlock( + model_channels, + ) + ) + + self.res_blocks = nn.ModuleList(res_blocks) + self.final_layer = FinalLayer(model_channels, out_channels) + + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize timestep embedding MLP + nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02) + nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers + for block in self.res_blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def forward(self, x, t, c): + """ + Apply the model to an input batch. + :param x: an [N x C] Tensor of inputs. + :param t: a 1-D batch of timesteps. + :param c: conditioning from AR transformer. + :return: an [N x C] Tensor of outputs. + """ + x = self.input_proj(x) + t = self.time_embed(t) + c = self.cond_embed(c) + + y = t + c + + if self.grad_checkpointing and not torch.jit.is_scripting(): + for block in self.res_blocks: + x = checkpoint(block, x, y) + else: + for block in self.res_blocks: + x = block(x, y) + + return self.final_layer(x, y) + + def forward_with_cfg(self, x, t, c, cfg_scale): + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + model_out = self.forward(combined, t, c) + eps, rest = model_out[:, : self.in_channels], model_out[:, self.in_channels :] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) diff --git a/src/vqvaes/titok/modules/discriminator.py b/src/vqvaes/titok/modules/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..60ec5e1c610555ce5b1d4f9223262a5d6d54e78b --- /dev/null +++ b/src/vqvaes/titok/modules/discriminator.py @@ -0,0 +1,144 @@ +"""This file contains some base implementation for discrminators. + +Copyright (2024) Bytedance Ltd. and/or its affiliates + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. + +TODO: Add reference to Mark Weber's tech report on the improved discriminator architecture. +""" + +import functools +import math +from typing import Tuple + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .maskgit_vqgan import Conv2dSame + + +class BlurBlock(torch.nn.Module): + def __init__(self, kernel: Tuple[int] = (1, 3, 3, 1)): + super().__init__() + + kernel = torch.tensor(kernel, dtype=torch.float32, requires_grad=False) + kernel = kernel[None, :] * kernel[:, None] + kernel /= kernel.sum() + kernel = kernel.unsqueeze(0).unsqueeze(0) + self.register_buffer("kernel", kernel) + + def calc_same_pad(self, i: int, k: int, s: int) -> int: + return max((math.ceil(i / s) - 1) * s + (k - 1) + 1 - i, 0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + ic, ih, iw = x.size()[-3:] + pad_h = self.calc_same_pad(i=ih, k=4, s=2) + pad_w = self.calc_same_pad(i=iw, k=4, s=2) + if pad_h > 0 or pad_w > 0: + x = F.pad( + x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] + ) + + weight = self.kernel.expand(ic, -1, -1, -1) + + out = F.conv2d(input=x, weight=weight, stride=2, groups=x.shape[1]) + return out + + +class NLayerDiscriminator(torch.nn.Module): + def __init__( + self, + num_channels: int = 3, + hidden_channels: int = 128, + num_stages: int = 3, + blur_resample: bool = True, + blur_kernel_size: int = 4, + ): + """Initializes the NLayerDiscriminator. + + Args: + num_channels -> int: The number of input channels. + hidden_channels -> int: The number of hidden channels. + num_stages -> int: The number of stages. + blur_resample -> bool: Whether to use blur resampling. + blur_kernel_size -> int: The blur kernel size. + """ + super().__init__() + assert num_stages > 0, "Discriminator cannot have 0 stages" + assert (not blur_resample) or ( + blur_kernel_size >= 3 and blur_kernel_size <= 5 + ), "Blur kernel size must be in [3,5] when sampling]" + + in_channel_mult = (1,) + tuple(map(lambda t: 2**t, range(num_stages))) + init_kernel_size = 5 + activation = functools.partial(torch.nn.LeakyReLU, negative_slope=0.1) + + self.block_in = torch.nn.Sequential( + Conv2dSame(num_channels, hidden_channels, kernel_size=init_kernel_size), + activation(), + ) + + BLUR_KERNEL_MAP = { + 3: (1, 2, 1), + 4: (1, 3, 3, 1), + 5: (1, 4, 6, 4, 1), + } + + discriminator_blocks = [] + for i_level in range(num_stages): + in_channels = hidden_channels * in_channel_mult[i_level] + out_channels = hidden_channels * in_channel_mult[i_level + 1] + block = torch.nn.Sequential( + Conv2dSame( + in_channels, + out_channels, + kernel_size=3, + ), + ( + torch.nn.AvgPool2d(kernel_size=2, stride=2) + if not blur_resample + else BlurBlock(BLUR_KERNEL_MAP[blur_kernel_size]) + ), + torch.nn.GroupNorm(32, out_channels), + activation(), + ) + discriminator_blocks.append(block) + + self.blocks = torch.nn.ModuleList(discriminator_blocks) + + self.pool = torch.nn.AdaptiveMaxPool2d((16, 16)) + + self.to_logits = torch.nn.Sequential( + Conv2dSame(out_channels, out_channels, 1), + activation(), + Conv2dSame(out_channels, 1, kernel_size=5), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x -> torch.Tensor: The input tensor. + + Returns: + output -> torch.Tensor: The output tensor. 
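+
+ Illustrative usage (a sketch):
+
+     disc = NLayerDiscriminator(num_channels=3, hidden_channels=128, num_stages=3)
+     logits = disc(torch.randn(2, 3, 256, 256))  # patch logits, shape [2, 1, 16, 16]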
+ """ + hidden_states = self.block_in(x) + for block in self.blocks: + hidden_states = block(hidden_states) + + hidden_states = self.pool(hidden_states) + + return self.to_logits(hidden_states) diff --git a/src/vqvaes/titok/modules/ema_model.py b/src/vqvaes/titok/modules/ema_model.py new file mode 100644 index 0000000000000000000000000000000000000000..12d2ab448fd1c6cde0bfd81a9c7cda7604fe9bf6 --- /dev/null +++ b/src/vqvaes/titok/modules/ema_model.py @@ -0,0 +1,260 @@ +"""This file contains some base class implementation for EMA. + +This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”). +All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates. + +Reference: + https://github.com/huggingface/open-muse/blob/64e1afe033717d795866ab8204484705cd4dc3f7/muse/modeling_ema.py#L8 +""" + +import copy +from typing import Any, Iterable, Optional, Union + +import torch + + +class EMAModel: + """Exponential Moving Average of models weights.""" + + def __init__( + self, + parameters: Iterable[torch.nn.Parameter], + decay: float = 0.9999, + min_decay: float = 0.0, + update_after_step: int = 0, + update_every: int = 1, + current_step: int = 0, + use_ema_warmup: bool = False, + inv_gamma: Union[float, int] = 1.0, + power: Union[float, int] = 2 / 3, + model_cls: Optional[Any] = None, + **model_config_kwargs + ): + """ + Args: + parameters (Iterable[torch.nn.Parameter]): The parameters to track. + decay (float): The decay factor for the exponential moving average. + min_decay (float): The minimum decay factor for the exponential moving average. + update_after_step (int): The number of steps to wait before starting to update the EMA weights. + update_every (int): The number of steps between each EMA update. + current_step (int): The current training step. + use_ema_warmup (bool): Whether to use EMA warmup. + inv_gamma (float): + Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True. + power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. + + notes on EMA Warmup: + If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan + to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), + gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 + at 215.4k steps). + """ + + parameters = list(parameters) + self.shadow_params = [p.clone().detach() for p in parameters] + self.temp_stored_params = None + + self.decay = decay + self.min_decay = min_decay + self.update_after_step = update_after_step + self.update_every = update_every + self.use_ema_warmup = use_ema_warmup + self.inv_gamma = inv_gamma + self.power = power + self.optimization_step = current_step + self.cur_decay_value = None # set in `step()` + + self.model_cls = model_cls + self.model_config_kwargs = model_config_kwargs + + @classmethod + def from_pretrained( + cls, checkpoint, model_cls, **model_config_kwargs + ) -> "EMAModel": + model = model_cls(**model_config_kwargs) + model.load_pretrained_weight(checkpoint) + + ema_model = cls(model.parameters(), model_cls=model_cls, **model_config_kwargs) + return ema_model + + def save_pretrained(self, path): + if self.model_cls is None: + raise ValueError( + "`save_pretrained` can only be used if `model_cls` was defined at __init__." 
+ ) + + if self.model_config_kwargs is None: + raise ValueError( + "`save_pretrained` can only be used if `model_config_kwargs` was defined at __init__." + ) + + model = self.model_cls(**self.model_config_kwargs) + self.copy_to(model.parameters()) + model.save_pretrained_weight(path) + + def set_step(self, optimization_step: int): + self.optimization_step = optimization_step + + def get_decay(self, optimization_step: int) -> float: + """Computes the decay factor for the exponential moving average.""" + step = max(0, optimization_step - self.update_after_step - 1) + + if step <= 0: + return 0.0 + + if self.use_ema_warmup: + cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power + else: + cur_decay_value = (1 + step) / (10 + step) + + cur_decay_value = min(cur_decay_value, self.decay) + # Make sure decay is not smaller than min_decay. + cur_decay_value = max(cur_decay_value, self.min_decay) + return cur_decay_value + + @torch.no_grad() + def step(self, parameters: Iterable[torch.nn.Parameter]): + parameters = list(parameters) + + self.optimization_step += 1 + + if (self.optimization_step - 1) % self.update_every != 0: + return + + # Compute the decay factor for the exponential moving average. + decay = self.get_decay(self.optimization_step) + self.cur_decay_value = decay + one_minus_decay = 1 - decay + + for s_param, param in zip(self.shadow_params, parameters): + if param.requires_grad: + s_param.sub_(one_minus_decay * (s_param - param)) + else: + s_param.copy_(param) + + def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: + """Copies current averaged parameters into given collection of parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. If `None`, the parameters with which this + `ExponentialMovingAverage` was initialized will be used. + """ + parameters = list(parameters) + for s_param, param in zip(self.shadow_params, parameters): + param.data.copy_(s_param.to(param.device).data) + + def to(self, device=None, dtype=None) -> None: + r"""Moves internal buffers of the ExponentialMovingAverage to `device`. + + Args: + device: like `device` argument to `torch.Tensor.to` + """ + # .to() on the tensors handles None correctly + self.shadow_params = [ + ( + p.to(device=device, dtype=dtype) + if p.is_floating_point() + else p.to(device=device) + ) + for p in self.shadow_params + ] + + def state_dict(self) -> dict: + r"""Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during + checkpointing to save the ema state dict. + """ + # Following PyTorch conventions, references to tensors are returned: + # "returns a reference to the state and not its copy!" - + # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict + return { + "decay": self.decay, + "min_decay": self.min_decay, + "optimization_step": self.optimization_step, + "update_after_step": self.update_after_step, + "use_ema_warmup": self.use_ema_warmup, + "inv_gamma": self.inv_gamma, + "power": self.power, + "shadow_params": self.shadow_params, + } + + def store(self, parameters: Iterable[torch.nn.Parameter]) -> None: + r""" + Args: + Save the current parameters for restoring later. + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. 
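+
+ Illustrative validation pattern (a sketch; `ema`, `model`, and `validate`
+ are assumed names):
+
+     ema.store(model.parameters())    # stash the current weights
+     ema.copy_to(model.parameters())  # swap in the EMA weights
+     validate(model)                  # hypothetical evaluation call
+     ema.restore(model.parameters())  # put the original weights back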
+ """ + self.temp_stored_params = [param.detach().cpu().clone() for param in parameters] + + def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: + r"""Restores the parameters stored with the `store` method. Useful to validate + the model with EMA parameters without affecting the original optimization process. + Store the parameters before the `copy_to()` method. After validation (or + model saving), use this to restore the former parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. If `None`, the parameters with which this + `ExponentialMovingAverage` was initialized will be used. + """ + if self.temp_stored_params is None: + raise RuntimeError( + "This ExponentialMovingAverage has no `store()`ed weights to `restore()`" + ) + for c_param, param in zip(self.temp_stored_params, parameters): + param.data.copy_(c_param.data) + + # Better memory-wise. + self.temp_stored_params = None + + def load_state_dict(self, state_dict: dict) -> None: + r"""Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the + ema state dict. + + Args: + state_dict (dict): EMA state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # Deepcopy, to be consistent with module API + state_dict = copy.deepcopy(state_dict) + + self.decay = state_dict.get("decay", self.decay) + if self.decay < 0.0 or self.decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.min_decay = state_dict.get("min_decay", self.min_decay) + if not isinstance(self.min_decay, float): + raise ValueError("Invalid min_decay") + + self.optimization_step = state_dict.get( + "optimization_step", self.optimization_step + ) + if not isinstance(self.optimization_step, int): + raise ValueError("Invalid optimization_step") + + self.update_after_step = state_dict.get( + "update_after_step", self.update_after_step + ) + if not isinstance(self.update_after_step, int): + raise ValueError("Invalid update_after_step") + + self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup) + if not isinstance(self.use_ema_warmup, bool): + raise ValueError("Invalid use_ema_warmup") + + self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma) + if not isinstance(self.inv_gamma, (float, int)): + raise ValueError("Invalid inv_gamma") + + self.power = state_dict.get("power", self.power) + if not isinstance(self.power, (float, int)): + raise ValueError("Invalid power") + + shadow_params = state_dict.get("shadow_params", None) + if shadow_params is not None: + self.shadow_params = shadow_params + if not isinstance(self.shadow_params, list): + raise ValueError("shadow_params must be a list") + if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): + raise ValueError("shadow_params must all be Tensors") diff --git a/src/vqvaes/titok/modules/losses.py b/src/vqvaes/titok/modules/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..9eb97684bfeff3ae57c8ea296c942ce943480c69 --- /dev/null +++ b/src/vqvaes/titok/modules/losses.py @@ -0,0 +1,517 @@ +"""This files contains training loss implementation. + +Copyright (2024) Bytedance Ltd. and/or its affiliates + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Ref: + https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/vqperceptual.py +""" + +from typing import Mapping, Text, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.cuda.amp import autocast + +from ..diffusion import create_diffusion +from .blocks import SimpleMLPAdaLN +from .perceptual_loss import PerceptualLoss +from .discriminator import NLayerDiscriminator + + +def hinge_d_loss(logits_real: torch.Tensor, logits_fake: torch.Tensor) -> torch.Tensor: + """Hinge loss for discrminator. + + This function is borrowed from + https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/vqperceptual.py#L20 + """ + loss_real = torch.mean(F.relu(1.0 - logits_real)) + loss_fake = torch.mean(F.relu(1.0 + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def compute_lecam_loss( + logits_real_mean: torch.Tensor, + logits_fake_mean: torch.Tensor, + ema_logits_real_mean: torch.Tensor, + ema_logits_fake_mean: torch.Tensor, +) -> torch.Tensor: + """Computes the LeCam loss for the given average real and fake logits. + + Args: + logits_real_mean -> torch.Tensor: The average real logits. + logits_fake_mean -> torch.Tensor: The average fake logits. + ema_logits_real_mean -> torch.Tensor: The EMA of the average real logits. + ema_logits_fake_mean -> torch.Tensor: The EMA of the average fake logits. + + Returns: + lecam_loss -> torch.Tensor: The LeCam loss. 
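+
+ Computed as
+     mean(relu(logits_real_mean - ema_logits_fake_mean) ** 2)
+     + mean(relu(ema_logits_real_mean - logits_fake_mean) ** 2),
+ i.e. real logits are penalized for rising above the EMA of the fake logits,
+ and fake logits for falling below the EMA of the real logits.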
+ """ + lecam_loss = torch.mean( + torch.pow(F.relu(logits_real_mean - ema_logits_fake_mean), 2) + ) + lecam_loss += torch.mean( + torch.pow(F.relu(ema_logits_real_mean - logits_fake_mean), 2) + ) + return lecam_loss + + +class ReconstructionLoss_Stage1(torch.nn.Module): + def __init__(self, config): + super().__init__() + loss_config = config.losses + self.quantizer_weight = loss_config.quantizer_weight + self.target_codebook_size = 1024 + + def forward( + self, + target_codes: torch.Tensor, + reconstructions: torch.Tensor, + quantizer_loss: torch.Tensor, + ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: + return self._forward_generator(target_codes, reconstructions, quantizer_loss) + + def _forward_generator( + self, + target_codes: torch.Tensor, + reconstructions: torch.Tensor, + quantizer_loss: Mapping[Text, torch.Tensor], + ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: + reconstructions = reconstructions.contiguous() + loss_fct = nn.CrossEntropyLoss(reduction="mean") + batch_size = reconstructions.shape[0] + reconstruction_loss = loss_fct( + reconstructions.view(batch_size, self.target_codebook_size, -1), + target_codes.view(batch_size, -1), + ) + total_loss = ( + reconstruction_loss + + self.quantizer_weight * quantizer_loss["quantizer_loss"] + ) + + loss_dict = dict( + total_loss=total_loss.clone().detach(), + reconstruction_loss=reconstruction_loss.detach(), + quantizer_loss=( + self.quantizer_weight * quantizer_loss["quantizer_loss"] + ).detach(), + commitment_loss=quantizer_loss["commitment_loss"].detach(), + codebook_loss=quantizer_loss["codebook_loss"].detach(), + ) + + return total_loss, loss_dict + + +class ReconstructionLoss_Stage2(torch.nn.Module): + def __init__(self, config): + """Initializes the losses module. + + Args: + config: A dictionary, the configuration for the model and everything else. + """ + super().__init__() + loss_config = config.losses + self.discriminator = NLayerDiscriminator() + + self.reconstruction_loss = loss_config.reconstruction_loss + self.reconstruction_weight = loss_config.reconstruction_weight + self.quantizer_weight = loss_config.quantizer_weight + self.perceptual_loss = PerceptualLoss(loss_config.perceptual_loss).eval() + self.perceptual_weight = loss_config.perceptual_weight + self.discriminator_iter_start = loss_config.discriminator_start + + self.discriminator_factor = loss_config.discriminator_factor + self.discriminator_weight = loss_config.discriminator_weight + self.lecam_regularization_weight = loss_config.lecam_regularization_weight + self.lecam_ema_decay = loss_config.get("lecam_ema_decay", 0.999) + if self.lecam_regularization_weight > 0.0: + self.register_buffer("ema_real_logits_mean", torch.zeros((1))) + self.register_buffer("ema_fake_logits_mean", torch.zeros((1))) + + self.config = config + + @autocast(enabled=False) + def forward( + self, + inputs: torch.Tensor, + reconstructions: torch.Tensor, + extra_result_dict: Mapping[Text, torch.Tensor], + global_step: int, + mode: str = "generator", + ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: + # Both inputs and reconstructions are in range [0, 1]. 
+ inputs = inputs.float() + reconstructions = reconstructions.float() + + if mode == "generator": + return self._forward_generator( + inputs, reconstructions, extra_result_dict, global_step + ) + elif mode == "discriminator": + return self._forward_discriminator(inputs, reconstructions, global_step) + else: + raise ValueError(f"Unsupported mode {mode}") + + def should_discriminator_be_trained(self, global_step: int): + return global_step >= self.discriminator_iter_start + + def _forward_generator( + self, + inputs: torch.Tensor, + reconstructions: torch.Tensor, + extra_result_dict: Mapping[Text, torch.Tensor], + global_step: int, + ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: + """Generator training step.""" + inputs = inputs.contiguous() + reconstructions = reconstructions.contiguous() + if self.reconstruction_loss == "l1": + reconstruction_loss = F.l1_loss(inputs, reconstructions, reduction="mean") + elif self.reconstruction_loss == "l2": + reconstruction_loss = F.mse_loss(inputs, reconstructions, reduction="mean") + else: + raise ValueError( + f"Unsuppored reconstruction_loss {self.reconstruction_loss}" + ) + reconstruction_loss *= self.reconstruction_weight + + # Compute perceptual loss. + perceptual_loss = self.perceptual_loss(inputs, reconstructions).mean() + + # Compute discriminator loss. + generator_loss = torch.zeros((), device=inputs.device) + discriminator_factor = ( + self.discriminator_factor + if self.should_discriminator_be_trained(global_step) + else 0 + ) + d_weight = 1.0 + if discriminator_factor > 0.0 and self.discriminator_weight > 0.0: + # Disable discriminator gradients. + for param in self.discriminator.parameters(): + param.requires_grad = False + logits_fake = self.discriminator(reconstructions) + generator_loss = -torch.mean(logits_fake) + + d_weight *= self.discriminator_weight + + # Compute quantizer loss. + quantizer_loss = extra_result_dict["quantizer_loss"] + total_loss = ( + reconstruction_loss + + self.perceptual_weight * perceptual_loss + + self.quantizer_weight * quantizer_loss + + d_weight * discriminator_factor * generator_loss + ) + loss_dict = dict( + total_loss=total_loss.clone().detach(), + reconstruction_loss=reconstruction_loss.detach(), + perceptual_loss=(self.perceptual_weight * perceptual_loss).detach(), + quantizer_loss=(self.quantizer_weight * quantizer_loss).detach(), + weighted_gan_loss=( + d_weight * discriminator_factor * generator_loss + ).detach(), + discriminator_factor=torch.tensor(discriminator_factor), + commitment_loss=extra_result_dict["commitment_loss"].detach(), + codebook_loss=extra_result_dict["codebook_loss"].detach(), + d_weight=d_weight, + gan_loss=generator_loss.detach(), + ) + + return total_loss, loss_dict + + def _forward_discriminator( + self, + inputs: torch.Tensor, + reconstructions: torch.Tensor, + global_step: int, + ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: + """Discrminator training step.""" + discriminator_factor = ( + self.discriminator_factor + if self.should_discriminator_be_trained(global_step) + else 0 + ) + loss_dict = {} + # Turn the gradients on. 
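+        # The generator step froze these parameters (requires_grad=False), so they are
+        # re-enabled here; reconstructions are detached below, so this loss only updates
+        # the discriminator.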
+ for param in self.discriminator.parameters(): + param.requires_grad = True + + real_images = inputs.detach().requires_grad_(True) + logits_real = self.discriminator(real_images) + logits_fake = self.discriminator(reconstructions.detach()) + + discriminator_loss = discriminator_factor * hinge_d_loss( + logits_real=logits_real, logits_fake=logits_fake + ) + + # optional lecam regularization + lecam_loss = torch.zeros((), device=inputs.device) + if self.lecam_regularization_weight > 0.0: + lecam_loss = ( + compute_lecam_loss( + torch.mean(logits_real), + torch.mean(logits_fake), + self.ema_real_logits_mean, + self.ema_fake_logits_mean, + ) + * self.lecam_regularization_weight + ) + + self.ema_real_logits_mean = ( + self.ema_real_logits_mean * self.lecam_ema_decay + + torch.mean(logits_real).detach() * (1 - self.lecam_ema_decay) + ) + self.ema_fake_logits_mean = ( + self.ema_fake_logits_mean * self.lecam_ema_decay + + torch.mean(logits_fake).detach() * (1 - self.lecam_ema_decay) + ) + + discriminator_loss += lecam_loss + + loss_dict = dict( + discriminator_loss=discriminator_loss.detach(), + logits_real=logits_real.detach().mean(), + logits_fake=logits_fake.detach().mean(), + lecam_loss=lecam_loss.detach(), + ) + return discriminator_loss, loss_dict + + +class ReconstructionLoss_Single_Stage(ReconstructionLoss_Stage2): + def __init__(self, config): + super().__init__(config) + loss_config = config.losses + self.quantize_mode = config.model.vq_model.get("quantize_mode", "vq") + + if self.quantize_mode == "vae": + self.kl_weight = loss_config.get("kl_weight", 1e-6) + logvar_init = loss_config.get("logvar_init", 0.0) + self.logvar = nn.Parameter( + torch.ones(size=()) * logvar_init, requires_grad=False + ) + + def _forward_generator( + self, + inputs: torch.Tensor, + reconstructions: torch.Tensor, + extra_result_dict: Mapping[Text, torch.Tensor], + global_step: int, + ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: + """Generator training step.""" + inputs = inputs.contiguous() + reconstructions = reconstructions.contiguous() + if self.reconstruction_loss == "l1": + reconstruction_loss = F.l1_loss(inputs, reconstructions, reduction="mean") + elif self.reconstruction_loss == "l2": + reconstruction_loss = F.mse_loss(inputs, reconstructions, reduction="mean") + else: + raise ValueError( + f"Unsuppored reconstruction_loss {self.reconstruction_loss}" + ) + reconstruction_loss *= self.reconstruction_weight + + # Compute perceptual loss. + perceptual_loss = self.perceptual_loss(inputs, reconstructions).mean() + + # Compute discriminator loss. + generator_loss = torch.zeros((), device=inputs.device) + discriminator_factor = ( + self.discriminator_factor + if self.should_discriminator_be_trained(global_step) + else 0 + ) + d_weight = 1.0 + if discriminator_factor > 0.0 and self.discriminator_weight > 0.0: + # Disable discriminator gradients. + for param in self.discriminator.parameters(): + param.requires_grad = False + logits_fake = self.discriminator(reconstructions) + generator_loss = -torch.mean(logits_fake) + + d_weight *= self.discriminator_weight + + if self.quantize_mode == "vq": + # Compute quantizer loss. 
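+            # VQ bottleneck: total = recon + perceptual_w * perceptual
+            #   + quantizer_w * (commitment + codebook) + d_w * d_factor * gan,
+            # matching ReconstructionLoss_Stage2; the "vae" branch below swaps the
+            # quantizer term for a KL term weighted by kl_weight.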
+ quantizer_loss = extra_result_dict["quantizer_loss"] + total_loss = ( + reconstruction_loss + + self.perceptual_weight * perceptual_loss + + self.quantizer_weight * quantizer_loss + + d_weight * discriminator_factor * generator_loss + ) + loss_dict = dict( + total_loss=total_loss.clone().detach(), + reconstruction_loss=reconstruction_loss.detach(), + perceptual_loss=(self.perceptual_weight * perceptual_loss).detach(), + quantizer_loss=(self.quantizer_weight * quantizer_loss).detach(), + weighted_gan_loss=( + d_weight * discriminator_factor * generator_loss + ).detach(), + discriminator_factor=torch.tensor(discriminator_factor), + commitment_loss=extra_result_dict["commitment_loss"].detach(), + codebook_loss=extra_result_dict["codebook_loss"].detach(), + d_weight=d_weight, + gan_loss=generator_loss.detach(), + ) + elif self.quantize_mode == "vae": + # Compute kl loss. + reconstruction_loss = reconstruction_loss / torch.exp(self.logvar) + posteriors = extra_result_dict + kl_loss = posteriors.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + total_loss = ( + reconstruction_loss + + self.perceptual_weight * perceptual_loss + + self.kl_weight * kl_loss + + d_weight * discriminator_factor * generator_loss + ) + loss_dict = dict( + total_loss=total_loss.clone().detach(), + reconstruction_loss=reconstruction_loss.detach(), + perceptual_loss=(self.perceptual_weight * perceptual_loss).detach(), + kl_loss=(self.kl_weight * kl_loss).detach(), + weighted_gan_loss=( + d_weight * discriminator_factor * generator_loss + ).detach(), + discriminator_factor=torch.tensor(discriminator_factor), + d_weight=d_weight, + gan_loss=generator_loss.detach(), + ) + else: + raise NotImplementedError + + return total_loss, loss_dict + + +class MLMLoss(torch.nn.Module): + def __init__(self, config): + super().__init__() + self.label_smoothing = config.losses.label_smoothing + self.loss_weight_unmasked_token = config.losses.loss_weight_unmasked_token + self.criterion = torch.nn.CrossEntropyLoss( + label_smoothing=self.label_smoothing, reduction="none" + ) + + def forward( + self, inputs: torch.Tensor, targets: torch.Tensor, weights=None + ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: + inputs = rearrange(inputs, "b n c -> b c n") + loss = self.criterion(inputs, targets) + weights = weights.to(loss) + loss_weights = ( + 1.0 - weights + ) * self.loss_weight_unmasked_token + weights # set 0 to self.loss_weight_unasked_token + loss = (loss * loss_weights).sum() / (loss_weights.sum() + 1e-8) + # we only compute correct tokens on masked tokens + correct_tokens = ((torch.argmax(inputs, dim=1) == targets) * weights).sum( + dim=1 + ) / (weights.sum(1) + 1e-8) + return loss, {"loss": loss, "correct_tokens": correct_tokens.mean()} + + +class ARLoss(torch.nn.Module): + def __init__(self, config): + super().__init__() + self.target_vocab_size = config.model.vq_model.codebook_size + self.criterion = torch.nn.CrossEntropyLoss(reduction="mean") + + def forward( + self, logits: torch.Tensor, labels: torch.Tensor + ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: + shift_logits = logits[..., :-1, :].permute(0, 2, 1).contiguous() # NLC->NCL + shift_labels = labels.contiguous() + shift_logits = shift_logits.view( + shift_logits.shape[0], self.target_vocab_size, -1 + ) + shift_labels = shift_labels.view(shift_labels.shape[0], -1) + shift_labels = shift_labels.to(shift_logits.device) + loss = self.criterion(shift_logits, shift_labels) + correct_tokens = (torch.argmax(shift_logits, dim=1) == shift_labels).sum( + dim=1 + ) 
/ shift_labels.size(1) + return loss, {"loss": loss, "correct_tokens": correct_tokens.mean()} + + +class DiffLoss(nn.Module): + """Diffusion Loss""" + + def __init__(self, config): + super(DiffLoss, self).__init__() + self.in_channels = config.model.vq_model.token_size + + self.net = SimpleMLPAdaLN( + in_channels=self.in_channels, + model_channels=config.losses.diffloss_w, + out_channels=self.in_channels * 2, # for vlb loss + z_channels=config.model.maskgen.decoder_embed_dim, + num_res_blocks=config.losses.diffloss_d, + grad_checkpointing=config.get("training.grad_checkpointing", False), + ) + + self.train_diffusion = create_diffusion( + timestep_respacing="", noise_schedule="cosine" + ) + self.gen_diffusion = create_diffusion( + timestep_respacing=config.losses.get("num_sampling_steps", "100"), + noise_schedule="cosine", + ) + + def forward(self, target, z, mask=None): + t = torch.randint( + 0, + self.train_diffusion.num_timesteps, + (target.shape[0],), + device=target.device, + ) + model_kwargs = dict(c=z) + loss_dict = self.train_diffusion.training_losses( + self.net, target, t, model_kwargs + ) + loss = loss_dict["loss"] + if mask is not None: + loss = (loss * mask).sum() / mask.sum() + + loss_dict = dict( + diff_loss=loss.clone().mean().detach(), + ) + + return loss.mean(), loss_dict + + def sample(self, z, temperature=1.0, cfg=1.0): + # diffusion loss sampling + if not cfg == 1.0: + noise = torch.randn(z.shape[0] // 2, self.in_channels).cuda() + noise = torch.cat([noise, noise], dim=0) + model_kwargs = dict(c=z, cfg_scale=cfg) + sample_fn = self.net.forward_with_cfg + else: + noise = torch.randn(z.shape[0], self.in_channels).cuda() + model_kwargs = dict(c=z) + sample_fn = self.net.forward + + sampled_token_latent = self.gen_diffusion.p_sample_loop( + sample_fn, + noise.shape, + noise, + clip_denoised=False, + model_kwargs=model_kwargs, + progress=False, + temperature=temperature, + ) + + return sampled_token_latent diff --git a/src/vqvaes/titok/modules/lpips.py b/src/vqvaes/titok/modules/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..804f4a1d4300b4e67e0e8cbdbc1b7ce060c3b72d --- /dev/null +++ b/src/vqvaes/titok/modules/lpips.py @@ -0,0 +1,191 @@ +"""This file contains code for LPIPS. + +This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”). +All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates. 
+ +Reference: + https://github.com/richzhang/PerceptualSimilarity/ + https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/lpips.py + https://github.com/CompVis/taming-transformers/blob/master/taming/util.py +""" + +import os +import hashlib +import requests +from collections import namedtuple +from tqdm import tqdm + +import torch +import torch.nn as nn + +from torchvision import models + +_LPIPS_MEAN = [-0.030, -0.088, -0.188] +_LPIPS_STD = [0.458, 0.448, 0.450] + + +URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} + +CKPT_MAP = {"vgg_lpips": "vgg.pth"} + +MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class LPIPS(nn.Module): + # Learned perceptual metric. + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.load_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_pretrained(self): + workspace = os.environ.get("WORKSPACE", "") + VGG_PATH = get_ckpt_path( + "vgg_lpips", os.path.join(workspace, "models/vgg_lpips.pth"), check=True + ) + self.load_state_dict( + torch.load(VGG_PATH, map_location=torch.device("cpu")), strict=False + ) + + def forward(self, input, target): + # Notably, the LPIPS w/ pre-trained weights expect the input in the range of [-1, 1]. + # However, our codebase assumes all inputs are in range of [0, 1], and thus a scaling is needed. 
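+        # After rescaling, the distance is computed as: extract VGG16 features at five
+        # depths, unit-normalize each feature map along channels, square the differences,
+        # weight them with the learned 1x1 "lin" layers, spatially average, and sum over
+        # the five layers.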
+ input = input * 2.0 - 1.0 + target = target * 2.0 - 1.0 + in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor( + outs1[kk] + ) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [ + spatial_average(lins[kk].model(diffs[kk]), keepdim=True) + for kk in range(len(self.chns)) + ] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer("shift", torch.Tensor(_LPIPS_MEAN)[None, :, None, None]) + self.register_buffer("scale", torch.Tensor(_LPIPS_STD)[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """A single linear layer which does a 1x1 conv.""" + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = ( + [ + nn.Dropout(), + ] + if (use_dropout) + else [] + ) + layers += [ + nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), + ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16( + weights=models.VGG16_Weights.IMAGENET1K_V1 + ).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple( + "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] + ) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x / (norm_factor + eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keepdim=keepdim) diff --git a/src/vqvaes/titok/modules/maskgit_vqgan.py b/src/vqvaes/titok/modules/maskgit_vqgan.py new file mode 100644 index 0000000000000000000000000000000000000000..805fabe27de74b38b5e8c0c26e7e97fb7bccc5fc --- /dev/null +++ b/src/vqvaes/titok/modules/maskgit_vqgan.py @@ -0,0 +1,428 @@ +"""This file contains code for MaskGIT-VQGAN. + +This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”). +All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates. 
+ +Reference: + https://github.com/huggingface/open-muse/blob/main/muse/modeling_maskgit_vqgan.py +""" + +# Copyright 2023 Google LLC and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""MaskGIT Tokenizer based on VQGAN. + +This tokenizer is a reimplementation of VQGAN [https://arxiv.org/abs/2012.09841] +with several modifications. The non-local layers are removed from VQGAN for +faster speed. +""" + +import math + +import torch +import torch.nn.functional as F +from torch import nn + + +# Conv2D with same padding +class Conv2dSame(nn.Conv2d): + def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int: + return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + ih, iw = x.size()[-2:] + + pad_h = self.calc_same_pad( + i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0] + ) + pad_w = self.calc_same_pad( + i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1] + ) + + if pad_h > 0 or pad_w > 0: + x = F.pad( + x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] + ) + return super().forward(x) + + +class ResnetBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int = None, + dropout_prob: float = 0.0, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.out_channels_ = ( + self.in_channels if self.out_channels is None else self.out_channels + ) + + self.norm1 = nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + self.conv1 = Conv2dSame( + self.in_channels, self.out_channels_, kernel_size=3, bias=False + ) + + self.norm2 = nn.GroupNorm( + num_groups=32, num_channels=self.out_channels_, eps=1e-6, affine=True + ) + self.dropout = nn.Dropout(dropout_prob) + self.conv2 = Conv2dSame( + self.out_channels_, self.out_channels_, kernel_size=3, bias=False + ) + + if self.in_channels != self.out_channels_: + self.nin_shortcut = Conv2dSame( + self.out_channels_, self.out_channels_, kernel_size=1, bias=False + ) + + def forward(self, hidden_states): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states = F.silu(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states = F.silu(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels_: + residual = self.nin_shortcut(hidden_states) + + return hidden_states + residual + + +class DownsamplingBlock(nn.Module): + def __init__(self, config, block_idx: int): + super().__init__() + + self.config = config + self.block_idx = block_idx + + in_channel_mult = (1,) + tuple(self.config.channel_mult) + block_in = self.config.hidden_channels * in_channel_mult[self.block_idx] + block_out = ( + self.config.hidden_channels * self.config.channel_mult[self.block_idx] + ) + + res_blocks = nn.ModuleList() + for _ in range(self.config.num_res_blocks): + 
res_blocks.append( + ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout) + ) + block_in = block_out + self.block = res_blocks + + self.downsample = self.block_idx != self.config.num_resolutions - 1 + + def forward(self, hidden_states): + for res_block in self.block: + hidden_states = res_block(hidden_states) + + if self.downsample: + hidden_states = F.avg_pool2d(hidden_states, kernel_size=2, stride=2) + + return hidden_states + + +class UpsamplingBlock(nn.Module): + def __init__(self, config, block_idx: int): + super().__init__() + + self.config = config + self.block_idx = block_idx + + if self.block_idx == self.config.num_resolutions - 1: + block_in = self.config.hidden_channels * self.config.channel_mult[-1] + else: + block_in = ( + self.config.hidden_channels + * self.config.channel_mult[self.block_idx + 1] + ) + + block_out = ( + self.config.hidden_channels * self.config.channel_mult[self.block_idx] + ) + + res_blocks = [] + for _ in range(self.config.num_res_blocks): + res_blocks.append( + ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout) + ) + block_in = block_out + self.block = nn.ModuleList(res_blocks) + + self.add_upsample = self.block_idx != 0 + if self.add_upsample: + self.upsample_conv = Conv2dSame(block_out, block_out, kernel_size=3) + + def forward(self, hidden_states): + for res_block in self.block: + hidden_states = res_block(hidden_states) + + if self.add_upsample: + hidden_states = F.interpolate( + hidden_states, scale_factor=2.0, mode="nearest" + ) + hidden_states = self.upsample_conv(hidden_states) + + return hidden_states + + +class Encoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + # downsampling + self.conv_in = Conv2dSame( + self.config.num_channels, + self.config.hidden_channels, + kernel_size=3, + bias=False, + ) + + downsample_blocks = [] + for i_level in range(self.config.num_resolutions): + downsample_blocks.append(DownsamplingBlock(self.config, block_idx=i_level)) + self.down = nn.ModuleList(downsample_blocks) + + # middle + mid_channels = self.config.hidden_channels * self.config.channel_mult[-1] + res_blocks = nn.ModuleList() + for _ in range(self.config.num_res_blocks): + res_blocks.append( + ResnetBlock( + mid_channels, mid_channels, dropout_prob=self.config.dropout + ) + ) + self.mid = res_blocks + + # end + self.norm_out = nn.GroupNorm( + num_groups=32, num_channels=mid_channels, eps=1e-6, affine=True + ) + self.conv_out = Conv2dSame(mid_channels, self.config.z_channels, kernel_size=1) + + def forward(self, pixel_values): + # downsampling + hidden_states = self.conv_in(pixel_values) + for block in self.down: + hidden_states = block(hidden_states) + + # middle + for block in self.mid: + hidden_states = block(hidden_states) + + # end + hidden_states = self.norm_out(hidden_states) + hidden_states = F.silu(hidden_states) + hidden_states = self.conv_out(hidden_states) + return hidden_states + + +class Decoder(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + + # compute in_channel_mult, block_in and curr_res at lowest res + block_in = ( + self.config.hidden_channels + * self.config.channel_mult[self.config.num_resolutions - 1] + ) + curr_res = self.config.resolution // 2 ** (self.config.num_resolutions - 1) + self.z_shape = (1, self.config.z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = Conv2dSame(self.config.z_channels, block_in, kernel_size=3) + + # middle + res_blocks = nn.ModuleList() + for _ in 
range(self.config.num_res_blocks): + res_blocks.append( + ResnetBlock(block_in, block_in, dropout_prob=self.config.dropout) + ) + self.mid = res_blocks + + # upsampling + upsample_blocks = [] + for i_level in reversed(range(self.config.num_resolutions)): + upsample_blocks.append(UpsamplingBlock(self.config, block_idx=i_level)) + self.up = nn.ModuleList( + list(reversed(upsample_blocks)) + ) # reverse to get consistent order + + # end + block_out = self.config.hidden_channels * self.config.channel_mult[0] + self.norm_out = nn.GroupNorm( + num_groups=32, num_channels=block_out, eps=1e-6, affine=True + ) + self.conv_out = Conv2dSame(block_out, self.config.num_channels, kernel_size=3) + + def forward(self, hidden_states): + # z to block_in + hidden_states = self.conv_in(hidden_states) + + # middle + for block in self.mid: + hidden_states = block(hidden_states) + + # upsampling + for block in reversed(self.up): + hidden_states = block(hidden_states) + + # end + hidden_states = self.norm_out(hidden_states) + hidden_states = F.silu(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +class VectorQuantizer(nn.Module): + """ + see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py + Discretization bottleneck part of the VQ-VAE. + """ + + def __init__(self, num_embeddings, embedding_dim, commitment_cost): + r""" + Args: + num_embeddings: number of vectors in the quantized space. + embedding_dim: dimensionality of the tensors in the quantized space. + Inputs to the modules must be in this format as well. + commitment_cost: scalar which controls the weighting of the loss terms + (see equation 4 in the paper https://arxiv.org/abs/1711.00937 - this variable is Beta). + """ + super().__init__() + + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.commitment_cost = commitment_cost + + self.embedding = nn.Embedding(num_embeddings, embedding_dim) + self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings) + + def forward(self, hidden_states, return_loss=False): + """ + Inputs the output of the encoder network z and maps it to a discrete one-hot vector that is the index of the + closest embedding vector e_j z (continuous) -> z_q (discrete) z.shape = (batch, channel, height, width) + quantization pipeline: + 1. get encoder input (B,C,H,W) + 2. 
flatten input to (B*H*W,C) + """ + # reshape z -> (batch, height, width, channel) and flatten + hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous() + + distances = self.compute_distances(hidden_states) + min_encoding_indices = torch.argmin(distances, axis=1).unsqueeze(1) + min_encodings = torch.zeros( + min_encoding_indices.shape[0], self.num_embeddings + ).to(hidden_states) + min_encodings.scatter_(1, min_encoding_indices, 1) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings, self.embedding.weight).view( + hidden_states.shape + ) + + # reshape to (batch, num_tokens) + min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1) + + # compute loss for embedding + loss = None + if return_loss: + loss = torch.mean( + (z_q.detach() - hidden_states) ** 2 + ) + self.commitment_cost * torch.mean((z_q - hidden_states.detach()) ** 2) + # preserve gradients + z_q = hidden_states + (z_q - hidden_states).detach() + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q, min_encoding_indices, loss + + def compute_distances(self, hidden_states): + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + hidden_states_flattended = hidden_states.reshape((-1, self.embedding_dim)) + emb_weights = self.embedding.weight.t() + + inputs_norm_sq = hidden_states_flattended.pow(2.0).sum(dim=1, keepdim=True) + codebook_t_norm_sq = emb_weights.pow(2.0).sum(dim=0, keepdim=True) + distances = torch.addmm( + inputs_norm_sq + codebook_t_norm_sq, + hidden_states_flattended, + emb_weights, + alpha=-2.0, + ) + return distances + + def get_codebook_entry(self, indices): + # indices are expected to be of shape (batch, num_tokens) + # get quantized latent vectors + if len(indices.shape) == 2: + batch, num_tokens = indices.shape + z_q = self.embedding(indices) + z_q = z_q.reshape( + batch, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), -1 + ).permute(0, 3, 1, 2) + elif len(indices.shape) == 3: + batch, height, width = indices.shape + indices = indices.view(batch, -1) + z_q = self.embedding(indices) + z_q = z_q.reshape(batch, height, width, -1).permute(0, 3, 1, 2) + else: + print(indices.shape) + raise NotImplementedError + return z_q + + # adapted from https://github.com/kakaobrain/rq-vae-transformer/blob/main/rqvae/models/rqvae/quantizations.py#L372 + def get_soft_code(self, hidden_states, temp=1.0, stochastic=False): + hidden_states = hidden_states.permute( + 0, 2, 3, 1 + ).contiguous() # (batch, height, width, channel) + distances = self.compute_distances( + hidden_states + ) # (batch * height * width, num_embeddings) + + soft_code = F.softmax( + -distances / temp, dim=-1 + ) # (batch * height * width, num_embeddings) + if stochastic: + code = torch.multinomial(soft_code, 1) # (batch * height * width, 1) + else: + code = distances.argmin(dim=-1) # (batch * height * width) + + code = code.reshape(hidden_states.shape[0], -1) # (batch, height * width) + batch, num_tokens = code.shape + soft_code = soft_code.reshape( + batch, num_tokens, -1 + ) # (batch, height * width, num_embeddings) + return soft_code, code + + def get_code(self, hidden_states): + # reshape z -> (batch, height, width, channel) + hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous() + distances = self.compute_distances(hidden_states) + indices = torch.argmin(distances, axis=1).unsqueeze(1) + indices = indices.reshape(hidden_states.shape[0], -1) + return indices diff --git a/src/vqvaes/titok/modules/perceptual_loss.py 
b/src/vqvaes/titok/modules/perceptual_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..0d0d064f1edd0161726dc0f3f5a55da02925ba5d --- /dev/null +++ b/src/vqvaes/titok/modules/perceptual_loss.py @@ -0,0 +1,130 @@ +"""This file contains perceptual loss module using LPIPS and ConvNeXt-S. + +Copyright (2024) Bytedance Ltd. and/or its affiliates + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F + +from torchvision import models +from .lpips import LPIPS + +_IMAGENET_MEAN = [0.485, 0.456, 0.406] +_IMAGENET_STD = [0.229, 0.224, 0.225] + + +class PerceptualLoss(torch.nn.Module): + def __init__(self, model_name: str = "convnext_s"): + """Initializes the PerceptualLoss class. + + Args: + model_name: A string, the name of the perceptual loss model to use. + + Raise: + ValueError: If the model_name does not contain "lpips" or "convnext_s". + """ + super().__init__() + if ("lpips" not in model_name) and ("convnext_s" not in model_name): + raise ValueError(f"Unsupported Perceptual Loss model name {model_name}") + self.lpips = None + self.convnext = None + self.loss_weight_lpips = None + self.loss_weight_convnext = None + + # Parsing the model name. We support name formatted in + # "lpips-convnext_s-{float_number}-{float_number}", where the + # {float_number} refers to the loss weight for each component. + # E.g., lpips-convnext_s-1.0-2.0 refers to compute the perceptual loss + # using both the convnext_s and lpips, and average the final loss with + # (1.0 * loss(lpips) + 2.0 * loss(convnext_s)) / (1.0 + 2.0). + if "lpips" in model_name: + self.lpips = LPIPS().eval() + + if "convnext_s" in model_name: + self.convnext = models.convnext_small( + weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1 + ).eval() + + if "lpips" in model_name and "convnext_s" in model_name: + loss_config = model_name.split("-")[-2:] + self.loss_weight_lpips, self.loss_weight_convnext = float( + loss_config[0] + ), float(loss_config[1]) + print( + f"self.loss_weight_lpips, self.loss_weight_convnext: {self.loss_weight_lpips}, {self.loss_weight_convnext}" + ) + + self.register_buffer( + "imagenet_mean", torch.Tensor(_IMAGENET_MEAN)[None, :, None, None] + ) + self.register_buffer( + "imagenet_std", torch.Tensor(_IMAGENET_STD)[None, :, None, None] + ) + + for param in self.parameters(): + param.requires_grad = False + + def forward(self, input: torch.Tensor, target: torch.Tensor): + """Computes the perceptual loss. + + Args: + input: A tensor of shape (B, C, H, W), the input image. Normalized to [0, 1]. + target: A tensor of shape (B, C, H, W), the target image. Normalized to [0, 1]. + + Returns: + A scalar tensor, the perceptual loss. + """ + # Always in eval mode. + self.eval() + loss = 0.0 + num_losses = 0.0 + lpips_loss = 0.0 + convnext_loss = 0.0 + # Computes LPIPS loss, if available. 
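+        # When both backbones are configured, the result is the weighted average
+        #   (w_lpips * lpips + w_convnext * convnext) / (w_lpips + w_convnext);
+        # with a single backbone, its weight defaults to 1 (plain average).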
+ if self.lpips is not None: + lpips_loss = self.lpips(input, target) + if self.loss_weight_lpips is None: + loss += lpips_loss + num_losses += 1 + else: + num_losses += self.loss_weight_lpips + loss += self.loss_weight_lpips * lpips_loss + + if self.convnext is not None: + # Computes ConvNeXt-s loss, if available. + input = torch.nn.functional.interpolate( + input, size=224, mode="bilinear", align_corners=False, antialias=True + ) + target = torch.nn.functional.interpolate( + target, size=224, mode="bilinear", align_corners=False, antialias=True + ) + pred_input = self.convnext((input - self.imagenet_mean) / self.imagenet_std) + pred_target = self.convnext( + (target - self.imagenet_mean) / self.imagenet_std + ) + convnext_loss = torch.nn.functional.mse_loss( + pred_input, pred_target, reduction="mean" + ) + + if self.loss_weight_convnext is None: + num_losses += 1 + loss += convnext_loss + else: + num_losses += self.loss_weight_convnext + loss += self.loss_weight_convnext * convnext_loss + + # weighted avg. + loss = loss / num_losses + return loss diff --git a/src/vqvaes/titok/quantizer/__init__.py b/src/vqvaes/titok/quantizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f2767f5f2649a2158406669ba30dd4c67fe8dbf2 --- /dev/null +++ b/src/vqvaes/titok/quantizer/__init__.py @@ -0,0 +1 @@ +from .quantizer import VectorQuantizer, DiagonalGaussianDistribution diff --git a/src/vqvaes/titok/quantizer/quantizer.py b/src/vqvaes/titok/quantizer/quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..1540f7c7c88241c29ae01ab086022bab6db89244 --- /dev/null +++ b/src/vqvaes/titok/quantizer/quantizer.py @@ -0,0 +1,202 @@ +"""Vector quantizer. + +Copyright (2024) Bytedance Ltd. and/or its affiliates + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ +Reference: + https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py + https://github.com/google-research/magvit/blob/main/videogvt/models/vqvae.py + https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/distributions/distributions.py + https://github.com/lyndonzheng/CVQ-VAE/blob/main/quantise.py +""" + +from typing import Mapping, Text, Tuple + +import torch +from einops import rearrange +from accelerate.utils.operations import gather +from torch.cuda.amp import autocast + + +class VectorQuantizer(torch.nn.Module): + def __init__( + self, + codebook_size: int = 1024, + token_size: int = 256, + commitment_cost: float = 0.25, + use_l2_norm: bool = False, + clustering_vq: bool = False, + ): + super().__init__() + self.codebook_size = codebook_size + self.token_size = token_size + self.commitment_cost = commitment_cost + + self.embedding = torch.nn.Embedding(codebook_size, token_size) + self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size) + self.use_l2_norm = use_l2_norm + + self.clustering_vq = clustering_vq + if clustering_vq: + self.decay = 0.99 + self.register_buffer("embed_prob", torch.zeros(self.codebook_size)) + + # Ensure quantization is performed using f32 + @autocast(enabled=False) + def forward( + self, z: torch.Tensor + ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: + z = z.float() + z = rearrange(z, "b c h w -> b h w c").contiguous() + z_flattened = rearrange(z, "b h w c -> (b h w) c") + unnormed_z_flattened = z_flattened + + if self.use_l2_norm: + z_flattened = torch.nn.functional.normalize(z_flattened, dim=-1) + embedding = torch.nn.functional.normalize(self.embedding.weight, dim=-1) + else: + embedding = self.embedding.weight + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(embedding**2, dim=1) + - 2 * torch.einsum("bd,dn->bn", z_flattened, embedding.T) + ) + + min_encoding_indices = torch.argmin(d, dim=1) # num_ele + z_quantized = self.get_codebook_entry(min_encoding_indices).view(z.shape) + + if self.use_l2_norm: + z = torch.nn.functional.normalize(z, dim=-1) + + # compute loss for embedding + commitment_loss = self.commitment_cost * torch.mean( + (z_quantized.detach() - z) ** 2 + ) + codebook_loss = torch.mean((z_quantized - z.detach()) ** 2) + + if self.clustering_vq and self.training: + with torch.no_grad(): + # Gather distance matrix from all GPUs. + encoding_indices = gather(min_encoding_indices) + if len(min_encoding_indices.shape) != 1: + raise ValueError( + f"min_encoding_indices in a wrong shape, {min_encoding_indices.shape}" + ) + # Compute and update the usage of each entry in the codebook. + encodings = torch.zeros( + encoding_indices.shape[0], self.codebook_size, device=z.device + ) + encodings.scatter_(1, encoding_indices.unsqueeze(1), 1) + avg_probs = torch.mean(encodings, dim=0) + self.embed_prob.mul_(self.decay).add_(avg_probs, alpha=1 - self.decay) + # Closest sampling to update the codebook. + all_d = gather(d) + all_unnormed_z_flattened = gather(unnormed_z_flattened).detach() + if all_d.shape[0] != all_unnormed_z_flattened.shape[0]: + raise ValueError( + "all_d and all_unnormed_z_flattened have different length" + + f"{all_d.shape}, {all_unnormed_z_flattened.shape}" + ) + indices = torch.argmin(all_d, dim=0) + random_feat = all_unnormed_z_flattened[indices] + # Decay parameter based on the average usage. 
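+                # Rarely used entries (embed_prob near 0) get decay close to 1 and are
+                # pulled toward the matched encoder features; frequently used entries get
+                # decay close to 0 and are left unchanged (CVQ-VAE-style codebook revival).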
+ decay = ( + torch.exp( + -(self.embed_prob * self.codebook_size * 10) / (1 - self.decay) + - 1e-3 + ) + .unsqueeze(1) + .repeat(1, self.token_size) + ) + self.embedding.weight.data = ( + self.embedding.weight.data * (1 - decay) + random_feat * decay + ) + + loss = commitment_loss + codebook_loss + + # preserve gradients + z_quantized = z + (z_quantized - z).detach() + + # reshape back to match original input shape + z_quantized = rearrange(z_quantized, "b h w c -> b c h w").contiguous() + + result_dict = dict( + quantizer_loss=loss, + commitment_loss=commitment_loss, + codebook_loss=codebook_loss, + min_encoding_indices=min_encoding_indices.view( + z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3] + ), + ) + + return z_quantized, result_dict + + def get_codebook_entry(self, indices): + if len(indices.shape) == 1: + z_quantized = self.embedding(indices) + elif len(indices.shape) == 2: + z_quantized = torch.einsum("bd,dn->bn", indices, self.embedding.weight) + else: + raise NotImplementedError + if self.use_l2_norm: + z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1) + return z_quantized + + +class DiagonalGaussianDistribution(object): + @autocast(enabled=False) + def __init__(self, parameters, deterministic=False): + """Initializes a Gaussian distribution instance given the parameters. + + Args: + parameters (torch.Tensor): The parameters for the Gaussian distribution. It is expected + to be in shape [B, 2 * C, *], where B is batch size, and C is the embedding dimension. + First C channels are used for mean and last C are used for logvar in the Gaussian distribution. + deterministic (bool): Whether to use deterministic sampling. When it is true, the sampling results + is purely based on mean (i.e., std = 0). + """ + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters.float(), 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to( + device=self.parameters.device + ) + + @autocast(enabled=False) + def sample(self): + x = self.mean.float() + self.std.float() * torch.randn(self.mean.shape).to( + device=self.parameters.device + ) + return x + + @autocast(enabled=False) + def mode(self): + return self.mean + + @autocast(enabled=False) + def kl(self): + if self.deterministic: + return torch.Tensor([0.0]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean.float(), 2) + + self.var.float() + - 1.0 + - self.logvar.float(), + dim=[1, 2], + ) diff --git a/src/vqvaes/titok/titok.py b/src/vqvaes/titok/titok.py new file mode 100644 index 0000000000000000000000000000000000000000..339fab2d8c72f32d1de9c5a8adceebf3610bb052 --- /dev/null +++ b/src/vqvaes/titok/titok.py @@ -0,0 +1,240 @@ +"""This file contains the model definition of TiTok. + +Copyright (2024) Bytedance Ltd. and/or its affiliates + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import torch +import torch.nn as nn +from einops import rearrange + +from .modules.base_model import BaseModel +from .modules.blocks import TiTokEncoder, TiTokDecoder +from .quantizer.quantizer import VectorQuantizer, DiagonalGaussianDistribution +from .modules.maskgit_vqgan import Encoder as Pixel_Eecoder +from .modules.maskgit_vqgan import Decoder as Pixel_Decoder +from .modules.maskgit_vqgan import VectorQuantizer as Pixel_Quantizer +import json +from omegaconf import OmegaConf +from pathlib import Path + +from huggingface_hub import PyTorchModelHubMixin + + +class PretrainedTokenizer(nn.Module): + def __init__(self, pretrained_weight): + super().__init__() + conf = OmegaConf.create( + { + "channel_mult": [1, 1, 2, 2, 4], + "num_resolutions": 5, + "dropout": 0.0, + "hidden_channels": 128, + "num_channels": 3, + "num_res_blocks": 2, + "resolution": 256, + "z_channels": 256, + } + ) + self.encoder = Pixel_Eecoder(conf) + self.decoder = Pixel_Decoder(conf) + self.quantize = Pixel_Quantizer( + num_embeddings=1024, embedding_dim=256, commitment_cost=0.25 + ) + # Load pretrained weights + self.load_state_dict( + torch.load(pretrained_weight, map_location=torch.device("cpu")), strict=True + ) + + self.eval() + for param in self.parameters(): + param.requires_grad = False + + @torch.no_grad() + def encode(self, x): + hidden_states = self.encoder(x) + quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states) + return codebook_indices.detach() + + @torch.no_grad() + def decode(self, codes): + quantized_states = self.quantize.get_codebook_entry(codes) + rec_images = self.decoder(quantized_states) + rec_images = torch.clamp(rec_images, 0.0, 1.0) + return rec_images.detach() + + @torch.no_grad() + def decode_tokens(self, codes): + return self.decode(codes) + + +class TiTok( + BaseModel, + PyTorchModelHubMixin, + tags=["arxiv:2406.07550", "image-tokenization"], + repo_url="https://github.com/bytedance/1d-tokenizer", + license="apache-2.0", +): + def __init__(self, config): + + if isinstance(config, dict): + config = OmegaConf.create(config) + + super().__init__() + self.config = config + # This should be False for stage1 and True for stage2. + self.finetune_decoder = config.model.vq_model.get("finetune_decoder", True) + + self.quantize_mode = config.model.vq_model.get("quantize_mode", "vq") + if self.quantize_mode not in ["vq", "vae"]: + raise ValueError(f"Unsupported quantize mode {self.quantize_mode}.") + + if self.finetune_decoder and self.quantize_mode not in ["vq"]: + raise ValueError( + "Only supprot finetune_decoder with vq quantization for now." 
+ ) + + self.encoder = TiTokEncoder(config) + self.decoder = TiTokDecoder(config) + + self.num_latent_tokens = config.model.vq_model.num_latent_tokens + scale = self.encoder.width**-0.5 + self.latent_tokens = nn.Parameter( + scale * torch.randn(self.num_latent_tokens, self.encoder.width) + ) + + self.apply(self._init_weights) + + if self.quantize_mode == "vq": + self.quantize = VectorQuantizer( + codebook_size=config.model.vq_model.codebook_size, + token_size=config.model.vq_model.token_size, + commitment_cost=config.model.vq_model.commitment_cost, + use_l2_norm=config.model.vq_model.use_l2_norm, + ) + elif self.quantize_mode == "vae": + self.quantize = DiagonalGaussianDistribution + else: + raise NotImplementedError + + if self.finetune_decoder: + # Freeze encoder/quantizer/latent tokens + self.latent_tokens.requires_grad_(False) + self.encoder.eval() + self.encoder.requires_grad_(False) + self.quantize.eval() + self.quantize.requires_grad_(False) + + # Include MaskGiT-VQGAN's quantizer and decoder + self.pixel_quantize = Pixel_Quantizer( + num_embeddings=1024, embedding_dim=256, commitment_cost=0.25 + ) + self.pixel_decoder = Pixel_Decoder( + OmegaConf.create( + { + "channel_mult": [1, 1, 2, 2, 4], + "num_resolutions": 5, + "dropout": 0.0, + "hidden_channels": 128, + "num_channels": 3, + "num_res_blocks": 2, + "resolution": 256, + "z_channels": 256, + } + ) + ) + + def _save_pretrained(self, save_directory: Path) -> None: + """Save weights and config to a local directory.""" + # Assume 'self.config' is your DictConfig object + # Convert to a regular dictionary + dict_config = OmegaConf.to_container(self.config) + # Save as JSON + file_path = Path(save_directory) / "config.json" + with open(file_path, "w") as json_file: + json.dump(dict_config, json_file, indent=4) + super()._save_pretrained(save_directory) + + def _init_weights(self, module): + """Initialize the weights. 
+ :param: + module -> torch.nn.Module: module to initialize + """ + if ( + isinstance(module, nn.Linear) + or isinstance(module, nn.Conv1d) + or isinstance(module, nn.Conv2d) + ): + module.weight.data = nn.init.trunc_normal_( + module.weight.data, mean=0.0, std=0.02 + ) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data = nn.init.trunc_normal_( + module.weight.data, mean=0.0, std=0.02 + ) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def encode(self, x): + if self.finetune_decoder: + with torch.no_grad(): + self.encoder.eval() + self.quantize.eval() + z = self.encoder(pixel_values=x, latent_tokens=self.latent_tokens) + z_quantized, result_dict = self.quantize(z) + result_dict["quantizer_loss"] *= 0 + result_dict["commitment_loss"] *= 0 + result_dict["codebook_loss"] *= 0 + else: + z = self.encoder(pixel_values=x, latent_tokens=self.latent_tokens) + if self.quantize_mode == "vq": + z_quantized, result_dict = self.quantize(z) + elif self.quantize_mode == "vae": + posteriors = self.quantize(z) + z_quantized = posteriors.sample() + result_dict = posteriors + + return z_quantized, result_dict + + def decode(self, z_quantized): + decoded = self.decoder(z_quantized) + if self.finetune_decoder: + quantized_states = torch.einsum( + "nchw,cd->ndhw", + decoded.softmax(1), + self.pixel_quantize.embedding.weight, + ) + decoded = self.pixel_decoder(quantized_states) + return decoded + + def decode_tokens(self, tokens): + if self.quantize_mode == "vq": + tokens = tokens.squeeze(1) + batch, seq_len = tokens.shape # B x N + z_quantized = self.quantize.get_codebook_entry(tokens.reshape(-1)).reshape( + batch, 1, seq_len, -1 + ) + z_quantized = rearrange(z_quantized, "b h w c -> b c h w").contiguous() + elif self.quantize_mode == "vae": + z_quantized = tokens + decoded = self.decode(z_quantized) + return decoded + + def forward(self, x): + z_quantized, result_dict = self.encode(x) + decoded = self.decode(z_quantized) + return decoded, result_dict["min_encoding_indices"], z_quantized diff --git a/src/vqvaes/var/basic_vae.py b/src/vqvaes/var/basic_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..feef01c9126bf66f036586ec9eda76757427c23a --- /dev/null +++ b/src/vqvaes/var/basic_vae.py @@ -0,0 +1,294 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# this file only provides the 2 modules used in VQVAE +__all__ = [ + "Encoder", + "Decoder", +] + + +""" +References: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py +""" + + +# swish +def nonlinearity(x): + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class Upsample2x(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + return self.conv(F.interpolate(x, scale_factor=2, mode="nearest")) + + +class Downsample2x(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + return self.conv(F.pad(x, pad=(0, 1, 0, 1), mode="constant", value=0)) + + +class ResnetBlock(nn.Module): + def __init__( + self, *, 
in_channels, out_channels=None, dropout + ): # conv_shortcut=False, # conv_shortcut: always False in VAE + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) if dropout > 1e-6 else nn.Identity() + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + else: + self.nin_shortcut = nn.Identity() + + def forward(self, x): + h = self.conv1(F.silu(self.norm1(x), inplace=True)) + h = self.conv2(self.dropout(F.silu(self.norm2(h), inplace=True))) + return self.nin_shortcut(x) + h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.C = in_channels + + self.norm = Normalize(in_channels) + self.qkv = torch.nn.Conv2d( + in_channels, 3 * in_channels, kernel_size=1, stride=1, padding=0 + ) + self.w_ratio = int(in_channels) ** (-0.5) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + qkv = self.qkv(self.norm(x)) + B, _, H, W = qkv.shape # should be B,3C,H,W + C = self.C + q, k, v = qkv.reshape(B, 3, C, H, W).unbind(1) + + # compute attention + q = q.view(B, C, H * W).contiguous() + q = q.permute(0, 2, 1).contiguous() # B,HW,C + k = k.view(B, C, H * W).contiguous() # B,C,HW + w = torch.bmm(q, k).mul_( + self.w_ratio + ) # B,HW,HW w[B,i,j]=sum_c q[B,i,C]k[B,C,j] + w = F.softmax(w, dim=2) + + # attend to values + v = v.view(B, C, H * W).contiguous() + w = w.permute(0, 2, 1).contiguous() # B,HW,HW (first HW of k, second of q) + h = torch.bmm(v, w) # B, C,HW (HW of q) h[B,C,j] = sum_i v[B,C,i] w[B,i,j] + h = h.view(B, C, H, W).contiguous() + + return x + self.proj_out(h) + + +def make_attn(in_channels, using_sa=True): + return AttnBlock(in_channels) if using_sa else nn.Identity() + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch=128, + ch_mult=(1, 2, 4, 8), + num_res_blocks=2, + dropout=0.0, + in_channels=3, + z_channels, + double_z=False, + using_sa=True, + using_mid_sa=True, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.downsample_ratio = 2 ** (self.num_resolutions - 1) + self.num_res_blocks = num_res_blocks + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, dropout=dropout + ) + ) + block_in = block_out + if i_level == self.num_resolutions - 1 and using_sa: + attn.append(make_attn(block_in, using_sa=True)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample2x(block_in) + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + 
in_channels=block_in, out_channels=block_in, dropout=dropout + ) + self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, dropout=dropout + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + (2 * z_channels if double_z else z_channels), + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + # downsampling + h = self.conv_in(x) + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](h) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + if i_level != self.num_resolutions - 1: + h = self.down[i_level].downsample(h) + + # middle + h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(h))) + + # end + h = self.conv_out(F.silu(self.norm_out(h), inplace=True)) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch=128, + ch_mult=(1, 2, 4, 8), + num_res_blocks=2, + dropout=0.0, + in_channels=3, # in_channels: raw img channels + z_channels, + using_sa=True, + using_mid_sa=True, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.in_channels = in_channels + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, dropout=dropout + ) + self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, dropout=dropout + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, dropout=dropout + ) + ) + block_in = block_out + if i_level == self.num_resolutions - 1 and using_sa: + attn.append(make_attn(block_in, using_sa=True)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample2x(block_in) + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, z): + # z to block_in + # middle + h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(self.conv_in(z)))) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.conv_out(F.silu(self.norm_out(h), inplace=True)) + return h diff --git a/src/vqvaes/var/dist.py b/src/vqvaes/var/dist.py new file mode 100644 index 0000000000000000000000000000000000000000..54a0618d8827851a3ea4a2e65065b3fee0594704 --- /dev/null +++ b/src/vqvaes/var/dist.py @@ -0,0 +1,231 @@ +import datetime +import functools +import os +import sys +from typing import List +from typing import Union + +import torch +import 
torch.distributed as tdist +import torch.multiprocessing as mp + +__rank, __local_rank, __world_size, __device = ( + 0, + 0, + 1, + "cuda" if torch.cuda.is_available() else "cpu", +) +__initialized = False + + +def initialized(): + return __initialized + + +def initialize(fork=False, backend="nccl", gpu_id_if_not_distibuted=0, timeout=30): + global __device + if not torch.cuda.is_available(): + print( + f"[dist initialize] cuda is not available, use cpu instead", file=sys.stderr + ) + return + elif "RANK" not in os.environ: + torch.cuda.set_device(gpu_id_if_not_distibuted) + __device = torch.empty(1).cuda().device + print( + f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', + file=sys.stderr, + ) + return + # then 'RANK' must exist + global_rank, num_gpus = int(os.environ["RANK"]), torch.cuda.device_count() + local_rank = global_rank % num_gpus + torch.cuda.set_device(local_rank) + + # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29 + if mp.get_start_method(allow_none=True) is None: + method = "fork" if fork else "spawn" + print(f"[dist initialize] mp method={method}") + mp.set_start_method(method) + tdist.init_process_group( + backend=backend, timeout=datetime.timedelta(seconds=timeout * 60) + ) + + global __rank, __local_rank, __world_size, __initialized + __local_rank = local_rank + __rank, __world_size = tdist.get_rank(), tdist.get_world_size() + __device = torch.empty(1).cuda().device + __initialized = True + + assert tdist.is_initialized(), "torch.distributed is not initialized!" + print(f"[lrk={get_local_rank()}, rk={get_rank()}]") + + +def get_rank(): + return __rank + + +def get_local_rank(): + return __local_rank + + +def get_world_size(): + return __world_size + + +def get_device(): + return __device + + +def set_gpu_id(gpu_id: int): + if gpu_id is None: + return + global __device + if isinstance(gpu_id, (str, int)): + torch.cuda.set_device(int(gpu_id)) + __device = torch.empty(1).cuda().device + else: + raise NotImplementedError + + +def is_master(): + return __rank == 0 + + +def is_local_master(): + return __local_rank == 0 + + +def new_group(ranks: List[int]): + if __initialized: + return tdist.new_group(ranks=ranks) + return None + + +def barrier(): + if __initialized: + tdist.barrier() + + +def allreduce(t: torch.Tensor, async_op=False): + if __initialized: + if not t.is_cuda: + cu = t.detach().cuda() + ret = tdist.all_reduce(cu, async_op=async_op) + t.copy_(cu.cpu()) + else: + ret = tdist.all_reduce(t, async_op=async_op) + return ret + return None + + +def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]: + if __initialized: + if not t.is_cuda: + t = t.cuda() + ls = [torch.empty_like(t) for _ in range(__world_size)] + tdist.all_gather(ls, t) + else: + ls = [t] + if cat: + ls = torch.cat(ls, dim=0) + return ls + + +def allgather_diff_shape( + t: torch.Tensor, cat=True +) -> Union[List[torch.Tensor], torch.Tensor]: + if __initialized: + if not t.is_cuda: + t = t.cuda() + + t_size = torch.tensor(t.size(), device=t.device) + ls_size = [torch.empty_like(t_size) for _ in range(__world_size)] + tdist.all_gather(ls_size, t_size) + + max_B = max(size[0].item() for size in ls_size) + pad = max_B - t_size[0].item() + if pad: + pad_size = (pad, *t.size()[1:]) + t = torch.cat((t, t.new_empty(pad_size)), dim=0) + + ls_padded = [torch.empty_like(t) for _ in range(__world_size)] + tdist.all_gather(ls_padded, t) + ls = [] + for t, size in zip(ls_padded, ls_size): + ls.append(t[: 
size[0].item()]) + else: + ls = [t] + if cat: + ls = torch.cat(ls, dim=0) + return ls + + +def broadcast(t: torch.Tensor, src_rank) -> None: + if __initialized: + if not t.is_cuda: + cu = t.detach().cuda() + tdist.broadcast(cu, src=src_rank) + t.copy_(cu.cpu()) + else: + tdist.broadcast(t, src=src_rank) + + +def dist_fmt_vals( + val: float, fmt: Union[str, None] = "%.2f" +) -> Union[torch.Tensor, List]: + if not initialized(): + return torch.tensor([val]) if fmt is None else [fmt % val] + + ts = torch.zeros(__world_size) + ts[__rank] = val + allreduce(ts) + if fmt is None: + return ts + return [fmt % v for v in ts.cpu().numpy().tolist()] + + +def master_only(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + force = kwargs.pop("force", False) + if force or is_master(): + ret = func(*args, **kwargs) + else: + ret = None + barrier() + return ret + + return wrapper + + +def local_master_only(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + force = kwargs.pop("force", False) + if force or is_local_master(): + ret = func(*args, **kwargs) + else: + ret = None + barrier() + return ret + + return wrapper + + +def for_visualize(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if is_master(): + # with torch.no_grad(): + ret = func(*args, **kwargs) + else: + ret = None + return ret + + return wrapper + + +def finalize(): + if __initialized: + tdist.destroy_process_group() diff --git a/src/vqvaes/var/quant.py b/src/vqvaes/var/quant.py new file mode 100644 index 0000000000000000000000000000000000000000..5541a5221bc719f9ed6681b5d0b81fdfa750e2ba --- /dev/null +++ b/src/vqvaes/var/quant.py @@ -0,0 +1,409 @@ +from typing import List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +from torch import distributed as tdist, nn as nn +from torch.nn import functional as F + +from . 
import dist + + +# this file only provides the VectorQuantizer2 used in VQVAE +__all__ = [ + "VectorQuantizer2", +] + + +class VectorQuantizer2(nn.Module): + # VQGAN originally use beta=1.0, never tried 0.25; SD seems using 0.25 + def __init__( + self, + vocab_size, + Cvae, + using_znorm, + beta: float = 0.25, + default_qresi_counts=0, + v_patch_nums=None, + quant_resi=0.5, + share_quant_resi=4, # share_quant_resi: args.qsr + ): + super().__init__() + self.vocab_size: int = vocab_size + self.Cvae: int = Cvae + self.using_znorm: bool = using_znorm + self.v_patch_nums: Tuple[int] = v_patch_nums + + self.quant_resi_ratio = quant_resi + if share_quant_resi == 0: # non-shared: \phi_{1 to K} for K scales + self.quant_resi = PhiNonShared( + [ + (Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) + for _ in range(default_qresi_counts or len(self.v_patch_nums)) + ] + ) + elif share_quant_resi == 1: # fully shared: only a single \phi for K scales + self.quant_resi = PhiShared( + Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity() + ) + else: # partially shared: \phi_{1 to share_quant_resi} for K scales + self.quant_resi = PhiPartiallyShared( + nn.ModuleList( + [ + ( + Phi(Cvae, quant_resi) + if abs(quant_resi) > 1e-6 + else nn.Identity() + ) + for _ in range(share_quant_resi) + ] + ) + ) + + self.register_buffer( + "ema_vocab_hit_SV", + torch.full((len(self.v_patch_nums), self.vocab_size), fill_value=0.0), + ) + self.record_hit = 0 + + self.beta: float = beta + self.embedding = nn.Embedding(self.vocab_size, self.Cvae) + + # only used for progressive training of VAR (not supported yet, will be tested and supported in the future) + self.prog_si = -1 # progressive training: not supported yet, prog_si always -1 + + def eini(self, eini): + if eini > 0: + nn.init.trunc_normal_(self.embedding.weight.data, std=eini) + elif eini < 0: + self.embedding.weight.data.uniform_( + -abs(eini) / self.vocab_size, abs(eini) / self.vocab_size + ) + + def extra_repr(self) -> str: + return f"{self.v_patch_nums}, znorm={self.using_znorm}, beta={self.beta} | S={len(self.v_patch_nums)}, quant_resi={self.quant_resi_ratio}" + + # ===================== `forward` is only used in VAE training ===================== + def forward( + self, f_BChw: torch.Tensor, ret_usages=False + ) -> Tuple[torch.Tensor, List[float], torch.Tensor]: + dtype = f_BChw.dtype + if dtype != torch.float32: + f_BChw = f_BChw.float() + B, C, H, W = f_BChw.shape + f_no_grad = f_BChw.detach() + + f_rest = f_no_grad.clone() + f_hat = torch.zeros_like(f_rest) + + with torch.cuda.amp.autocast(enabled=False): + mean_vq_loss: torch.Tensor = 0.0 + vocab_hit_V = torch.zeros( + self.vocab_size, dtype=torch.float, device=f_BChw.device + ) + SN = len(self.v_patch_nums) + for si, pn in enumerate(self.v_patch_nums): # from small to large + # find the nearest embedding + if self.using_znorm: + rest_NC = ( + F.interpolate(f_rest, size=(pn, pn), mode="area") + .permute(0, 2, 3, 1) + .reshape(-1, C) + if (si != SN - 1) + else f_rest.permute(0, 2, 3, 1).reshape(-1, C) + ) + rest_NC = F.normalize(rest_NC, dim=-1) + idx_N = torch.argmax( + rest_NC @ F.normalize(self.embedding.weight.data.T, dim=0), + dim=1, + ) + else: + rest_NC = ( + F.interpolate(f_rest, size=(pn, pn), mode="area") + .permute(0, 2, 3, 1) + .reshape(-1, C) + if (si != SN - 1) + else f_rest.permute(0, 2, 3, 1).reshape(-1, C) + ) + d_no_grad = torch.sum( + rest_NC.square(), dim=1, keepdim=True + ) + torch.sum( + self.embedding.weight.data.square(), dim=1, keepdim=False + ) + 
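+                    # nearest-codebook search via the expansion
+                    # ||z - e||^2 = ||z||^2 + ||e||^2 - 2 * z @ e^T;
+                    # addmm_ below folds in the -2 * z @ e^T term in place, then argmin picks the code index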
d_no_grad.addmm_( + rest_NC, self.embedding.weight.data.T, alpha=-2, beta=1 + ) # (B*h*w, vocab_size) + idx_N = torch.argmin(d_no_grad, dim=1) + + hit_V = idx_N.bincount(minlength=self.vocab_size).float() + if self.training: + if dist.initialized(): + handler = tdist.all_reduce(hit_V, async_op=True) + + # calc loss + idx_Bhw = idx_N.view(B, pn, pn) + h_BChw = ( + F.interpolate( + self.embedding(idx_Bhw).permute(0, 3, 1, 2), + size=(H, W), + mode="bicubic", + ).contiguous() + if (si != SN - 1) + else self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous() + ) + h_BChw = self.quant_resi[si / (SN - 1)](h_BChw) + f_hat = f_hat + h_BChw + f_rest -= h_BChw + + if self.training and dist.initialized(): + handler.wait() + if self.record_hit == 0: + self.ema_vocab_hit_SV[si].copy_(hit_V) + elif self.record_hit < 100: + self.ema_vocab_hit_SV[si].mul_(0.9).add_(hit_V.mul(0.1)) + else: + self.ema_vocab_hit_SV[si].mul_(0.99).add_(hit_V.mul(0.01)) + self.record_hit += 1 + vocab_hit_V.add_(hit_V) + mean_vq_loss += F.mse_loss(f_hat.data, f_BChw).mul_( + self.beta + ) + F.mse_loss(f_hat, f_no_grad) + + mean_vq_loss *= 1.0 / SN + f_hat = (f_hat.data - f_no_grad).add_(f_BChw) + + # margin = tdist.get_world_size() * (f_BChw.numel() / f_BChw.shape[1]) / self.vocab_size * 0.08 + margin = pn * pn / 100 + if ret_usages: + usages = [ + (self.ema_vocab_hit_SV[si] >= margin).float().mean().item() * 100 + for si, pn in enumerate(self.v_patch_nums) + ] + else: + usages = None + return f_hat, usages, mean_vq_loss + + # ===================== `forward` is only used in VAE training ===================== + + def embed_to_fhat( + self, ms_h_BChw: List[torch.Tensor], all_to_max_scale=True, last_one=False + ) -> Union[List[torch.Tensor], torch.Tensor]: + ls_f_hat_BChw = [] + B = ms_h_BChw[0].shape[0] + H = W = self.v_patch_nums[-1] + SN = len(self.v_patch_nums) + if all_to_max_scale: + f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, H, W, dtype=torch.float32) + for si, pn in enumerate(self.v_patch_nums): # from small to large + h_BChw = ms_h_BChw[si] + if si < len(self.v_patch_nums) - 1: + h_BChw = F.interpolate(h_BChw, size=(H, W), mode="bicubic") + h_BChw = self.quant_resi[si / (SN - 1)](h_BChw) + f_hat.add_(h_BChw) + if last_one: + ls_f_hat_BChw = f_hat + else: + ls_f_hat_BChw.append(f_hat.clone()) + else: + # WARNING: this is not the case in VQ-VAE training or inference (we'll interpolate every token map to the max H W, like above) + # WARNING: this should only be used for experimental purpose + f_hat = ms_h_BChw[0].new_zeros( + B, + self.Cvae, + self.v_patch_nums[0], + self.v_patch_nums[0], + dtype=torch.float32, + ) + for si, pn in enumerate(self.v_patch_nums): # from small to large + f_hat = F.interpolate(f_hat, size=(pn, pn), mode="bicubic") + h_BChw = self.quant_resi[si / (SN - 1)](ms_h_BChw[si]) + f_hat.add_(h_BChw) + if last_one: + ls_f_hat_BChw = f_hat + else: + ls_f_hat_BChw.append(f_hat) + + return ls_f_hat_BChw + + def f_to_idxBl_or_fhat( + self, + f_BChw: torch.Tensor, + to_fhat: bool, + v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None, + ) -> List[ + Union[torch.Tensor, torch.LongTensor] + ]: # z_BChw is the feature from inp_img_no_grad + B, C, H, W = f_BChw.shape + f_no_grad = f_BChw.detach() + f_rest = f_no_grad.clone() + f_hat = torch.zeros_like(f_rest) + + f_hat_or_idx_Bl: List[torch.Tensor] = [] + + patch_hws = [ + (pn, pn) if isinstance(pn, int) else (pn[0], pn[1]) + for pn in (v_patch_nums or self.v_patch_nums) + ] # from small to large + assert ( + patch_hws[-1][0] == H and 
patch_hws[-1][1] == W + ), f"{patch_hws[-1]=} != ({H=}, {W=})" + + SN = len(patch_hws) + for si, (ph, pw) in enumerate(patch_hws): # from small to large + if 0 <= self.prog_si < si: + break # progressive training: not supported yet, prog_si always -1 + # find the nearest embedding + z_NC = ( + F.interpolate(f_rest, size=(ph, pw), mode="area") + .permute(0, 2, 3, 1) + .reshape(-1, C) + if (si != SN - 1) + else f_rest.permute(0, 2, 3, 1).reshape(-1, C) + ) + if self.using_znorm: + z_NC = F.normalize(z_NC, dim=-1) + idx_N = torch.argmax( + z_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1 + ) + else: + d_no_grad = torch.sum(z_NC.square(), dim=1, keepdim=True) + torch.sum( + self.embedding.weight.data.square(), dim=1, keepdim=False + ) + d_no_grad.addmm_( + z_NC, self.embedding.weight.data.T, alpha=-2, beta=1 + ) # (B*h*w, vocab_size) + idx_N = torch.argmin(d_no_grad, dim=1) + + idx_Bhw = idx_N.view(B, ph, pw) + h_BChw = ( + F.interpolate( + self.embedding(idx_Bhw).permute(0, 3, 1, 2), + size=(H, W), + mode="bicubic", + ).contiguous() + if (si != SN - 1) + else self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous() + ) + h_BChw = self.quant_resi[si / (SN - 1)](h_BChw) + f_hat.add_(h_BChw) + f_rest.sub_(h_BChw) + f_hat_or_idx_Bl.append( + f_hat.clone() if to_fhat else idx_N.reshape(B, ph * pw) + ) + + return f_hat_or_idx_Bl + + # ===================== idxBl_to_var_input: only used in VAR training, for getting teacher-forcing input ===================== + def idxBl_to_var_input(self, gt_ms_idx_Bl: List[torch.Tensor]) -> torch.Tensor: + next_scales = [] + B = gt_ms_idx_Bl[0].shape[0] + C = self.Cvae + H = W = self.v_patch_nums[-1] + SN = len(self.v_patch_nums) + + f_hat = gt_ms_idx_Bl[0].new_zeros(B, C, H, W, dtype=torch.float32) + pn_next: int = self.v_patch_nums[0] + for si in range(SN - 1): + if self.prog_si == 0 or (0 <= self.prog_si - 1 < si): + break # progressive training: not supported yet, prog_si always -1 + h_BChw = F.interpolate( + self.embedding(gt_ms_idx_Bl[si]) + .transpose_(1, 2) + .view(B, C, pn_next, pn_next), + size=(H, W), + mode="bicubic", + ) + f_hat.add_(self.quant_resi[si / (SN - 1)](h_BChw)) + pn_next = self.v_patch_nums[si + 1] + next_scales.append( + F.interpolate(f_hat, size=(pn_next, pn_next), mode="area") + .view(B, C, -1) + .transpose(1, 2) + ) + return ( + torch.cat(next_scales, dim=1) if len(next_scales) else None + ) # cat BlCs to BLC, this should be float32 + + # ===================== get_next_autoregressive_input: only used in VAR inference, for getting next step's input ===================== + def get_next_autoregressive_input( + self, si: int, SN: int, f_hat: torch.Tensor, h_BChw: torch.Tensor + ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: # only used in VAR inference + HW = self.v_patch_nums[-1] + if si != SN - 1: + h = self.quant_resi[si / (SN - 1)]( + F.interpolate(h_BChw, size=(HW, HW), mode="bicubic") + ) # conv after upsample + f_hat.add_(h) + return f_hat, F.interpolate( + f_hat, + size=(self.v_patch_nums[si + 1], self.v_patch_nums[si + 1]), + mode="area", + ) + else: + h = self.quant_resi[si / (SN - 1)](h_BChw) + f_hat.add_(h) + return f_hat, f_hat + + +class Phi(nn.Conv2d): + def __init__(self, embed_dim, quant_resi): + ks = 3 + super().__init__( + in_channels=embed_dim, + out_channels=embed_dim, + kernel_size=ks, + stride=1, + padding=ks // 2, + ) + self.resi_ratio = abs(quant_resi) + + def forward(self, h_BChw): + return h_BChw.mul(1 - self.resi_ratio) + super().forward(h_BChw).mul_( + self.resi_ratio + ) + + +class 
PhiShared(nn.Module): + def __init__(self, qresi: Phi): + super().__init__() + self.qresi: Phi = qresi + + def __getitem__(self, _) -> Phi: + return self.qresi + + +class PhiPartiallyShared(nn.Module): + def __init__(self, qresi_ls: nn.ModuleList): + super().__init__() + self.qresi_ls = qresi_ls + K = len(qresi_ls) + self.ticks = ( + np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K) + if K == 4 + else np.linspace(1 / 2 / K, 1 - 1 / 2 / K, K) + ) + + def __getitem__(self, at_from_0_to_1: float) -> Phi: + return self.qresi_ls[np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()] + + def extra_repr(self) -> str: + return f"ticks={self.ticks}" + + +class PhiNonShared(nn.ModuleList): + def __init__(self, qresi: List): + super().__init__(qresi) + # self.qresi = qresi + K = len(qresi) + self.ticks = ( + np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K) + if K == 4 + else np.linspace(1 / 2 / K, 1 - 1 / 2 / K, K) + ) + + def __getitem__(self, at_from_0_to_1: float) -> Phi: + return super().__getitem__( + np.argmin(np.abs(self.ticks - at_from_0_to_1)).item() + ) + + def extra_repr(self) -> str: + return f"ticks={self.ticks}" diff --git a/src/vqvaes/var/var_vq.py b/src/vqvaes/var/var_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..8685cfc56640a96100766cc0def5b78289f05695 --- /dev/null +++ b/src/vqvaes/var/var_vq.py @@ -0,0 +1,175 @@ +""" +References: +- VectorQuantizer2: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L110 +- GumbelQuantize: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L213 +- VQVAE (VQModel): https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/models/autoencoder.py#L14 +""" + +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn + +from .basic_vae import Decoder, Encoder +from .quant import VectorQuantizer2 + + +class VQVAE(nn.Module): + def __init__( + self, + vocab_size=4096, + z_channels=32, + ch=128, + dropout=0.0, + beta=0.25, # commitment loss weight + using_znorm=False, # whether to normalize when computing the nearest neighbors + quant_conv_ks=3, # quant conv kernel size + quant_resi=0.5, # 0.5 means \phi(x) = 0.5conv(x) + (1-0.5)x + share_quant_resi=4, # use 4 \phi layers for K scales: partially-shared \phi + default_qresi_counts=0, # if is 0: automatically set to len(v_patch_nums) + v_patch_nums=( + 1, + 2, + 3, + 4, + 5, + 6, + 8, + 10, + 13, + 16, + ), # number of patches for each scale, h_{1 to K} = w_{1 to K} = v_patch_nums[k] + test_mode=True, + ): + super().__init__() + self.test_mode = test_mode + self.V, self.Cvae = vocab_size, z_channels + # ddconfig is copied from https://github.com/CompVis/latent-diffusion/blob/e66308c7f2e64cb581c6d27ab6fbeb846828253b/models/first_stage_models/vq-f16/config.yaml + ddconfig = dict( + dropout=dropout, + ch=ch, + z_channels=z_channels, + in_channels=3, + ch_mult=(1, 1, 2, 2, 4), + num_res_blocks=2, # from vq-f16/config.yaml above + using_sa=True, + using_mid_sa=True, # from vq-f16/config.yaml above + # resamp_with_conv=True, # always True, removed. 
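+            # ch_mult has 5 entries, so the spatial downsample factor is 2 ** 4 = 16 (the "f16" in vq-f16)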
+ ) + ddconfig.pop("double_z", None) # only KL-VAE should use double_z=True + self.encoder = Encoder(double_z=False, **ddconfig) + self.decoder = Decoder(**ddconfig) + + self.vocab_size = vocab_size + self.downsample = 2 ** (len(ddconfig["ch_mult"]) - 1) + self.quantize: VectorQuantizer2 = VectorQuantizer2( + vocab_size=vocab_size, + Cvae=self.Cvae, + using_znorm=using_znorm, + beta=beta, + default_qresi_counts=default_qresi_counts, + v_patch_nums=v_patch_nums, + quant_resi=quant_resi, + share_quant_resi=share_quant_resi, + ) + self.quant_conv = torch.nn.Conv2d( + self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks // 2 + ) + self.post_quant_conv = torch.nn.Conv2d( + self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks // 2 + ) + + if self.test_mode: + self.eval() + [p.requires_grad_(False) for p in self.parameters()] + + # ===================== `forward` is only used in VAE training ===================== + def forward(self, inp, ret_usages=False): # -> rec_B3HW, idx_N, loss + VectorQuantizer2.forward + quanz = self.quant_conv(self.encoder(inp)) + img_tok = self.quantize.f_to_idxBl_or_fhat(quanz, to_fhat=False) + decoded = self.idxBl_to_img(img_tok, same_shape=True) + return decoded[-1], img_tok[-1], quanz + + # ===================== `forward` is only used in VAE training ===================== + + def fhat_to_img(self, f_hat: torch.Tensor): + return self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1) + + def img_to_idxBl( + self, + inp_img_no_grad: torch.Tensor, + v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None, + ) -> List[torch.LongTensor]: # return List[Bl] + f = self.quant_conv(self.encoder(inp_img_no_grad)) + return self.quantize.f_to_idxBl_or_fhat( + f, to_fhat=False, v_patch_nums=v_patch_nums + ) + + def idxBl_to_img( + self, ms_idx_Bl: List[torch.Tensor], same_shape: bool, last_one=False + ) -> Union[List[torch.Tensor], torch.Tensor]: + B = ms_idx_Bl[0].shape[0] + ms_h_BChw = [] + for idx_Bl in ms_idx_Bl: + l = idx_Bl.shape[1] + pn = round(l**0.5) + ms_h_BChw.append( + self.quantize.embedding(idx_Bl) + .transpose(1, 2) + .view(B, self.Cvae, pn, pn) + ) + return self.embed_to_img( + ms_h_BChw=ms_h_BChw, all_to_max_scale=same_shape, last_one=last_one + ) + + def embed_to_img( + self, ms_h_BChw: List[torch.Tensor], all_to_max_scale: bool, last_one=False + ) -> Union[List[torch.Tensor], torch.Tensor]: + if last_one: + return self.decoder( + self.post_quant_conv( + self.quantize.embed_to_fhat( + ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=True + ) + ) + ).clamp_(-1, 1) + else: + return [ + self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1) + for f_hat in self.quantize.embed_to_fhat( + ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=False + ) + ] + + def img_to_reconstructed_img( + self, + x, + v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None, + last_one=False, + ) -> List[torch.Tensor]: + f = self.quant_conv(self.encoder(x)) + ls_f_hat_BChw = self.quantize.f_to_idxBl_or_fhat( + f, to_fhat=False, v_patch_nums=v_patch_nums + ) + return ( + self.decoder(self.post_quant_conv(ls_f_hat_BChw[-1])).clamp_(-1, 1), + ls_f_hat_BChw, + f, + ) + + # if last_one: + # return self.decoder(self.post_quant_conv(ls_f_hat_BChw[-1])).clamp_(-1, 1) + # else: + # return [self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1) for f_hat in ls_f_hat_BChw] + + def load_state_dict(self, state_dict: Dict[str, Any], strict=True, assign=False): + if ( + "quantize.ema_vocab_hit_SV" in state_dict + and 
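+            # if the checkpoint was trained with a different number of scales, keep this
+            # model's own EMA codebook-usage buffer instead of loading the mismatched one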
state_dict["quantize.ema_vocab_hit_SV"].shape[0] + != self.quantize.ema_vocab_hit_SV.shape[0] + ): + state_dict["quantize.ema_vocab_hit_SV"] = self.quantize.ema_vocab_hit_SV + return super().load_state_dict( + state_dict=state_dict, strict=strict, assign=assign + ) diff --git a/src/vqvaes/xqgan/cliploss.py b/src/vqvaes/xqgan/cliploss.py new file mode 100644 index 0000000000000000000000000000000000000000..bd1dedd29d96b27b4d960793792ae36a4f775ec4 --- /dev/null +++ b/src/vqvaes/xqgan/cliploss.py @@ -0,0 +1,478 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F + +try: + import torch.distributed.nn + from torch import distributed as dist + + has_distributed = True +except ImportError: + has_distributed = False + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + + +def gather_features( + image_features, + text_features, + local_loss=False, + gather_with_grad=False, + rank=0, + world_size=1, + use_horovod=False, +): + assert ( + has_distributed + ), "torch.distributed did not import correctly, please use a PyTorch version with support." + if use_horovod: + assert hvd is not None, "Please install horovod" + if gather_with_grad: + all_image_features = hvd.allgather(image_features) + all_text_features = hvd.allgather(text_features) + else: + with torch.no_grad(): + all_image_features = hvd.allgather(image_features) + all_text_features = hvd.allgather(text_features) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_image_features = list( + all_image_features.chunk(world_size, dim=0) + ) + gathered_text_features = list( + all_text_features.chunk(world_size, dim=0) + ) + gathered_image_features[rank] = image_features + gathered_text_features[rank] = text_features + all_image_features = torch.cat(gathered_image_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + else: + # We gather tensors from all gpus + if gather_with_grad: + all_image_features = torch.cat( + torch.distributed.nn.all_gather(image_features), dim=0 + ) + all_text_features = torch.cat( + torch.distributed.nn.all_gather(text_features), dim=0 + ) + else: + gathered_image_features = [ + torch.zeros_like(image_features) for _ in range(world_size) + ] + gathered_text_features = [ + torch.zeros_like(text_features) for _ in range(world_size) + ] + dist.all_gather(gathered_image_features, image_features) + dist.all_gather(gathered_text_features, text_features) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_image_features[rank] = image_features + gathered_text_features[rank] = text_features + all_image_features = torch.cat(gathered_image_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + + return all_image_features, all_text_features + + +class ClipLoss(nn.Module): + + def __init__( + self, + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + ): + super().__init__() + self.local_loss = local_loss + self.gather_with_grad = gather_with_grad + self.cache_labels = cache_labels + self.rank = rank + self.world_size = world_size + self.use_horovod = use_horovod + + # cache state + self.prev_num_logits = 0 + self.labels = {} + + def get_ground_truth(self, device, num_logits) -> torch.Tensor: + # calculated ground-truth and cache if enabled + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, 
dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + return labels + + def get_logits(self, image_features, text_features, logit_scale): + if self.world_size > 1: + all_image_features, all_text_features = gather_features( + image_features, + text_features, + self.local_loss, + self.gather_with_grad, + self.rank, + self.world_size, + self.use_horovod, + ) + + if self.local_loss: + logits_per_image = logit_scale * image_features @ all_text_features.T + logits_per_text = logit_scale * text_features @ all_image_features.T + else: + logits_per_image = ( + logit_scale * all_image_features @ all_text_features.T + ) + logits_per_text = logits_per_image.T + else: + logits_per_image = logit_scale * image_features @ text_features.T + logits_per_text = logit_scale * text_features @ image_features.T + + return logits_per_image, logits_per_text + + def forward(self, image_features, text_features, logit_scale, output_dict=False): + device = image_features.device + logits_per_image, logits_per_text = self.get_logits( + image_features, text_features, logit_scale + ) + + labels = self.get_ground_truth(device, logits_per_image.shape[0]) + + total_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + return {"contrastive_loss": total_loss} if output_dict else total_loss + + +class CoCaLoss(ClipLoss): + def __init__( + self, + caption_loss_weight, + clip_loss_weight, + pad_id=0, # pad_token for open_clip custom tokenizer + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + ): + super().__init__( + local_loss=local_loss, + gather_with_grad=gather_with_grad, + cache_labels=cache_labels, + rank=rank, + world_size=world_size, + use_horovod=use_horovod, + ) + + self.clip_loss_weight = clip_loss_weight + self.caption_loss_weight = caption_loss_weight + self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) + + def forward( + self, + image_features, + text_features, + logits, + labels, + logit_scale, + output_dict=False, + ): + + clip_loss = torch.tensor(0) + + if self.clip_loss_weight: + clip_loss = super().forward(image_features, text_features, logit_scale) + clip_loss = self.clip_loss_weight * clip_loss + + caption_loss = self.caption_loss( + logits.permute(0, 2, 1), + labels, + ) + caption_loss = caption_loss * self.caption_loss_weight + + if output_dict: + return {"contrastive_loss": clip_loss, "caption_loss": caption_loss} + + return clip_loss, caption_loss + + +class DistillClipLoss(ClipLoss): + + def dist_loss(self, teacher_logits, student_logits): + return ( + -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)) + .sum(dim=1) + .mean(dim=0) + ) + + def forward( + self, + image_features, + text_features, + logit_scale, + dist_image_features, + dist_text_features, + dist_logit_scale, + output_dict=False, + ): + logits_per_image, logits_per_text = self.get_logits( + image_features, text_features, logit_scale + ) + + dist_logits_per_image, dist_logits_per_text = self.get_logits( + dist_image_features, dist_text_features, dist_logit_scale + ) + + labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0]) + + contrastive_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + distill_loss = ( + 
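+            # soft-target cross-entropy: teacher probabilities against student log-probabilities,
+            # averaged over the image->text and text->image directions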
self.dist_loss(dist_logits_per_image, logits_per_image) + + self.dist_loss(dist_logits_per_text, logits_per_text) + ) / 2 + + if output_dict: + return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss} + + return contrastive_loss, distill_loss + + +def neighbour_exchange(from_rank, to_rank, tensor, group=None): + tensor_recv = torch.zeros_like(tensor) + send_op = torch.distributed.P2POp( + torch.distributed.isend, + tensor, + to_rank, + group=group, + ) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, + tensor_recv, + from_rank, + group=group, + ) + reqs = torch.distributed.batch_isend_irecv([send_op, recv_op]) + for req in reqs: + req.wait() + return tensor_recv + + +def neighbour_exchange_bidir( + left_rank, right_rank, tensor_to_left, tensor_to_right, group=None +): + tensor_from_left = torch.zeros_like(tensor_to_right) + tensor_from_right = torch.zeros_like(tensor_to_left) + send_op_left = torch.distributed.P2POp( + torch.distributed.isend, + tensor_to_left, + left_rank, + group=group, + ) + send_op_right = torch.distributed.P2POp( + torch.distributed.isend, + tensor_to_right, + right_rank, + group=group, + ) + recv_op_left = torch.distributed.P2POp( + torch.distributed.irecv, + tensor_from_left, + left_rank, + group=group, + ) + recv_op_right = torch.distributed.P2POp( + torch.distributed.irecv, + tensor_from_right, + right_rank, + group=group, + ) + reqs = torch.distributed.batch_isend_irecv( + [send_op_right, send_op_left, recv_op_right, recv_op_left] + ) + for req in reqs: + req.wait() + return tensor_from_right, tensor_from_left + + +class NeighbourExchange(torch.autograd.Function): + @staticmethod + def forward(ctx, from_rank, to_rank, group, tensor): + ctx.group = group + ctx.from_rank = from_rank + ctx.to_rank = to_rank + return neighbour_exchange(from_rank, to_rank, tensor, group=group) + + @staticmethod + def backward(ctx, grad_output): + return (None, None, None) + ( + NeighbourExchange.apply(ctx.to_rank, ctx.from_rank, ctx.group, grad_output), + ) + + +def neighbour_exchange_with_grad(from_rank, to_rank, tensor, group=None): + return NeighbourExchange.apply(from_rank, to_rank, group, tensor) + + +class NeighbourExchangeBidir(torch.autograd.Function): + @staticmethod + def forward(ctx, left_rank, right_rank, group, tensor_to_left, tensor_to_right): + ctx.group = group + ctx.left_rank = left_rank + ctx.right_rank = right_rank + return neighbour_exchange_bidir( + left_rank, right_rank, tensor_to_left, tensor_to_right, group=group + ) + + @staticmethod + def backward(ctx, *grad_outputs): + return (None, None, None) + NeighbourExchangeBidir.apply( + ctx.right_rank, ctx.left_rank, ctx.group, *grad_outputs + ) + + +def neighbour_exchange_bidir_with_grad( + left_rank, right_rank, tensor_to_left, tensor_to_right, group=None +): + return NeighbourExchangeBidir.apply( + left_rank, right_rank, group, tensor_to_left, tensor_to_right + ) + + +class SigLipLoss(nn.Module): + """Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343 + + @article{zhai2023sigmoid, + title={Sigmoid loss for language image pre-training}, + author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas}, + journal={arXiv preprint arXiv:2303.15343}, + year={2023} + } + """ + + def __init__( + self, + cache_labels=False, + rank=0, + world_size=1, + bidir=True, + use_horovod=False, + ): + super().__init__() + self.cache_labels = cache_labels + self.rank = rank + self.world_size = world_size + assert not use_horovod # FIXME need 
to look at hvd ops for ring transfers + self.use_horovod = use_horovod + self.bidir = bidir + + # cache state FIXME cache not currently used, worthwhile? + self.prev_num_logits = 0 + self.labels = {} + + def get_ground_truth( + self, device, dtype, num_logits, negative_only=False + ) -> torch.Tensor: + labels = -torch.ones((num_logits, num_logits), device=device, dtype=dtype) + if not negative_only: + labels = 2 * torch.eye(num_logits, device=device, dtype=dtype) + labels + return labels + + def get_logits(self, image_features, text_features, logit_scale, logit_bias=None): + logits = logit_scale * image_features @ text_features.T + if logit_bias is not None: + logits += logit_bias + return logits + + def _loss( + self, + image_features, + text_features, + logit_scale, + logit_bias=None, + negative_only=False, + ): + logits = self.get_logits(image_features, text_features, logit_scale, logit_bias) + labels = self.get_ground_truth( + image_features.device, + image_features.dtype, + image_features.shape[0], + negative_only=negative_only, + ) + loss = -F.logsigmoid(labels * logits).sum() / image_features.shape[0] + return loss + + def forward( + self, image_features, text_features, logit_scale, logit_bias, output_dict=False + ): + loss = self._loss(image_features, text_features, logit_scale, logit_bias) + + if self.world_size > 1: + # exchange text features w/ neighbour world_size - 1 times + right_rank = (self.rank + 1) % self.world_size + left_rank = (self.rank - 1 + self.world_size) % self.world_size + if self.bidir: + text_features_to_right = text_features_to_left = text_features + num_bidir, remainder = divmod(self.world_size - 1, 2) + for i in range(num_bidir): + text_features_recv = neighbour_exchange_bidir_with_grad( + left_rank, + right_rank, + text_features_to_left, + text_features_to_right, + ) + + for f in text_features_recv: + loss += self._loss( + image_features, + f, + logit_scale, + logit_bias, + negative_only=True, + ) + text_features_to_left, text_features_to_right = text_features_recv + + if remainder: + text_features_recv = neighbour_exchange_with_grad( + left_rank, right_rank, text_features_to_right + ) + + loss += self._loss( + image_features, + text_features_recv, + logit_scale, + logit_bias, + negative_only=True, + ) + else: + text_features_to_right = text_features + for i in range(self.world_size - 1): + text_features_from_left = neighbour_exchange_with_grad( + left_rank, right_rank, text_features_to_right + ) + + loss += self._loss( + image_features, + text_features_from_left, + logit_scale, + logit_bias, + negative_only=True, + ) + text_features_to_right = text_features_from_left + + return {"contrastive_loss": loss} if output_dict else loss diff --git a/src/vqvaes/xqgan/diffaug.py b/src/vqvaes/xqgan/diffaug.py new file mode 100644 index 0000000000000000000000000000000000000000..d74d1feecfbfec535fbc43b549da0fe8d2078bcf --- /dev/null +++ b/src/vqvaes/xqgan/diffaug.py @@ -0,0 +1,168 @@ +# this file is taken from https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/training/diffaug.py +import math + +import torch +import torch.nn.functional as F + + +def load_png(file_name: str): + from torchvision.io import read_image + + return ( + read_image(file_name).float().div_(255).mul_(2).sub_(1).unsqueeze(0) + ) # to [-1, 1] + + +def show(tensor): # from [-1, 1] + from torchvision.utils import make_grid + from torchvision.transforms.functional import to_pil_image + + if tensor.shape[0] == 1: + tensor = tensor[0] + if tensor.ndim == 3: + 
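+        # single image (C, H, W): map from [-1, 1] back to [0, 1] before converting to PIL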
to_pil_image(tensor.add(1).div_(2).clamp_(0, 1).detach().cpu()).convert( + "RGB" + ).show() + else: + to_pil_image( + make_grid(tensor.add(1).div_(2).clamp_(0, 1).detach().cpu()) + ).convert("RGB").show() + + +class DiffAug(object): + def __init__(self, prob=1.0, cutout=0.2): # todo: swin ratio = 0.5, T&XL = 0.2 + self.grids = {} + self.prob = abs(prob) + self.using_cutout = prob > 0 + self.cutout = cutout + self.img_channels = -1 + self.last_blur_radius = -1 + self.last_blur_kernel_h = self.last_blur_kernel_w = None + + def get_grids(self, B, x, y, dev): + if (B, x, y) in self.grids: + return self.grids[(B, x, y)] + + self.grids[(B, x, y)] = ret = torch.meshgrid( + torch.arange(B, dtype=torch.long, device=dev), + torch.arange(x, dtype=torch.long, device=dev), + torch.arange(y, dtype=torch.long, device=dev), + indexing="ij", + ) + return ret + + def aug(self, BCHW: torch.Tensor, warmup_blur_schedule: float = 0) -> torch.Tensor: + # warmup blurring + if BCHW.dtype != torch.float32: + BCHW = BCHW.float() + if warmup_blur_schedule > 0: + self.img_channels = BCHW.shape[1] + sigma0 = (BCHW.shape[-2] * 0.5) ** 0.5 + sigma = sigma0 * warmup_blur_schedule + blur_radius = math.floor(sigma * 3) # 3-sigma is enough for Gaussian + if blur_radius >= 1: + if self.last_blur_radius != blur_radius: + self.last_blur_radius = blur_radius + gaussian = torch.arange( + -blur_radius, + blur_radius + 1, + dtype=torch.float32, + device=BCHW.device, + ) + gaussian = gaussian.mul_(1 / sigma).square_().neg_().exp2_() + gaussian.div_(gaussian.sum()) # normalize + self.last_blur_kernel_h = ( + gaussian.view(1, 1, 2 * blur_radius + 1, 1) + .repeat(self.img_channels, 1, 1, 1) + .contiguous() + ) + self.last_blur_kernel_w = ( + gaussian.view(1, 1, 1, 2 * blur_radius + 1) + .repeat(self.img_channels, 1, 1, 1) + .contiguous() + ) + + BCHW = F.pad( + BCHW, + [blur_radius, blur_radius, blur_radius, blur_radius], + mode="reflect", + ) + BCHW = F.conv2d( + input=BCHW, + weight=self.last_blur_kernel_h, + bias=None, + groups=self.img_channels, + ) + BCHW = F.conv2d( + input=BCHW, + weight=self.last_blur_kernel_w, + bias=None, + groups=self.img_channels, + ) + # BCHW = filter2d(BCHW, f.div_(f.sum())) # no need to specify padding (filter2d will add padding in itself based on filter size) + + if self.prob < 1e-6: + return BCHW + trans, color, cut = torch.rand(3) <= self.prob + trans, color, cut = trans.item(), color.item(), cut.item() + B, dev = BCHW.shape[0], BCHW.device + rand01 = torch.rand(7, B, 1, 1, device=dev) if (trans or color or cut) else None + + raw_h, raw_w = BCHW.shape[-2:] + if trans: + ratio = 0.125 + delta_h = round(raw_h * ratio) + delta_w = round(raw_w * ratio) + translation_h = ( + rand01[0].mul(delta_h + delta_h + 1).floor().long() - delta_h + ) + translation_w = ( + rand01[1].mul(delta_w + delta_w + 1).floor().long() - delta_w + ) + # translation_h = torch.randint(-delta_h, delta_h+1, size=(B, 1, 1), device=dev) + # translation_w = torch.randint(-delta_w, delta_w+1, size=(B, 1, 1), device=dev) + + grid_B, grid_h, grid_w = self.get_grids(B, raw_h, raw_w, dev) + grid_h = (grid_h + translation_h).add_(1).clamp_(0, raw_h + 1) + grid_w = (grid_w + translation_w).add_(1).clamp_(0, raw_w + 1) + bchw_pad = F.pad(BCHW, [1, 1, 1, 1, 0, 0, 0, 0]) + BCHW = ( + bchw_pad.permute(0, 2, 3, 1) + .contiguous()[grid_B, grid_h, grid_w] + .permute(0, 3, 1, 2) + .contiguous() + ) + + if color: + BCHW = BCHW.add(rand01[2].unsqueeze(-1).sub(0.5)) + # BCHW.add_(torch.rand(B, 1, 1, 1, dtype=BCHW.dtype, device=dev).sub_(0.5)) + bchw_mean 
= BCHW.mean(dim=1, keepdim=True) + BCHW = ( + BCHW.sub(bchw_mean).mul(rand01[3].unsqueeze(-1).mul(2)).add_(bchw_mean) + ) + # BCHW.sub_(bchw_mean).mul_(torch.rand(B, 1, 1, 1, dtype=BCHW.dtype, device=dev).mul_(2)).add_(bchw_mean) + bchw_mean = BCHW.mean(dim=(1, 2, 3), keepdim=True) + BCHW = ( + BCHW.sub(bchw_mean) + .mul(rand01[4].unsqueeze(-1).add(0.5)) + .add_(bchw_mean) + ) + # BCHW.sub_(bchw_mean).mul_(torch.rand(B, 1, 1, 1, dtype=BCHW.dtype, device=dev).add_(0.5)).add_(bchw_mean) + + if self.using_cutout and cut: + ratio = self.cutout # todo: styleswin ratio = 0.5, T&XL = 0.2 + cutout_h = round(raw_h * ratio) + cutout_w = round(raw_w * ratio) + offset_h = rand01[5].mul(raw_h + (1 - cutout_h % 2)).floor().long() + offset_w = rand01[6].mul(raw_w + (1 - cutout_w % 2)).floor().long() + # offset_h = torch.randint(0, raw_h + (1 - cutout_h % 2), size=(B, 1, 1), device=dev) + # offset_w = torch.randint(0, raw_w + (1 - cutout_w % 2), size=(B, 1, 1), device=dev) + + grid_B, grid_h, grid_w = self.get_grids(B, cutout_h, cutout_w, dev) + grid_h = (grid_h + offset_h).sub_(cutout_h // 2).clamp(min=0, max=raw_h - 1) + grid_w = (grid_w + offset_w).sub_(cutout_w // 2).clamp(min=0, max=raw_w - 1) + mask = torch.ones(B, raw_h, raw_w, dtype=BCHW.dtype, device=dev) + mask[grid_B, grid_h, grid_w] = 0 + BCHW = BCHW.mul(mask.unsqueeze(1)) + + return BCHW diff --git a/src/vqvaes/xqgan/dino_enc/__init__.py b/src/vqvaes/xqgan/dino_enc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed82566609390db490376c63ddced522761d736e --- /dev/null +++ b/src/vqvaes/xqgan/dino_enc/__init__.py @@ -0,0 +1,2 @@ +from .dinov2 import DINOv2Encoder, DINOv2Decoder, DINOv2Decoder_ +from .vision_transformer import Attention, RoPEAttention diff --git a/src/vqvaes/xqgan/dino_enc/dinov2.py b/src/vqvaes/xqgan/dino_enc/dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..1cdfaebd53415567f8d2fed26c1dfe110892b8ed --- /dev/null +++ b/src/vqvaes/xqgan/dino_enc/dinov2.py @@ -0,0 +1,643 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +import peft +from timm.models import create_model, safe_model_name +from timm.layers import trunc_normal_, Mlp + +import sys + +from .to_pixel import ToPixel + +from .vision_transformer import Attention, RoPEAttention + +import math + + +class DINOv2Encoder(nn.Module): + def __init__( + self, + in_channels=3, + num_latent_tokens=32, + use_attn_mask=False, + model_name="vit_small_patch14_dinov2.lvd142m", + model_kwargs={ + "img_size": 224, + "patch_size": 14, + "drop_path_rate": 0.0, + }, + pretrained=True, + tuning_method="lora", + tuning_kwargs={"r": 8}, + abs_pos_embed=False, + product_quant=1, + ): + super().__init__() + + assert model_name in [ + "vit_small_patch14_dinov2.lvd142m", + "vit_base_patch14_dinov2.lvd142m", + "vit_large_patch14_dinov2.lvd142m", + "vit_giant_patch14_dinov2.lvd142m", + "vit_small_patch14_reg4_dinov2.lvd142m", + "vit_base_patch14_reg4_dinov2.lvd142m", + "vit_large_patch14_reg4_dinov2.lvd142m", + "vit_giant_patch14_reg4_dinov2.lvd142m", + ], f"{model_name} not found" + + # parameters + self.num_latent_tokens = num_latent_tokens + self.use_attn_mask = use_attn_mask + self.product_quant = product_quant + + # load model + model = create_model(model_name, pretrained=pretrained, **model_kwargs) + # model = vit_base_patch14_dinov2(pretrained=pretrained, **model_kwargs) + + self.embed_dim = model.embed_dim + # get num of img tokens + self.num_img_tokens = model.patch_embed.num_patches + 
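+        # prefix tokens = the class token (plus 4 register tokens for the *_reg4_* variants);
+        # they precede the image patch tokens in the sequence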
self.num_prefix_tokens = model.num_prefix_tokens + self.abs_pos_embed = abs_pos_embed + + # tuning method + if tuning_method == "full": + # doing nothing + self.model = model + elif tuning_method == "lora": + # lora tuning the backbone + # config = peft.LoraConfig(target_modules=r".*\.mlp\.fc\d", modules_to_save=['patch_embed.proj', 'patch_embed.norm', 'norm'], **tuning_kwargs) + config = peft.LoraConfig( + target_modules=r".*\.mlp\.fc\d", + modules_to_save=["norm"], + **tuning_kwargs, + ) + # config = peft.LoraConfig(target_modules=r".*\.mlp\.fc\d|.*\.qkv|.*\.proj", modules_to_save=['norm'], **tuning_kwargs) + self.model = peft.get_peft_model(model, config) + # self.model.base_model.model.pos_embed.requires_grad = True + self.model.print_trainable_parameters() + elif tuning_method == "lora_unfreeze_patch_embed": + config = peft.LoraConfig( + target_modules=r".*\.mlp\.fc\d", + modules_to_save=["patch_embed.proj", "patch_embed.norm", "norm"], + **tuning_kwargs, + ) + self.model = peft.get_peft_model(model, config) + # self.model.base_model.model.pos_embed.requires_grad = True + self.model.print_trainable_parameters() + elif tuning_method == "lat_lora": + from models.peft_models.lora import LatentLoRALinear + + config = peft.LoraConfig( + target_modules=r".*\.mlp\.fc\d|.*\.qkv|.*\.proj", + modules_to_save=["norm"], + **tuning_kwargs, + ) + config._register_custom_module({nn.Linear: LatentLoRALinear}) + self.model = peft.get_peft_model(model, config) + self.use_attn_mask = True # force to use attn mask + self.model.print_trainable_parameters() + elif tuning_method == "frozen": + for param in model.parameters(): + param.requires_grad = False + self.model = model + + if self.num_latent_tokens: + # latent tokens + self.latent_tokens = nn.Parameter( + torch.zeros(1, self.num_latent_tokens, model.embed_dim) + ) + nn.init.normal_(self.latent_tokens, std=1e-6) + + if self.abs_pos_embed: + if self.product_quant > 1: + self.lvl_embed = nn.Embedding( + 1 + self.product_quant, model.embed_dim + ) + patch_size = model_kwargs["patch_size"] + nn.init.trunc_normal_( + self.lvl_embed.weight.data, + mean=0, + std=math.sqrt(1 / model.embed_dim / 3), + ) + lvl1LC = torch.cat( + [ + torch.full((patch_size * patch_size + 1,), 0), + ] + + [ + torch.full( + (self.num_latent_tokens // self.product_quant,), i + 1 + ) + for i in range(self.product_quant) + ] + ).view(1, -1) + else: + self.lvl_embed = nn.Embedding(2, model.embed_dim) + patch_size = model_kwargs["patch_size"] + nn.init.trunc_normal_( + self.lvl_embed.weight.data, + mean=0, + std=math.sqrt(1 / model.embed_dim / 3), + ) + lvl1LC = torch.cat( + [ + torch.full((patch_size * patch_size + 1,), 0), + torch.full((self.num_latent_tokens,), 1), + ] + ).view(1, -1) + self.register_buffer("lvl1LC", lvl1LC) + else: + self.latent_pos_embed = nn.Parameter( + torch.zeros(1, self.num_latent_tokens, model.embed_dim) + ) + trunc_normal_(self.latent_pos_embed, std=0.02) + + if self.use_attn_mask: + # create attn mask + total_length = ( + self.num_img_tokens + + self.num_latent_tokens + + self.num_prefix_tokens + ) + attn_mask = torch.zeros((total_length, total_length)) + attn_mask[ + : self.num_prefix_tokens + self.num_img_tokens, + -self.num_latent_tokens :, + ] = -torch.inf + attn_mask = attn_mask.view(1, 1, total_length, total_length) + print(attn_mask) + self.register_buffer("attn_mask", attn_mask) + + def finetine(self, tuning_method, tuning_kwargs={"r": 8}): + if tuning_method == "full": + return + elif tuning_method == "lora": + # lora tuning the backbone + # config 
= peft.LoraConfig(target_modules=r".*\.mlp\.fc\d", modules_to_save=['patch_embed.proj', 'patch_embed.norm', 'norm'], **tuning_kwargs) + config = peft.LoraConfig( + target_modules=r".*\.mlp\.fc\d", + modules_to_save=["norm"], + **tuning_kwargs, + ) + # config = peft.LoraConfig(target_modules=r".*\.mlp\.fc\d|.*\.qkv|.*\.proj", modules_to_save=['norm'], **tuning_kwargs) + self.model = peft.get_peft_model(self.model, config) + # self.model.base_model.model.pos_embed.requires_grad = True + self.model.print_trainable_parameters() + elif tuning_method == "lora_unfreeze_patch_embed": + config = peft.LoraConfig( + target_modules=r".*\.mlp\.fc\d", + modules_to_save=["patch_embed.proj", "patch_embed.norm", "norm"], + **tuning_kwargs, + ) + self.model = peft.get_peft_model(self.model, config) + # self.model.base_model.model.pos_embed.requires_grad = True + self.model.print_trainable_parameters() + elif tuning_method == "lat_lora": + from models.peft_models.lora import LatentLoRALinear + + config = peft.LoraConfig( + target_modules=r".*\.mlp\.fc\d|.*\.qkv|.*\.proj", + modules_to_save=["norm"], + **tuning_kwargs, + ) + config._register_custom_module({nn.Linear: LatentLoRALinear}) + self.model = peft.get_peft_model(self.model, config) + self.use_attn_mask = True # force to use attn mask + self.model.print_trainable_parameters() + elif tuning_method == "frozen": + for param in self.model.parameters(): + param.requires_grad = False + + def no_weight_decay(self): + return [ + "model.pos_embed", + "model.cls_token", + "model.dist_token", + "latent_tokens", + "latent_pos_embed", + ] + + def forward(self, x, masks=None): + + # get tokens + x = self.model.patch_embed(x) + + with torch.cuda.amp.autocast(enabled=False): + x = self.model._pos_embed(x) + x = self.model.patch_drop(x) + + if self.num_latent_tokens: + # insert latent tokens + z = self.latent_tokens.expand(x.size(0), -1, -1) + if self.abs_pos_embed: + if self.product_quant > 1: + H, W = int( + math.sqrt(self.num_latent_tokens // self.product_quant) + ), int(math.sqrt(self.num_latent_tokens // self.product_quant)) + assert H * W == self.num_latent_tokens // self.product_quant + z = z.view(x.size(0), self.product_quant * H, W, -1) + z_list = z.chunk(chunks=self.product_quant, dim=1) + z_list = [ + self.model._pos_embed(z)[ + :, + 1:, + ] + for z in z_list + ] # remove cls token + x = torch.cat( + [ + x, + ] + + z_list, + dim=1, + ) + x += self.lvl_embed(self.lvl1LC.expand(x.size(0), -1)) + else: + H, W = int(math.sqrt(self.num_latent_tokens)), int( + math.sqrt(self.num_latent_tokens) + ) + assert H * W == self.num_latent_tokens + z = z.view(x.size(0), H, W, -1) + z = self.model._pos_embed(z)[ + :, + 1:, + ] # remove cls token + x = torch.cat([x, z], dim=1) + x += self.lvl_embed(self.lvl1LC.expand(x.size(0), -1)) + else: + x = torch.cat([x, z + self.latent_pos_embed], dim=1) + # get dtype + temp = x.new_ones(8, 8) + main_type = torch.matmul(temp, temp).dtype + x = x.to(main_type) + + # pre layer norm + x = self.model.norm_pre(x) + + # forward backbones + if self.use_attn_mask: + for blk in self.model.blocks: + x = blk(x, self.attn_mask) + else: + x = self.model.blocks(x) + x = self.model.norm(x) + + if self.num_latent_tokens: + # get z tokens as out + out = x[:, -self.num_latent_tokens :] + else: + # get img tokens as out + out = x[:, self.num_prefix_tokens :] + return out + + +class DINOv2Decoder(nn.Module): + def __init__( + self, + in_channels=3, + model_name="vit_small_patch14_dinov2.lvd142m", + model_kwargs={"img_size": 224, "patch_size": 14, 
"drop_path_rate": 0.0}, + pretrained=True, + tuning_method="lora", + tuning_kwargs={"r": 8}, + num_latent_tokens=32, + to_pixel="linear", + use_rope=False, + cond_latent=False, + abs_pos_embed=False, + ): + super().__init__() + + assert model_name in [ + "vit_small_patch14_dinov2.lvd142m", + "vit_base_patch14_dinov2.lvd142m", + "vit_large_patch14_dinov2.lvd142m", + "vit_giant_patch14_dinov2.lvd142m", + "vit_small_patch14_reg4_dinov2.lvd142m", + "vit_base_patch14_reg4_dinov2.lvd142m", + "vit_large_patch14_reg4_dinov2.lvd142m", + "vit_giant_patch14_reg4_dinov2.lvd142m", + ] + + # load model + if use_rope: + print("using RoPEAttention") + attn_layer = RoPEAttention + else: + attn_layer = Attention + + model_kwargs["num_latent_tokens"] = num_latent_tokens + model_kwargs["attn_layer"] = attn_layer + model = create_model(model_name, pretrained=pretrained, **model_kwargs) + self.use_rope = use_rope + self.embed_dim = model.embed_dim + # get num of img tokens + self.num_img_tokens = model.patch_embed.num_patches + self.num_prefix_tokens = model.num_prefix_tokens + self.num_latent_tokens = num_latent_tokens + + self.abs_pos_embed = abs_pos_embed + + # for n, m in model.named_modules(): + # print(n, type(m)) + + # tuning method + if tuning_method == "full": + # doing nothing + self.model = model + elif tuning_method == "lora": + # lora tuning the backbone + # config = peft.LoraConfig(target_modules=r".*\.mlp\.fc\d", modules_to_save=['patch_embed.proj', 'patch_embed.norm', 'norm'], **tuning_kwargs) + config = peft.LoraConfig( + target_modules=r".*\.mlp\.fc\d", + modules_to_save=["norm"], + **tuning_kwargs, + ) + self.model = peft.get_peft_model(model, config) + # self.model.base_model.model.pos_embed.requires_grad = True + self.model.print_trainable_parameters() + elif tuning_method == "frozen": + for param in model.parameters(): + param.requires_grad = False + + # latent tokens + self.mask_token = nn.Parameter(torch.zeros(1, 1, model.embed_dim)) + # self.mask_token = nn.Parameter(torch.zeros(1, self.num_img_tokens, model.embed_dim)) + nn.init.normal_(self.mask_token, std=1e-6) + # self.mask_token = nn.Parameter(torch.clone(model.cls_token)) + + if not self.use_rope: + if self.abs_pos_embed: + self.lvl_embed = nn.Embedding(2, model.embed_dim) + patch_size = model_kwargs["patch_size"] + nn.init.trunc_normal_( + self.lvl_embed.weight.data, + mean=0, + std=math.sqrt(1 / model.embed_dim / 3), + ) + lvl1LC = torch.cat( + [ + torch.full((patch_size * patch_size + 1,), 0), + torch.full((self.num_latent_tokens + 1,), 1), + ] + ).view(1, -1) + self.register_buffer("lvl1LC", lvl1LC) + else: + self.latent_pos_embed = nn.Parameter( + torch.zeros(1, self.num_latent_tokens, model.embed_dim) + ) + trunc_normal_(self.latent_pos_embed, std=0.02) + # from timm.models.vision_transformer import resize_pos_embed + # latent_pos_embed = resize_pos_embed(model.pos_embed, torch.zeros(1, self.num_latent_tokens, model.embed_dim), 0) + # self.latent_pos_embed = nn.Parameter(latent_pos_embed) + + # to pixel + self.to_pixel = ToPixel( + to_pixel=to_pixel, + img_size=model_kwargs["img_size"], + in_channels=in_channels, + in_dim=model.embed_dim, + patch_size=model_kwargs["patch_size"], + ) + + # latent initial as pooled dino feature + self.cond_latent = cond_latent + if self.cond_latent: + self.mlp1 = Mlp(model.embed_dim, model.embed_dim, norm_layer=nn.LayerNorm) + self.mlp2 = Mlp(model.embed_dim, model.embed_dim, norm_layer=nn.LayerNorm) + self.norm1 = nn.LayerNorm(model.embed_dim) + + del self.model.patch_embed.proj.bias + del 
self.model.patch_embed.proj.weight + + def finetine(self, tuning_method, tuning_kwargs={"r": 8}): + if tuning_method == "full": + # doing nothing + return + elif tuning_method == "lora": + # lora tuning the backbone + # config = peft.LoraConfig(target_modules=r".*\.mlp\.fc\d", modules_to_save=['patch_embed.proj', 'patch_embed.norm', 'norm'], **tuning_kwargs) + config = peft.LoraConfig( + target_modules=r".*\.mlp\.fc\d", + modules_to_save=["norm"], + **tuning_kwargs, + ) + self.model = peft.get_peft_model(self.model, config) + # self.model.base_model.model.pos_embed.requires_grad = True + self.model.print_trainable_parameters() + elif tuning_method == "frozen": + for param in self.model.parameters(): + param.requires_grad = False + + def no_weight_decay(self): + return [ + "model.pos_embed", + "model.cls_token", + "model.dist_token", + "mask_token", + "latent_pos_embed", + ] + + @property + def last_layer(self): + return self.to_pixel.model.weight + + def forward(self, z): + + # mask tokens + x = self.mask_token.expand(z.size(0), self.num_img_tokens, -1) + # x = self.mask_token.expand(z.size(0), -1, -1) + + with torch.cuda.amp.autocast(enabled=False): + if not self.use_rope: + x = self.model._pos_embed(x) + + if self.cond_latent: + ffnout = x + self.mlp1(torch.mean(z.float(), dim=1, keepdim=True)) + x = x + self.mlp2(self.norm1(ffnout)) + if self.abs_pos_embed: + H, W = int(math.sqrt(self.num_latent_tokens)), int( + math.sqrt(self.num_latent_tokens) + ) + assert H * W == self.num_latent_tokens + z = z.view(x.size(0), H, W, -1) + z = self.model._pos_embed(z) + else: + z = z + self.latent_pos_embed + else: + to_cat = [] + if self.model.cls_token is not None: + to_cat.append(self.model.cls_token.expand(x.shape[0], -1, -1)) + if self.model.reg_token is not None: + to_cat.append(self.model.reg_token.expand(x.shape[0], -1, -1)) + x = torch.cat(to_cat + [x], dim=1) + x = self.model.patch_drop(x) + + x = torch.cat([x, z], dim=1) + if self.abs_pos_embed: + x += self.lvl_embed(self.lvl1LC.expand(x.size(0), -1)) + # get dtype + temp = x.new_ones(8, 8) + main_type = torch.matmul(temp, temp).dtype + x = x.to(main_type) + + x = self.model.norm_pre(x) + + # forward backbones + x = self.model.blocks(x) + x = self.model.norm(x) + + # get img tokens as out + # x = x[:, z.size(1)+self.num_prefix_tokens:] + # out = x[:, self.num_prefix_tokens:] + # x = x[:, -self.num_img_tokens:] + # x = self.to_pixel(x) + x = x[:, self.num_prefix_tokens : self.num_img_tokens + self.num_prefix_tokens] + + out = self.to_pixel(x) + + return out + + +class DINOv2Decoder_(nn.Module): + def __init__( + self, + in_channels=3, + model_name="vit_small_patch14_dinov2.lvd142m", + model_kwargs={"img_size": 224, "patch_size": 14, "drop_path_rate": 0.0}, + pretrained=True, + tuning_method="lora", + tuning_kwargs={"r": 8}, + to_pixel="linear", + use_rope=False, + cond_latent=False, + ): + super().__init__() + + assert model_name in [ + "vit_small_patch14_dinov2.lvd142m", + "vit_base_patch14_dinov2.lvd142m", + "vit_large_patch14_dinov2.lvd142m", + "vit_giant_patch14_dinov2.lvd142m", + "vit_small_patch14_reg4_dinov2.lvd142m", + "vit_base_patch14_reg4_dinov2.lvd142m", + "vit_large_patch14_reg4_dinov2.lvd142m", + "vit_giant_patch14_reg4_dinov2.lvd142m", + ] + + # load model + if use_rope: + print("using RoPEAttention") + attn_layer = RoPEAttention + else: + attn_layer = Attention + + model_kwargs["attn_layer"] = attn_layer + model = create_model(model_name, pretrained=pretrained, **model_kwargs) + self.use_rope = use_rope + self.embed_dim = 
model.embed_dim + # get num of img tokens + self.num_img_tokens = model.patch_embed.num_patches + self.num_prefix_tokens = model.num_prefix_tokens + + # for n, m in model.named_modules(): + # print(n, type(m)) + + # tuning method + if tuning_method == "full": + # doing nothing + self.model = model + elif tuning_method == "lora": + # lora tuning the backbone + # config = peft.LoraConfig(target_modules=r".*\.mlp\.fc\d", modules_to_save=['patch_embed.proj', 'patch_embed.norm', 'norm'], **tuning_kwargs) + config = peft.LoraConfig( + target_modules=r".*\.mlp\.fc\d", + modules_to_save=["norm"], + **tuning_kwargs, + ) + self.model = peft.get_peft_model(model, config) + # self.model.base_model.model.pos_embed.requires_grad = True + self.model.print_trainable_parameters() + elif tuning_method == "frozen": + for param in model.parameters(): + param.requires_grad = False + + # from timm.models.vision_transformer import resize_pos_embed + # latent_pos_embed = resize_pos_embed(model.pos_embed, torch.zeros(1, self.num_latent_tokens, model.embed_dim), 0) + # self.latent_pos_embed = nn.Parameter(latent_pos_embed) + + # to pixel + self.to_pixel = ToPixel( + to_pixel=to_pixel, + img_size=model_kwargs["img_size"], + in_channels=in_channels, + in_dim=model.embed_dim, + patch_size=model_kwargs["patch_size"], + ) + + # latent initial as pooled dino feature + self.cond_latent = cond_latent + if self.cond_latent: + self.mlp1 = Mlp(model.embed_dim, model.embed_dim, norm_layer=nn.LayerNorm) + self.mlp2 = Mlp(model.embed_dim, model.embed_dim, norm_layer=nn.LayerNorm) + self.norm1 = nn.LayerNorm(model.embed_dim) + + def no_weight_decay(self): + return [ + "model.pos_embed", + "model.cls_token", + "model.dist_token", + "mask_token", + "latent_pos_embed", + ] + + def forward(self, x): + + with torch.cuda.amp.autocast(enabled=False): + x = self.model._pos_embed(x) + x = self.model.patch_drop(x) + # get dtype + temp = x.new_ones(8, 8) + main_type = torch.matmul(temp, temp).dtype + x = x.to(main_type) + + x = self.model.norm_pre(x) + + # forward backbones + x = self.model.blocks(x) + x = self.model.norm(x) + + # get img tokens as out + # x = x[:, z.size(1)+self.num_prefix_tokens:] + # out = x[:, self.num_prefix_tokens:] + # x = x[:, -self.num_img_tokens:] + # x = self.to_pixel(x) + x = x[:, self.num_prefix_tokens : self.num_img_tokens + self.num_prefix_tokens] + + out = self.to_pixel(x) + + return out + + +if __name__ == "__main__": + encoder = DINOv2Encoder( + model_name="vit_small_patch14_dinov2.lvd142m", + model_kwargs={"img_size": 256, "patch_size": 16, "drop_path_rate": 0.0}, + tuning_method="lat_lora", + tuning_kwargs={"r": 8}, + num_latent_tokens=32, + ) + decoder = DINOv2Decoder( + model_name="vit_small_patch14_dinov2.lvd142m", + model_kwargs={"img_size": 256, "patch_size": 16, "drop_path_rate": 0.0}, + tuning_method="full", + tuning_kwargs={"r": 8}, + num_latent_tokens=32, + use_rope=True, + ) + x = torch.randn(1, 3, 256, 256) + out = encoder(x) + out = decoder(out) + print(out.shape) diff --git a/src/vqvaes/xqgan/dino_enc/to_pixel.py b/src/vqvaes/xqgan/dino_enc/to_pixel.py new file mode 100644 index 0000000000000000000000000000000000000000..8a520f10b272d5ca86d8a141c94c1182e69f0f45 --- /dev/null +++ b/src/vqvaes/xqgan/dino_enc/to_pixel.py @@ -0,0 +1,112 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + + +class SineLayer(nn.Module): + """ + Paper: Implicit Neural Representation with 
Periodic Activ ation Function (SIREN) + """ + + def __init__( + self, in_features, out_features, bias=True, is_first=False, omega_0=30 + ): + super().__init__() + self.omega_0 = omega_0 + self.is_first = is_first + + self.in_features = in_features + self.linear = nn.Linear(in_features, out_features, bias=bias) + + self.init_weights() + + def init_weights(self): + with torch.no_grad(): + if self.is_first: + self.linear.weight.uniform_(-1 / self.in_features, 1 / self.in_features) + else: + self.linear.weight.uniform_( + -np.sqrt(6 / self.in_features) / self.omega_0, + np.sqrt(6 / self.in_features) / self.omega_0, + ) + + def forward(self, input): + return torch.sin(self.omega_0 * self.linear(input)) + + +class ToPixel(nn.Module): + def __init__( + self, to_pixel="linear", img_size=256, in_channels=3, in_dim=512, patch_size=16 + ) -> None: + super().__init__() + self.to_pixel_name = to_pixel + self.patch_size = patch_size + self.num_patches = (img_size // patch_size) ** 2 + self.in_channels = in_channels + if to_pixel == "linear": + self.model = nn.Linear(in_dim, in_channels * patch_size * patch_size) + elif to_pixel == "conv": + self.model = nn.Sequential( + Rearrange("b (h w) c -> b c h w", h=img_size // patch_size), + nn.ConvTranspose2d( + in_dim, in_channels, kernel_size=patch_size, stride=patch_size + ), + ) + elif to_pixel == "siren": + self.model = nn.Sequential( + SineLayer(in_dim, in_dim * 2, is_first=True, omega_0=30.0), + SineLayer( + in_dim * 2, + img_size // patch_size * patch_size * in_channels, + is_first=False, + omega_0=30, + ), + ) + elif to_pixel == "identity": + self.model = nn.Identity() + else: + raise NotImplementedError + + def get_last_layer(self): + if self.to_pixel_name == "linear": + return self.model.weight + elif self.to_pixel_name == "siren": + return self.model[1].linear.weight + elif self.to_pixel_name == "conv": + return self.model[1].weight + else: + return None + + def unpatchify(self, x): + """ + x: (N, L, patch_size**2 *3) + imgs: (N, 3, H, W) + """ + p = self.patch_size + h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) + x = torch.einsum("nhwpqc->nchpwq", x) + imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) + return imgs + + def forward(self, x): + if self.to_pixel_name == "linear": + x = self.model(x) + x = self.unpatchify(x) + elif self.to_pixel_name == "siren": + x = self.model(x) + x = x.view( + x.shape[0], + self.in_channels, + self.patch_size * int(self.num_patches**0.5), + self.patch_size * int(self.num_patches**0.5), + ) + elif self.to_pixel_name == "conv": + x = self.model(x) + elif self.to_pixel_name == "identity": + pass + return x diff --git a/src/vqvaes/xqgan/dino_enc/vision_transformer.py b/src/vqvaes/xqgan/dino_enc/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..6eb31eb47d03d42a7979adff2dab11c4ec3c38b9 --- /dev/null +++ b/src/vqvaes/xqgan/dino_enc/vision_transformer.py @@ -0,0 +1,4730 @@ +"""Vision Transformer (ViT) in PyTorch + +A PyTorch implement of Vision Transformers as described in: + +'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' + - https://arxiv.org/abs/2010.11929 + +`How to train your ViT? 
Data, Augmentation, and Regularization in Vision Transformers` + - https://arxiv.org/abs/2106.10270 + +`FlexiViT: One Model for All Patch Sizes` + - https://arxiv.org/abs/2212.08013 + +The official jax code is released and available at + * https://github.com/google-research/vision_transformer + * https://github.com/google-research/big_vision + +Acknowledgments: + * The paper authors for releasing code and weights, thanks! + * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch + * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT + * Bert reference code checks against Huggingface Transformers and Tensorflow Bert + +Hacked together by / Copyright 2020, Ross Wightman +""" + +import logging +import math +from collections import OrderedDict +from functools import partial +from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union, List + +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.jit import Final + +from timm.data import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + IMAGENET_INCEPTION_MEAN, + IMAGENET_INCEPTION_STD, + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, +) +from timm.layers import ( + PatchEmbed, + Mlp, + DropPath, + AttentionPoolLatent, + RmsNorm, + PatchDropout, + SwiGLUPacked, + trunc_normal_, + lecun_normal_, + resample_patch_embed, + resample_abs_pos_embed, + use_fused_attn, + get_act_layer, + get_norm_layer, + LayerType, +) +from timm.models._builder import build_model_with_cfg +from timm.models._features import feature_take_indices +from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv +from timm.models._registry import ( + generate_default_cfgs, + register_model, + register_model_deprecations, +) + +__all__ = ["VisionTransformer"] # model_registry will add each entrypoint fn to this + +_logger = logging.getLogger(__name__) + + +def init_1d_freqs(dim: int, end: int, theta: float = 10000.0): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. 
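As a side note on the rotary-embedding helpers defined here: the following is a minimal, self-contained sketch of the 1-D frequency precomputation described above, with illustrative head_dim and token count. It is not this module's own `init_1d_freqs`, just the same idea in isolation.

```python
# Standalone sketch of 1-D rotary-frequency precomputation (illustrative sizes).
import torch

def rope_1d_freqs(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # One frequency per channel pair: 1 / theta^(2i/dim), i = 0..dim/2-1
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    t = torch.arange(end).float()                 # token positions 0..end-1
    angles = torch.outer(t, freqs)                # (end, dim//2) rotation angles
    return torch.polar(torch.ones_like(angles), angles)  # complex64 "cis" table

cis = rope_1d_freqs(dim=64, end=32)               # e.g. head_dim=64, 32 latent tokens
print(cis.shape, cis.dtype)                       # torch.Size([32, 32]) torch.complex64
print(torch.allclose(cis.abs(), torch.ones_like(cis.abs())))  # rotations preserve norm
```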
+ """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def init_2d_freqs(dim: int, num_heads: int, theta: float = 10.0, rotate: bool = True): + freqs_x = [] + freqs_y = [] + mag = 1 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + for i in range(num_heads): + angles = torch.rand(1) * 2 * torch.pi if rotate else torch.zeros(1) + fx = torch.cat( + [mag * torch.cos(angles), mag * torch.cos(torch.pi / 2 + angles)], dim=-1 + ) + fy = torch.cat( + [mag * torch.sin(angles), mag * torch.sin(torch.pi / 2 + angles)], dim=-1 + ) + freqs_x.append(fx) + freqs_y.append(fy) + freqs_x = torch.stack(freqs_x, dim=0) + freqs_y = torch.stack(freqs_y, dim=0) + freqs = torch.stack([freqs_x, freqs_y], dim=0) + return freqs + + +def init_t_xy(end_x: int, end_y: int): + t = torch.arange(end_x * end_y, dtype=torch.float32) + t_x = (t % end_x).float() + t_y = torch.div(t, end_x, rounding_mode="floor").float() + return t_x, t_y + + +def compute_mixed_cis( + freqs: torch.Tensor, t_x: torch.Tensor, t_y: torch.Tensor, num_heads: int +): + N = t_x.shape[0] + # No float 16 for this range + with torch.cuda.amp.autocast(enabled=False): + freqs_x = ( + (t_x.unsqueeze(-1) @ freqs[0].unsqueeze(-2)) + .view(N, num_heads, -1) + .permute(1, 0, 2) + ) + freqs_y = ( + (t_y.unsqueeze(-1) @ freqs[1].unsqueeze(-2)) + .view(N, num_heads, -1) + .permute(1, 0, 2) + ) + freqs_cis = torch.polar(torch.ones_like(freqs_x), freqs_x + freqs_y) + return freqs_cis + + +def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 100.0): + freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + + t_x, t_y = init_t_xy(end_x, end_y) + freqs_x = torch.outer(t_x, freqs_x) + freqs_y = torch.outer(t_y, freqs_y) + freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) + freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) + return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + if freqs_cis.shape == (x.shape[-2], x.shape[-1]): + shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] + elif freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1]): + shape = [d if i >= ndim - 3 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor): + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) + + +class Attention(nn.Module): + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + **kwargs, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // 
num_heads + self.scale = self.head_dim**-0.5 + self.fused_attn = use_fused_attn() + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, attn_mask: torch.Tensor = None) -> torch.Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, self.head_dim) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + attn_mask=attn_mask, + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + if attn_mask is not None: + attn += attn_mask + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class RoPEAttention(Attention): + """Multi-head Attention block with rotary position embeddings.""" + + def __init__( + self, + *args, + num_prefix_tokens=1, + num_latent_tokens=32, + num_image_tokens=256, + rope_theta=10.0, + rope_mixed=True, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.rope_mixed = rope_mixed + self.num_prefix_tokens = num_prefix_tokens + self.num_latent_tokens = num_latent_tokens + self.num_image_tokens = num_image_tokens + self.num_axis_tokens = int(num_image_tokens**0.5) + + if self.rope_mixed: + self.compute_cis = partial(compute_mixed_cis, num_heads=self.num_heads) + + freqs = init_2d_freqs( + dim=self.head_dim, + num_heads=self.num_heads, + theta=rope_theta, + rotate=True, + ).view(2, -1) + self.freqs = nn.Parameter(freqs, requires_grad=True) + + t_x, t_y = init_t_xy(end_x=self.num_axis_tokens, end_y=self.num_axis_tokens) + self.register_buffer("freqs_t_x", t_x) + self.register_buffer("freqs_t_y", t_y) + else: + self.compute_cis = partial( + compute_axial_cis, dim=self.head_dim, theta=rope_theta + ) + freqs_cis = self.compute_cis( + end_x=self.num_axis_tokens, end_y=self.num_axis_tokens + ) + self.freqs_cis = freqs_cis + + # get pre-compted 1d rope + freqs_1d = init_1d_freqs(dim=self.head_dim, end=self.num_latent_tokens) + self.freqs_1d = nn.Parameter(freqs_1d, requires_grad=True) + + def forward(self, x, attn_mask=None): + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv[0], qkv[1], qkv[2] + + ###### Apply rotary position embedding + w = h = math.sqrt(x.shape[1] - 1) + if self.rope_mixed: + t_x, t_y = self.freqs_t_x, self.freqs_t_y + if ( + self.freqs_t_x.shape[0] + != x.shape[1] - self.num_prefix_tokens - self.num_latent_tokens + ): + t_x, t_y = init_t_xy(end_x=w, end_y=h) + t_x, t_y = t_x.to(x.device), t_y.to(x.device) + freqs_cis = self.compute_cis(self.freqs, t_x, t_y) + else: + freqs_cis = self.freqs_cis + if ( + self.freqs_cis.shape[0] + != x.shape[1] - self.num_prefix_tokens - self.num_latent_tokens + ): + freqs_cis = self.compute_cis(end_x=w, end_y=h) + freqs_cis = freqs_cis.to(x.device) + + # apply rotary position embedding to image tokens + dtype = x.dtype + with torch.cuda.amp.autocast(enabled=False): + ( + q[:, :, self.num_prefix_tokens : -self.num_latent_tokens], + k[:, :, self.num_prefix_tokens : -self.num_latent_tokens], + ) = 
apply_rotary_emb( + q[:, :, self.num_prefix_tokens : -self.num_latent_tokens], + k[:, :, self.num_prefix_tokens : -self.num_latent_tokens], + freqs_cis=freqs_cis, + ) + q[:, :, -self.num_latent_tokens :], k[:, :, -self.num_latent_tokens :] = ( + apply_rotary_emb( + q[:, :, -self.num_latent_tokens :], + k[:, :, -self.num_latent_tokens :], + freqs_cis=self.freqs_1d, + ) + ) + q, k = q.to(dtype), k.to(dtype) + ######### + + attn = (q * self.scale) @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_norm: bool = False, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: Optional[float] = None, + drop_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + mlp_layer: nn.Module = Mlp, + attn_layer: nn.Module = Attention, + ) -> None: + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = attn_layer( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + ) + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = mlp_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x: torch.Tensor, attn_mask: torch.Tensor = None) -> torch.Tensor: + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), attn_mask))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class ResPostBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_norm: bool = False, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: Optional[float] = None, + drop_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + mlp_layer: nn.Module = Mlp, + ) -> None: + super().__init__() + self.init_values = init_values + + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + ) + self.norm1 = norm_layer(dim) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.mlp = mlp_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.norm2 = norm_layer(dim) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.init_weights() + + def init_weights(self) -> None: + # NOTE this init overrides that base model init with specific changes for the block type + if self.init_values is not 
None: + nn.init.constant_(self.norm1.weight, self.init_values) + nn.init.constant_(self.norm2.weight, self.init_values) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.drop_path1(self.norm1(self.attn(x))) + x = x + self.drop_path2(self.norm2(self.mlp(x))) + return x + + +class ParallelScalingBlock(nn.Module): + """Parallel ViT block (MLP & Attention in parallel) + Based on: + 'Scaling Vision Transformers to 22 Billion Parameters` - https://arxiv.org/abs/2302.05442 + """ + + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_norm: bool = False, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: Optional[float] = None, + drop_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + mlp_layer: Optional[nn.Module] = None, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.fused_attn = use_fused_attn() + mlp_hidden_dim = int(mlp_ratio * dim) + in_proj_out_dim = mlp_hidden_dim + 3 * dim + + self.in_norm = norm_layer(dim) + self.in_proj = nn.Linear(dim, in_proj_out_dim, bias=qkv_bias) + self.in_split = [mlp_hidden_dim] + [dim] * 3 + if qkv_bias: + self.register_buffer("qkv_bias", None) + self.register_parameter("mlp_bias", None) + else: + self.register_buffer("qkv_bias", torch.zeros(3 * dim), persistent=False) + self.mlp_bias = nn.Parameter(torch.zeros(mlp_hidden_dim)) + + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.attn_out_proj = nn.Linear(dim, dim) + + self.mlp_drop = nn.Dropout(proj_drop) + self.mlp_act = act_layer() + self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim) + + self.ls = ( + LayerScale(dim, init_values=init_values) + if init_values is not None + else nn.Identity() + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape + + # Combined MLP fc1 & qkv projections + y = self.in_norm(x) + if self.mlp_bias is not None: + # Concat constant zero-bias for qkv w/ trainable mlp_bias. 
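The fused projection being commented on here can be illustrated in isolation: a single linear layer emits the MLP hidden activations and q/k/v in one matmul, and `torch.split` carves them apart afterwards. A hedged sketch with illustrative sizes (not the block's real dimensions):

```python
# Sketch of the fused in-projection + split pattern (illustrative dim=8, mlp_ratio=4).
import torch
import torch.nn as nn

dim, mlp_hidden = 8, 32
in_proj = nn.Linear(dim, mlp_hidden + 3 * dim)    # one matmul for MLP fc1 and qkv
x = torch.randn(2, 5, dim)                        # (batch, tokens, dim)
y = in_proj(x)
x_mlp, q, k, v = torch.split(y, [mlp_hidden, dim, dim, dim], dim=-1)
print(x_mlp.shape, q.shape, k.shape, v.shape)     # (2,5,32) (2,5,8) (2,5,8) (2,5,8)
```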
+ # Appears faster than adding to x_mlp separately + y = F.linear( + y, self.in_proj.weight, torch.cat((self.qkv_bias, self.mlp_bias)) + ) + else: + y = self.in_proj(y) + x_mlp, q, k, v = torch.split(y, self.in_split, dim=-1) + + # Dot product attention w/ qk norm + q = self.q_norm(q.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2) + k = self.k_norm(k.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2) + v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2) + if self.fused_attn: + x_attn = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x_attn = attn @ v + x_attn = x_attn.transpose(1, 2).reshape(B, N, C) + x_attn = self.attn_out_proj(x_attn) + + # MLP activation, dropout, fc2 + x_mlp = self.mlp_act(x_mlp) + x_mlp = self.mlp_drop(x_mlp) + x_mlp = self.mlp_out_proj(x_mlp) + + # Add residual w/ drop path & layer scale applied + y = self.drop_path(self.ls(x_attn + x_mlp)) + x = x + y + return x + + +class ParallelThingsBlock(nn.Module): + """Parallel ViT block (N parallel attention followed by N parallel MLP) + Based on: + `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 + """ + + def __init__( + self, + dim: int, + num_heads: int, + num_parallel: int = 2, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_norm: bool = False, + init_values: Optional[float] = None, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + drop_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + mlp_layer: nn.Module = Mlp, + ) -> None: + super().__init__() + self.num_parallel = num_parallel + self.attns = nn.ModuleList() + self.ffns = nn.ModuleList() + for _ in range(num_parallel): + self.attns.append( + nn.Sequential( + OrderedDict( + [ + ("norm", norm_layer(dim)), + ( + "attn", + Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + ), + ), + ( + "ls", + ( + LayerScale(dim, init_values=init_values) + if init_values + else nn.Identity() + ), + ), + ( + "drop_path", + ( + DropPath(drop_path) + if drop_path > 0.0 + else nn.Identity() + ), + ), + ] + ) + ) + ) + self.ffns.append( + nn.Sequential( + OrderedDict( + [ + ("norm", norm_layer(dim)), + ( + "mlp", + mlp_layer( + dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ), + ), + ( + "ls", + ( + LayerScale(dim, init_values=init_values) + if init_values + else nn.Identity() + ), + ), + ( + "drop_path", + ( + DropPath(drop_path) + if drop_path > 0.0 + else nn.Identity() + ), + ), + ] + ) + ) + ) + + def _forward_jit(self, x: torch.Tensor) -> torch.Tensor: + x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0) + x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0) + return x + + @torch.jit.ignore + def _forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + sum(attn(x) for attn in self.attns) + x = x + sum(ffn(x) for ffn in self.ffns) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return self._forward_jit(x) + else: + return self._forward(x) + + +def global_pool_nlc( + x: torch.Tensor, + pool_type: str = "token", + num_prefix_tokens: int = 1, + reduce_include_prefix: bool = False, +): + if not pool_type: + return x + + if pool_type 
== "token": + x = x[:, 0] # class token + else: + x = x if reduce_include_prefix else x[:, num_prefix_tokens:] + if pool_type == "avg": + x = x.mean(dim=1) + elif pool_type == "avgmax": + x = 0.5 * (x.amax(dim=1) + x.mean(dim=1)) + elif pool_type == "max": + x = x.amax(dim=1) + else: + assert not pool_type, f"Unknown pool type {pool_type}" + + return x + + +class VisionTransformer(nn.Module): + """Vision Transformer + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + """ + + dynamic_img_size: Final[bool] + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: Literal["", "avg", "avgmax", "max", "token", "map"] = "token", + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_norm: bool = False, + init_values: Optional[float] = None, + class_token: bool = True, + pos_embed: str = "learn", + no_embed_class: bool = False, + reg_tokens: int = 0, + pre_norm: bool = False, + fc_norm: Optional[bool] = None, + dynamic_img_size: bool = False, + dynamic_img_pad: bool = False, + drop_rate: float = 0.0, + pos_drop_rate: float = 0.0, + patch_drop_rate: float = 0.0, + proj_drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "", + fix_init: bool = False, + embed_layer: Callable = PatchEmbed, + norm_layer: Optional[LayerType] = None, + act_layer: Optional[LayerType] = None, + block_fn: Type[nn.Module] = Block, + mlp_layer: Type[nn.Module] = Mlp, + attn_layer: Type[nn.Module] = Attention, + num_latent_tokens: int = 32, + ) -> None: + """ + Args: + img_size: Input image size. + patch_size: Patch size. + in_chans: Number of image input channels. + num_classes: Number of classes for classification head. + global_pool: Type of global pooling for final sequence (default: 'token'). + embed_dim: Transformer embedding dimension. + depth: Depth of transformer. + num_heads: Number of attention heads. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: Enable bias for qkv projections if True. + init_values: Layer-scale init values (layer-scale enabled if not None). + class_token: Use class token. + no_embed_class: Don't include position embeddings for class (or reg) tokens. + reg_tokens: Number of register tokens. + fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'. + drop_rate: Head dropout rate. + pos_drop_rate: Position embedding dropout rate. + attn_drop_rate: Attention dropout rate. + drop_path_rate: Stochastic depth rate. + weight_init: Weight initialization scheme. + fix_init: Apply weight initialization fix (scaling w/ layer index). + embed_layer: Patch embedding layer. + norm_layer: Normalization layer. + act_layer: MLP activation layer. + block_fn: Transformer block layer. 
+ """ + super().__init__() + assert global_pool in ("", "avg", "avgmax", "max", "token", "map") + assert class_token or global_pool != "token" + assert pos_embed in ("", "none", "learn") + use_fc_norm = ( + global_pool in ("avg", "avgmax", "max") if fc_norm is None else fc_norm + ) + norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) + act_layer = get_act_layer(act_layer) or nn.GELU + + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = self.head_hidden_size = self.embed_dim = ( + embed_dim # for consistency with other models + ) + self.num_prefix_tokens = 1 if class_token else 0 + self.num_prefix_tokens += reg_tokens + self.num_reg_tokens = reg_tokens + self.has_class_token = class_token + self.no_embed_class = ( + no_embed_class # don't embed prefix positions (includes reg) + ) + self.dynamic_img_size = dynamic_img_size + self.grad_checkpointing = False + + embed_args = {} + if dynamic_img_size: + # flatten deferred until after pos embed + embed_args.update(dict(strict_img_size=False, output_fmt="NHWC")) + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP) + dynamic_img_pad=dynamic_img_pad, + **embed_args, + ) + num_patches = self.patch_embed.num_patches + reduction = ( + self.patch_embed.feat_ratio() + if hasattr(self.patch_embed, "feat_ratio") + else patch_size + ) + + self.cls_token = ( + nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + ) + self.reg_token = ( + nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None + ) + embed_len = ( + num_patches if no_embed_class else num_patches + self.num_prefix_tokens + ) + if not pos_embed or pos_embed == "none": + self.pos_embed = None + else: + self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02) + self.pos_drop = nn.Dropout(p=pos_drop_rate) + if patch_drop_rate > 0: + self.patch_drop = PatchDropout( + patch_drop_rate, + num_prefix_tokens=self.num_prefix_tokens, + ) + else: + self.patch_drop = nn.Identity() + self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + self.blocks = nn.Sequential( + *[ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + init_values=init_values, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + mlp_layer=mlp_layer, + attn_layer=partial( + attn_layer, + num_prefix_tokens=self.num_prefix_tokens, + num_latent_tokens=num_latent_tokens, + patch_size=patch_size, + ), + ) + for i in range(depth) + ] + ) + self.feature_info = [ + dict(module=f"blocks.{i}", num_chs=embed_dim, reduction=reduction) + for i in range(depth) + ] + self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() + + # Classifier Head + if global_pool == "map": + self.attn_pool = AttentionPoolLatent( + self.embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + ) + else: + self.attn_pool = None + self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() + self.head_drop = nn.Dropout(drop_rate) + self.head = ( + nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + ) + + if weight_init != "skip": + self.init_weights(weight_init) + if fix_init: + 
self.fix_init_weight() + + def fix_init_weight(self): + def rescale(param, _layer_id): + param.div_(math.sqrt(2.0 * _layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def init_weights(self, mode: str = "") -> None: + assert mode in ("jax", "jax_nlhb", "moco", "") + head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0 + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=0.02) + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(get_init_weights_vit(mode, head_bias), self) + + def _init_weights(self, m: nn.Module) -> None: + # this fn left here for compat with downstream users + init_weights_vit_timm(m) + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path: str, prefix: str = "") -> None: + _load_weights(self, checkpoint_path, prefix) + + @torch.jit.ignore + def no_weight_decay(self) -> Set: + return {"pos_embed", "cls_token", "dist_token"} + + @torch.jit.ignore + def group_matcher(self, coarse: bool = False) -> Dict: + return dict( + stem=r"^cls_token|pos_embed|patch_embed", # stem and embed + blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))], + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True) -> None: + self.grad_checkpointing = enable + if hasattr(self.patch_embed, "set_grad_checkpointing"): + self.patch_embed.set_grad_checkpointing(enable) + + @torch.jit.ignore + def get_classifier(self) -> nn.Module: + return self.head + + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ("", "avg", "avgmax", "max", "token", "map") + if global_pool == "map" and self.attn_pool is None: + assert ( + False + ), "Cannot currently add attention pooling in reset_classifier()." 
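`reset_classifier()` swaps the classification head in place; with `num_classes=0` the head becomes an identity and the model returns pooled features. A usage sketch against a stock timm ViT (for illustration only; it assumes the upstream `timm` package rather than this modified copy):

```python
# Hedged sketch: dropping the head of a stock timm ViT via reset_classifier().
import timm
import torch

model = timm.create_model("vit_small_patch16_224", pretrained=False)
model.reset_classifier(num_classes=0)             # head -> nn.Identity, pooled features out
feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)                                # torch.Size([1, 384]) for vit_small
```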
+ elif global_pool != "map " and self.attn_pool is not None: + self.attn_pool = None # remove attention pooling + self.global_pool = global_pool + self.head = ( + nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + ) + + def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: + if self.pos_embed is None: + return x.view(x.shape[0], -1, x.shape[-1]) + + if self.dynamic_img_size or len(x.shape) == 4: + B, H, W, C = x.shape + pos_embed = resample_abs_pos_embed( + self.pos_embed, + (H, W), + num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, + ) + x = x.view(B, -1, C) + else: + pos_embed = self.pos_embed + + to_cat = [] + if self.cls_token is not None: + to_cat.append(self.cls_token.expand(x.shape[0], -1, -1)) + if self.reg_token is not None: + to_cat.append(self.reg_token.expand(x.shape[0], -1, -1)) + + if self.no_embed_class: + # deit-3, updated JAX (big vision) + # position embedding does not overlap with class token, add then concat + x = x + pos_embed + if to_cat: + x = torch.cat(to_cat + [x], dim=1) + else: + # original timm, JAX, and deit vit impl + # pos_embed has entry for class token, concat then add + if to_cat: + x = torch.cat(to_cat + [x], dim=1) + x = x + pos_embed + + return self.pos_drop(x) + + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, + return_prefix_tokens: bool = False, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = "NCHW", + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + return_prefix_tokens: Return both prefix and spatial intermediate tokens + norm: Apply norm layer to all intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ( + "NCHW", + "NLC", + ), "Output format must be one of NCHW or NLC." + reshape = output_fmt == "NCHW" + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.blocks), indices) + + # forward pass + B, _, height, width = x.shape + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + + if ( + torch.jit.is_scripting() or not stop_early + ): # can't slice blocks in torchscript + blocks = self.blocks + else: + blocks = self.blocks[: max_index + 1] + for i, blk in enumerate(blocks): + x = blk(x) + if i in take_indices: + # normalize intermediates with final norm layer if enabled + intermediates.append(self.norm(x) if norm else x) + + # process intermediates + if self.num_prefix_tokens: + # split prefix (e.g. 
class, distill) and spatial feature tokens + prefix_tokens = [y[:, 0 : self.num_prefix_tokens] for y in intermediates] + intermediates = [y[:, self.num_prefix_tokens :] for y in intermediates] + if reshape: + # reshape to BCHW output format + H, W = self.patch_embed.dynamic_feat_size((height, width)) + intermediates = [ + y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + for y in intermediates + ] + if not torch.jit.is_scripting() and return_prefix_tokens: + # return_prefix not support in torchscript due to poor type handling + intermediates = list(zip(intermediates, prefix_tokens)) + + if intermediates_only: + return intermediates + + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """Prune layers not required for specified intermediates.""" + take_indices, max_index = feature_take_indices(len(self.blocks), indices) + self.blocks = self.blocks[: max_index + 1] # truncate blocks + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.fc_norm = nn.Identity() + self.reset_classifier(0, "") + return take_indices + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, List[int], Tuple[int]] = 1, + reshape: bool = False, + return_prefix_tokens: bool = False, + norm: bool = False, + ) -> List[torch.Tensor]: + """Intermediate layer accessor inspired by DINO / DINOv2 interface. + NOTE: This API is for backwards compat, favour using forward_intermediates() directly. + """ + return self.forward_intermediates( + x, + n, + return_prefix_tokens=return_prefix_tokens, + norm=norm, + output_fmt="NCHW" if reshape else "NLC", + intermediates_only=True, + ) + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + x = self.norm(x) + return x + + def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor: + if self.attn_pool is not None: + x = self.attn_pool(x) + return x + pool_type = self.global_pool if pool_type is None else pool_type + x = global_pool_nlc( + x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens + ) + return x + + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + x = self.pool(x) + x = self.fc_norm(x) + x = self.head_drop(x) + return x if pre_logits else self.head(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def init_weights_vit_timm(module: nn.Module, name: str = "") -> None: + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, "init_weights"): + module.init_weights() + + +def init_weights_vit_jax( + module: nn.Module, name: str = "", head_bias: float = 0.0 +) -> None: + """ViT weight initialization, matching JAX (Flax) impl""" + if isinstance(module, nn.Linear): + if name.startswith("head"): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + else: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + ( + nn.init.normal_(module.bias, std=1e-6) + if "mlp" in name + else 
nn.init.zeros_(module.bias) + ) + elif isinstance(module, nn.Conv2d): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, "init_weights"): + module.init_weights() + + +def init_weights_vit_moco(module: nn.Module, name: str = "") -> None: + """ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed""" + if isinstance(module, nn.Linear): + if "qkv" in name: + # treat the weights of Q, K, V separately + val = math.sqrt( + 6.0 / float(module.weight.shape[0] // 3 + module.weight.shape[1]) + ) + nn.init.uniform_(module.weight, -val, val) + else: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, "init_weights"): + module.init_weights() + + +def get_init_weights_vit(mode: str = "jax", head_bias: float = 0.0) -> Callable: + if "jax" in mode: + return partial(init_weights_vit_jax, head_bias=head_bias) + elif "moco" in mode: + return init_weights_vit_moco + else: + return init_weights_vit_timm + + +def resize_pos_embed( + posemb: torch.Tensor, + posemb_new: torch.Tensor, + num_prefix_tokens: int = 1, + gs_new: Tuple[int, int] = (), + interpolation: str = "bicubic", + antialias: bool = False, +) -> torch.Tensor: + """Rescale the grid of position embeddings when loading from state_dict. + *DEPRECATED* This function is being deprecated in favour of using resample_abs_pos_embed + """ + ntok_new = posemb_new.shape[1] - num_prefix_tokens + ntok_old = posemb.shape[1] - num_prefix_tokens + gs_old = [int(math.sqrt(ntok_old))] * 2 + if not len(gs_new): # backwards compatibility + gs_new = [int(math.sqrt(ntok_new))] * 2 + return resample_abs_pos_embed( + posemb, + gs_new, + gs_old, + num_prefix_tokens=num_prefix_tokens, + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + + +@torch.no_grad() +def _load_weights( + model: VisionTransformer, checkpoint_path: str, prefix: str = "" +) -> None: + """Load weights from .npz checkpoints for official Google Brain Flax implementation""" + import numpy as np + + def _n2p(w, t=True, idx=None): + if idx is not None: + w = w[idx] + if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: + w = w.flatten() + if t: + if w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif w.ndim == 3: + w = w.transpose([2, 0, 1]) + elif w.ndim == 2: + w = w.transpose([1, 0]) + return torch.from_numpy(w) + + w = np.load(checkpoint_path) + interpolation = "bilinear" + antialias = False + big_vision = False + if not prefix: + if "opt/target/embedding/kernel" in w: + prefix = "opt/target/" + elif "params/embedding/kernel" in w: + prefix = "params/" + big_vision = True + elif "params/img/embedding/kernel" in w: + prefix = "params/img/" + big_vision = True + + if hasattr(model.patch_embed, "backbone"): + # hybrid + backbone = model.patch_embed.backbone + stem_only = not hasattr(backbone, "stem") + stem = backbone if stem_only else backbone.stem + stem.conv.weight.copy_( + adapt_input_conv( + stem.conv.weight.shape[1], _n2p(w[f"{prefix}conv_root/kernel"]) + ) + ) + stem.norm.weight.copy_(_n2p(w[f"{prefix}gn_root/scale"])) + stem.norm.bias.copy_(_n2p(w[f"{prefix}gn_root/bias"])) + if not stem_only: + for i, stage in enumerate(backbone.stages): + for j, block in enumerate(stage.blocks): + bp = f"{prefix}block{i + 1}/unit{j + 1}/" + for r in range(3): + getattr(block, f"conv{r + 1}").weight.copy_( + _n2p(w[f"{bp}conv{r + 1}/kernel"]) + ) + getattr(block, f"norm{r + 1}").weight.copy_( + _n2p(w[f"{bp}gn{r + 1}/scale"]) + ) + 
getattr(block, f"norm{r + 1}").bias.copy_( + _n2p(w[f"{bp}gn{r + 1}/bias"]) + ) + if block.downsample is not None: + block.downsample.conv.weight.copy_( + _n2p(w[f"{bp}conv_proj/kernel"]) + ) + block.downsample.norm.weight.copy_( + _n2p(w[f"{bp}gn_proj/scale"]) + ) + block.downsample.norm.bias.copy_(_n2p(w[f"{bp}gn_proj/bias"])) + embed_conv_w = _n2p(w[f"{prefix}embedding/kernel"]) + else: + embed_conv_w = adapt_input_conv( + model.patch_embed.proj.weight.shape[1], _n2p(w[f"{prefix}embedding/kernel"]) + ) + if embed_conv_w.shape[-2:] != model.patch_embed.proj.weight.shape[-2:]: + embed_conv_w = resample_patch_embed( + embed_conv_w, + model.patch_embed.proj.weight.shape[-2:], + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + + model.patch_embed.proj.weight.copy_(embed_conv_w) + model.patch_embed.proj.bias.copy_(_n2p(w[f"{prefix}embedding/bias"])) + if model.cls_token is not None: + model.cls_token.copy_(_n2p(w[f"{prefix}cls"], t=False)) + if big_vision: + pos_embed_w = _n2p(w[f"{prefix}pos_embedding"], t=False) + else: + pos_embed_w = _n2p( + w[f"{prefix}Transformer/posembed_input/pos_embedding"], t=False + ) + if pos_embed_w.shape != model.pos_embed.shape: + old_shape = pos_embed_w.shape + num_prefix_tokens = ( + 0 + if getattr(model, "no_embed_class", False) + else getattr(model, "num_prefix_tokens", 1) + ) + pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights + pos_embed_w, + new_size=model.patch_embed.grid_size, + num_prefix_tokens=num_prefix_tokens, + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + model.pos_embed.copy_(pos_embed_w) + model.norm.weight.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/scale"])) + model.norm.bias.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/bias"])) + if ( + isinstance(model.head, nn.Linear) + and f"{prefix}head/bias" in w + and model.head.bias.shape[0] == w[f"{prefix}head/bias"].shape[-1] + ): + model.head.weight.copy_(_n2p(w[f"{prefix}head/kernel"])) + model.head.bias.copy_(_n2p(w[f"{prefix}head/bias"])) + # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights + # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: + # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) + # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) + if model.attn_pool is not None: + block_prefix = f"{prefix}MAPHead_0/" + mha_prefix = block_prefix + f"MultiHeadDotProductAttention_0/" + model.attn_pool.latent.copy_(_n2p(w[f"{block_prefix}probe"], t=False)) + model.attn_pool.kv.weight.copy_( + torch.cat( + [ + _n2p(w[f"{mha_prefix}{n}/kernel"], t=False).flatten(1).T + for n in ("key", "value") + ] + ) + ) + model.attn_pool.kv.bias.copy_( + torch.cat( + [ + _n2p(w[f"{mha_prefix}{n}/bias"], t=False).reshape(-1) + for n in ("key", "value") + ] + ) + ) + model.attn_pool.q.weight.copy_( + _n2p(w[f"{mha_prefix}query/kernel"], t=False).flatten(1).T + ) + model.attn_pool.q.bias.copy_( + _n2p(w[f"{mha_prefix}query/bias"], t=False).reshape(-1) + ) + model.attn_pool.proj.weight.copy_(_n2p(w[f"{mha_prefix}out/kernel"]).flatten(1)) + model.attn_pool.proj.bias.copy_(_n2p(w[f"{mha_prefix}out/bias"])) + model.attn_pool.norm.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/scale"])) + model.attn_pool.norm.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/bias"])) + for r in range(2): + getattr(model.attn_pool.mlp, f"fc{r + 1}").weight.copy_( + 
_n2p(w[f"{block_prefix}MlpBlock_0/Dense_{r}/kernel"]) + ) + getattr(model.attn_pool.mlp, f"fc{r + 1}").bias.copy_( + _n2p(w[f"{block_prefix}MlpBlock_0/Dense_{r}/bias"]) + ) + + mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2) + for i, block in enumerate(model.blocks.children()): + if f"{prefix}Transformer/encoderblock/LayerNorm_0/scale" in w: + block_prefix = f"{prefix}Transformer/encoderblock/" + idx = i + else: + block_prefix = f"{prefix}Transformer/encoderblock_{i}/" + idx = None + mha_prefix = block_prefix + f"MultiHeadDotProductAttention_{mha_sub}/" + block.norm1.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/scale"], idx=idx)) + block.norm1.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/bias"], idx=idx)) + block.attn.qkv.weight.copy_( + torch.cat( + [ + _n2p(w[f"{mha_prefix}{n}/kernel"], t=False, idx=idx).flatten(1).T + for n in ("query", "key", "value") + ] + ) + ) + block.attn.qkv.bias.copy_( + torch.cat( + [ + _n2p(w[f"{mha_prefix}{n}/bias"], t=False, idx=idx).reshape(-1) + for n in ("query", "key", "value") + ] + ) + ) + block.attn.proj.weight.copy_( + _n2p(w[f"{mha_prefix}out/kernel"], idx=idx).flatten(1) + ) + block.attn.proj.bias.copy_(_n2p(w[f"{mha_prefix}out/bias"], idx=idx)) + block.norm2.weight.copy_( + _n2p(w[f"{block_prefix}LayerNorm_{ln1_sub}/scale"], idx=idx) + ) + block.norm2.bias.copy_( + _n2p(w[f"{block_prefix}LayerNorm_{ln1_sub}/bias"], idx=idx) + ) + for r in range(2): + getattr(block.mlp, f"fc{r + 1}").weight.copy_( + _n2p(w[f"{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel"], idx=idx) + ) + getattr(block.mlp, f"fc{r + 1}").bias.copy_( + _n2p(w[f"{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias"], idx=idx) + ) + + +def _convert_openai_clip( + state_dict: Dict[str, torch.Tensor], + model: VisionTransformer, + prefix: str = "visual.", +) -> Dict[str, torch.Tensor]: + out_dict = {} + swaps = [ + ("conv1", "patch_embed.proj"), + ("positional_embedding", "pos_embed"), + ("transformer.resblocks.", "blocks."), + ("ln_pre", "norm_pre"), + ("ln_post", "norm"), + ("ln_", "norm"), + ("in_proj_", "qkv."), + ("out_proj", "proj"), + ("mlp.c_fc", "mlp.fc1"), + ("mlp.c_proj", "mlp.fc2"), + ] + for k, v in state_dict.items(): + if not k.startswith(prefix): + continue + k = k.replace(prefix, "") + for sp in swaps: + k = k.replace(sp[0], sp[1]) + + if k == "proj": + k = "head.weight" + v = v.transpose(0, 1) + out_dict["head.bias"] = torch.zeros(v.shape[0]) + elif k == "class_embedding": + k = "cls_token" + v = v.unsqueeze(0).unsqueeze(1) + elif k == "pos_embed": + v = v.unsqueeze(0) + out_dict[k] = v + return out_dict + + +def _convert_dinov2( + state_dict: Dict[str, torch.Tensor], + model: VisionTransformer, +) -> Dict[str, torch.Tensor]: + import re + + out_dict = {} + state_dict.pop("mask_token", None) + if "register_tokens" in state_dict: + # convert dinov2 w/ registers to no_embed_class timm model (neither cls or reg tokens overlap pos embed) + out_dict["reg_token"] = state_dict.pop("register_tokens") + out_dict["cls_token"] = ( + state_dict.pop("cls_token") + state_dict["pos_embed"][:, 0] + ) + out_dict["pos_embed"] = state_dict.pop("pos_embed")[:, 1:] + for k, v in state_dict.items(): + if re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k): + out_dict[k.replace("w12", "fc1")] = v + continue + elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k): + out_dict[k.replace("w3", "fc2")] = v + continue + out_dict[k] = v + return out_dict + + +def checkpoint_filter_fn( + state_dict: Dict[str, torch.Tensor], + model: VisionTransformer, + adapt_layer_scale: 
bool = False, + interpolation: str = "bicubic", + antialias: bool = True, +) -> Dict[str, torch.Tensor]: + """convert patch embedding weight from manual patchify + linear proj to conv""" + import re + + out_dict = {} + state_dict = state_dict.get("model", state_dict) + state_dict = state_dict.get("state_dict", state_dict) + prefix = "" + + if "visual.class_embedding" in state_dict: + state_dict = _convert_openai_clip(state_dict, model) + elif "module.visual.class_embedding" in state_dict: + state_dict = _convert_openai_clip(state_dict, model, prefix="module.visual.") + elif "mask_token" in state_dict: + state_dict = _convert_dinov2(state_dict, model) + elif "encoder" in state_dict: + # IJEPA, vit in an 'encoder' submodule + state_dict = state_dict["encoder"] + prefix = "module." + elif ( + "visual.trunk.pos_embed" in state_dict + or "visual.trunk.blocks.0.norm1.weight" in state_dict + ): + # OpenCLIP model with timm vision encoder + prefix = "visual.trunk." + if "visual.head.proj.weight" in state_dict and isinstance( + model.head, nn.Linear + ): + # remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj) + out_dict["head.weight"] = state_dict["visual.head.proj.weight"] + out_dict["head.bias"] = torch.zeros( + state_dict["visual.head.proj.weight"].shape[0] + ) + + if prefix: + # filter on & remove prefix string from keys + state_dict = { + k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix) + } + + for k, v in state_dict.items(): + if "patch_embed.proj.weight" in k: + O, I, H, W = model.patch_embed.proj.weight.shape + if len(v.shape) < 4: + # For old models that I trained prior to conv based patchification + O, I, H, W = model.patch_embed.proj.weight.shape + v = v.reshape(O, -1, H, W) + if v.shape[-1] != W or v.shape[-2] != H: + v = resample_patch_embed( + v, + (H, W), + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + elif k == "pos_embed" and v.shape[1] != model.pos_embed.shape[1]: + # To resize pos embedding when using model at different size from pretrained weights + num_prefix_tokens = ( + 0 + if getattr(model, "no_embed_class", False) + else getattr(model, "num_prefix_tokens", 1) + ) + v = resample_abs_pos_embed( + v, + new_size=model.patch_embed.grid_size, + num_prefix_tokens=num_prefix_tokens, + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + elif adapt_layer_scale and "gamma_" in k: + # remap layer-scale gamma into sub-module (deit3 models) + k = re.sub(r"gamma_([0-9])", r"ls\1.gamma", k) + elif "pre_logits" in k: + # NOTE representation layer removed as not used in latest 21k/1k pretrained weights + continue + out_dict[k] = v + return out_dict + + +def _cfg(url: str = "", **kwargs) -> Dict[str, Any]: + return { + "url": url, + "num_classes": 1000, + "input_size": (3, 224, 224), + "pool_size": None, + "crop_pct": 0.9, + "interpolation": "bicubic", + "fixed_input_size": True, + "mean": IMAGENET_INCEPTION_MEAN, + "std": IMAGENET_INCEPTION_STD, + "first_conv": "patch_embed.proj", + "classifier": "head", + **kwargs, + } + + +default_cfgs = { + # re-finetuned augreg 21k FT on in1k weights + "vit_base_patch16_224.augreg2_in21k_ft_in1k": _cfg(hf_hub_id="timm/"), + "vit_base_patch16_384.augreg2_in21k_ft_in1k": _cfg(), + "vit_base_patch8_224.augreg2_in21k_ft_in1k": _cfg(hf_hub_id="timm/"), + # How to train your ViT (augreg) weights, pretrained on 21k FT on in1k + "vit_tiny_patch16_224.augreg_in21k_ft_in1k": _cfg( + 
url="https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz", + hf_hub_id="timm/", + custom_load=True, + ), + "vit_tiny_patch16_384.augreg_in21k_ft_in1k": _cfg( + url="https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz", + hf_hub_id="timm/", + custom_load=True, + input_size=(3, 384, 384), + crop_pct=1.0, + ), + "vit_small_patch32_224.augreg_in21k_ft_in1k": _cfg( + url="https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz", + hf_hub_id="timm/", + custom_load=True, + ), + "vit_small_patch32_384.augreg_in21k_ft_in1k": _cfg( + url="https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz", + hf_hub_id="timm/", + custom_load=True, + input_size=(3, 384, 384), + crop_pct=1.0, + ), + "vit_small_patch16_224.augreg_in21k_ft_in1k": _cfg( + url="https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz", + hf_hub_id="timm/", + custom_load=True, + ), + "vit_small_patch16_384.augreg_in21k_ft_in1k": _cfg( + url="https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz", + hf_hub_id="timm/", + custom_load=True, + input_size=(3, 384, 384), + crop_pct=1.0, + ), + "vit_base_patch32_224.augreg_in21k_ft_in1k": _cfg( + url="https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz", + hf_hub_id="timm/", + custom_load=True, + ), + "vit_base_patch32_384.augreg_in21k_ft_in1k": _cfg( + url="https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz", + hf_hub_id="timm/", + custom_load=True, + input_size=(3, 384, 384), + crop_pct=1.0, + ), + "vit_base_patch16_224.augreg_in21k_ft_in1k": _cfg( + url="https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz", + hf_hub_id="timm/", + custom_load=True, + ), + "vit_base_patch16_384.augreg_in21k_ft_in1k": _cfg( + url="https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz", + hf_hub_id="timm/", + custom_load=True, + input_size=(3, 384, 384), + crop_pct=1.0, + ), + "vit_base_patch8_224.augreg_in21k_ft_in1k": _cfg( + url="https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz", + hf_hub_id="timm/", + custom_load=True, + ), + "vit_large_patch16_224.augreg_in21k_ft_in1k": _cfg( + url="https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz", + hf_hub_id="timm/", + custom_load=True, + ), + "vit_large_patch16_384.augreg_in21k_ft_in1k": _cfg( + url="https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz", + hf_hub_id="timm/", + custom_load=True, + input_size=(3, 384, 
384), + crop_pct=1.0, + ), + # patch models (weights from official Google JAX impl) pretrained on in21k FT on in1k + "vit_base_patch16_224.orig_in21k_ft_in1k": _cfg( + url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth", + hf_hub_id="timm/", + ), + "vit_base_patch16_384.orig_in21k_ft_in1k": _cfg( + url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth", + hf_hub_id="timm/", + input_size=(3, 384, 384), + crop_pct=1.0, + ), + "vit_large_patch32_384.orig_in21k_ft_in1k": _cfg( + url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth", + hf_hub_id="timm/", + input_size=(3, 384, 384), + crop_pct=1.0, + ), + # How to train your ViT (augreg) weights trained on in1k only + "vit_small_patch16_224.augreg_in1k": _cfg( + url="https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz", + hf_hub_id="timm/", + custom_load=True, + ), + "vit_small_patch16_384.augreg_in1k": _cfg( + url="https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz", + hf_hub_id="timm/", + custom_load=True, + input_size=(3, 384, 384), + crop_pct=1.0, + ), + "vit_base_patch32_224.augreg_in1k": _cfg( + url="https://storage.googleapis.com/vit_models/augreg/B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz", + hf_hub_id="timm/", + custom_load=True, + ), + "vit_base_patch32_384.augreg_in1k": _cfg( + url="https://storage.googleapis.com/vit_models/augreg/B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz", + hf_hub_id="timm/", + custom_load=True, + input_size=(3, 384, 384), + crop_pct=1.0, + ), + "vit_base_patch16_224.augreg_in1k": _cfg( + url="https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz", + hf_hub_id="timm/", + custom_load=True, + ), + "vit_base_patch16_384.augreg_in1k": _cfg( + url="https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz", + hf_hub_id="timm/", + custom_load=True, + input_size=(3, 384, 384), + crop_pct=1.0, + ), + "vit_large_patch14_224.untrained": _cfg(url=""), + "vit_huge_patch14_224.untrained": _cfg(url=""), + "vit_giant_patch14_224.untrained": _cfg(url=""), + "vit_gigantic_patch14_224.untrained": _cfg(url=""), + # patch models, imagenet21k (weights from official Google JAX impl), classifier not valid + "vit_base_patch32_224.orig_in21k": _cfg( + # url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth', + hf_hub_id="timm/", + num_classes=0, + ), + "vit_base_patch16_224.orig_in21k": _cfg( + # url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth', + hf_hub_id="timm/", + num_classes=0, + ), + "vit_large_patch32_224.orig_in21k": _cfg( + # url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', + hf_hub_id="timm/", + num_classes=0, + ), + "vit_large_patch16_224.orig_in21k": _cfg( + # 
url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth', + hf_hub_id="timm/", + num_classes=0, + ), + "vit_huge_patch14_224.orig_in21k": _cfg(hf_hub_id="timm/", num_classes=0), + # How to train your ViT (augreg) weights, pretrained on in21k + "vit_tiny_patch16_224.augreg_in21k": _cfg( + url="https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz", + hf_hub_id="timm/", + custom_load=True, + num_classes=21843, + ), + "vit_small_patch32_224.augreg_in21k": _cfg( + url="https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz", + hf_hub_id="timm/", + custom_load=True, + num_classes=21843, + ), + "vit_small_patch16_224.augreg_in21k": _cfg( + url="https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz", + hf_hub_id="timm/", + custom_load=True, + num_classes=21843, + ), + "vit_base_patch32_224.augreg_in21k": _cfg( + url="https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz", + hf_hub_id="timm/", + custom_load=True, + num_classes=21843, + ), + "vit_base_patch16_224.augreg_in21k": _cfg( + url="https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz", + hf_hub_id="timm/", + custom_load=True, + num_classes=21843, + ), + "vit_base_patch8_224.augreg_in21k": _cfg( + url="https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz", + hf_hub_id="timm/", + custom_load=True, + num_classes=21843, + ), + "vit_large_patch16_224.augreg_in21k": _cfg( + url="https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz", + hf_hub_id="timm/", + custom_load=True, + num_classes=21843, + ), + # SAM trained models (https://arxiv.org/abs/2106.01548) + "vit_base_patch32_224.sam_in1k": _cfg( + url="https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz", + custom_load=True, + hf_hub_id="timm/", + ), + "vit_base_patch16_224.sam_in1k": _cfg( + url="https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz", + custom_load=True, + hf_hub_id="timm/", + ), + # DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only) + "vit_small_patch16_224.dino": _cfg( + url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", + hf_hub_id="timm/", + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + ), + "vit_small_patch8_224.dino": _cfg( + url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", + hf_hub_id="timm/", + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + ), + "vit_base_patch16_224.dino": _cfg( + url="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth", + hf_hub_id="timm/", + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + ), + "vit_base_patch8_224.dino": _cfg( + url="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", + hf_hub_id="timm/", + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + ), + # DINOv2 pretrained - https://arxiv.org/abs/2304.07193 (no classifier head, for fine-tune/features only) + "vit_small_patch14_dinov2.lvd142m": _cfg( + 
url="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth", + hf_hub_id="timm/", + license="apache-2.0", + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + input_size=(3, 518, 518), + crop_pct=1.0, + ), + "vit_base_patch14_dinov2.lvd142m": _cfg( + url="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth", + hf_hub_id="timm/", + license="apache-2.0", + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + input_size=(3, 518, 518), + crop_pct=1.0, + ), + "vit_large_patch14_dinov2.lvd142m": _cfg( + url="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", + hf_hub_id="timm/", + license="apache-2.0", + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + input_size=(3, 518, 518), + crop_pct=1.0, + ), + "vit_giant_patch14_dinov2.lvd142m": _cfg( + url="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth", + hf_hub_id="timm/", + license="apache-2.0", + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + input_size=(3, 518, 518), + crop_pct=1.0, + ), + # DINOv2 pretrained w/ registers - https://arxiv.org/abs/2309.16588 (no classifier head, for fine-tune/features only) + "vit_small_patch14_reg4_dinov2.lvd142m": _cfg( + url="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth", + hf_hub_id="timm/", + license="apache-2.0", + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + input_size=(3, 518, 518), + crop_pct=1.0, + ), + "vit_base_patch14_reg4_dinov2.lvd142m": _cfg( + url="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth", + hf_hub_id="timm/", + license="apache-2.0", + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + input_size=(3, 518, 518), + crop_pct=1.0, + ), + "vit_large_patch14_reg4_dinov2.lvd142m": _cfg( + url="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth", + hf_hub_id="timm/", + license="apache-2.0", + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + input_size=(3, 518, 518), + crop_pct=1.0, + ), + "vit_giant_patch14_reg4_dinov2.lvd142m": _cfg( + url="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth", + hf_hub_id="timm/", + license="apache-2.0", + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + input_size=(3, 518, 518), + crop_pct=1.0, + ), + # ViT ImageNet-21K-P pretraining by MILL + "vit_base_patch16_224_miil.in21k": _cfg( + url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth", + hf_hub_id="timm/", + mean=(0.0, 0.0, 0.0), + std=(1.0, 1.0, 1.0), + crop_pct=0.875, + interpolation="bilinear", + num_classes=11221, + ), + "vit_base_patch16_224_miil.in21k_ft_in1k": _cfg( + url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_1k_miil_84_4-2deb18e3.pth", + hf_hub_id="timm/", + mean=(0.0, 0.0, 0.0), + std=(1.0, 1.0, 1.0), + crop_pct=0.875, + interpolation="bilinear", + ), + # Custom timm variants + "vit_base_patch16_rpn_224.sw_in1k": _cfg( + url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth", + hf_hub_id="timm/", + ), + "vit_medium_patch16_gap_240.sw_in12k": _cfg( + hf_hub_id="timm/", input_size=(3, 240, 240), crop_pct=0.95, num_classes=11821 + ), + 
"vit_medium_patch16_gap_256.sw_in12k_ft_in1k": _cfg( + hf_hub_id="timm/", input_size=(3, 256, 256), crop_pct=0.95 + ), + "vit_medium_patch16_gap_384.sw_in12k_ft_in1k": _cfg( + hf_hub_id="timm/", input_size=(3, 384, 384), crop_pct=0.95, crop_mode="squash" + ), + "vit_base_patch16_gap_224": _cfg(), + # CLIP pretrained image tower and related fine-tuned weights + "vit_base_patch32_clip_224.laion2b_ft_in12k_in1k": _cfg( + hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD + ), + "vit_base_patch32_clip_384.laion2b_ft_in12k_in1k": _cfg( + hf_hub_id="timm/", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=1.0, + input_size=(3, 384, 384), + ), + "vit_base_patch32_clip_448.laion2b_ft_in12k_in1k": _cfg( + hf_hub_id="timm/", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=1.0, + input_size=(3, 448, 448), + ), + "vit_base_patch16_clip_224.laion2b_ft_in12k_in1k": _cfg( + hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95 + ), + "vit_base_patch16_clip_384.laion2b_ft_in12k_in1k": _cfg( + hf_hub_id="timm/", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=1.0, + input_size=(3, 384, 384), + crop_mode="squash", + ), + "vit_large_patch14_clip_224.laion2b_ft_in12k_in1k": _cfg( + hf_hub_id="timm/", + mean=IMAGENET_INCEPTION_MEAN, + std=IMAGENET_INCEPTION_STD, + crop_pct=1.0, + ), + "vit_large_patch14_clip_336.laion2b_ft_in12k_in1k": _cfg( + hf_hub_id="timm/", + mean=IMAGENET_INCEPTION_MEAN, + std=IMAGENET_INCEPTION_STD, + crop_pct=1.0, + input_size=(3, 336, 336), + crop_mode="squash", + ), + "vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k": _cfg( + hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0 + ), + "vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k": _cfg( + hf_hub_id="timm/", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=1.0, + input_size=(3, 336, 336), + crop_mode="squash", + ), + "vit_base_patch32_clip_224.openai_ft_in12k_in1k": _cfg( + # hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k_in1k', # FIXME weight exists, need to push + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + ), + "vit_base_patch32_clip_384.openai_ft_in12k_in1k": _cfg( + hf_hub_id="timm/", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=0.95, + input_size=(3, 384, 384), + crop_mode="squash", + ), + "vit_base_patch16_clip_224.openai_ft_in12k_in1k": _cfg( + hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95 + ), + "vit_base_patch16_clip_384.openai_ft_in12k_in1k": _cfg( + hf_hub_id="timm/", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=0.95, + input_size=(3, 384, 384), + crop_mode="squash", + ), + "vit_large_patch14_clip_224.openai_ft_in12k_in1k": _cfg( + hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0 + ), + "vit_large_patch14_clip_336.openai_ft_in12k_in1k": _cfg( + hf_hub_id="timm/", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=1.0, + input_size=(3, 336, 336), + crop_mode="squash", + ), + "vit_base_patch32_clip_224.laion2b_ft_in1k": _cfg( + hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD + ), + "vit_base_patch16_clip_224.laion2b_ft_in1k": _cfg( + hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0 + ), + "vit_base_patch16_clip_384.laion2b_ft_in1k": _cfg( + hf_hub_id="timm/", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=1.0, + input_size=(3, 384, 384), + crop_mode="squash", + ), + "vit_large_patch14_clip_224.laion2b_ft_in1k": _cfg( + hf_hub_id="timm/", + 
mean=IMAGENET_INCEPTION_MEAN, + std=IMAGENET_INCEPTION_STD, + crop_pct=1.0, + ), + "vit_large_patch14_clip_336.laion2b_ft_in1k": _cfg( + hf_hub_id="timm/", + mean=IMAGENET_INCEPTION_MEAN, + std=IMAGENET_INCEPTION_STD, + crop_pct=1.0, + input_size=(3, 336, 336), + crop_mode="squash", + ), + "vit_huge_patch14_clip_224.laion2b_ft_in1k": _cfg( + hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0 + ), + "vit_huge_patch14_clip_336.laion2b_ft_in1k": _cfg( + hf_hub_id="", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=1.0, + input_size=(3, 336, 336), + crop_mode="squash", + ), + "vit_base_patch32_clip_224.openai_ft_in1k": _cfg( + hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD + ), + "vit_base_patch16_clip_224.openai_ft_in1k": _cfg( + hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD + ), + "vit_base_patch16_clip_384.openai_ft_in1k": _cfg( + hf_hub_id="timm/", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=1.0, + input_size=(3, 384, 384), + crop_mode="squash", + ), + "vit_large_patch14_clip_224.openai_ft_in1k": _cfg( + hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0 + ), + "vit_base_patch32_clip_224.laion2b_ft_in12k": _cfg( + # hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in12k', # FIXME weight exists, need to push + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + num_classes=11821, + ), + "vit_base_patch16_clip_224.laion2b_ft_in12k": _cfg( + hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821 + ), + "vit_large_patch14_clip_224.laion2b_ft_in12k": _cfg( + hf_hub_id="timm/", + mean=IMAGENET_INCEPTION_MEAN, + std=IMAGENET_INCEPTION_STD, + crop_pct=1.0, + num_classes=11821, + ), + "vit_huge_patch14_clip_224.laion2b_ft_in12k": _cfg( + hf_hub_id="timm/", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=1.0, + num_classes=11821, + ), + "vit_base_patch32_clip_224.openai_ft_in12k": _cfg( + # hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k', # FIXME weight exists, need to push + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + num_classes=11821, + ), + "vit_base_patch16_clip_224.openai_ft_in12k": _cfg( + hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821 + ), + "vit_large_patch14_clip_224.openai_ft_in12k": _cfg( + hf_hub_id="timm/", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=1.0, + num_classes=11821, + ), + "vit_base_patch32_clip_224.laion2b": _cfg( + hf_hub_id="laion/CLIP-ViT-B-32-laion2B-s34B-b79K", + hf_hub_filename="open_clip_pytorch_model.bin", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + num_classes=512, + ), + "vit_base_patch16_clip_224.laion2b": _cfg( + hf_hub_id="laion/CLIP-ViT-B-16-laion2B-s34B-b88K", + hf_hub_filename="open_clip_pytorch_model.bin", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=1.0, + num_classes=512, + ), + "vit_large_patch14_clip_224.laion2b": _cfg( + hf_hub_id="laion/CLIP-ViT-L-14-laion2B-s32B-b82K", + hf_hub_filename="open_clip_pytorch_model.bin", + mean=IMAGENET_INCEPTION_MEAN, + std=IMAGENET_INCEPTION_STD, + crop_pct=1.0, + num_classes=768, + ), + "vit_huge_patch14_clip_224.laion2b": _cfg( + hf_hub_id="laion/CLIP-ViT-H-14-laion2B-s32B-b79K", + hf_hub_filename="open_clip_pytorch_model.bin", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=1.0, + num_classes=1024, + ), + "vit_giant_patch14_clip_224.laion2b": _cfg( + hf_hub_id="laion/CLIP-ViT-g-14-laion2B-s12B-b42K", + hf_hub_filename="open_clip_pytorch_model.bin", + mean=OPENAI_CLIP_MEAN, + 
std=OPENAI_CLIP_STD, + crop_pct=1.0, + num_classes=1024, + ), + "vit_gigantic_patch14_clip_224.laion2b": _cfg( + hf_hub_id="laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", + hf_hub_filename="open_clip_pytorch_model.bin", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=1.0, + num_classes=1280, + ), + "vit_base_patch32_clip_224.datacompxl": _cfg( + hf_hub_id="laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K", + hf_hub_filename="open_clip_pytorch_model.bin", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=1.0, + num_classes=512, + ), + "vit_base_patch32_clip_256.datacompxl": _cfg( + hf_hub_id="laion/CLIP-ViT-B-32-256x256-DataComp-s34B-b86K", + hf_hub_filename="open_clip_pytorch_model.bin", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=1.0, + input_size=(3, 256, 256), + num_classes=512, + ), + "vit_base_patch16_clip_224.datacompxl": _cfg( + hf_hub_id="laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K", + hf_hub_filename="open_clip_pytorch_model.bin", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=1.0, + num_classes=512, + ), + "vit_large_patch14_clip_224.datacompxl": _cfg( + hf_hub_id="laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K", + hf_hub_filename="open_clip_pytorch_model.bin", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=1.0, + num_classes=768, + ), + "vit_base_patch16_clip_224.dfn2b": _cfg( + hf_hub_id="apple/DFN2B-CLIP-ViT-B-16", + hf_hub_filename="open_clip_pytorch_model.bin", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=1.0, + num_classes=512, + ), + "vit_large_patch14_clip_224.dfn2b": _cfg( + hf_hub_id="apple/DFN2B-CLIP-ViT-L-14", + hf_hub_filename="open_clip_pytorch_model.bin", + notes=("natively QuickGELU, use quickgelu model variant for original results",), + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=1.0, + num_classes=768, + ), + "vit_huge_patch14_clip_224.dfn5b": _cfg( + hf_hub_id="apple/DFN5B-CLIP-ViT-H-14", + hf_hub_filename="open_clip_pytorch_model.bin", + notes=("natively QuickGELU, use quickgelu model variant for original results",), + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=1.0, + num_classes=1024, + ), + "vit_huge_patch14_clip_378.dfn5b": _cfg( + hf_hub_id="apple/DFN5B-CLIP-ViT-H-14-378", + hf_hub_filename="open_clip_pytorch_model.bin", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + notes=("natively QuickGELU, use quickgelu model variant for original results",), + crop_pct=1.0, + input_size=(3, 378, 378), + num_classes=1024, + ), + "vit_base_patch32_clip_224.metaclip_2pt5b": _cfg( + hf_hub_id="facebook/metaclip-b32-fullcc2.5b", + hf_hub_filename="metaclip_b32_fullcc2.5b.bin", + license="cc-by-nc-4.0", + notes=("natively QuickGELU, use quickgelu model variant for original results",), + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=1.0, + num_classes=512, + ), + "vit_base_patch16_clip_224.metaclip_2pt5b": _cfg( + hf_hub_id="facebook/metaclip-b16-fullcc2.5b", + hf_hub_filename="metaclip_b16_fullcc2.5b.bin", + license="cc-by-nc-4.0", + notes=("natively QuickGELU, use quickgelu model variant for original results",), + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=1.0, + num_classes=512, + ), + "vit_large_patch14_clip_224.metaclip_2pt5b": _cfg( + hf_hub_id="facebook/metaclip-l14-fullcc2.5b", + hf_hub_filename="metaclip_l14_fullcc2.5b.bin", + license="cc-by-nc-4.0", + notes=("natively QuickGELU, use quickgelu model variant for original results",), + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=1.0, + num_classes=768, + ), + 
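In the OpenCLIP-derived entries above, `num_classes` is the width of the CLIP projection (512 / 768 / 1024 / 1280), so the "classifier" head emits image embeddings in CLIP space rather than ImageNet logits. A hedged usage sketch follows, assuming the upstream timm package (not this vendored copy) and Hugging Face Hub access; the model name is taken from the entries above:

# Rough usage sketch, assuming upstream timm and HF Hub access.
import timm
import torch

model = timm.create_model("vit_base_patch16_clip_224.laion2b", pretrained=True)
model.eval()
with torch.no_grad():
    embedding = model(torch.randn(1, 3, 224, 224))  # shape (1, 512): CLIP-space image embedding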
"vit_huge_patch14_clip_224.metaclip_2pt5b": _cfg( + hf_hub_id="facebook/metaclip-h14-fullcc2.5b", + hf_hub_filename="metaclip_h14_fullcc2.5b.bin", + license="cc-by-nc-4.0", + notes=("natively QuickGELU, use quickgelu model variant for original results",), + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=1.0, + num_classes=1024, + ), + "vit_base_patch32_clip_224.openai": _cfg( + hf_hub_id="timm/vit_base_patch32_clip_224.openai", + notes=("natively QuickGELU, use quickgelu model variant for original results",), + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + num_classes=512, + ), + "vit_base_patch16_clip_224.openai": _cfg( + hf_hub_id="timm/vit_base_patch16_clip_224.openai", + notes=("natively QuickGELU, use quickgelu model variant for original results",), + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + num_classes=512, + ), + "vit_large_patch14_clip_224.openai": _cfg( + hf_hub_id="timm/vit_large_patch14_clip_224.openai", + notes=("natively QuickGELU, use quickgelu model variant for original results",), + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=1.0, + num_classes=768, + ), + "vit_large_patch14_clip_336.openai": _cfg( + hf_hub_id="timm/vit_large_patch14_clip_336.openai", + hf_hub_filename="open_clip_pytorch_model.bin", + notes=("natively QuickGELU, use quickgelu model variant for original results",), + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + crop_pct=1.0, + input_size=(3, 336, 336), + num_classes=768, + ), + # experimental (may be removed) + "vit_base_patch32_plus_256.untrained": _cfg( + url="", input_size=(3, 256, 256), crop_pct=0.95 + ), + "vit_base_patch16_plus_240.untrained": _cfg( + url="", input_size=(3, 240, 240), crop_pct=0.95 + ), + "vit_small_patch16_36x1_224.untrained": _cfg(url=""), + "vit_small_patch16_18x2_224.untrained": _cfg(url=""), + "vit_base_patch16_18x2_224.untrained": _cfg(url=""), + # EVA fine-tuned weights from MAE style MIM - EVA-CLIP target pretrain + # https://github.com/baaivision/EVA/blob/7ecf2c0a370d97967e86d047d7af9188f78d2df3/eva/README.md#eva-l-learning-better-mim-representations-from-eva-clip + "eva_large_patch14_196.in22k_ft_in22k_in1k": _cfg( + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_21k_to_1k_ft_88p6.pt', + hf_hub_id="timm/", + license="mit", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + input_size=(3, 196, 196), + crop_pct=1.0, + ), + "eva_large_patch14_336.in22k_ft_in22k_in1k": _cfg( + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_21k_to_1k_ft_89p2.pt', + hf_hub_id="timm/", + license="mit", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + input_size=(3, 336, 336), + crop_pct=1.0, + crop_mode="squash", + ), + "eva_large_patch14_196.in22k_ft_in1k": _cfg( + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_1k_ft_88p0.pt', + hf_hub_id="timm/", + license="mit", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + input_size=(3, 196, 196), + crop_pct=1.0, + ), + "eva_large_patch14_336.in22k_ft_in1k": _cfg( + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_1k_ft_88p65.pt', + hf_hub_id="timm/", + license="mit", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + input_size=(3, 336, 336), + crop_pct=1.0, + crop_mode="squash", + ), + "flexivit_small.1200ep_in1k": _cfg( + url="https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k.npz", + custom_load=True, + hf_hub_id="timm/", + input_size=(3, 240, 240), + crop_pct=0.95, + ), + "flexivit_small.600ep_in1k": _cfg( + url="https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_600ep.npz", + custom_load=True, 
+ hf_hub_id="timm/", + input_size=(3, 240, 240), + crop_pct=0.95, + ), + "flexivit_small.300ep_in1k": _cfg( + url="https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_300ep.npz", + custom_load=True, + hf_hub_id="timm/", + input_size=(3, 240, 240), + crop_pct=0.95, + ), + "flexivit_base.1200ep_in1k": _cfg( + url="https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k.npz", + custom_load=True, + hf_hub_id="timm/", + input_size=(3, 240, 240), + crop_pct=0.95, + ), + "flexivit_base.600ep_in1k": _cfg( + url="https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_600ep.npz", + custom_load=True, + hf_hub_id="timm/", + input_size=(3, 240, 240), + crop_pct=0.95, + ), + "flexivit_base.300ep_in1k": _cfg( + url="https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_300ep.npz", + custom_load=True, + hf_hub_id="timm/", + input_size=(3, 240, 240), + crop_pct=0.95, + ), + "flexivit_base.1000ep_in21k": _cfg( + url="https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_1000ep.npz", + custom_load=True, + hf_hub_id="timm/", + input_size=(3, 240, 240), + crop_pct=0.95, + num_classes=21843, + ), + "flexivit_base.300ep_in21k": _cfg( + url="https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_300ep.npz", + custom_load=True, + hf_hub_id="timm/", + input_size=(3, 240, 240), + crop_pct=0.95, + num_classes=21843, + ), + "flexivit_large.1200ep_in1k": _cfg( + url="https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k.npz", + custom_load=True, + hf_hub_id="timm/", + input_size=(3, 240, 240), + crop_pct=0.95, + ), + "flexivit_large.600ep_in1k": _cfg( + url="https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_600ep.npz", + custom_load=True, + hf_hub_id="timm/", + input_size=(3, 240, 240), + crop_pct=0.95, + ), + "flexivit_large.300ep_in1k": _cfg( + url="https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_300ep.npz", + custom_load=True, + hf_hub_id="timm/", + input_size=(3, 240, 240), + crop_pct=0.95, + ), + "flexivit_base.patch16_in21k": _cfg( + url="https://storage.googleapis.com/big_vision/flexivit/vit_b16_i21k_300ep.npz", + custom_load=True, + hf_hub_id="timm/", + input_size=(3, 240, 240), + crop_pct=0.95, + num_classes=21843, + ), + "flexivit_base.patch30_in21k": _cfg( + url="https://storage.googleapis.com/big_vision/flexivit/vit_b30_i21k_300ep.npz", + custom_load=True, + hf_hub_id="timm/", + input_size=(3, 240, 240), + crop_pct=0.95, + num_classes=21843, + ), + "vit_base_patch16_xp_224.untrained": _cfg(url=""), + "vit_large_patch14_xp_224.untrained": _cfg(url=""), + "vit_huge_patch14_xp_224.untrained": _cfg(url=""), + "vit_base_patch16_224.mae": _cfg( + url="https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth", + hf_hub_id="timm/", + license="cc-by-nc-4.0", + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + ), + "vit_large_patch16_224.mae": _cfg( + url="https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_large.pth", + hf_hub_id="timm/", + license="cc-by-nc-4.0", + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + ), + "vit_huge_patch14_224.mae": _cfg( + url="https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_huge.pth", + hf_hub_id="timm/", + license="cc-by-nc-4.0", + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + ), + "vit_huge_patch14_gap_224.in1k_ijepa": _cfg( + url="https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar", + # hf_hub_id='timm/', + license="cc-by-nc-4.0", + 
mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + ), + "vit_huge_patch14_gap_224.in22k_ijepa": _cfg( + url="https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.h.14-900e.pth.tar", + # hf_hub_id='timm/', + license="cc-by-nc-4.0", + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + ), + "vit_huge_patch16_gap_448.in1k_ijepa": _cfg( + url="https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.16-448px-300e.pth.tar", + # hf_hub_id='timm/', + license="cc-by-nc-4.0", + input_size=(3, 448, 448), + crop_pct=1.0, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + ), + "vit_giant_patch16_gap_224.in22k_ijepa": _cfg( + url="https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.g.16-600e.pth.tar", + # hf_hub_id='timm/', + license="cc-by-nc-4.0", + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + ), + "vit_base_patch16_siglip_224.webli": _cfg( + hf_hub_id="timm/ViT-B-16-SigLIP", + hf_hub_filename="open_clip_pytorch_model.bin", + num_classes=0, + ), + "vit_base_patch16_siglip_256.webli": _cfg( + hf_hub_id="timm/ViT-B-16-SigLIP-256", + hf_hub_filename="open_clip_pytorch_model.bin", + input_size=(3, 256, 256), + num_classes=0, + ), + "vit_base_patch16_siglip_384.webli": _cfg( + hf_hub_id="timm/ViT-B-16-SigLIP-384", + hf_hub_filename="open_clip_pytorch_model.bin", + input_size=(3, 384, 384), + num_classes=0, + ), + "vit_base_patch16_siglip_512.webli": _cfg( + hf_hub_id="timm/ViT-B-16-SigLIP-512", + hf_hub_filename="open_clip_pytorch_model.bin", + input_size=(3, 512, 512), + num_classes=0, + ), + "vit_large_patch16_siglip_256.webli": _cfg( + hf_hub_id="timm/ViT-L-16-SigLIP-256", + hf_hub_filename="open_clip_pytorch_model.bin", + input_size=(3, 256, 256), + num_classes=0, + ), + "vit_large_patch16_siglip_384.webli": _cfg( + hf_hub_id="timm/ViT-L-16-SigLIP-384", + hf_hub_filename="open_clip_pytorch_model.bin", + input_size=(3, 384, 384), + num_classes=0, + ), + "vit_so400m_patch14_siglip_224.webli": _cfg( + hf_hub_id="timm/ViT-SO400M-14-SigLIP", + hf_hub_filename="open_clip_pytorch_model.bin", + num_classes=0, + ), + "vit_so400m_patch14_siglip_384.webli": _cfg( + hf_hub_id="timm/ViT-SO400M-14-SigLIP-384", + hf_hub_filename="open_clip_pytorch_model.bin", + input_size=(3, 384, 384), + num_classes=0, + ), + "vit_base_patch16_siglip_gap_224.webli": _cfg( + hf_hub_id="timm/ViT-B-16-SigLIP", + hf_hub_filename="open_clip_pytorch_model.bin", + num_classes=0, + ), + "vit_base_patch16_siglip_gap_256.webli": _cfg( + hf_hub_id="timm/ViT-B-16-SigLIP-256", + hf_hub_filename="open_clip_pytorch_model.bin", + input_size=(3, 256, 256), + num_classes=0, + ), + "vit_base_patch16_siglip_gap_384.webli": _cfg( + hf_hub_id="timm/ViT-B-16-SigLIP-384", + hf_hub_filename="open_clip_pytorch_model.bin", + input_size=(3, 384, 384), + num_classes=0, + ), + "vit_base_patch16_siglip_gap_512.webli": _cfg( + hf_hub_id="timm/ViT-B-16-SigLIP-512", + hf_hub_filename="open_clip_pytorch_model.bin", + input_size=(3, 512, 512), + num_classes=0, + ), + "vit_large_patch16_siglip_gap_256.webli": _cfg( + hf_hub_id="timm/ViT-L-16-SigLIP-256", + hf_hub_filename="open_clip_pytorch_model.bin", + input_size=(3, 256, 256), + num_classes=0, + ), + "vit_large_patch16_siglip_gap_384.webli": _cfg( + hf_hub_id="timm/ViT-L-16-SigLIP-384", + hf_hub_filename="open_clip_pytorch_model.bin", + input_size=(3, 384, 384), + num_classes=0, + ), + "vit_so400m_patch14_siglip_gap_224.webli": _cfg( + hf_hub_id="timm/ViT-SO400M-14-SigLIP", + hf_hub_filename="open_clip_pytorch_model.bin", + 
num_classes=0, + ), + "vit_so400m_patch14_siglip_gap_224.pali_mix": _cfg( + hf_hub_id="google/paligemma-3b-mix-224-jax", + hf_hub_filename="paligemma-3b-mix-224.npz", + custom_load="hf", + num_classes=0, + ), + "vit_so400m_patch14_siglip_gap_224.pali_pt": _cfg( + hf_hub_id="google/paligemma-3b-pt-224-jax", + hf_hub_filename="paligemma-3b-pt-224.npz", + custom_load="hf", + num_classes=0, + ), + "vit_so400m_patch14_siglip_gap_384.webli": _cfg( + hf_hub_id="timm/ViT-SO400M-14-SigLIP-384", + hf_hub_filename="open_clip_pytorch_model.bin", + input_size=(3, 384, 384), + crop_pct=1.0, + num_classes=0, + ), + "vit_so400m_patch14_siglip_gap_448.pali_mix": _cfg( + hf_hub_id="google/paligemma-3b-mix-448-jax", + hf_hub_filename="paligemma-3b-mix-448.npz", + custom_load="hf", + input_size=(3, 448, 448), + crop_pct=1.0, + num_classes=0, + ), + "vit_so400m_patch14_siglip_gap_448.pali_pt": _cfg( + hf_hub_id="google/paligemma-3b-pt-448-jax", + hf_hub_filename="paligemma-3b-pt-448.npz", + custom_load="hf", + input_size=(3, 448, 448), + crop_pct=1.0, + num_classes=0, + ), + "vit_so400m_patch14_siglip_gap_896.pali_pt": _cfg( + hf_hub_id="google/paligemma-3b-pt-896-jax", + hf_hub_filename="paligemma-3b-pt-896.npz", + custom_load="hf", + input_size=(3, 896, 896), + crop_pct=1.0, + num_classes=0, + ), + "vit_xsmall_patch16_clip_224.tinyclip_yfcc15m": _cfg( + hf_hub_id="timm/", + hf_hub_filename="open_clip_pytorch_model.bin", + license="mit", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + num_classes=512, + ), + "vit_medium_patch32_clip_224.tinyclip_laion400m": _cfg( + hf_hub_id="timm/", + hf_hub_filename="open_clip_pytorch_model.bin", + license="mit", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + num_classes=512, + ), + "vit_medium_patch16_clip_224.tinyclip_yfcc15m": _cfg( + hf_hub_id="timm/", + hf_hub_filename="open_clip_pytorch_model.bin", + license="mit", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + num_classes=512, + ), + "vit_betwixt_patch32_clip_224.tinyclip_laion400m": _cfg( + hf_hub_id="timm/", + hf_hub_filename="open_clip_pytorch_model.bin", + license="mit", + mean=OPENAI_CLIP_MEAN, + std=OPENAI_CLIP_STD, + num_classes=512, + ), + "vit_wee_patch16_reg1_gap_256.sbb_in1k": _cfg( + hf_hub_id="timm/", input_size=(3, 256, 256), crop_pct=0.95 + ), + "vit_pwee_patch16_reg1_gap_256.sbb_in1k": _cfg( + hf_hub_id="timm/", input_size=(3, 256, 256), crop_pct=0.95 + ), + "vit_little_patch16_reg1_gap_256.sbb_in12k_ft_in1k": _cfg( + hf_hub_id="timm/", input_size=(3, 256, 256), crop_pct=0.95 + ), + "vit_little_patch16_reg1_gap_256.sbb_in12k": _cfg( + hf_hub_id="timm/", num_classes=11821, input_size=(3, 256, 256), crop_pct=0.95 + ), + "vit_little_patch16_reg4_gap_256.sbb_in1k": _cfg( + hf_hub_id="timm/", input_size=(3, 256, 256), crop_pct=0.95 + ), + "vit_medium_patch16_reg1_gap_256.sbb_in1k": _cfg( + hf_hub_id="timm/", input_size=(3, 256, 256), crop_pct=0.95 + ), + "vit_medium_patch16_reg4_gap_256.sbb_in12k_ft_in1k": _cfg( + hf_hub_id="timm/", input_size=(3, 256, 256), crop_pct=0.95 + ), + "vit_medium_patch16_reg4_gap_256.sbb_in1k": _cfg( + hf_hub_id="timm/", input_size=(3, 256, 256), crop_pct=0.95 + ), + "vit_medium_patch16_reg4_gap_256.sbb_in12k": _cfg( + hf_hub_id="timm/", num_classes=11821, input_size=(3, 256, 256), crop_pct=0.95 + ), + "vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k": _cfg( + hf_hub_id="timm/", input_size=(3, 256, 256), crop_pct=0.95 + ), + "vit_mediumd_patch16_reg4_gap_256.sbb_in12k": _cfg( + hf_hub_id="timm/", num_classes=11821, input_size=(3, 256, 256), crop_pct=0.95 + ), + 
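Several of the entries above (SigLIP, PaliGemma, and earlier the DINO/DINOv2 and MAE weights) set `num_classes=0`: they load as headless backbones intended for feature extraction or fine-tuning. A rough sketch of how such a cfg behaves, again assuming the upstream timm package rather than this vendored copy:

# Rough sketch, assuming upstream timm: num_classes=0 cfgs load as headless backbones.
import timm
import torch

backbone = timm.create_model("vit_base_patch16_siglip_224.webli", pretrained=True)
backbone.eval()
with torch.no_grad():
    x = torch.randn(1, 3, 224, 224)
    tokens = backbone.forward_features(x)   # unpooled patch tokens
    pooled = backbone.forward_head(tokens)  # pooled embedding; head is Identity since num_classes=0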
"vit_betwixt_patch16_reg1_gap_256.sbb_in1k": _cfg( + hf_hub_id="timm/", input_size=(3, 256, 256), crop_pct=0.95 + ), + "vit_betwixt_patch16_reg4_gap_256.sbb_in12k_ft_in1k": _cfg( + hf_hub_id="timm/", input_size=(3, 256, 256), crop_pct=0.95 + ), + "vit_betwixt_patch16_reg4_gap_256.sbb_in1k": _cfg( + hf_hub_id="timm/", input_size=(3, 256, 256), crop_pct=0.95 + ), + "vit_betwixt_patch16_reg4_gap_256.sbb_in12k": _cfg( + hf_hub_id="timm/", num_classes=11821, input_size=(3, 256, 256), crop_pct=0.95 + ), + "vit_base_patch16_reg4_gap_256": _cfg(input_size=(3, 256, 256)), + "vit_so150m_patch16_reg4_gap_256": _cfg(input_size=(3, 256, 256)), + "vit_so150m_patch16_reg4_map_256": _cfg(input_size=(3, 256, 256)), +} + +_quick_gelu_cfgs = [ + "vit_large_patch14_clip_224.dfn2b", + "vit_huge_patch14_clip_224.dfn5b", + "vit_huge_patch14_clip_378.dfn5b", + "vit_base_patch32_clip_224.metaclip_2pt5b", + "vit_base_patch16_clip_224.metaclip_2pt5b", + "vit_large_patch14_clip_224.metaclip_2pt5b", + "vit_huge_patch14_clip_224.metaclip_2pt5b", + "vit_base_patch32_clip_224.openai", + "vit_base_patch16_clip_224.openai", + "vit_large_patch14_clip_224.openai", + "vit_large_patch14_clip_336.openai", +] +default_cfgs.update( + {n.replace("_clip_", "_clip_quickgelu_"): default_cfgs[n] for n in _quick_gelu_cfgs} +) +default_cfgs = generate_default_cfgs(default_cfgs) + + +def _create_vision_transformer( + variant: str, pretrained: bool = False, **kwargs +) -> VisionTransformer: + out_indices = kwargs.pop("out_indices", 3) + if "flexi" in variant: + # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed + # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation. + _filter_fn = partial( + checkpoint_filter_fn, interpolation="bilinear", antialias=False + ) + else: + _filter_fn = checkpoint_filter_fn + + # FIXME attn pool (currently only in siglip) params removed if pool disabled, is there a better soln? 
+ strict = True + if "siglip" in variant and kwargs.get("global_pool", None) != "map": + strict = False + + if "attn_layer" in kwargs: + strict = False + + return build_model_with_cfg( + VisionTransformer, + variant, + pretrained, + pretrained_filter_fn=_filter_fn, + pretrained_strict=strict, + feature_cfg=dict(out_indices=out_indices, feature_cls="getter"), + **kwargs, + ) + + +@register_model +def vit_tiny_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Tiny (Vit-Ti/16)""" + model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3) + model = _create_vision_transformer( + "vit_tiny_patch16_224", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_tiny_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Tiny (Vit-Ti/16) @ 384x384.""" + model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3) + model = _create_vision_transformer( + "vit_tiny_patch16_384", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_small_patch32_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Small (ViT-S/32)""" + model_args = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6) + model = _create_vision_transformer( + "vit_small_patch32_224", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_small_patch32_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Small (ViT-S/32) at 384x384.""" + model_args = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6) + model = _create_vision_transformer( + "vit_small_patch32_384", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_small_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Small (ViT-S/16)""" + model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6) + model = _create_vision_transformer( + "vit_small_patch16_224", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_small_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Small (ViT-S/16)""" + model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6) + model = _create_vision_transformer( + "vit_small_patch16_384", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_small_patch8_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Small (ViT-S/8)""" + model_args = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6) + model = _create_vision_transformer( + "vit_small_patch8_224", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_base_patch32_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer. + """ + model_args = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12) + model = _create_vision_transformer( + "vit_base_patch32_224", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_base_patch32_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). 
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_args = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12) + model = _create_vision_transformer( + "vit_base_patch32_384", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_base_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12) + model = _create_vision_transformer( + "vit_base_patch16_224", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_base_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12) + model = _create_vision_transformer( + "vit_base_patch16_384", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_base_patch8_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_args = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12) + model = _create_vision_transformer( + "vit_base_patch8_224", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_large_patch32_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.""" + model_args = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16) + model = _create_vision_transformer( + "vit_large_patch32_224", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_large_patch32_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_args = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16) + model = _create_vision_transformer( + "vit_large_patch32_384", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_large_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 
+ """ + model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16) + model = _create_vision_transformer( + "vit_large_patch16_224", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_large_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16) + model = _create_vision_transformer( + "vit_large_patch16_384", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_large_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Large model (ViT-L/14)""" + model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16) + model = _create_vision_transformer( + "vit_large_patch14_224", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_huge_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).""" + model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16) + model = _create_vision_transformer( + "vit_huge_patch14_224", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_giant_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560""" + model_args = dict( + patch_size=14, embed_dim=1408, mlp_ratio=48 / 11, depth=40, num_heads=16 + ) + model = _create_vision_transformer( + "vit_giant_patch14_224", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_gigantic_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Gigantic (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560""" + model_args = dict( + patch_size=14, embed_dim=1664, mlp_ratio=64 / 13, depth=48, num_heads=16 + ) + model = _create_vision_transformer( + "vit_gigantic_patch14_224", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_base_patch16_224_miil(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 
+ Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K + """ + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False + ) + model = _create_vision_transformer( + "vit_base_patch16_224_miil", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_medium_patch16_gap_240(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 240x240""" + model_args = dict( + patch_size=16, + embed_dim=512, + depth=12, + num_heads=8, + class_token=False, + global_pool="avg", + qkv_bias=False, + init_values=1e-6, + fc_norm=False, + ) + model = _create_vision_transformer( + "vit_medium_patch16_gap_240", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_medium_patch16_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 256x256""" + model_args = dict( + patch_size=16, + embed_dim=512, + depth=12, + num_heads=8, + class_token=False, + global_pool="avg", + qkv_bias=False, + init_values=1e-6, + fc_norm=False, + ) + model = _create_vision_transformer( + "vit_medium_patch16_gap_256", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_medium_patch16_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 384x384""" + model_args = dict( + patch_size=16, + embed_dim=512, + depth=12, + num_heads=8, + class_token=False, + global_pool="avg", + qkv_bias=False, + init_values=1e-6, + fc_norm=False, + ) + model = _create_vision_transformer( + "vit_medium_patch16_gap_384", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_betwixt_patch16_gap_256( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + """ViT-Betwixt (ViT-b/16) w/o class token, w/ avg-pool @ 256x256""" + model_args = dict( + patch_size=16, + embed_dim=640, + depth=12, + num_heads=10, + class_token=False, + global_pool="avg", + qkv_bias=False, + init_values=1e-6, + fc_norm=False, + ) + model = _create_vision_transformer( + "vit_medium_patch16_gap_256", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_base_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 224x224""" + model_args = dict( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=16, + class_token=False, + global_pool="avg", + fc_norm=False, + ) + model = _create_vision_transformer( + "vit_base_patch16_gap_224", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_huge_patch14_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Huge model (ViT-H/14) w/ no class token, avg pool""" + model_args = dict( + patch_size=14, + embed_dim=1280, + depth=32, + num_heads=16, + class_token=False, + global_pool="avg", + fc_norm=False, + ) + model = _create_vision_transformer( + "vit_huge_patch14_gap_224", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_huge_patch16_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Huge model (ViT-H/16) w/ no class token, avg pool @ 448x448""" + model_args = dict( + patch_size=16, + embed_dim=1280, + depth=32, + num_heads=16, + 
class_token=False, + global_pool="avg", + fc_norm=False, + ) + model = _create_vision_transformer( + "vit_huge_patch16_gap_448", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_giant_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Giant (little-gg) model (ViT-g/16) w/ no class token, avg pool""" + model_args = dict( + patch_size=16, + embed_dim=1408, + depth=40, + num_heads=16, + mlp_ratio=48 / 11, + class_token=False, + global_pool="avg", + fc_norm=False, + ) + model = _create_vision_transformer( + "vit_giant_patch16_gap_224", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_xsmall_patch16_clip_224( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + # TinyCLIP 8M + model_args = dict( + embed_dim=256, depth=10, num_heads=4, pre_norm=True, norm_layer=nn.LayerNorm + ) + model = _create_vision_transformer( + "vit_xsmall_patch16_clip_224", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_medium_patch32_clip_224( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + # TinyCLIP 40M + model_args = dict( + patch_size=32, + embed_dim=512, + depth=12, + num_heads=8, + pre_norm=True, + norm_layer=nn.LayerNorm, + ) + model = _create_vision_transformer( + "vit_medium_patch32_clip_224", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_medium_patch16_clip_224( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + # TinyCLIP 39M + model_args = dict( + embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=nn.LayerNorm + ) + model = _create_vision_transformer( + "vit_medium_patch16_clip_224", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_betwixt_patch32_clip_224( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + # TinyCLIP 61M + model_args = dict( + patch_size=32, + embed_dim=640, + depth=12, + num_heads=10, + pre_norm=True, + norm_layer=nn.LayerNorm, + ) + model = _create_vision_transformer( + "vit_betwixt_patch32_clip_224", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_base_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-B/32 CLIP image tower @ 224x224""" + model_args = dict( + patch_size=32, + embed_dim=768, + depth=12, + num_heads=12, + pre_norm=True, + norm_layer=nn.LayerNorm, + ) + model = _create_vision_transformer( + "vit_base_patch32_clip_224", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_base_patch32_clip_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-B/32 CLIP image tower @ 256x256""" + model_args = dict( + patch_size=32, + embed_dim=768, + depth=12, + num_heads=12, + pre_norm=True, + norm_layer=nn.LayerNorm, + ) + model = _create_vision_transformer( + "vit_base_patch32_clip_256", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_base_patch32_clip_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-B/32 CLIP image tower @ 384x384""" + model_args = dict( + patch_size=32, + embed_dim=768, + depth=12, + num_heads=12, + pre_norm=True, + norm_layer=nn.LayerNorm, + ) + model = _create_vision_transformer( + "vit_base_patch32_clip_384", pretrained=pretrained, **dict(model_args, **kwargs) + ) + 
return model + + +@register_model +def vit_base_patch32_clip_448(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-B/32 CLIP image tower @ 448x448""" + model_args = dict( + patch_size=32, + embed_dim=768, + depth=12, + num_heads=12, + pre_norm=True, + norm_layer=nn.LayerNorm, + ) + model = _create_vision_transformer( + "vit_base_patch32_clip_448", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_base_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-B/16 CLIP image tower""" + model_args = dict( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + pre_norm=True, + norm_layer=nn.LayerNorm, + ) + model = _create_vision_transformer( + "vit_base_patch16_clip_224", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_base_patch16_clip_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-B/16 CLIP image tower @ 384x384""" + model_args = dict( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + pre_norm=True, + norm_layer=nn.LayerNorm, + ) + model = _create_vision_transformer( + "vit_base_patch16_clip_384", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_large_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Large model (ViT-L/14) CLIP image tower""" + model_args = dict( + patch_size=14, + embed_dim=1024, + depth=24, + num_heads=16, + pre_norm=True, + norm_layer=nn.LayerNorm, + ) + model = _create_vision_transformer( + "vit_large_patch14_clip_224", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_large_patch14_clip_336(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Large model (ViT-L/14) CLIP image tower @ 336x336""" + model_args = dict( + patch_size=14, + embed_dim=1024, + depth=24, + num_heads=16, + pre_norm=True, + norm_layer=nn.LayerNorm, + ) + model = _create_vision_transformer( + "vit_large_patch14_clip_336", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_huge_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Huge model (ViT-H/14) CLIP image tower.""" + model_args = dict( + patch_size=14, + embed_dim=1280, + depth=32, + num_heads=16, + pre_norm=True, + norm_layer=nn.LayerNorm, + ) + model = _create_vision_transformer( + "vit_huge_patch14_clip_224", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_huge_patch14_clip_336(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Huge model (ViT-H/14) CLIP image tower @ 336x336""" + model_args = dict( + patch_size=14, + embed_dim=1280, + depth=32, + num_heads=16, + pre_norm=True, + norm_layer=nn.LayerNorm, + ) + model = _create_vision_transformer( + "vit_huge_patch14_clip_336", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_huge_patch14_clip_378(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Huge model (ViT-H/14) CLIP image tower @ 378x378""" + model_args = dict( + patch_size=14, + embed_dim=1280, + depth=32, + num_heads=16, + pre_norm=True, + norm_layer=nn.LayerNorm, + ) + model = _create_vision_transformer( + "vit_huge_patch14_clip_378", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_giant_patch14_clip_224(pretrained: 
bool = False, **kwargs) -> VisionTransformer: + """ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 + Pretrained weights from CLIP image tower. + """ + model_args = dict( + patch_size=14, + embed_dim=1408, + mlp_ratio=48 / 11, + depth=40, + num_heads=16, + pre_norm=True, + norm_layer=nn.LayerNorm, + ) + model = _create_vision_transformer( + "vit_giant_patch14_clip_224", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_gigantic_patch14_clip_224( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + """ViT-bigG model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 + Pretrained weights from CLIP image tower. + """ + model_args = dict( + patch_size=14, + embed_dim=1664, + mlp_ratio=64 / 13, + depth=48, + num_heads=16, + pre_norm=True, + norm_layer=nn.LayerNorm, + ) + model = _create_vision_transformer( + "vit_gigantic_patch14_clip_224", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_base_patch32_clip_quickgelu_224( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + """ViT-B/32 CLIP image tower @ 224x224""" + model_args = dict( + patch_size=32, + embed_dim=768, + depth=12, + num_heads=12, + pre_norm=True, + norm_layer=nn.LayerNorm, + act_layer="quick_gelu", + ) + model = _create_vision_transformer( + "vit_base_patch32_clip_quickgelu_224", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_base_patch16_clip_quickgelu_224( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + """ViT-B/16 CLIP image tower w/ QuickGELU act""" + model_args = dict( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + pre_norm=True, + norm_layer=nn.LayerNorm, + act_layer="quick_gelu", + ) + model = _create_vision_transformer( + "vit_base_patch16_clip_quickgelu_224", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_large_patch14_clip_quickgelu_224( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + """ViT-Large model (ViT-L/14) CLIP image tower w/ QuickGELU act""" + model_args = dict( + patch_size=14, + embed_dim=1024, + depth=24, + num_heads=16, + pre_norm=True, + norm_layer=nn.LayerNorm, + act_layer="quick_gelu", + ) + model = _create_vision_transformer( + "vit_large_patch14_clip_quickgelu_224", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_large_patch14_clip_quickgelu_336( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + """ViT-Large model (ViT-L/14) CLIP image tower @ 336x336 w/ QuickGELU act""" + model_args = dict( + patch_size=14, + embed_dim=1024, + depth=24, + num_heads=16, + pre_norm=True, + norm_layer=nn.LayerNorm, + act_layer="quick_gelu", + ) + model = _create_vision_transformer( + "vit_large_patch14_clip_quickgelu_336", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_huge_patch14_clip_quickgelu_224( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + """ViT-Huge model (ViT-H/14) CLIP image tower w/ QuickGELU act.""" + model_args = dict( + patch_size=14, + embed_dim=1280, + depth=32, + num_heads=16, + pre_norm=True, + norm_layer=nn.LayerNorm, + act_layer="quick_gelu", + ) + model = _create_vision_transformer( + "vit_huge_patch14_clip_quickgelu_224", + pretrained=pretrained, + **dict(model_args, 
**kwargs), + ) + return model + + +@register_model +def vit_huge_patch14_clip_quickgelu_378( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + """ViT-Huge model (ViT-H/14) CLIP image tower @ 378x378 w/ QuickGELU act""" + model_args = dict( + patch_size=14, + embed_dim=1280, + depth=32, + num_heads=16, + pre_norm=True, + norm_layer=nn.LayerNorm, + act_layer="quick_gelu", + ) + model = _create_vision_transformer( + "vit_huge_patch14_clip_quickgelu_378", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +# Experimental models below + + +@register_model +def vit_base_patch32_plus_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Base (ViT-B/32+)""" + model_args = dict( + patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5 + ) + model = _create_vision_transformer( + "vit_base_patch32_plus_256", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_base_patch16_plus_240(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Base (ViT-B/16+)""" + model_args = dict( + patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5 + ) + model = _create_vision_transformer( + "vit_base_patch16_plus_240", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_base_patch16_rpn_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Base (ViT-B/16) w/ residual post-norm""" + model_args = dict( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + qkv_bias=False, + init_values=1e-5, + class_token=False, + block_fn=ResPostBlock, + global_pool="avg", + ) + model = _create_vision_transformer( + "vit_base_patch16_rpn_224", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_small_patch16_36x1_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Base w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove. + Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 + Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. + """ + model_args = dict( + patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5 + ) + model = _create_vision_transformer( + "vit_small_patch16_36x1_224", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_small_patch16_18x2_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Small w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. + Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 + Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. + """ + model_args = dict( + patch_size=16, + embed_dim=384, + depth=18, + num_heads=6, + init_values=1e-5, + block_fn=ParallelThingsBlock, + ) + model = _create_vision_transformer( + "vit_small_patch16_18x2_224", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_base_patch16_18x2_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. 
+ Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 + """ + model_args = dict( + patch_size=16, + embed_dim=768, + depth=18, + num_heads=12, + init_values=1e-5, + block_fn=ParallelThingsBlock, + ) + model = _create_vision_transformer( + "vit_base_patch16_18x2_224", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def eva_large_patch14_196(pretrained: bool = False, **kwargs) -> VisionTransformer: + """EVA-large model https://arxiv.org/abs/2211.07636 /via MAE MIM pretrain""" + model_args = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool="avg" + ) + model = _create_vision_transformer( + "eva_large_patch14_196", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def eva_large_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer: + """EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain""" + model_args = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool="avg" + ) + model = _create_vision_transformer( + "eva_large_patch14_336", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def flexivit_small(pretrained: bool = False, **kwargs) -> VisionTransformer: + """FlexiViT-Small""" + model_args = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True + ) + model = _create_vision_transformer( + "flexivit_small", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def flexivit_base(pretrained: bool = False, **kwargs) -> VisionTransformer: + """FlexiViT-Base""" + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True + ) + model = _create_vision_transformer( + "flexivit_base", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def flexivit_large(pretrained: bool = False, **kwargs) -> VisionTransformer: + """FlexiViT-Large""" + model_args = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True + ) + model = _create_vision_transformer( + "flexivit_large", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_base_patch16_xp_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled.""" + model_args = dict( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + pre_norm=True, + no_embed_class=True, + norm_layer=RmsNorm, + block_fn=ParallelScalingBlock, + qkv_bias=False, + qk_norm=True, + ) + model = _create_vision_transformer( + "vit_base_patch16_xp_224", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_large_patch14_xp_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled.""" + model_args = dict( + patch_size=14, + embed_dim=1024, + depth=24, + num_heads=16, + pre_norm=True, + no_embed_class=True, + norm_layer=RmsNorm, + block_fn=ParallelScalingBlock, + qkv_bias=False, + qk_norm=True, + ) + model = _create_vision_transformer( + "vit_large_patch14_xp_224", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_huge_patch14_xp_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-Huge model (ViT-H/14) w/ parallel blocks and qk norm enabled.""" + 
model_args = dict( + patch_size=14, + embed_dim=1280, + depth=32, + num_heads=16, + pre_norm=True, + no_embed_class=True, + norm_layer=RmsNorm, + block_fn=ParallelScalingBlock, + qkv_bias=False, + qk_norm=True, + ) + model = _create_vision_transformer( + "vit_huge_patch14_xp_224", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_small_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-S/14 for DINOv2""" + model_args = dict( + patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5 + ) + model = _create_vision_transformer( + "vit_small_patch14_dinov2", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_base_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-B/14 for DINOv2""" + model_args = dict( + patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5 + ) + model = _create_vision_transformer( + "vit_base_patch14_dinov2", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_large_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-L/14 for DINOv2""" + model_args = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5 + ) + model = _create_vision_transformer( + "vit_large_patch14_dinov2", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_giant_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ViT-G/14 for DINOv2""" + # The hidden_features of SwiGLU is calculated by: + # hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + # When embed_dim=1536, hidden_features=4096 + # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192 + model_args = dict( + patch_size=14, + embed_dim=1536, + depth=40, + num_heads=24, + init_values=1e-5, + mlp_ratio=2.66667 * 2, + mlp_layer=SwiGLUPacked, + act_layer=nn.SiLU, + ) + model = _create_vision_transformer( + "vit_giant_patch14_dinov2", pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + +@register_model +def vit_small_patch14_reg4_dinov2( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + """ViT-S/14 for DINOv2 w/ 4 registers""" + model_args = dict( + patch_size=14, + embed_dim=384, + depth=12, + num_heads=6, + init_values=1e-5, + reg_tokens=4, + no_embed_class=True, + ) + model = _create_vision_transformer( + "vit_small_patch14_reg4_dinov2", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_base_patch14_reg4_dinov2( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + """ViT-B/14 for DINOv2 w/ 4 registers""" + model_args = dict( + patch_size=14, + embed_dim=768, + depth=12, + num_heads=12, + init_values=1e-5, + reg_tokens=4, + no_embed_class=True, + ) + model = _create_vision_transformer( + "vit_base_patch14_reg4_dinov2", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_large_patch14_reg4_dinov2( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + """ViT-L/14 for DINOv2 w/ 4 registers""" + model_args = dict( + patch_size=14, + embed_dim=1024, + depth=24, + num_heads=16, + init_values=1e-5, + reg_tokens=4, + no_embed_class=True, + ) + model = _create_vision_transformer( + "vit_large_patch14_reg4_dinov2", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model 
+def vit_giant_patch14_reg4_dinov2( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + """ViT-G/14 for DINOv2""" + # The hidden_features of SwiGLU is calculated by: + # hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + # When embed_dim=1536, hidden_features=4096 + # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192 + model_args = dict( + patch_size=14, + embed_dim=1536, + depth=40, + num_heads=24, + init_values=1e-5, + mlp_ratio=2.66667 * 2, + mlp_layer=SwiGLUPacked, + act_layer=nn.SiLU, + reg_tokens=4, + no_embed_class=True, + ) + model = _create_vision_transformer( + "vit_giant_patch14_reg4_dinov2", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_base_patch16_siglip_224( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + model_args = dict( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + class_token=False, + global_pool="map", + ) + model = _create_vision_transformer( + "vit_base_patch16_siglip_224", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_base_patch16_siglip_256( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + model_args = dict( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + class_token=False, + global_pool="map", + ) + model = _create_vision_transformer( + "vit_base_patch16_siglip_256", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_base_patch16_siglip_384( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + model_args = dict( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + class_token=False, + global_pool="map", + ) + model = _create_vision_transformer( + "vit_base_patch16_siglip_384", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_base_patch16_siglip_512( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + model_args = dict( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + class_token=False, + global_pool="map", + ) + model = _create_vision_transformer( + "vit_base_patch16_siglip_512", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_large_patch16_siglip_256( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + model_args = dict( + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + class_token=False, + global_pool="map", + ) + model = _create_vision_transformer( + "vit_large_patch16_siglip_256", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_large_patch16_siglip_384( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + model_args = dict( + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + class_token=False, + global_pool="map", + ) + model = _create_vision_transformer( + "vit_large_patch16_siglip_384", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_so400m_patch14_siglip_224( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + model_args = dict( + patch_size=14, + embed_dim=1152, + depth=27, + num_heads=16, + mlp_ratio=3.7362, + class_token=False, + global_pool="map", + ) + model = _create_vision_transformer( + "vit_so400m_patch14_siglip_224", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def 
vit_so400m_patch14_siglip_384( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + model_args = dict( + patch_size=14, + embed_dim=1152, + depth=27, + num_heads=16, + mlp_ratio=3.7362, + class_token=False, + global_pool="map", + ) + model = _create_vision_transformer( + "vit_so400m_patch14_siglip_384", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_base_patch16_siglip_gap_224( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + """A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).""" + model_args = dict( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + class_token=False, + global_pool="avg", + fc_norm=False, + ) + model = _create_vision_transformer( + "vit_base_patch16_siglip_gap_224", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_base_patch16_siglip_gap_256( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + """A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).""" + model_args = dict( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + class_token=False, + global_pool="avg", + fc_norm=False, + ) + model = _create_vision_transformer( + "vit_base_patch16_siglip_gap_256", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_base_patch16_siglip_gap_384( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + """A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).""" + model_args = dict( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + class_token=False, + global_pool="avg", + fc_norm=False, + ) + model = _create_vision_transformer( + "vit_base_patch16_siglip_gap_384", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_base_patch16_siglip_gap_512( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + """A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).""" + model_args = dict( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + class_token=False, + global_pool="avg", + fc_norm=False, + ) + model = _create_vision_transformer( + "vit_base_patch16_siglip_gap_512", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_large_patch16_siglip_gap_256( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + """A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).""" + model_args = dict( + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + class_token=False, + global_pool="avg", + fc_norm=False, + ) + model = _create_vision_transformer( + "vit_large_patch16_siglip_gap_256", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_large_patch16_siglip_gap_384( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + """A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).""" + model_args = dict( + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + class_token=False, + global_pool="avg", + fc_norm=False, + ) + model = _create_vision_transformer( + "vit_large_patch16_siglip_gap_384", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def 
vit_so400m_patch14_siglip_gap_224( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + """A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).""" + model_args = dict( + patch_size=14, + embed_dim=1152, + depth=27, + num_heads=16, + mlp_ratio=3.7362, + class_token=False, + global_pool="avg", + fc_norm=False, + ) + model = _create_vision_transformer( + "vit_so400m_patch14_siglip_gap_224", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_so400m_patch14_siglip_gap_384( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + """A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).""" + model_args = dict( + patch_size=14, + embed_dim=1152, + depth=27, + num_heads=16, + mlp_ratio=3.7362, + class_token=False, + global_pool="avg", + fc_norm=False, + ) + model = _create_vision_transformer( + "vit_so400m_patch14_siglip_gap_384", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_so400m_patch14_siglip_gap_448( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + """A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).""" + model_args = dict( + patch_size=14, + embed_dim=1152, + depth=27, + num_heads=16, + mlp_ratio=3.7362, + class_token=False, + global_pool="avg", + fc_norm=False, + ) + model = _create_vision_transformer( + "vit_so400m_patch14_siglip_gap_448", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_so400m_patch14_siglip_gap_896( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + """A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).""" + model_args = dict( + patch_size=14, + embed_dim=1152, + depth=27, + num_heads=16, + mlp_ratio=3.7362, + class_token=False, + global_pool="avg", + fc_norm=False, + ) + model = _create_vision_transformer( + "vit_so400m_patch14_siglip_gap_896", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_wee_patch16_reg1_gap_256( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + model_args = dict( + patch_size=16, + embed_dim=256, + depth=14, + num_heads=4, + init_values=1e-5, + mlp_ratio=5, + class_token=False, + no_embed_class=True, + reg_tokens=1, + global_pool="avg", + ) + model = _create_vision_transformer( + "vit_wee_patch16_reg1_gap_256", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_pwee_patch16_reg1_gap_256( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + model_args = dict( + patch_size=16, + embed_dim=256, + depth=16, + num_heads=4, + init_values=1e-5, + mlp_ratio=5, + class_token=False, + no_embed_class=True, + reg_tokens=1, + global_pool="avg", + block_fn=ParallelScalingBlock, + ) + model = _create_vision_transformer( + "vit_pwee_patch16_reg1_gap_256", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_little_patch16_reg1_gap_256( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + model_args = dict( + patch_size=16, + embed_dim=320, + depth=14, + num_heads=5, + init_values=1e-5, + mlp_ratio=5.6, + class_token=False, + no_embed_class=True, + reg_tokens=1, + global_pool="avg", + ) + model = _create_vision_transformer( + "vit_little_patch16_reg1_gap_256", + pretrained=pretrained, + 
**dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_little_patch16_reg4_gap_256( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + model_args = dict( + patch_size=16, + embed_dim=320, + depth=14, + num_heads=5, + init_values=1e-5, + mlp_ratio=5.6, + class_token=False, + no_embed_class=True, + reg_tokens=4, + global_pool="avg", + ) + model = _create_vision_transformer( + "vit_little_patch16_reg4_gap_256", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_medium_patch16_reg1_gap_256( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + model_args = dict( + patch_size=16, + embed_dim=512, + depth=12, + num_heads=8, + init_values=1e-5, + class_token=False, + no_embed_class=True, + reg_tokens=1, + global_pool="avg", + ) + model = _create_vision_transformer( + "vit_medium_patch16_reg1_gap_256", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_medium_patch16_reg4_gap_256( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + model_args = dict( + patch_size=16, + embed_dim=512, + depth=12, + num_heads=8, + init_values=1e-5, + class_token=False, + no_embed_class=True, + reg_tokens=4, + global_pool="avg", + ) + model = _create_vision_transformer( + "vit_medium_patch16_reg4_gap_256", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_mediumd_patch16_reg4_gap_256( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + model_args = dict( + patch_size=16, + embed_dim=512, + depth=20, + num_heads=8, + init_values=1e-5, + class_token=False, + no_embed_class=True, + reg_tokens=4, + global_pool="avg", + ) + model = _create_vision_transformer( + "vit_mediumd_patch16_reg4_gap_256", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_betwixt_patch16_reg1_gap_256( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + model_args = dict( + patch_size=16, + embed_dim=640, + depth=12, + num_heads=10, + init_values=1e-5, + class_token=False, + no_embed_class=True, + reg_tokens=1, + global_pool="avg", + ) + model = _create_vision_transformer( + "vit_betwixt_patch16_reg1_gap_256", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_betwixt_patch16_reg4_gap_256( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + model_args = dict( + patch_size=16, + embed_dim=640, + depth=12, + num_heads=10, + init_values=1e-5, + class_token=False, + no_embed_class=True, + reg_tokens=4, + global_pool="avg", + ) + model = _create_vision_transformer( + "vit_betwixt_patch16_reg4_gap_256", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_base_patch16_reg4_gap_256( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + model_args = dict( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + class_token=False, + no_embed_class=True, + global_pool="avg", + reg_tokens=4, + ) + model = _create_vision_transformer( + "vit_base_patch16_reg4_gap_256", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_so150m_patch16_reg4_map_256( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + model_args = dict( + patch_size=16, + embed_dim=896, + depth=18, + num_heads=14, + mlp_ratio=2.572, + class_token=False, + reg_tokens=4, + global_pool="map", + ) + 
model = _create_vision_transformer( + "vit_so150m_patch16_reg4_map_256", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +@register_model +def vit_so150m_patch16_reg4_gap_256( + pretrained: bool = False, **kwargs +) -> VisionTransformer: + model_args = dict( + patch_size=16, + embed_dim=896, + depth=18, + num_heads=14, + mlp_ratio=2.572, + class_token=False, + reg_tokens=4, + global_pool="avg", + fc_norm=False, + ) + model = _create_vision_transformer( + "vit_so150m_patch16_reg4_gap_256", + pretrained=pretrained, + **dict(model_args, **kwargs), + ) + return model + + +register_model_deprecations( + __name__, + { + "vit_tiny_patch16_224_in21k": "vit_tiny_patch16_224.augreg_in21k", + "vit_small_patch32_224_in21k": "vit_small_patch32_224.augreg_in21k", + "vit_small_patch16_224_in21k": "vit_small_patch16_224.augreg_in21k", + "vit_base_patch32_224_in21k": "vit_base_patch32_224.augreg_in21k", + "vit_base_patch16_224_in21k": "vit_base_patch16_224.augreg_in21k", + "vit_base_patch8_224_in21k": "vit_base_patch8_224.augreg_in21k", + "vit_large_patch32_224_in21k": "vit_large_patch32_224.orig_in21k", + "vit_large_patch16_224_in21k": "vit_large_patch16_224.augreg_in21k", + "vit_huge_patch14_224_in21k": "vit_huge_patch14_224.orig_in21k", + "vit_base_patch32_224_sam": "vit_base_patch32_224.sam", + "vit_base_patch16_224_sam": "vit_base_patch16_224.sam", + "vit_small_patch16_224_dino": "vit_small_patch16_224.dino", + "vit_small_patch8_224_dino": "vit_small_patch8_224.dino", + "vit_base_patch16_224_dino": "vit_base_patch16_224.dino", + "vit_base_patch8_224_dino": "vit_base_patch8_224.dino", + "vit_base_patch16_224_miil_in21k": "vit_base_patch16_224_miil.in21k", + "vit_base_patch32_224_clip_laion2b": "vit_base_patch32_clip_224.laion2b", + "vit_large_patch14_224_clip_laion2b": "vit_large_patch14_clip_224.laion2b", + "vit_huge_patch14_224_clip_laion2b": "vit_huge_patch14_clip_224.laion2b", + "vit_giant_patch14_224_clip_laion2b": "vit_giant_patch14_clip_224.laion2b", + }, +) diff --git a/src/vqvaes/xqgan/discriminator.py b/src/vqvaes/xqgan/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..cc447053a285823fc96043dee63d8c3003d498d9 --- /dev/null +++ b/src/vqvaes/xqgan/discriminator.py @@ -0,0 +1,283 @@ +# Modified from: +# taming-transformers: https://github.com/CompVis/taming-transformers +# stylegan2-pytorch: https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py +# maskgit: https://github.com/google-research/maskgit/blob/main/maskgit/nets/discriminator.py +import functools +import math +import torch +import torch.nn as nn + +try: + from kornia.filters import filter2d +except: + pass + + +################################################################################# +# PatchGAN # +################################################################################# +class PatchGANDiscriminator(nn.Module): + """Defines a PatchGAN discriminator as in Pix2Pix + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + + def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(PatchGANDiscriminator, self).__init__() + if not use_actnorm: + norm_layer = nn.BatchNorm2d + else: + norm_layer = 
ActNorm + if ( + type(norm_layer) == functools.partial + ): # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func != nn.BatchNorm2d + else: + use_bias = norm_layer != nn.BatchNorm2d + + kw = 4 + padw = 1 + sequence = [ + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True), + ] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2**n, 8) + sequence += [ + nn.Conv2d( + ndf * nf_mult_prev, + ndf * nf_mult, + kernel_size=kw, + stride=2, + padding=padw, + bias=use_bias, + ), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + nf_mult_prev = nf_mult + nf_mult = min(2**n_layers, 8) + sequence += [ + nn.Conv2d( + ndf * nf_mult_prev, + ndf * nf_mult, + kernel_size=kw, + stride=1, + padding=padw, + bias=use_bias, + ), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + sequence += [ + nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) + ] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + self.apply(self._init_weights) + + def _init_weights(self, module): + if isinstance(module, nn.Conv2d): + nn.init.normal_(module.weight.data, 0.0, 0.02) + elif isinstance(module, nn.BatchNorm2d): + nn.init.normal_(module.weight.data, 1.0, 0.02) + nn.init.constant_(module.bias.data, 0) + + def forward(self, input): + """Standard forward.""" + return self.main(input) + + +class ActNorm(nn.Module): + def __init__( + self, num_features, logdet=False, affine=True, allow_reverse_init=False + ): + assert affine + super().__init__() + self.logdet = logdet + self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) + self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) + self.allow_reverse_init = allow_reverse_init + + self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) + + def initialize(self, input): + with torch.no_grad(): + flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) + mean = ( + flatten.mean(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + std = ( + flatten.std(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + + self.loc.data.copy_(-mean) + self.scale.data.copy_(1 / (std + 1e-6)) + + def forward(self, input, reverse=False): + if reverse: + return self.reverse(input) + if len(input.shape) == 2: + input = input[:, :, None, None] + squeeze = True + else: + squeeze = False + + _, _, height, width = input.shape + + if self.training and self.initialized.item() == 0: + self.initialize(input) + self.initialized.fill_(1) + + h = self.scale * (input + self.loc) + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + + if self.logdet: + log_abs = torch.log(torch.abs(self.scale)) + logdet = height * width * torch.sum(log_abs) + logdet = logdet * torch.ones(input.shape[0]).to(input) + return h, logdet + + return h + + def reverse(self, output): + if self.training and self.initialized.item() == 0: + if not self.allow_reverse_init: + raise RuntimeError( + "Initializing ActNorm in reverse direction is " + "disabled by default. Use allow_reverse_init=True to enable." 
+ ) + else: + self.initialize(output) + self.initialized.fill_(1) + + if len(output.shape) == 2: + output = output[:, :, None, None] + squeeze = True + else: + squeeze = False + + h = output / self.scale - self.loc + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + return h + + +################################################################################# +# StyleGAN # +################################################################################# +class StyleGANDiscriminator(nn.Module): + def __init__( + self, input_nc=3, ndf=64, n_layers=3, channel_multiplier=1, image_size=256 + ): + super().__init__() + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + log_size = int(math.log(image_size, 2)) + in_channel = channels[image_size] + + blocks = [nn.Conv2d(input_nc, in_channel, 3, padding=1), leaky_relu()] + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + blocks.append(DiscriminatorBlock(in_channel, out_channel)) + in_channel = out_channel + self.blocks = nn.ModuleList(blocks) + + self.final_conv = nn.Sequential( + nn.Conv2d(in_channel, channels[4], 3, padding=1), + leaky_relu(), + ) + self.final_linear = nn.Sequential( + nn.Linear(channels[4] * 4 * 4, channels[4]), + leaky_relu(), + nn.Linear(channels[4], 1), + ) + + def forward(self, x): + for block in self.blocks: + x = block(x) + x = self.final_conv(x) + x = x.view(x.shape[0], -1) + x = self.final_linear(x) + return x + + +class DiscriminatorBlock(nn.Module): + def __init__(self, input_channels, filters, downsample=True): + super().__init__() + self.conv_res = nn.Conv2d( + input_channels, filters, 1, stride=(2 if downsample else 1) + ) + + self.net = nn.Sequential( + nn.Conv2d(input_channels, filters, 3, padding=1), + leaky_relu(), + nn.Conv2d(filters, filters, 3, padding=1), + leaky_relu(), + ) + + self.downsample = ( + nn.Sequential(Blur(), nn.Conv2d(filters, filters, 3, padding=1, stride=2)) + if downsample + else None + ) + + def forward(self, x): + res = self.conv_res(x) + x = self.net(x) + if exists(self.downsample): + x = self.downsample(x) + x = (x + res) * (1 / math.sqrt(2)) + return x + + +class Blur(nn.Module): + def __init__(self): + super().__init__() + f = torch.Tensor([1, 2, 1]) + self.register_buffer("f", f) + + def forward(self, x): + f = self.f + f = f[None, None, :] * f[None, :, None] + return filter2d(x, f, normalized=True) + + +def leaky_relu(p=0.2): + return nn.LeakyReLU(p, inplace=True) + + +def exists(val): + return val is not None diff --git a/src/vqvaes/xqgan/discriminator_dino.py b/src/vqvaes/xqgan/discriminator_dino.py new file mode 100644 index 0000000000000000000000000000000000000000..3e11c38c0d40af5c52615bfb7481367e582b673e --- /dev/null +++ b/src/vqvaes/xqgan/discriminator_dino.py @@ -0,0 +1,620 @@ +import math +import os.path +import random +from typing import List, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.spectral_norm import SpectralNorm +from torchvision.transforms import RandomCrop + +import dist + +try: + from flash_attn.ops.layer_norm import dropout_add_layer_norm + from flash_attn.ops.fused_dense import fused_mlp_func +except: + dropout_add_layer_norm = fused_mlp_func = None + +try: + from flash_attn import flash_attn_qkvpacked_func # qkv: BL3Hc, ret: BLHcq +except: + flash_attn_qkvpacked_func = None + +try: + assert 
torch.cuda.is_available() + from torch.nn.functional import ( + scaled_dot_product_attention as slow_attn, + ) # q, k, v: BHLc +except: + + def slow_attn(query, key, value, scale: float, attn_mask=None, dropout_p=0.0): + attn = query.mul(scale) @ key.transpose(-2, -1) # BHLc @ BHcL => BHLL + if attn_mask is not None: + attn.add_(attn_mask) + return ( + F.dropout(attn.softmax(dim=-1), p=dropout_p, inplace=True) + if dropout_p > 0 + else attn.softmax(dim=-1) + ) @ value + + +class MLPNoDrop(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + fused_if_available=True, + ): + super().__init__() + self.fused_mlp_func = ( + fused_mlp_func + if (torch.cuda.is_available() and fused_if_available) + else None + ) + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = nn.GELU(approximate="tanh") + self.fc2 = nn.Linear(hidden_features, out_features) + + def forward(self, x): + if self.fused_mlp_func is not None: + return self.fused_mlp_func( + x=x, + weight1=self.fc1.weight, + weight2=self.fc2.weight, + bias1=self.fc1.bias, + bias2=self.fc2.bias, + activation="gelu_approx", + save_pre_act=self.training, + return_residual=False, + checkpoint_lvl=0, + heuristic=0, + process_group=None, + ) + else: + return self.fc2(self.act(self.fc1(x))) + + def extra_repr(self) -> str: + return f"fused_mlp_func={self.fused_mlp_func is not None}" + + +class SelfAttentionNoDrop(nn.Module): + def __init__( + self, + block_idx, + embed_dim=768, + num_heads=12, + flash_if_available=True, + ): + super().__init__() + assert embed_dim % num_heads == 0 + self.block_idx, self.num_heads, self.head_dim = ( + block_idx, + num_heads, + embed_dim // num_heads, + ) # =64 + self.scale = 1 / math.sqrt(self.head_dim) + self.qkv, self.proj = nn.Linear(embed_dim, embed_dim * 3, bias=True), nn.Linear( + embed_dim, embed_dim, bias=True + ) + self.using_flash_attn = ( + torch.cuda.is_available() + and flash_if_available + and flash_attn_qkvpacked_func is not None + ) + + def forward(self, x): + B, L, C = x.shape + qkv = self.qkv(x).view(B, L, 3, self.num_heads, self.head_dim) + if self.using_flash_attn and qkv.dtype != torch.float32: + oup = flash_attn_qkvpacked_func(qkv, softmax_scale=self.scale).view(B, L, C) + else: + q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0) # BHLc + oup = ( + slow_attn(query=q, key=k, value=v, scale=self.scale) + .transpose(1, 2) + .reshape(B, L, C) + ) + return self.proj(oup) + + def extra_repr(self) -> str: + return f"using_flash_attn={self.using_flash_attn}" + + +class SABlockNoDrop(nn.Module): + def __init__(self, block_idx, embed_dim, num_heads, mlp_ratio, norm_eps): + super(SABlockNoDrop, self).__init__() + self.norm1 = nn.LayerNorm(embed_dim, eps=norm_eps) + self.attn = SelfAttentionNoDrop( + block_idx=block_idx, + embed_dim=embed_dim, + num_heads=num_heads, + flash_if_available=True, + ) + self.norm2 = nn.LayerNorm(embed_dim, eps=norm_eps) + self.mlp = MLPNoDrop( + in_features=embed_dim, + hidden_features=round(embed_dim * mlp_ratio), + fused_if_available=True, + ) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class ResidualBlock(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.ratio = 1 / np.sqrt(2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x = x.float() + return (self.fn(x).add(x)).mul_(self.ratio) + + +class SpectralConv1d(nn.Conv1d): + 
def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + SpectralNorm.apply(self, name="weight", n_power_iterations=1, dim=0, eps=1e-12) + + +class BatchNormLocal(nn.Module): + def __init__( + self, + num_features: int, + affine: bool = True, + virtual_bs: int = 8, + eps: float = 1e-6, + ): + super().__init__() + self.virtual_bs = virtual_bs + self.eps = eps + self.affine = affine + + if self.affine: + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shape = x.size() + x = x.float() + + # Reshape batch into groups. + G = np.ceil(x.size(0) / self.virtual_bs).astype(int) + x = x.view(G, -1, x.size(-2), x.size(-1)) + + # Calculate stats. + mean = x.mean([1, 3], keepdim=True) + var = x.var([1, 3], keepdim=True, unbiased=False) + x = (x - mean) / (torch.sqrt(var + self.eps)) + + if self.affine: + x = x * self.weight[None, :, None] + self.bias[None, :, None] + + return x.view(shape) + + +def make_block( + channels: int, + kernel_size: int, + norm_type: str, + norm_eps: float, + using_spec_norm: bool, +) -> nn.Module: + if norm_type == "bn": + norm = BatchNormLocal(channels, eps=norm_eps) + elif norm_type == "sbn": + norm = nn.SyncBatchNorm(channels, eps=norm_eps, process_group=None) + elif norm_type in {"lbn", "hbn"}: + norm = nn.SyncBatchNorm( + channels, eps=norm_eps, process_group=dist.new_local_machine_group() + ) + elif norm_type == "gn": + norm = nn.GroupNorm( + num_groups=32, num_channels=channels, eps=norm_eps, affine=True + ) + else: + raise NotImplementedError + + return nn.Sequential( + (SpectralConv1d if using_spec_norm else nn.Conv1d)( + channels, + channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + padding_mode="circular", + ), + norm, + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + +class DinoDisc(nn.Module): + def __init__( + self, + dino_ckpt_path="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", + device="cuda", + ks=9, + depth=12, + key_depths=(2, 5, 8, 11), + norm_type="bn", + using_spec_norm=True, + norm_eps=1e-6, + ): + super().__init__() + # load state + state = torch.hub.load_state_dict_from_url(dino_ckpt_path, map_location="cpu") + # state = torch.load(dino_ckpt_path, 'cpu') + for k in sorted(state.keys()): + if ".attn.qkv.bias" in k: + bias = state[k] + C = bias.numel() // 3 + bias[C : 2 * C].zero_() # zero out k_bias + # build DINO + key_depths = tuple(d for d in key_depths if d < depth) + d = FrozenDINOSmallNoDrop(depth=depth, key_depths=key_depths, norm_eps=norm_eps) + missing, unexpected = d.load_state_dict(state, strict=False) + missing = [ + m + for m in missing + if all( + x not in m + for x in { + "x_scale", + "x_shift", + } + ) + ] + if torch.cuda.is_available(): + assert len(missing) == 0, f"missing keys: {missing}" + assert len(unexpected) == 0, f"unexpected keys: {unexpected}" + + # todo: don't compile! 
reduce-overhead would raise CudaERR + self.dino_proxy: Tuple[FrozenDINOSmallNoDrop] = (d.to(device=device),) + dino_C = self.dino_proxy[0].embed_dim + # if 'KEVIN_LOCAL' in os.environ: + # torch.manual_seed(0) + # np.random.seed(0) + # random.seed(0) + self.heads = nn.ModuleList( + [ + nn.Sequential( + make_block( + dino_C, + kernel_size=1, + norm_type=norm_type, + norm_eps=norm_eps, + using_spec_norm=using_spec_norm, + ), + ResidualBlock( + make_block( + dino_C, + kernel_size=ks, + norm_type=norm_type, + norm_eps=norm_eps, + using_spec_norm=using_spec_norm, + ) + ), + (SpectralConv1d if using_spec_norm else nn.Conv1d)( + dino_C, 1, kernel_size=1, padding=0 + ), + ) + for _ in range(len(key_depths) + 1) # +1: before all attention blocks + ] + ) + + def reinit( + self, + dino_ckpt_path="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", + device="cuda", + ks=9, + depth=12, + key_depths=(2, 5, 8, 11), + norm_type="bn", + using_spec_norm=True, + norm_eps=1e-6, + ): + dino_C = self.dino_proxy[0].embed_dim + heads = nn.ModuleList( + [ + nn.Sequential( + make_block( + dino_C, + kernel_size=1, + norm_type=norm_type, + norm_eps=norm_eps, + using_spec_norm=using_spec_norm, + ), + ResidualBlock( + make_block( + dino_C, + kernel_size=ks, + norm_type=norm_type, + norm_eps=norm_eps, + using_spec_norm=using_spec_norm, + ) + ), + (SpectralConv1d if using_spec_norm else nn.Conv1d)( + dino_C, 1, kernel_size=1, padding=0 + ), + ) + for _ in range(len(key_depths) + 1) + ] + ) + + self.heads.load_state_dict(heads.state_dict()) + + def forward( + self, x_in_pm1, grad_ckpt=False + ): # x_in_pm1: image tensor normalized to [-1, 1] + dino_grad_ckpt = grad_ckpt and x_in_pm1.requires_grad + FrozenDINOSmallNoDrop.forward + activations: List[torch.Tensor] = self.dino_proxy[0]( + x_in_pm1.float(), grad_ckpt=dino_grad_ckpt + ) + B = x_in_pm1.shape[0] + return torch.cat( + [ + ( + h(act) + if not grad_ckpt + else torch.utils.checkpoint.checkpoint(h, act, use_reentrant=False) + ).view(B, -1) + for h, act in zip(self.heads, activations) + ], + dim=1, + ) # cat 5 BL => B, 5L + + +class PatchEmbed(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + ): + super().__init__() + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = (img_size // patch_size) ** 2 + self.flatten = flatten + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.proj(x).flatten(2).transpose(1, 2) # BCHW => BCL => BLC + return self.norm(x) + + +class FrozenDINOSmallNoDrop(nn.Module): + """ + Frozen DINO ViT without any dropout or droppath layers (eval node only), based on timm.create_model('vit_small_patch16_224', pretrained=False, num_classes=0) + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + + Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` + - https://arxiv.org/abs/2012.12877 + """ + + def __init__( + self, + depth=12, + key_depths=(2, 5, 8, 11), + norm_eps=1e-6, # 4 stages: 012, 345, 678, 9 10 11 + patch_size=16, + in_chans=3, + num_classes=0, + embed_dim=384, + num_heads=6, + mlp_ratio=4.0, + # drop_rate=0., attn_drop_rate=0., drop_path_rate=0. 
# no drop for frozen model + ): + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = ( + embed_dim # num_features for consistency with other models + ) + + self.img_size = 224 + self.patch_embed = PatchEmbed( + img_size=self.img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + self.patch_size = patch_size + self.patch_nums = self.img_size // patch_size + + # x \in [-1, 1] + # x = ((x+1)/2 - m) / s = 0.5x/s + 0.5/s - m/s = (0.5/s) x + (0.5-m)/s + m, s = torch.tensor((0.485, 0.456, 0.406)), torch.tensor((0.229, 0.224, 0.225)) + self.register_buffer("x_scale", (0.5 / s).reshape(1, 3, 1, 1)) + self.register_buffer("x_shift", ((0.5 - m) / s).reshape(1, 3, 1, 1)) + self.crop = RandomCrop(self.img_size) + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.dist_token = None + self.pos_embed = nn.Parameter( + torch.zeros(1, self.patch_nums * self.patch_nums + 1, embed_dim) + ) # +1: for cls + # self.pos_drop = nn.Dropout(p=drop_rate) + # self.pos_pool = dict() + + self.key_depths = set(d for d in key_depths if d < depth) + # dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # no drop for frozen model + self.blocks = nn.Sequential( + *[ + SABlockNoDrop( + block_idx=i, + embed_dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_eps=norm_eps, + ) + for i in range(max(depth, 1 + max(self.key_depths))) + ] + ) + self.norm = nn.LayerNorm(embed_dim, eps=norm_eps) + + # eval mode only + self.eval() + [p.requires_grad_(False) for p in self.parameters()] + + def inter_pos_embed(self, patch_nums=(14, 14)): + if patch_nums[0] == self.patch_nums and patch_nums[1] == self.patch_nums: + return self.pos_embed + pe_cls, pe_grid = self.pos_embed[:, :1], self.pos_embed[0, 1:] + pe_grid = pe_grid.reshape(1, self.patch_nums, self.patch_nums, -1).permute( + 0, 3, 1, 2 + ) + pe_grid = F.interpolate( + pe_grid, + size=(patch_nums[0], patch_nums[1]), + mode="bilinear", + align_corners=False, + ) + pe_grid = pe_grid.permute(0, 2, 3, 1).reshape( + 1, patch_nums[0] * patch_nums[1], -1 + ) + return torch.cat([pe_cls, pe_grid], dim=1) + + def forward(self, x, grad_ckpt=False): + with torch.cuda.amp.autocast(enabled=False): + x = (self.x_scale * x.float()).add_(self.x_shift) + H, W = x.shape[-2], x.shape[-1] + if H > self.img_size and W > self.img_size and random.random() <= 0.5: + x = self.crop(x) + else: + x = F.interpolate( + x, + size=(self.img_size, self.img_size), + mode="area" if H > self.img_size else "bicubic", + ) + # x now must be self.img_size x self.img_size + + # patch_nums = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size + # x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), self.patch_embed(x)), dim=1) + # if patch_nums in self.pos_pool: + # x += self.pos_pool[patch_nums] + # else: + # self.pos_pool[patch_nums] = pe = self.inter_pos_embed(patch_nums) + # x += pe + # x = self.pos_drop(x) + + x = self.patch_embed(x) + + with torch.cuda.amp.autocast(enabled=False): + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x.float()), dim=1) + x = x + self.pos_embed + activations = [(x[:, 1:] + x[:, :1]).transpose_(1, 2)] # readout + for i, b in enumerate(self.blocks): + if not grad_ckpt: + x = b(x) + else: + x = torch.utils.checkpoint.checkpoint(b, x, use_reentrant=False) + if i in self.key_depths: + activations.append( + (x[:, 1:].float() + x[:, :1].float()).transpose_(1, 2) + ) # readout + # x = self.norm(x) + return activations + + +if __name__ == "__main__": + 
torch.manual_seed(0) + np.random.seed(0) + random.seed(0) + ks = 9 + norm_type = "sbn" + norm_eps = 1e-6 + dino_C = 384 + key_layers = (2, 5, 8, 11) + using_spec_norm = True + + heads = nn.ModuleList( + [ + nn.Sequential( + make_block( + dino_C, + kernel_size=1, + norm_type=norm_type, + norm_eps=norm_eps, + using_spec_norm=using_spec_norm, + ), + ResidualBlock( + make_block( + dino_C, + kernel_size=ks, + norm_type=norm_type, + norm_eps=norm_eps, + using_spec_norm=using_spec_norm, + ) + ), + (SpectralConv1d if using_spec_norm else nn.Conv1d)( + dino_C, 1, kernel_size=1, padding=0 + ), + ) + for _ in range(len(key_layers) + 1) + ] + ) + + # ckpt = os.path.join(os.path.dirname(__file__), '/mnt/bn/foundation-lq/tiankeyu/ckpt_vae/vit_small_patch16_224.pth') + ckpt = "https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" + DinoDisc.forward + dd = DinoDisc( + dino_ckpt_path=ckpt, + device="cpu", + ks=ks, + norm_type=norm_type, + norm_eps=norm_eps, + key_depths=key_layers, + ) + dd.eval() + dd.heads.load_state_dict(heads.state_dict()) + print(f"{sum(p.numel() for p in dd.parameters() if p.requires_grad) / 1e6:.2f}M") + inp = torch.linspace(-2, 2, 2 * 3 * 224 * 224).reshape(2, 3, 224, 224) + inp.requires_grad = True + cond = torch.rand(2, 64) + mid_ls = dd.dino_proxy[0](inp) + means = [round(m.mean().item(), 3) for m in mid_ls] + stds = [round(m.std().item(), 3) for m in mid_ls] + print(f"mean: {means}") + print(f"std: {stds}") + + o = dd(inp, grad_ckpt=True) + print(f"o: {o.abs().mean().item():.9f}, {o.abs().std().item():.9f}") + o.abs().mean().backward() + + # for n, p in dd.named_parameters(): + # tag = n.split('heads.')[-1][0] + # if p.ndim == 3: tag += '.conv1d' + # print(f'[{tag}] {n}: {p.shape}') + +""" +For the version using qkv, the output is: +7.39M +mean: [0.019, -0.028, 0.054, 0.058, 0.074] +std: [0.427, 0.142, 0.169, 0.194, 0.153] +o: 50.266475677, 91.698143005 + +For the version using zero_k_bias, the output is: +7.39M +mean: [0.019, -0.028, 0.054, 0.058, 0.074] +std: [0.427, 0.142, 0.169, 0.194, 0.153] +o: 50.266475677, 91.698143005 +""" diff --git a/src/vqvaes/xqgan/discriminator_patchgan.py b/src/vqvaes/xqgan/discriminator_patchgan.py new file mode 100644 index 0000000000000000000000000000000000000000..ae668eec77a09b247194eff06d1628baf0132084 --- /dev/null +++ b/src/vqvaes/xqgan/discriminator_patchgan.py @@ -0,0 +1,174 @@ +# Modified from: +# taming-transformers: https://github.com/CompVis/taming-transformers +import functools +import torch +import torch.nn as nn + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator as in Pix2Pix + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + + def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(NLayerDiscriminator, self).__init__() + if not use_actnorm: + norm_layer = nn.BatchNorm2d + else: + norm_layer = ActNorm + if ( + type(norm_layer) == functools.partial + ): # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func != nn.BatchNorm2d + else: + use_bias = norm_layer != nn.BatchNorm2d + + kw = 4 + padw = 1 + sequence = [ + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True), + ] + nf_mult = 1 +
nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2**n, 8) + sequence += [ + nn.Conv2d( + ndf * nf_mult_prev, + ndf * nf_mult, + kernel_size=kw, + stride=2, + padding=padw, + bias=use_bias, + ), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + nf_mult_prev = nf_mult + nf_mult = min(2**n_layers, 8) + sequence += [ + nn.Conv2d( + ndf * nf_mult_prev, + ndf * nf_mult, + kernel_size=kw, + stride=1, + padding=padw, + bias=use_bias, + ), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + sequence += [ + nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) + ] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + self.apply(self._init_weights) + + def _init_weights(self, module): + if isinstance(module, nn.Conv2d): + nn.init.normal_(module.weight.data, 0.0, 0.02) + elif isinstance(module, nn.BatchNorm2d): + nn.init.normal_(module.weight.data, 1.0, 0.02) + nn.init.constant_(module.bias.data, 0) + + def forward(self, input): + """Standard forward.""" + return self.main(input) + + +class ActNorm(nn.Module): + def __init__( + self, num_features, logdet=False, affine=True, allow_reverse_init=False + ): + assert affine + super().__init__() + self.logdet = logdet + self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) + self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) + self.allow_reverse_init = allow_reverse_init + + self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) + + def initialize(self, input): + with torch.no_grad(): + flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) + mean = ( + flatten.mean(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + std = ( + flatten.std(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + + self.loc.data.copy_(-mean) + self.scale.data.copy_(1 / (std + 1e-6)) + + def forward(self, input, reverse=False): + if reverse: + return self.reverse(input) + if len(input.shape) == 2: + input = input[:, :, None, None] + squeeze = True + else: + squeeze = False + + _, _, height, width = input.shape + + if self.training and self.initialized.item() == 0: + self.initialize(input) + self.initialized.fill_(1) + + h = self.scale * (input + self.loc) + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + + if self.logdet: + log_abs = torch.log(torch.abs(self.scale)) + logdet = height * width * torch.sum(log_abs) + logdet = logdet * torch.ones(input.shape[0]).to(input) + return h, logdet + + return h + + def reverse(self, output): + if self.training and self.initialized.item() == 0: + if not self.allow_reverse_init: + raise RuntimeError( + "Initializing ActNorm in reverse direction is " + "disabled by default. Use allow_reverse_init=True to enable." 
+ ) + else: + self.initialize(output) + self.initialized.fill_(1) + + if len(output.shape) == 2: + output = output[:, :, None, None] + squeeze = True + else: + squeeze = False + + h = output / self.scale - self.loc + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + return h diff --git a/src/vqvaes/xqgan/discriminator_stylegan.py b/src/vqvaes/xqgan/discriminator_stylegan.py new file mode 100644 index 0000000000000000000000000000000000000000..18b0e0551614534c6fc6797393adfe5c42fa4250 --- /dev/null +++ b/src/vqvaes/xqgan/discriminator_stylegan.py @@ -0,0 +1,107 @@ +# Modified from: +# stylegan2-pytorch: https://github.com/lucidrains/stylegan2-pytorch/blob/master/stylegan2_pytorch/stylegan2_pytorch.py +# stylegan2-pytorch: https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py +# maskgit: https://github.com/google-research/maskgit/blob/main/maskgit/nets/discriminator.py +import math +import torch +import torch.nn as nn + +try: + from kornia.filters import filter2d +except: + pass + + +class Discriminator(nn.Module): + def __init__( + self, input_nc=3, ndf=64, n_layers=3, channel_multiplier=1, image_size=256 + ): + super().__init__() + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + log_size = int(math.log(image_size, 2)) + in_channel = channels[image_size] + + blocks = [nn.Conv2d(input_nc, in_channel, 3, padding=1), leaky_relu()] + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + blocks.append(DiscriminatorBlock(in_channel, out_channel)) + in_channel = out_channel + self.blocks = nn.ModuleList(blocks) + + self.final_conv = nn.Sequential( + nn.Conv2d(in_channel, channels[4], 3, padding=1), + leaky_relu(), + ) + self.final_linear = nn.Sequential( + nn.Linear(channels[4] * 4 * 4, channels[4]), + leaky_relu(), + nn.Linear(channels[4], 1), + ) + + def forward(self, x): + for block in self.blocks: + x = block(x) + x = self.final_conv(x) + x = x.view(x.shape[0], -1) + x = self.final_linear(x) + return x + + +class DiscriminatorBlock(nn.Module): + def __init__(self, input_channels, filters, downsample=True): + super().__init__() + self.conv_res = nn.Conv2d( + input_channels, filters, 1, stride=(2 if downsample else 1) + ) + + self.net = nn.Sequential( + nn.Conv2d(input_channels, filters, 3, padding=1), + leaky_relu(), + nn.Conv2d(filters, filters, 3, padding=1), + leaky_relu(), + ) + + self.downsample = ( + nn.Sequential(Blur(), nn.Conv2d(filters, filters, 3, padding=1, stride=2)) + if downsample + else None + ) + + def forward(self, x): + res = self.conv_res(x) + x = self.net(x) + if exists(self.downsample): + x = self.downsample(x) + x = (x + res) * (1 / math.sqrt(2)) + return x + + +class Blur(nn.Module): + def __init__(self): + super().__init__() + f = torch.Tensor([1, 2, 1]) + self.register_buffer("f", f) + + def forward(self, x): + f = self.f + f = f[None, None, :] * f[None, :, None] + return filter2d(x, f, normalized=True) + + +def leaky_relu(p=0.2): + return nn.LeakyReLU(p, inplace=True) + + +def exists(val): + return val is not None diff --git a/src/vqvaes/xqgan/dist.py b/src/vqvaes/xqgan/dist.py new file mode 100644 index 0000000000000000000000000000000000000000..54a0618d8827851a3ea4a2e65065b3fee0594704 --- /dev/null +++ b/src/vqvaes/xqgan/dist.py @@ -0,0 +1,231 @@ +import datetime +import functools +import os +import sys +from typing import List +from typing import Union + +import 
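Editorial note (not part of the committed files): a quick shape check for the two discriminators defined above, assuming both classes are in scope and, for the StyleGAN variant, that kornia is installed. The PatchGAN head returns a map of per-patch logits, while the StyleGAN head returns one logit per image.

import torch

disc_patch = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3)
disc_style = Discriminator(input_nc=3, image_size=256)
x = torch.randn(2, 3, 256, 256)
print(disc_patch(x).shape)  # torch.Size([2, 1, 30, 30]) -- per-patch real/fake logits
print(disc_style(x).shape)  # torch.Size([2, 1])          -- one logit per image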
torch +import torch.distributed as tdist +import torch.multiprocessing as mp + +__rank, __local_rank, __world_size, __device = ( + 0, + 0, + 1, + "cuda" if torch.cuda.is_available() else "cpu", +) +__initialized = False + + +def initialized(): + return __initialized + + +def initialize(fork=False, backend="nccl", gpu_id_if_not_distibuted=0, timeout=30): + global __device + if not torch.cuda.is_available(): + print( + f"[dist initialize] cuda is not available, use cpu instead", file=sys.stderr + ) + return + elif "RANK" not in os.environ: + torch.cuda.set_device(gpu_id_if_not_distibuted) + __device = torch.empty(1).cuda().device + print( + f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', + file=sys.stderr, + ) + return + # then 'RANK' must exist + global_rank, num_gpus = int(os.environ["RANK"]), torch.cuda.device_count() + local_rank = global_rank % num_gpus + torch.cuda.set_device(local_rank) + + # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29 + if mp.get_start_method(allow_none=True) is None: + method = "fork" if fork else "spawn" + print(f"[dist initialize] mp method={method}") + mp.set_start_method(method) + tdist.init_process_group( + backend=backend, timeout=datetime.timedelta(seconds=timeout * 60) + ) + + global __rank, __local_rank, __world_size, __initialized + __local_rank = local_rank + __rank, __world_size = tdist.get_rank(), tdist.get_world_size() + __device = torch.empty(1).cuda().device + __initialized = True + + assert tdist.is_initialized(), "torch.distributed is not initialized!" + print(f"[lrk={get_local_rank()}, rk={get_rank()}]") + + +def get_rank(): + return __rank + + +def get_local_rank(): + return __local_rank + + +def get_world_size(): + return __world_size + + +def get_device(): + return __device + + +def set_gpu_id(gpu_id: int): + if gpu_id is None: + return + global __device + if isinstance(gpu_id, (str, int)): + torch.cuda.set_device(int(gpu_id)) + __device = torch.empty(1).cuda().device + else: + raise NotImplementedError + + +def is_master(): + return __rank == 0 + + +def is_local_master(): + return __local_rank == 0 + + +def new_group(ranks: List[int]): + if __initialized: + return tdist.new_group(ranks=ranks) + return None + + +def barrier(): + if __initialized: + tdist.barrier() + + +def allreduce(t: torch.Tensor, async_op=False): + if __initialized: + if not t.is_cuda: + cu = t.detach().cuda() + ret = tdist.all_reduce(cu, async_op=async_op) + t.copy_(cu.cpu()) + else: + ret = tdist.all_reduce(t, async_op=async_op) + return ret + return None + + +def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]: + if __initialized: + if not t.is_cuda: + t = t.cuda() + ls = [torch.empty_like(t) for _ in range(__world_size)] + tdist.all_gather(ls, t) + else: + ls = [t] + if cat: + ls = torch.cat(ls, dim=0) + return ls + + +def allgather_diff_shape( + t: torch.Tensor, cat=True +) -> Union[List[torch.Tensor], torch.Tensor]: + if __initialized: + if not t.is_cuda: + t = t.cuda() + + t_size = torch.tensor(t.size(), device=t.device) + ls_size = [torch.empty_like(t_size) for _ in range(__world_size)] + tdist.all_gather(ls_size, t_size) + + max_B = max(size[0].item() for size in ls_size) + pad = max_B - t_size[0].item() + if pad: + pad_size = (pad, *t.size()[1:]) + t = torch.cat((t, t.new_empty(pad_size)), dim=0) + + ls_padded = [torch.empty_like(t) for _ in range(__world_size)] + tdist.all_gather(ls_padded, t) + ls = [] + for t, size in zip(ls_padded, ls_size): + ls.append(t[: 
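The helpers above fall back gracefully when the script is launched without torchrun (no "RANK" in the environment). A minimal single-process sanity check, editorial and assuming these functions are in scope:

import torch

initialize()                # prints a notice and falls back to a single device
x = torch.ones(2, 3)
print(get_world_size())     # 1
print(allgather(x).shape)   # torch.Size([2, 3]) -- identity when not distributed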
size[0].item()]) + else: + ls = [t] + if cat: + ls = torch.cat(ls, dim=0) + return ls + + +def broadcast(t: torch.Tensor, src_rank) -> None: + if __initialized: + if not t.is_cuda: + cu = t.detach().cuda() + tdist.broadcast(cu, src=src_rank) + t.copy_(cu.cpu()) + else: + tdist.broadcast(t, src=src_rank) + + +def dist_fmt_vals( + val: float, fmt: Union[str, None] = "%.2f" +) -> Union[torch.Tensor, List]: + if not initialized(): + return torch.tensor([val]) if fmt is None else [fmt % val] + + ts = torch.zeros(__world_size) + ts[__rank] = val + allreduce(ts) + if fmt is None: + return ts + return [fmt % v for v in ts.cpu().numpy().tolist()] + + +def master_only(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + force = kwargs.pop("force", False) + if force or is_master(): + ret = func(*args, **kwargs) + else: + ret = None + barrier() + return ret + + return wrapper + + +def local_master_only(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + force = kwargs.pop("force", False) + if force or is_local_master(): + ret = func(*args, **kwargs) + else: + ret = None + barrier() + return ret + + return wrapper + + +def for_visualize(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if is_master(): + # with torch.no_grad(): + ret = func(*args, **kwargs) + else: + ret = None + return ret + + return wrapper + + +def finalize(): + if __initialized: + tdist.destroy_process_group() diff --git a/src/vqvaes/xqgan/latent_perturbation.py b/src/vqvaes/xqgan/latent_perturbation.py new file mode 100644 index 0000000000000000000000000000000000000000..29b4912a833a1555823ac1d32cf64880fa0f6fd0 --- /dev/null +++ b/src/vqvaes/xqgan/latent_perturbation.py @@ -0,0 +1,43 @@ +import torch +import torch.nn.functional as F + + +def add_perturbation(z, z_q, z_channels, codebook_norm, codebook, alpha, beta, delta): + # reshape z -> (batch, height * width, channel) and flatten + z = torch.einsum("b c h w -> b h w c", z).contiguous() + z_flattened = z.view(-1, z_channels) + + if codebook_norm: + z = F.normalize(z, p=2, dim=-1) + z_flattened = F.normalize(z_flattened, p=2, dim=-1) + embedding = F.normalize(codebook.weight, p=2, dim=-1) + else: + embedding = codebook.weight + + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(embedding**2, dim=1) + - 2 + * torch.einsum("bd,dn->bn", z_flattened, torch.einsum("n d -> d n", embedding)) + ) + + _, min_encoding_indices = torch.topk(d, delta, dim=1, largest=False) + random_prob = torch.rand(min_encoding_indices.shape[0], device=d.device) + random_idx = torch.randint(0, delta, random_prob.shape, device=d.device) + random_idx = torch.where(random_prob > alpha, 0, random_idx) + min_encoding_indices = min_encoding_indices[ + torch.arange(min_encoding_indices.size(0)), random_idx + ] + + perturbed_z_q = codebook(min_encoding_indices).view(z.shape) + if codebook_norm: + perturbed_z_q = F.normalize(perturbed_z_q, p=2, dim=-1) + perturbed_z_q = z + (perturbed_z_q - z).detach() + perturbed_z_q = torch.einsum("b h w c -> b c h w", perturbed_z_q) + + mask = torch.arange(z.shape[0], device=perturbed_z_q.device) < int( + z.shape[0] * beta + ) + mask = mask[:, None, None, None] + + return torch.where(mask, perturbed_z_q, z_q) diff --git a/src/vqvaes/xqgan/linear_probing.py b/src/vqvaes/xqgan/linear_probing.py new file mode 100644 index 0000000000000000000000000000000000000000..1059fa26e4f9721f8f0f472692d904a4eb0a2fa6 --- /dev/null +++ b/src/vqvaes/xqgan/linear_probing.py @@ -0,0 +1,811 @@ +import os +import sys + 
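For intuition: add_perturbation above re-assigns each latent position, with probability up to alpha, to one of its delta nearest codebook entries, and applies the perturbed assignment only to the first beta fraction of the batch. A toy editorial sketch; the codebook and shapes here are placeholders and the function is assumed to be in scope.

import torch
import torch.nn as nn

codebook = nn.Embedding(4096, 16)   # toy codebook: 4096 codes of dim 16
z = torch.randn(2, 16, 8, 8)        # encoder features (B, C, H, W)
z_q = torch.randn(2, 16, 8, 8)      # stand-in for the quantized features
out = add_perturbation(
    z, z_q, z_channels=16, codebook_norm=True, codebook=codebook,
    alpha=0.5, beta=0.5, delta=4,
)
print(out.shape)                    # torch.Size([2, 16, 8, 8])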
+sys.path.append("/home/xiangl/LlamaGen") +import logging +import json +import numpy as np +import torch.distributed +from tqdm.auto import tqdm +from PIL import Image +from logging import getLogger as get_logger +import torch +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as NativeDDP +from torch.utils.data import DataLoader, default_collate +from utils2 import ( + init_distributed_device, + is_global_primary, + is_primary, + seed_everything, + str2bool, +) +from tokenizer.tokenizer_image.msvq_model import VQ_models +from datasets import ( + create_dataset, + fast_collate, + PrefetchLoader, + Normalize, + Denormalize, +) + +from timm.optim import create_optimizer_v2 as create_optimizer +from timm.scheduler import create_scheduler_v2 as create_scheduler + +import argparse + +logger = get_logger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser() + + # config file + parser.add_argument( + "--config", + type=str, + default="configs/vqgan/imagenet/vqvae_vq_dinov2base_v4096z16n64_pretrained_ae.yaml", + help="config file used to specify parameters", + ) + + # data + parser.add_argument( + "--data_dir", type=str, default="imagenet/train", help="data folder" + ) + parser.add_argument( + "--dataset_name", type=str, default="imagenet", help="dataset name" + ) + parser.add_argument( + "--val_data_dir", type=str, default="imagenet/val", help="data folder" + ) + parser.add_argument("--image_size", type=int, default=256, help="image size") + parser.add_argument("--batch_size", type=int, default=4, help="per gpu batch size") + parser.add_argument("--num_workers", type=int, default=8, help="batch size") + parser.add_argument( + "--num_classes", type=int, default=1000, help="number of classes in dataset" + ) + parser.add_argument( + "--use_prefetcher", type=str2bool, default=True, help="use prefetch" + ) + + # training + parser.add_argument("--run_name", type=str, default=None, help="run_name") + parser.add_argument( + "--output_dir", type=str, default="experiments", help="output folder" + ) + parser.add_argument("--num_epochs", type=int, default=10) + parser.add_argument("--optimizer", type=str, default="adamw", help="optimizer") + parser.add_argument( + "--learning_rate", type=float, default=1e-4, help="learning rate" + ) + parser.add_argument("--min_lr", type=float, default=5e-5, help="end learning rate") + parser.add_argument("--weight_decay", type=float, default=1e-4, help="weight decay") + parser.add_argument( + "--lr_scheduler", type=str, default="cosine", help="lr scheduler" + ) + parser.add_argument( + "--lr_warmup_epochs", type=float, default=1, help="warmup epochs" + ) + parser.add_argument( + "--log_interval", type=int, default=50, help="log interval for steps" + ) + parser.add_argument( + "--val_interval", type=int, default=1000, help="validation interval for epochs" + ) + parser.add_argument("--save_interval", type=int, default=1, help="save interval") + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="gradient accumulation steps", + ) + parser.add_argument( + "--gradient_clip", type=float, default=1.0, help="gradient clip" + ) + parser.add_argument( + "--torchcompile", type=str2bool, default=False, help="use torch compile" + ) + parser.add_argument( + "--report_to", + type=str, + default="wandb", + help="report to", + choices=["wandb", "tensorboard", "none"], + ) + parser.add_argument( + "--resume", type=str, default=None, help="resume from pre-trained checkpoint" + ) + parser.add_argument( + 
"--auto_resume", + type=str2bool, + default=False, + help="auto resume from latest checkpoint", + ) + parser.add_argument("--seed", type=int, default=42, help="random seed") + parser.add_argument( + "--mixed_precision", + type=str, + default="bf16", + choices=["fp16", "bf16", "fp32"], + help="mixed precision", + ) + parser.add_argument( + "--ema", type=float, default=0, help="ema updates of the models" + ) + parser.add_argument("--beta1", type=int, default=0.9, help="beta1 for adam") + parser.add_argument("--beta2", type=int, default=0.99, help="beta2 for adam") + parser.add_argument( + "--quantizer_lr_multiplier", + type=float, + default=1.0, + help="lr multiplier for quantization", + ) + parser.add_argument( + "--compile", type=str2bool, default=False, help="use torch compile" + ) + + # loss weight + parser.add_argument( + "--disc_adaptive", + type=str2bool, + default=True, + help="flag of whether to use adaptive discriminator weight", + ) + parser.add_argument( + "--disc_loss_start", + type=float, + default=0, + help="starting threshold of adaptive discriminator weight for discriminator training", + ) + parser.add_argument( + "--disc_loss_weight", type=float, default=0.8, help="discriminator loss weight" + ) + parser.add_argument( + "--gen_disc_loss_weight", + type=float, + default=0.1, + help="discriminator loss weight of generator", + ) + parser.add_argument( + "--gen_disc_loss_type", + type=str, + default="non-saturating", + choices=["hinge", "vanilla", "non-saturating"], + help="generator loss type", + ) + parser.add_argument( + "--disc_loss_type", + type=str, + default="hinge", + choices=["hinge", "vanilla", "non-saturating"], + help="discriminator loss type", + ) + parser.add_argument( + "--disc_model", + type=str, + default="patchgan", + choices=["patchgan", "stylegan"], + help="discriminator loss type", + ) + parser.add_argument( + "--lecam_loss_weight", + type=float, + default=0.0, + help="lecam regularization loss weight of discriminator", + ) + parser.add_argument( + "--codebook_loss_weight", type=float, default=1.0, help="codebook loss weight" + ) + parser.add_argument( + "--perceptual_loss_weight", + type=float, + default=0.1, + help="perceptual loss weight", + ) + parser.add_argument( + "--logit_scale_loss_weight", + type=float, + default=0.1, + help="logit_scale loss weight", + ) + parser.add_argument( + "--rec_loss_weight", type=float, default=1.0, help="rec loss weight" + ) + + parser.add_argument( + "--ent_loss_weight", type=float, default=0.1, help="entropy loss weight" + ) + parser.add_argument( + "--ent_loss_weight_end", type=float, default=0.0, help="entropy loss weight" + ) + parser.add_argument( + "--ent_loss_start", type=float, default=1.0, help="start to add entropy loss" + ) + parser.add_argument( + "--ent_loss_annealing_steps", + type=float, + default=2000, + help="steps to anneal entropy loss weight", + ) + parser.add_argument( + "--sem_loss_weight", type=float, default=0.01, help="semantic loss weight" + ) + parser.add_argument( + "--ent_sample_min_loss_weight", + type=float, + default=1.0, + help="sample entropy minimization loss weight", + ) + parser.add_argument( + "--ent_batch_max_loss_weight", + type=float, + default=1.0, + help="batch entropy maximization loss weight", + ) + + # vqvae + parser.add_argument( + "--recon_loss", + type=str, + default="l1", + choices=["l1", "l2"], + help="reconstruction loss", + ) + parser.add_argument( + "--quantizer", + type=str, + default="vq", + choices=[ + "vq", + "gumbel_vq", + "st_gumbel_vq", + "ema_vq", + "oc_vq", 
+ "diff_vq", + "diff_vq2", + "diff_vq_fix", + ], + help="quantizer type", + ) + parser.add_argument( + "--encoder", type=str, default="dinov2", help="encoder model type" + ) + parser.add_argument( + "--decoder", type=str, default="dinov2", help="deocder model type" + ) + parser.add_argument( + "--encoder_model", + type=str, + default="vit_small_patch14_dinov2.lvd142m", + help="encoder model name", + ) + parser.add_argument( + "--encoder_model_pretrained", + type=str2bool, + default=True, + help="encoder model load pretrained checkpoint", + ) + parser.add_argument( + "--encoder_patch_size", type=int, default=16, help="encoder patch size" + ) + parser.add_argument( + "--encoder_tuning", type=str, default="lora", help="encoder tuning method" + ) + parser.add_argument( + "--encoder_tuning_lora_r", default=8, type=int, help="encoder tuning lora r" + ) + parser.add_argument( + "--encoder_drop_path", type=float, default=0.0, help="encoder droppath rate" + ) + parser.add_argument( + "--decoder_model", + type=str, + default="vit_small_patch14_dinov2.lvd142m", + help="deocder model name", + ) + parser.add_argument( + "--decoder_model_pretrained", + type=str2bool, + default=True, + help="decoder model load pretrained checkpoint", + ) + parser.add_argument( + "--decoder_patch_size", type=int, default=16, help="decoder patch size" + ) + parser.add_argument( + "--decoder_drop_path", type=float, default=0.0, help="decoder droppath rate" + ) + parser.add_argument( + "--decoder_to_pixel", + type=str, + default="linear", + help="decoder to pixel", + choices=["linear", "conv", "ada_conv", "siren"], + ) + parser.add_argument( + "--decoder_use_rope", type=str2bool, default=False, help="decoder use RoPE" + ) + parser.add_argument( + "--decoder_cond_latent", + type=str2bool, + default=False, + help="use dino latent to initialize latent tokens (mask token)", + ) + parser.add_argument( + "--decoder_tuning", type=str, default="lora", help="deocder tuning method" + ) + parser.add_argument( + "--decoder_tuning_lora_r", default=8, type=int, help="decoder tuning lora r" + ) + parser.add_argument( + "--pretrained_path", type=str, default=None, help="pretrained model path" + ) + parser.add_argument( + "--semantic_guide", + type=str, + default="none", + help="semantic guidance on latent tokens", + ) + parser.add_argument( + "--sem_loss_scale", type=float, default=15.0, help="scale for clip loss" + ) + parser.add_argument( + "--renorm_input", type=str2bool, default=False, help="normalize input images" + ) + + parser.add_argument( + "--vocab_size", type=int, default=4096, nargs="+", help="codebook size" + ) + parser.add_argument( + "--z_channels", type=int, default=32, help="latent size of vqvae" + ) + parser.add_argument( + "--num_latent_tokens", type=int, default=32, help="number of latent tokens" + ) + parser.add_argument( + "--codebook_norm", type=str2bool, default=True, help="normalize codebook" + ) + parser.add_argument( + "--use_gumbel", + type=str2bool, + default=False, + help="use gumbel softmax for probs", + ) + parser.add_argument( + "--commit_loss_weight", type=float, default=0.0, help="commit loss weight" + ) + parser.add_argument( + "--kl_loss_weight", type=float, default=5e-4, help="kl loss weight" + ) + parser.add_argument( + "--ema_decay", + type=float, + default=0.999, + help="ema decay for embeddings of ema quantizer", + ) + parser.add_argument( + "--oc_anchor", + type=str, + default="cloest", + help="online cluster anchor", + choices=["closest", "random", "projrandom"], + ) + parser.add_argument( + 
"--contrastive_loss_weight", + type=float, + default=1.0, + help="contrastive loss weight", + ) + parser.add_argument( + "--freq_loss_weight", type=float, default=0.0, help="freq loss weight" + ) + parser.add_argument( + "--disc_r1_gamma", type=float, default=0.0, help="disc do r1 reg" + ) + parser.add_argument( + "--use_diffaug", type=str2bool, default=False, help="use diff aug" + ) + parser.add_argument( + "--init_logit_scale", + type=float, + default=10, + help="initial logit scale before log", + ) + parser.add_argument( + "--max_logit_scale", + type=float, + default=200, + help="maximum logit scale before log", + ) + + parser.add_argument( + "--v_patch_nums", + type=int, + default=[1, 2, 3, 4, 5, 6, 8, 10, 13, 16], + nargs="+", + help="number of patch numbers of each scale", + ) + parser.add_argument( + "--codebook-size", + type=int, + default=16384, + help="codebook size for vector quantization", + ) + parser.add_argument( + "--codebook-embed-dim", + type=int, + default=8, + help="codebook dimension for vector quantization", + ) + parser.add_argument( + "--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16" + ) + parser.add_argument( + "--vq-ckpt", type=str, default=None, help="ckpt path for resume training" + ) + parser.add_argument( + "--output_path", + type=str, + default="output/linear_probing", + help="output model path", + ) + parser.add_argument("--enc_type", type=str, default="cnn") + parser.add_argument("--dec_type", type=str, default="cnn") + # fFirst parse of command-line args to check for config file + # args = parser.parse_args() + + # # If a config file is specified, load it and set defaults + # if args.config is not None: + # with open(args.config, 'r', encoding='utf-8') as f: + # file_yaml = yaml.YAML() + # config_args = file_yaml.load(f) + # parser.set_defaults(**config_args) + + # re-parse command-line args to overwrite with any command-line inputs + args = parser.parse_args() + return args + + +class LinearClassifier(nn.Module): + def __init__(self, in_dim, out_dim): + super(LinearClassifier, self).__init__() + self.bn = nn.BatchNorm1d(in_dim, affine=False, eps=1e-6) + self.linear = nn.Linear(in_dim, out_dim) + + def forward(self, x): + x = x.mean(dim=1) + x = self.bn(x) + out = self.linear(x) + return out + + +@torch.no_grad() +def extract_feature(vqvae, images, args): + + if args.distributed: + + z_e = vqvae.module.encoder(images) + if args.enc_type == "dinov2": + b, l, c = z_e.shape + z_e = z_e.view(b, 16, 16, c) + z_e = z_e.permute(0, 3, 1, 2) + z_e = vqvae.module.quant_conv(z_e) + # z_q, _, _ = vqvae.module.quantize(z_e) + + else: + z_e = vqvae.encoder(images) + if args.enc_type == "dinov2": + b, l, c = z_e.shape + z_e = z_e.view(b, 16, 16, c) + z_e = z_e.permute(0, 3, 1, 2) + z_e = vqvae.quant_conv(z_e) + # z_q, _, _ = vqvae.quantize(z_e) + + return z_e + + +def train_epoch( + vqvae, linear_classifier, train_dataloader, optimizer, device, scaler, args +): + criterion = torch.nn.CrossEntropyLoss() + linear_classifier.train() + train_dtype = { + "none": torch.float32, + "bf16": torch.bfloat16, + "fp16": torch.float16, + }[args.mixed_precision] + total_loss = 0 + total_correct = 0 + total_samples = 0 + if args.renorm_input: + denormalize = Denormalize( + mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], device=device + ) + normalize = Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], device=device + ) + + for idx, batch in tqdm( + enumerate(train_dataloader), + total=len(train_dataloader), + disable=not is_primary(args), + ): + # 
features, labels = batch + # features, labels = features.to(device), labels.to(device) + + optimizer.zero_grad() + + images, labels = batch + if not args.use_prefetcher: + images = images.to(device) + labels = labels.to(device) + + if args.renorm_input: + input_images = denormalize(images) + input_images = normalize(input_images) + else: + input_images = images + + with torch.cuda.amp.autocast(dtype=train_dtype): + + features = extract_feature(vqvae, input_images, args).detach() + + features = features.flatten(2).permute(0, 2, 1) + logits = linear_classifier(features) + loss = criterion(logits, labels) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + total_loss += loss.item() + total_correct += (logits.argmax(1) == labels).sum().item() + total_samples += labels.size(0) + + if is_primary(args) and idx % 25 == 0: + logger.info(f"Training Loss: {loss.item():.4f}") + logger.info(f"Training Acc: {total_correct / total_samples * 100.0:.4f}") + return total_loss / len(train_dataloader), total_correct / total_samples * 100.0 + + +def evaluate(vqvae, linear_classifier, val_dataloader, device, args): + dtype = {"none": torch.float32, "bf16": torch.bfloat16, "fp16": torch.float16}[ + args.mixed_precision + ] + criterion = torch.nn.CrossEntropyLoss() + linear_classifier.eval() + total_loss = 0 + total_correct = 0 + total_samples = 0 + if args.renorm_input: + denormalize = Denormalize( + mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], device=device + ) + normalize = Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], device=device + ) + + with torch.no_grad(): + for batch in tqdm( + val_dataloader, total=len(val_dataloader), disable=not is_primary(args) + ): + # features, labels = batch + # features, labels = features.to(device), labels.to(device) + images, labels = batch + images = images.to(device) + labels = labels.to(device) + + if args.renorm_input: + input_images = denormalize(images) + input_images = normalize(input_images) + else: + input_images = images + + with torch.cuda.amp.autocast(dtype=dtype): + features = extract_feature(vqvae, input_images, args) + features = features.flatten(2).permute(0, 2, 1) + # z = extract_feature(vqvae, images, args) + logits = linear_classifier(features) + loss = criterion(logits, labels) + + total_loss += loss.item() + total_correct += (logits.argmax(1) == labels).sum().item() + total_samples += labels.size(0) + return total_loss / len(val_dataloader), total_correct / total_samples * 100.0 + + +def main(): + + args = parse_args() + + # seed + seed_everything(args.seed) + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + + device = init_distributed_device(args) + if args.distributed: + logger.info( + "Training in distributed mode with multiple processes, 1 device per process." + f"Process {args.rank}, total {args.world_size}, device {args.device}." 
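The probing head defined earlier (LinearClassifier) simply mean-pools the token features, applies a parameter-free BatchNorm, and fits one linear layer on top of the frozen tokenizer. A small editorial shape sketch; the dimensions are placeholders matching the default codebook_embed_dim of 8 and a 16x16 token grid.

import torch

head = LinearClassifier(in_dim=8, out_dim=1000)   # codebook_embed_dim -> number of classes
feats = torch.randn(4, 256, 8)                    # (B, num_tokens, embed_dim) after flatten/permute
print(head(feats).shape)                          # torch.Size([4, 1000])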
+ ) + os.environ["HF_HOME"] = f"./hf_cache_{args.rank}/" + os.environ["TRANSFORMERS_CACHE"] = f"./hf_cache_{args.rank}/" + else: + logger.info(f"Training with a single process on 1 device ({args.device}).") + assert args.rank >= 0 + + # create and load model + logger.info("Creating model") + # create and load model + vqvae = VQ_models[args.vq_model]( + codebook_size=args.codebook_size, + codebook_embed_dim=args.codebook_embed_dim, + v_patch_nums=args.v_patch_nums, + enc_type=args.enc_type, + dec_type=args.dec_type, + semantic_guide=args.semantic_guide, + ) + vqvae.to(device) + vqvae.eval() + checkpoint = torch.load(args.vq_ckpt, map_location="cpu") + if "ema" in checkpoint: # ema + model_weight = checkpoint["ema"] + elif "model" in checkpoint: # ddp + model_weight = checkpoint["model"] + elif "state_dict" in checkpoint: + model_weight = checkpoint["state_dict"] + else: + raise Exception("please check model weight") + vqvae.load_state_dict(model_weight) + del checkpoint + + # create linear classifier + linear_classifier = LinearClassifier(vqvae.codebook_embed_dim, args.num_classes) + linear_classifier = linear_classifier.to(device) + + if args.distributed: + if is_primary(args): + logger.info("Using native Torch DistributedDataParallel.") + vqvae = NativeDDP(vqvae, device_ids=[device], find_unused_parameters=True) + linear_classifier = NativeDDP( + linear_classifier, device_ids=[device], find_unused_parameters=True + ) + + logger.info("Creating dataset") + train_dataset = create_dataset( + args.dataset_name, + args.data_dir, + args.image_size, + is_train=True, + use_prefetcher=args.use_prefetcher, + ) + valid_dataset = create_dataset( + args.dataset_name, + args.val_data_dir, + args.image_size, + is_train=False, + use_prefetcher=False, + ) + sampler = None + if args.distributed: + sampler = torch.utils.data.DistributedSampler( + train_dataset, shuffle=True, drop_last=False + ) + shuffle = sampler is None + collate_fn = fast_collate if args.use_prefetcher else default_collate + train_dataloader = DataLoader( + train_dataset, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=True, + sampler=sampler, + shuffle=shuffle, + collate_fn=collate_fn, + ) + if args.use_prefetcher: + train_dataloader = PrefetchLoader( + train_dataloader, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], device=device + ) + val_dataloader = DataLoader( + valid_dataset, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=True, + shuffle=False, + ) + total_batch_size = args.batch_size * args.world_size + + # create output folder + # output_dir = os.path.join(args.output_dir, args.run_name, 'evaluations') + output_dir = args.output_path + output_dir = os.path.join(output_dir, "evaluations") + os.makedirs(output_dir, exist_ok=True) + lp_model_dir = os.path.join(output_dir, args.dataset_name) + os.makedirs(lp_model_dir, exist_ok=True) + + optimizer = create_optimizer( + linear_classifier, + opt=args.optimizer, + lr=args.learning_rate, + weight_decay=args.weight_decay, + ) + scaler = torch.cuda.amp.GradScaler(enabled=(args.mixed_precision == "fp16")) + scheduler, _ = create_scheduler( + sched="step", + decay_milestones=[ + int(args.num_epochs * 0.3), + int(args.num_epochs * 0.6), + int(args.num_epochs * 0.9), + ], + optimizer=optimizer, + patience_epochs=0, + step_on_epochs=True, + num_epochs=args.num_epochs, + warmup_epochs=args.lr_warmup_epochs, + min_lr=1e-6, + ) + + # train linear classifier + logger.info("Start training linear classifier") + max_accuracy = 0 + for epoch in 
range(args.num_epochs): + if args.distributed: + sampler.set_epoch(epoch) + train_loss, train_acc = train_epoch( + vqvae, linear_classifier, train_dataloader, optimizer, device, scaler, args + ) + val_loss, val_acc = evaluate( + vqvae, linear_classifier, val_dataloader, device, args + ) + + if is_global_primary(args): + if args.distributed: + torch.save( + linear_classifier.module.state_dict(), + os.path.join(lp_model_dir, f"epoch_{epoch}.pth"), + ) + else: + torch.save( + linear_classifier.state_dict(), + os.path.join(lp_model_dir, f"epoch_{epoch}.pth"), + ) + + if val_acc > max_accuracy: + max_accuracy = val_acc + logger.info(f"Saving best model with accuracy {max_accuracy}") + if args.distributed: + torch.save( + linear_classifier.module.state_dict(), + os.path.join(lp_model_dir, "best.pth"), + ) + else: + torch.save( + linear_classifier.state_dict(), + os.path.join(lp_model_dir, "best.pth"), + ) + + if is_primary(args): + logger.info( + f"Epoch {epoch}: train_loss={train_loss}, train_acc={train_acc}, val_loss={val_loss}, val_acc={val_acc}" + ) + logger.info(f"Best accuracy so far: {max_accuracy}") + + scheduler.step(epoch + 1) + results = {"best_lp_accuracy": max_accuracy} + + # logger.info("Start training k-nn") + + if is_primary(args): + logger.info("Finished training") + logger.info(f"Best accuracy: {max_accuracy}") + + with open(os.path.join(output_dir, "linear_results.json"), "w") as f: + json.dump(results, f) + + +if __name__ == "__main__": + main() diff --git a/src/vqvaes/xqgan/lookup_free_quantize.py b/src/vqvaes/xqgan/lookup_free_quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..f0462dfb5730bbbc9144664e1681e8bee0e0c15f --- /dev/null +++ b/src/vqvaes/xqgan/lookup_free_quantize.py @@ -0,0 +1,716 @@ +from typing import List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +from torch import distributed as tdist, nn as nn +from torch.nn import functional as F + +from math import sqrt +import math + +from einops import rearrange, reduce, pack, unpack + +import dist + + +def mult_along_first_dims(x, y): + """ + returns x * y elementwise along the leading dimensions of y + """ + ndim_to_expand = x.ndim - y.ndim + for _ in range(ndim_to_expand): + y = y.unsqueeze(-1) + return x * y + + +def masked_mean(x, m): + """ + takes the mean of the elements of x that are not masked + the mean is taken along the shared leading dims of m + equivalent to: x[m].mean(tuple(range(m.ndim))) + + The benefit of using masked_mean rather than using + tensor indexing is that masked_mean is much faster + for torch-compile on batches. + + The drawback is larger floating point errors + """ + x = mult_along_first_dims(x, m) + x = x / m.sum() + return x.sum(tuple(range(m.ndim))) + + +def entropy_loss( + logits, + mask=None, + temperature=0.01, + sample_minimization_weight=1.0, + batch_maximization_weight=1.0, + eps=1e-5, +): + """ + Entropy loss of unnormalized logits + + logits: Affinities are over the last dimension + + https://github.com/google-research/magvit/blob/05e8cfd6559c47955793d70602d62a2f9b0bdef5/videogvt/train_lib/losses.py#L279 + LANGUAGE MODEL BEATS DIFFUSION — TOKENIZER IS KEY TO VISUAL GENERATION (2024) + """ + probs = F.softmax(logits / temperature, -1) + log_probs = F.log_softmax(logits / temperature + eps, -1) + if mask is not None: + # avg_probs = probs[mask].mean(tuple(range(probs.ndim - 1))) + # avg_probs = einx.mean("... D -> D", probs[mask]) + avg_probs = reduce(masked_mean(probs, mask), "... 
D -> D", "mean") + # avg_probs = einx.mean("... D -> D", avg_probs) + else: + avg_probs = reduce(probs, "... D -> D", "mean") + avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + eps)) + + sample_entropy = -torch.sum(probs * log_probs, -1) + if mask is not None: + # sample_entropy = sample_entropy[mask].mean() + sample_entropy = masked_mean(sample_entropy, mask).mean() + else: + sample_entropy = torch.mean(sample_entropy) + + loss = (sample_minimization_weight * sample_entropy) - ( + batch_maximization_weight * avg_entropy + ) + + return sample_entropy, avg_entropy, loss + + +class LFQ(nn.Module): + # VQGAN originally use beta=1.0, never tried 0.25; SD seems using 0.25 + def __init__( + self, + codebook_size, + Cvae, + using_znorm=False, + beta: float = 0.25, + default_qresi_counts=0, + v_patch_nums=None, + quant_resi=0.5, + share_quant_resi=4, + num_latent_tokens=256, + codebook_drop=0.0, + scale=1, + sample_minimization_weight=1.0, + batch_maximization_weight=1.0, + entropy_weight=0.1, + soft_entropy=True, + # share_quant_resi: args.qsr + ): + super().__init__() + self.Cvae: int = Cvae + self.vocab_size: int = 2**self.Cvae + assert self.vocab_size == codebook_size + self.using_znorm: bool = using_znorm + self.v_patch_nums: Tuple[int] = v_patch_nums + self.num_latent_tokens = num_latent_tokens + self.entropy_weight = entropy_weight + self.soft_entropy = soft_entropy + self.persample_entropy_compute = "analytical" + + self.quant_resi_ratio = quant_resi + if share_quant_resi == 0: # non-shared: \phi_{1 to K} for K scales + self.quant_resi = PhiNonShared( + [ + (Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) + for _ in range(default_qresi_counts or len(self.v_patch_nums)) + ] + ) + elif share_quant_resi == 1: # fully shared: only a single \phi for K scales + self.quant_resi = PhiShared( + Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity() + ) + else: # partially shared: \phi_{1 to share_quant_resi} for K scales + self.quant_resi = PhiPartiallyShared( + nn.ModuleList( + [ + ( + Phi(Cvae, quant_resi) + if abs(quant_resi) > 1e-6 + else nn.Identity() + ) + for _ in range(share_quant_resi) + ] + ) + ) + + self.register_buffer( + "ema_vocab_hit_SV", + torch.full((len(self.v_patch_nums), self.vocab_size), fill_value=0.0), + ) + self.record_hit = 0 + + self.register_buffer("mask", 2 ** torch.arange(self.Cvae), persistent=False) + + self.beta: float = beta + + self.codebook_drop = codebook_drop + + scaler = scale ** torch.arange(len(self.v_patch_nums)) + if using_znorm: + scaler = scaler / sqrt(self.Cvae) + + self.register_buffer("scaler", scaler) + print("scale is", scaler) + + # for entropy loss + self.sample_minimization_weight = sample_minimization_weight + self.batch_maximization_weight = batch_maximization_weight + + # codes + all_codes = torch.arange(codebook_size) + bits = self.indices_to_bits(all_codes) + codebook = bits * 2.0 - 1.0 + + self.register_buffer("codebook", codebook, persistent=False) + + # only used for progressive training of VAR (not supported yet, will be tested and supported in the future) + self.prog_si = -1 # progressive training: not supported yet, prog_si always -1 + + def extra_repr(self) -> str: + return f"{self.v_patch_nums}, znorm={self.using_znorm}, beta={self.beta} | S={len(self.v_patch_nums)}, quant_resi={self.quant_resi_ratio}" + + # ===================== `forward` is only used in VAE training ===================== + + def forward( + self, f_BChw: torch.Tensor, ret_usages=False, dropout=None + ) -> Tuple[torch.Tensor, 
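A quick editorial check of the entropy loss above on random affinities: it returns the per-sample entropy (minimized, encouraging confident assignments), the entropy of the batch-averaged code distribution (maximized, encouraging codebook usage), and their weighted difference. Assumes entropy_loss is in scope.

import torch

logits = torch.randn(2, 256, 4096)   # (B, tokens, codebook_size) affinities
sample_H, codebook_H, loss = entropy_loss(
    logits, sample_minimization_weight=1.0, batch_maximization_weight=1.0
)
print(sample_H.item(), codebook_H.item(), loss.item())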
List[float], torch.Tensor]: + dtype = f_BChw.dtype + if dtype != torch.float32: + f_BChw = f_BChw.float() + B, C, H, W = f_BChw.shape + if self.using_znorm: + f_BChw = F.normalize(f_BChw, dim=1) + f_no_grad = f_BChw.detach() + + f_rest = f_no_grad.clone() + f_hat = torch.zeros_like(f_rest) + # x = f_BChw + + with torch.cuda.amp.autocast(enabled=False): + mean_vq_loss: torch.Tensor = 0.0 + mean_commit_loss: torch.Tensor = 0.0 + mean_entropy_loss: torch.Tensor = 0.0 + vocab_hit_V = torch.zeros( + self.vocab_size, dtype=torch.float, device=f_BChw.device + ) + SN = len(self.v_patch_nums) + + if self.training: + max_n = len(self.v_patch_nums) + 1 + n_quantizers = torch.ones((B,)) * max_n + n_dropout = int(B * self.codebook_drop) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(f_BChw.device) + else: + n_quantizers = torch.ones((B,)) * (self.v_patch_nums + 1) + + for si, pn in enumerate(self.v_patch_nums): # from small to large + codebook_value = ( + self.scaler[si].to(device=f_BChw.device, dtype=torch.float).detach() + ) + # find the nearest embedding + rest_NC = ( + F.interpolate(f_rest, size=(pn, pn), mode="area") + .permute(0, 2, 3, 1) + .reshape(-1, C) + if (si != SN - 1) or pn != int(sqrt(self.num_latent_tokens)) + else f_rest.permute(0, 2, 3, 1).reshape(-1, C) + ) + # rest_NC = f_rest.permute(0, 2, 3, 1).reshape(-1, C) + d_no_grad = torch.where(rest_NC > 0, codebook_value, -codebook_value) + idx_N = self.bits_to_indices((d_no_grad > 0)) + + hit_V = idx_N.bincount(minlength=self.vocab_size).float() + if self.training: + handler = tdist.all_reduce(hit_V, async_op=True) + # calc loss + idx_Bhw = idx_N.view(B, pn, pn) + + h_BChw = ( + F.interpolate( + self.indices_to_bits(idx_Bhw, si).permute(0, 3, 1, 2), + size=(H, W), + mode="bicubic", + ).contiguous() + if (si != SN - 1) + else self.indices_to_bits(idx_Bhw, si) + .permute(0, 3, 1, 2) + .contiguous() + ) + # h_BChw = self.indices_to_bits(idx_Bhw, si).permute(0, 3, 1, 2).contiguous() + h_BChw = self.quant_resi[si / (SN - 1)](h_BChw) + + # x = f_rest.clone().permute(0, 2, 3, 1) + x = rearrange((f_BChw - f_hat.detach()), "b d h w -> b (h w) 1 d") + + mask = ( + torch.full((B,), fill_value=si, device=h_BChw.device) < n_quantizers + )[:, None, None, None].int() + f_hat = f_hat + h_BChw * mask + + f_rest -= h_BChw + if self.training: + handler.wait() + if self.record_hit == 0: + self.ema_vocab_hit_SV[si].copy_(hit_V) + elif self.record_hit < 100: + self.ema_vocab_hit_SV[si].mul_(0.9).add_(hit_V.mul(0.1)) + else: + self.ema_vocab_hit_SV[si].mul_(0.99).add_(hit_V.mul(0.01)) + self.record_hit += 1 + vocab_hit_V.add_(hit_V) + ratio = mask.sum() / B + + codebook = self.codebook * codebook_value + + if self.soft_entropy: + per_sample_entropy, codebook_entropy, avg_prob = ( + self.soft_entropy_loss(x, si, codebook, mask.squeeze()) + ) + entropy_aux_loss = ( + self.sample_minimization_weight * per_sample_entropy + ) - (self.batch_maximization_weight * codebook_entropy) + else: + logits = 2 * torch.einsum("... i d, j d -> ... 
i j", x, codebook) + # the same as euclidean distance up to a constant + per_sample_entropy, codebook_entropy, entropy_aux_loss = ( + entropy_loss( + logits=logits, + mask=mask.squeeze(), + sample_minimization_weight=self.sample_minimization_weight, + batch_maximization_weight=self.batch_maximization_weight, + ) + ) + # F.mse_loss(f_hat, f_no_grad, reduction="none").mul_(mask).mean() / ratio + mean_vq_loss += ( + F.mse_loss(f_hat, f_no_grad, reduction="none").mul_(mask).mean() + / ratio + ) + mean_commit_loss += ( + F.mse_loss(f_hat.data, f_BChw, reduction="none") + .mul_(mask) + .mul_(self.beta / ratio) + .mean() + ) + + entropy_weight = self.entropy_weight / ratio + + mean_entropy_loss += entropy_aux_loss.mul_(entropy_weight) + # x -= h_BChw.detach() + + mean_vq_loss *= 1.0 / SN + mean_commit_loss *= 1.0 / SN + mean_entropy_loss *= 1.0 / SN + f_hat = (f_hat.data - f_no_grad).add_(f_BChw) + + margin = ( + tdist.get_world_size() + * (f_BChw.numel() / f_BChw.shape[1]) + / self.vocab_size + * 0.08 + ) + # margin = pn*pn / 100 + if ret_usages: + usages = [ + (self.ema_vocab_hit_SV[si] >= margin).float().mean().item() * 100 + for si, pn in enumerate(self.v_patch_nums) + ] + else: + usages = None + return f_hat, usages, mean_vq_loss, mean_commit_loss, mean_entropy_loss + + # ===================== `forward` is only used in VAE training ===================== + + def bits_to_indices(self, bits): + """ + bits: bool tensor of big endian bits, where the last dimension is the bit dimension + + returns indices, which are long integers from 0 to self.codebook_size + """ + assert bits.shape[-1] == self.Cvae + indices = 2 ** torch.arange( + 0, + self.Cvae, + 1, + dtype=torch.long, + device=bits.device, + ) + return (bits * indices).sum(-1) + + def indices_to_bits(self, x, si=None): + """ + x: long tensor of indices + + returns big endian bits + """ + mask = 2 ** torch.arange(self.Cvae, device=x.device, dtype=torch.long) + # x is now big endian bits, the last dimension being the bits + x = (x.unsqueeze(-1) & mask) != 0 + if si == None: + return x + return torch.where(x, self.scaler[si], -self.scaler[si]) + + def soft_entropy_loss(self, z, si, codebook, mask=None): + if mask != None: + z = z[mask] + distance = -2 * torch.einsum("... g c, d c ->... g d", z, codebook) + prob = (-distance).softmax(dim=-1) + if self.persample_entropy_compute == "analytical": + p = torch.sigmoid(-4 * z * (self.scaler[si])) + prob = torch.stack([p, 1 - p], dim=-1) + per_sample_entropy = ( + self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean() + ) + else: + per_sample_entropy = ( + self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean() + ) + + # macro average of the probability of each subgroup + avg_prob = reduce(prob, "... 
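The binary-code bookkeeping above is a base-2 expansion: each of the Cvae channels contributes one bit (its sign), so a 12-channel latent indexes a 2**12 = 4096-entry codebook. A tiny worked check of the same arithmetic used by bits_to_indices (editorial, 3-bit toy case):

import torch

bits = torch.tensor([[True, False, True]])
weights = 2 ** torch.arange(3)        # tensor([1, 2, 4])
print((bits * weights).sum(-1))       # tensor([5]) == 1*1 + 0*2 + 1*4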
g d ->g d", "mean") + codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False) + + # the approximation of the entropy is the sum of the entropy of each subgroup + return per_sample_entropy, codebook_entropy.sum(), avg_prob + + def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True): + if normalize: + probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True) + else: + probs = count + H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim) + return H + + def embed_to_fhat( + self, ms_h_BChw: List[torch.Tensor], all_to_max_scale=True, last_one=False + ) -> Union[List[torch.Tensor], torch.Tensor]: + ls_f_hat_BChw = [] + B = ms_h_BChw[0].shape[0] + H = W = self.v_patch_nums[-1] + SN = len(self.v_patch_nums) + if all_to_max_scale: + f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, H, W, dtype=torch.float32) + for si, pn in enumerate(self.v_patch_nums): # from small to large + h_BChw = ms_h_BChw[si] + if si < len(self.v_patch_nums) - 1: + h_BChw = F.interpolate(h_BChw, size=(H, W), mode="bicubic") + h_BChw = self.quant_resi[si / (SN - 1)](h_BChw) + f_hat.add_(h_BChw) + if last_one: + ls_f_hat_BChw = f_hat + else: + ls_f_hat_BChw.append(f_hat.clone()) + else: + # WARNING: this is not the case in VQ-VAE training or inference (we'll interpolate every token map to the max H W, like above) + # WARNING: this should only be used for experimental purpose + f_hat = ms_h_BChw[0].new_zeros( + B, + self.Cvae, + self.v_patch_nums[0], + self.v_patch_nums[0], + dtype=torch.float32, + ) + for si, pn in enumerate(self.v_patch_nums): # from small to large + f_hat = F.interpolate(f_hat, size=(pn, pn), mode="bicubic") + h_BChw = self.quant_resi[si / (SN - 1)](ms_h_BChw[si]) + f_hat.add_(h_BChw) + if last_one: + ls_f_hat_BChw = f_hat + else: + ls_f_hat_BChw.append(f_hat) + + return ls_f_hat_BChw + + def f_to_idxBl_or_fhat( + self, + f_BChw: torch.Tensor, + to_fhat: bool, + v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None, + ) -> List[ + Union[torch.Tensor, torch.LongTensor] + ]: # z_BChw is the feature from inp_img_no_grad + B, C, H, W = f_BChw.shape + if self.using_znorm: + f_BChw = F.normalize(f_BChw, dim=1) + f_no_grad = f_BChw.detach() + f_rest = f_no_grad.clone() + f_hat = torch.zeros_like(f_rest) + + f_hat_or_idx_Bl: List[torch.Tensor] = [] + + patch_hws = [ + (pn, pn) if isinstance(pn, int) else (pn[0], pn[1]) + for pn in (v_patch_nums or self.v_patch_nums) + ] # from small to large + # assert patch_hws[-1][0] == H and patch_hws[-1][1] == W, f'{patch_hws[-1]=} != ({H=}, {W=})' + + SN = len(patch_hws) + for si, (ph, pw) in enumerate(patch_hws): # from small to large + codebook_value = ( + self.scaler[si].to(device=f_BChw.device, dtype=torch.float).detach() + ) + if 0 <= self.prog_si < si: + break # progressive training: not supported yet, prog_si always -1 + # find the nearest embedding + z_NC = ( + F.interpolate(f_rest, size=(ph, pw), mode="area") + .permute(0, 2, 3, 1) + .reshape(-1, C) + if (si != SN - 1) or ph != 16 + else f_rest.permute(0, 2, 3, 1).reshape(-1, C) + ) + + d_no_grad = torch.where(z_NC > 0, codebook_value, -codebook_value) + idx_N = self.bits_to_indices((d_no_grad > 0)) + + idx_Bhw = idx_N.view(B, ph, pw) + h_BChw = ( + F.interpolate( + self.indices_to_bits(idx_Bhw, si).permute(0, 3, 1, 2), + size=(H, W), + mode="bicubic", + ).contiguous() + if (si != SN - 1) + else self.indices_to_bits(idx_Bhw, si).permute(0, 3, 1, 2).contiguous() + ) + h_BChw = self.quant_resi[si / (SN - 1)](h_BChw) + f_hat.add_(h_BChw) + f_rest.sub_(h_BChw) + f_hat_or_idx_Bl.append( 
+ f_hat.clone() if to_fhat else idx_N.reshape(B, ph * pw) + ) + + return f_hat_or_idx_Bl + + # ===================== idxBl_to_var_input: only used in VAR training, for getting teacher-forcing input ===================== + def idxBl_to_var_input(self, gt_ms_idx_Bl: List[torch.Tensor]) -> torch.Tensor: + next_scales = [] + B = gt_ms_idx_Bl[0].shape[0] + C = self.Cvae + H = W = self.v_patch_nums[-1] + SN = len(self.v_patch_nums) + + f_hat = gt_ms_idx_Bl[0].new_zeros(B, C, H, W, dtype=torch.float32) + pn_next: int = self.v_patch_nums[0] + for si in range(SN - 1): + if self.prog_si == 0 or (0 <= self.prog_si - 1 < si): + break # progressive training: not supported yet, prog_si always -1 + h_BChw = F.interpolate( + self.embedding(gt_ms_idx_Bl[si]) + .transpose_(1, 2) + .view(B, C, pn_next, pn_next), + size=(H, W), + mode="bicubic", + ) + f_hat.add_(self.quant_resi[si / (SN - 1)](h_BChw)) + pn_next = self.v_patch_nums[si + 1] + next_scales.append( + F.interpolate(f_hat, size=(pn_next, pn_next), mode="area") + .view(B, C, -1) + .transpose(1, 2) + ) + return ( + torch.cat(next_scales, dim=1) if len(next_scales) else None + ) # cat BlCs to BLC, this should be float32 + + # ===================== get_next_autoregressive_input: only used in VAR inference, for getting next step's input ===================== + def get_next_autoregressive_input( + self, si: int, SN: int, f_hat: torch.Tensor, h_BChw: torch.Tensor + ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: # only used in VAR inference + HW = self.v_patch_nums[-1] + if si != SN - 1: + h = self.quant_resi[si / (SN - 1)]( + F.interpolate(h_BChw, size=(HW, HW), mode="bicubic") + ) # conv after upsample + f_hat.add_(h) + return f_hat, F.interpolate( + f_hat, + size=(self.v_patch_nums[si + 1], self.v_patch_nums[si + 1]), + mode="area", + ) + else: + h = self.quant_resi[si / (SN - 1)](h_BChw) + f_hat.add_(h) + return f_hat, f_hat + + +class Phi(nn.Conv2d): + def __init__(self, embed_dim, quant_resi): + ks = 3 + super().__init__( + in_channels=embed_dim, + out_channels=embed_dim, + kernel_size=ks, + stride=1, + padding=ks // 2, + ) + self.resi_ratio = abs(quant_resi) + + def forward(self, h_BChw): + return h_BChw.mul(1 - self.resi_ratio) + super().forward(h_BChw).mul_( + self.resi_ratio + ) + + +class PhiShared(nn.Module): + def __init__(self, qresi: Phi): + super().__init__() + self.qresi: Phi = qresi + + def __getitem__(self, _) -> Phi: + return self.qresi + + +class PhiPartiallyShared(nn.Module): + def __init__(self, qresi_ls: nn.ModuleList): + super().__init__() + self.qresi_ls = qresi_ls + K = len(qresi_ls) + self.ticks = ( + np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K) + if K == 4 + else np.linspace(1 / 2 / K, 1 - 1 / 2 / K, K) + ) + + def __getitem__(self, at_from_0_to_1: float) -> Phi: + return self.qresi_ls[np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()] + + def extra_repr(self) -> str: + return f"ticks={self.ticks}" + + +class PhiNonShared(nn.ModuleList): + def __init__(self, qresi: List): + super().__init__(qresi) + # self.qresi = qresi + K = len(qresi) + self.ticks = ( + np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K) + if K == 4 + else np.linspace(1 / 2 / K, 1 - 1 / 2 / K, K) + ) + + def __getitem__(self, at_from_0_to_1: float) -> Phi: + return super().__getitem__( + np.argmin(np.abs(self.ticks - at_from_0_to_1)).item() + ) + + def extra_repr(self) -> str: + return f"ticks={self.ticks}" + + +def schedule(ratio, total_unknown, method="cosine"): + """Generates a mask rate by scheduling mask functions R. 
+ + Given a ratio in [0, 1), we generate a masking ratio from (0, 1]. During + training, the input ratio is uniformly sampled; during inference, the input + ratio is based on the step number divided by the total iteration number: t/T. + Based on experiements, we find that masking more in training helps. + Args: + ratio: The uniformly sampled ratio [0, 1) as input. + total_unknown: The total number of tokens that can be masked out. For + example, in MaskGIT, total_unknown = 256 for 256x256 images and 1024 for + 512x512 images. + method: implemented functions are ["uniform", "cosine", "pow", "log", "exp"] + "pow2.5" represents x^2.5 + + Returns: + The mask rate (float). + """ + if method == "uniform": + mask_ratio = 1.0 - ratio + elif "pow" in method: + exponent = float(method.replace("pow", "")) + mask_ratio = 1.0 - ratio**exponent + elif method == "cosine": + mask_ratio = np.cos(math.pi / 2.0 * ratio) + elif method == "log": + mask_ratio = -np.log2(ratio) / np.log2(total_unknown) + elif method == "exp": + mask_ratio = 1 - np.exp2(-np.log2(total_unknown) * (1 - ratio)) + # Clamps mask into [epsilon, 1) + mask_ratio = np.clip(mask_ratio, 0, 1.0) + return mask_ratio + + +if __name__ == "__main__": + + batch_size = 4 + seq_len = 16 + num_classes = 4096 + # # Generate random logits and integer mask + # logits = torch.randn(batch_size, seq_len,seq_len, num_classes) + mask = torch.ones(batch_size, dtype=torch.int) + + # # Calculate entropy loss + # sample_entropy, avg_entropy, loss = entropy_loss( + # logits, + # mask=mask, + # sample_minimization_weight=1.0, + # batch_maximization_weight=1.0, + # ) + + # # Output results + # print("Sample Entropy for mask:", sample_entropy) + # print("Average Entropy for mask:", avg_entropy) + # print("Entropy Loss for mask:", loss) + + # # Calculate entropy loss + # sample_entropy, avg_entropy, loss = entropy_loss( + # logits, + # sample_minimization_weight=1.0, + # batch_maximization_weight=1.0, + # ) + + # # Output results + # print("Sample Entropy:", sample_entropy) + # print("Average Entropy:", avg_entropy) + # print("Entropy Loss:", loss) + quantizer = LFQ( + 4096, + 12, + using_znorm=False, + v_patch_nums=[1, 2, 3, 4, 5, 6, 8, 10, 12, 16], + ) + + z = torch.randn(batch_size, seq_len * seq_len, 1, 12) + + for i in range(10): + + codebook = quantizer.codebook * quantizer.scaler[i] + logits = 2 * torch.einsum("... i d, j d -> ... 
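As a concrete value for the cosine schedule defined above: halfway through sampling (ratio = 0.5) roughly 71% of the tokens remain masked, versus exactly 50% under the uniform schedule. Editorial check, assuming schedule is in scope:

print(schedule(0.5, total_unknown=256, method="cosine"))    # ~0.7071
print(schedule(0.5, total_unknown=256, method="uniform"))   # 0.5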
i j", z, codebook) + + per_sample_entropy, codebook_entropy, avg_prob = quantizer.soft_entropy_loss( + z, i, codebook, mask + ) + print("Soft Sample Entropy :", per_sample_entropy) + print("Soft codebook Entropy:", codebook_entropy) + print("Soft Entropy Loss", per_sample_entropy - codebook_entropy) + + sample_entropy, avg_entropy, loss = entropy_loss( + logits, + mask=mask, + sample_minimization_weight=1.0, + batch_maximization_weight=1.0, + ) + print("Sample Entropy :", sample_entropy) + print("codebook Entropy:", avg_entropy) + print("Entropy Loss", loss) + + image_feats = torch.randn( + 2, 12, 16, 16 + ) # 16 is dim, must be power of 2 of codebook_size + + dropout_rand = torch.randint(3, len([1, 2, 3, 4, 5, 6, 8, 10, 12, 16]) + 1, (2,)) + + quantized, usgae, loss = quantizer( + image_feats, ret_usages=True, dropout=dropout_rand + ) # you may want to experiment with temperature + + assert image_feats.shape == quantized.shape diff --git a/src/vqvaes/xqgan/lpips.py b/src/vqvaes/xqgan/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..9629dd5d388a2cfd4808ca70e86027f0dcdae505 --- /dev/null +++ b/src/vqvaes/xqgan/lpips.py @@ -0,0 +1,187 @@ +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + +import os, hashlib +import requests +from tqdm import tqdm + +import torch +import torch.nn as nn +from torchvision import models +from collections import namedtuple + +URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} + +CKPT_MAP = {"vgg_lpips": "vgg.pth"} + +MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path( + name, os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache") + ) + self.load_state_dict( + torch.load(ckpt, map_location=torch.device("cpu")), strict=False + ) + print("loaded pretrained LPIPS loss from 
{}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name != "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path( + name, os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache") + ) + model.load_state_dict( + torch.load(ckpt, map_location=torch.device("cpu")), strict=False + ) + return model + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor( + outs1[kk] + ) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [ + spatial_average(lins[kk].model(diffs[kk]), keepdim=True) + for kk in range(len(self.chns)) + ] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer( + "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] + ) + self.register_buffer( + "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] + ) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """A single linear layer which does a 1x1 conv""" + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = ( + [ + nn.Dropout(), + ] + if (use_dropout) + else [] + ) + layers += [ + nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), + ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple( + "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] + ) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x / (norm_factor + eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keepdim=keepdim) diff --git a/src/vqvaes/xqgan/quant.py b/src/vqvaes/xqgan/quant.py new file mode 100644 index 0000000000000000000000000000000000000000..685135f841617eba800e777118c076a2b332170c --- /dev/null +++ b/src/vqvaes/xqgan/quant.py @@ -0,0 +1,450 @@ +from typing import 
List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +from torch import distributed as tdist, nn as nn +from torch.nn import functional as F + +from math import sqrt + +import dist + + +class VectorQuantizer2(nn.Module): + # VQGAN originally use beta=1.0, never tried 0.25; SD seems using 0.25 + def __init__( + self, + vocab_size, + Cvae, + using_znorm=True, + beta: float = 0.25, + default_qresi_counts=0, + v_patch_nums=None, + quant_resi=0.5, + share_quant_resi=4, + num_latent_tokens=256, + codebook_drop=0.0, + # share_quant_resi: args.qsr + ): + super().__init__() + self.vocab_size: int = vocab_size + self.Cvae: int = Cvae + self.using_znorm: bool = using_znorm + self.v_patch_nums: Tuple[int] = v_patch_nums + self.num_latent_tokens = num_latent_tokens + + self.quant_resi_ratio = quant_resi + if share_quant_resi == 0: # non-shared: \phi_{1 to K} for K scales + self.quant_resi = PhiNonShared( + [ + (Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) + for _ in range(default_qresi_counts or len(self.v_patch_nums)) + ] + ) + elif share_quant_resi == 1: # fully shared: only a single \phi for K scales + self.quant_resi = PhiShared( + Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity() + ) + else: # partially shared: \phi_{1 to share_quant_resi} for K scales + self.quant_resi = PhiPartiallyShared( + nn.ModuleList( + [ + ( + Phi(Cvae, quant_resi) + if abs(quant_resi) > 1e-6 + else nn.Identity() + ) + for _ in range(share_quant_resi) + ] + ) + ) + + self.register_buffer( + "ema_vocab_hit_SV", + torch.full((len(self.v_patch_nums), self.vocab_size), fill_value=0.0), + ) + self.record_hit = 0 + + self.beta: float = beta + self.embedding = nn.Embedding(self.vocab_size, self.Cvae) + self.codebook_drop = codebook_drop + + self.embedding.weight.data.uniform_( + -1.0 / self.vocab_size, 1.0 / self.vocab_size + ) + if self.using_znorm: + self.embedding.weight.data = F.normalize( + self.embedding.weight.data, p=2, dim=-1 + ) + + # only used for progressive training of VAR (not supported yet, will be tested and supported in the future) + self.prog_si = -1 # progressive training: not supported yet, prog_si always -1 + + def eini(self, eini): + if eini > 0: + nn.init.trunc_normal_(self.embedding.weight.data, std=eini) + elif eini < 0: + self.embedding.weight.data.uniform_( + -abs(eini) / self.vocab_size, abs(eini) / self.vocab_size + ) + + def extra_repr(self) -> str: + return f"{self.v_patch_nums}, znorm={self.using_znorm}, beta={self.beta} | S={len(self.v_patch_nums)}, quant_resi={self.quant_resi_ratio}" + + # ===================== `forward` is only used in VAE training ===================== + def forward( + self, f_BChw: torch.Tensor, ret_usages=False, dropout=None + ) -> Tuple[torch.Tensor, List[float], torch.Tensor]: + dtype = f_BChw.dtype + if dtype != torch.float32: + f_BChw = f_BChw.float() + B, C, H, W = f_BChw.shape + f_no_grad = f_BChw.detach() + + f_rest = f_no_grad.clone() + f_hat = torch.zeros_like(f_rest) + + with torch.cuda.amp.autocast(enabled=False): + mean_vq_loss: torch.Tensor = 0.0 + mean_commit_loss: torch.Tensor = 0.0 + vocab_hit_V = torch.zeros( + self.vocab_size, dtype=torch.float, device=f_BChw.device + ) + SN = len(self.v_patch_nums) + + if self.training and dropout != None: + max_n = len(self.v_patch_nums) + 1 + n_quantizers = torch.ones((B,)) * max_n + n_dropout = int(B * self.codebook_drop) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(f_BChw.device) + else: + n_quantizers = torch.ones((B,), 
device=f_BChw.device) * ( + len(self.v_patch_nums) + 1 + ) + + for si, pn in enumerate(self.v_patch_nums): # from small to large + # find the nearest embedding + if self.using_znorm: + rest_NC = ( + F.interpolate(f_rest, size=(pn, pn), mode="area") + .permute(0, 2, 3, 1) + .reshape(-1, C) + if (si != SN - 1) or pn != int(sqrt(self.num_latent_tokens)) + else f_rest.permute(0, 2, 3, 1).reshape(-1, C) + ) + rest_NC = F.normalize(rest_NC, dim=-1) + idx_N = torch.argmax( + rest_NC @ F.normalize(self.embedding.weight.data.T, dim=0), + dim=1, + ) + else: + rest_NC = ( + F.interpolate(f_rest, size=(pn, pn), mode="area") + .permute(0, 2, 3, 1) + .reshape(-1, C) + if (si != SN - 1) or pn != int(sqrt(self.num_latent_tokens)) + else f_rest.permute(0, 2, 3, 1).reshape(-1, C) + ) + d_no_grad = torch.sum( + rest_NC.square(), dim=1, keepdim=True + ) + torch.sum( + self.embedding.weight.data.square(), dim=1, keepdim=False + ) + d_no_grad.addmm_( + rest_NC, self.embedding.weight.data.T, alpha=-2, beta=1 + ) # (B*h*w, vocab_size) + idx_N = torch.argmin(d_no_grad, dim=1) + hit_V = idx_N.bincount(minlength=self.vocab_size).float() + if self.training: + handler = tdist.all_reduce(hit_V, async_op=True) + # calc loss + idx_Bhw = idx_N.view(B, pn, pn) + h_BChw = ( + F.interpolate( + self.embedding(idx_Bhw).permute(0, 3, 1, 2), + size=(H, W), + mode="bicubic", + ).contiguous() + if (si != SN - 1) + else self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous() + ) + if SN == 1: + h_BChw = self.quant_resi[0](h_BChw) + else: + h_BChw = self.quant_resi[si / (SN - 1)](h_BChw) + + mask = ( + torch.full((B,), fill_value=si, device=h_BChw.device) < n_quantizers + )[:, None, None, None].int() + f_hat = f_hat + h_BChw * mask + + f_rest -= h_BChw + if self.training: + handler.wait() + if self.record_hit == 0: + self.ema_vocab_hit_SV[si].copy_(hit_V) + elif self.record_hit < 100: + self.ema_vocab_hit_SV[si].mul_(0.9).add_(hit_V.mul(0.1)) + else: + self.ema_vocab_hit_SV[si].mul_(0.99).add_(hit_V.mul(0.01)) + self.record_hit += 1 + vocab_hit_V.add_(hit_V) + ratio = mask.sum() / B + + mean_vq_loss += ( + F.mse_loss(f_hat, f_no_grad, reduction="none").mul_(mask).mean() + / ratio + ) + mean_commit_loss += ( + F.mse_loss(f_hat.data, f_BChw, reduction="none") + .mul_(mask) + .mul_(self.beta / ratio) + .mean() + ) + + mean_vq_loss *= 1.0 / SN + f_hat = (f_hat.data - f_no_grad).add_(f_BChw) + + margin = ( + tdist.get_world_size() + * (f_BChw.numel() / f_BChw.shape[1]) + / self.vocab_size + * 0.08 + ) + # margin = pn*pn / 100 + if ret_usages: + usages = [ + (self.ema_vocab_hit_SV[si] >= margin).float().mean().item() * 100 + for si, pn in enumerate(self.v_patch_nums) + ] + else: + usages = None + return f_hat, usages, mean_vq_loss, mean_commit_loss, 0 + + # ===================== `forward` is only used in VAE training ===================== + + def embed_to_fhat( + self, ms_h_BChw: List[torch.Tensor], all_to_max_scale=True, last_one=False + ) -> Union[List[torch.Tensor], torch.Tensor]: + ls_f_hat_BChw = [] + B = ms_h_BChw[0].shape[0] + H = W = self.v_patch_nums[-1] + SN = len(self.v_patch_nums) + if all_to_max_scale: + f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, H, W, dtype=torch.float32) + for si, pn in enumerate(self.v_patch_nums): # from small to large + h_BChw = ms_h_BChw[si] + if si < len(self.v_patch_nums) - 1: + h_BChw = F.interpolate(h_BChw, size=(H, W), mode="bicubic") + h_BChw = self.quant_resi[si / (SN - 1)](h_BChw) + f_hat.add_(h_BChw) + if last_one: + ls_f_hat_BChw = f_hat + else: + ls_f_hat_BChw.append(f_hat.clone()) + else: 
+ # WARNING: this is not the case in VQ-VAE training or inference (we'll interpolate every token map to the max H W, like above) + # WARNING: this should only be used for experimental purpose + f_hat = ms_h_BChw[0].new_zeros( + B, + self.Cvae, + self.v_patch_nums[0], + self.v_patch_nums[0], + dtype=torch.float32, + ) + for si, pn in enumerate(self.v_patch_nums): # from small to large + f_hat = F.interpolate(f_hat, size=(pn, pn), mode="bicubic") + h_BChw = self.quant_resi[si / (SN - 1)](ms_h_BChw[si]) + f_hat.add_(h_BChw) + if last_one: + ls_f_hat_BChw = f_hat + else: + ls_f_hat_BChw.append(f_hat) + + return ls_f_hat_BChw + + def f_to_idxBl_or_fhat( + self, + f_BChw: torch.Tensor, + to_fhat: bool, + v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None, + ) -> List[ + Union[torch.Tensor, torch.LongTensor] + ]: # z_BChw is the feature from inp_img_no_grad + B, C, H, W = f_BChw.shape + f_no_grad = f_BChw.detach() + f_rest = f_no_grad.clone() + f_hat = torch.zeros_like(f_rest) + + f_hat_or_idx_Bl: List[torch.Tensor] = [] + + patch_hws = [ + (pn, pn) if isinstance(pn, int) else (pn[0], pn[1]) + for pn in (v_patch_nums or self.v_patch_nums) + ] # from small to large + # assert patch_hws[-1][0] == H and patch_hws[-1][1] == W, f'{patch_hws[-1]=} != ({H=}, {W=})' + + SN = len(patch_hws) + for si, (ph, pw) in enumerate(patch_hws): # from small to large + if 0 <= self.prog_si < si: + break # progressive training: not supported yet, prog_si always -1 + # find the nearest embedding + z_NC = ( + F.interpolate(f_rest, size=(ph, pw), mode="area") + .permute(0, 2, 3, 1) + .reshape(-1, C) + if (si != SN - 1) or ph != 16 + else f_rest.permute(0, 2, 3, 1).reshape(-1, C) + ) + if self.using_znorm: + z_NC = F.normalize(z_NC, dim=-1) + idx_N = torch.argmax( + z_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1 + ) + else: + d_no_grad = torch.sum(z_NC.square(), dim=1, keepdim=True) + torch.sum( + self.embedding.weight.data.square(), dim=1, keepdim=False + ) + d_no_grad.addmm_( + z_NC, self.embedding.weight.data.T, alpha=-2, beta=1 + ) # (B*h*w, vocab_size) + idx_N = torch.argmin(d_no_grad, dim=1) + + idx_Bhw = idx_N.view(B, ph, pw) + h_BChw = ( + F.interpolate( + self.embedding(idx_Bhw).permute(0, 3, 1, 2), + size=(H, W), + mode="bicubic", + ).contiguous() + if (si != SN - 1) + else self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous() + ) + if SN == 1: + h_BChw = self.quant_resi[0](h_BChw) + else: + h_BChw = self.quant_resi[si / (SN - 1)](h_BChw) + f_hat.add_(h_BChw) + f_rest.sub_(h_BChw) + f_hat_or_idx_Bl.append( + f_hat.clone() if to_fhat else idx_N.reshape(B, ph * pw) + ) + + return f_hat_or_idx_Bl + + # ===================== idxBl_to_var_input: only used in VAR training, for getting teacher-forcing input ===================== + def idxBl_to_var_input(self, gt_ms_idx_Bl: List[torch.Tensor]) -> torch.Tensor: + next_scales = [] + B = gt_ms_idx_Bl[0].shape[0] + C = self.Cvae + H = W = self.v_patch_nums[-1] + SN = len(self.v_patch_nums) + + f_hat = gt_ms_idx_Bl[0].new_zeros(B, C, H, W, dtype=torch.float32) + pn_next: int = self.v_patch_nums[0] + for si in range(SN - 1): + if self.prog_si == 0 or (0 <= self.prog_si - 1 < si): + break # progressive training: not supported yet, prog_si always -1 + h_BChw = F.interpolate( + self.embedding(gt_ms_idx_Bl[si]) + .transpose_(1, 2) + .view(B, C, pn_next, pn_next), + size=(H, W), + mode="bicubic", + ) + f_hat.add_(self.quant_resi[si / (SN - 1)](h_BChw)) + pn_next = self.v_patch_nums[si + 1] + next_scales.append( + F.interpolate(f_hat, 
size=(pn_next, pn_next), mode="area") + .view(B, C, -1) + .transpose(1, 2) + ) + return ( + torch.cat(next_scales, dim=1) if len(next_scales) else None + ) # cat BlCs to BLC, this should be float32 + + # ===================== get_next_autoregressive_input: only used in VAR inference, for getting next step's input ===================== + def get_next_autoregressive_input( + self, si: int, SN: int, f_hat: torch.Tensor, h_BChw: torch.Tensor + ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: # only used in VAR inference + HW = self.v_patch_nums[-1] + if si != SN - 1: + h = self.quant_resi[si / (SN - 1)]( + F.interpolate(h_BChw, size=(HW, HW), mode="bicubic") + ) # conv after upsample + f_hat.add_(h) + return f_hat, F.interpolate( + f_hat, + size=(self.v_patch_nums[si + 1], self.v_patch_nums[si + 1]), + mode="area", + ) + else: + h = self.quant_resi[si / (SN - 1)](h_BChw) + f_hat.add_(h) + return f_hat, f_hat + + +class Phi(nn.Conv2d): + def __init__(self, embed_dim, quant_resi): + ks = 3 + super().__init__( + in_channels=embed_dim, + out_channels=embed_dim, + kernel_size=ks, + stride=1, + padding=ks // 2, + ) + self.resi_ratio = abs(quant_resi) + + def forward(self, h_BChw): + return h_BChw.mul(1 - self.resi_ratio) + super().forward(h_BChw).mul_( + self.resi_ratio + ) + + +class PhiShared(nn.Module): + def __init__(self, qresi: Phi): + super().__init__() + self.qresi: Phi = qresi + + def __getitem__(self, _) -> Phi: + return self.qresi + + +class PhiPartiallyShared(nn.Module): + def __init__(self, qresi_ls: nn.ModuleList): + super().__init__() + self.qresi_ls = qresi_ls + K = len(qresi_ls) + self.ticks = ( + np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K) + if K == 4 + else np.linspace(1 / 2 / K, 1 - 1 / 2 / K, K) + ) + + def __getitem__(self, at_from_0_to_1: float) -> Phi: + return self.qresi_ls[np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()] + + def extra_repr(self) -> str: + return f"ticks={self.ticks}" + + +class PhiNonShared(nn.ModuleList): + def __init__(self, qresi: List): + super().__init__(qresi) + # self.qresi = qresi + K = len(qresi) + self.ticks = ( + np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K) + if K == 4 + else np.linspace(1 / 2 / K, 1 - 1 / 2 / K, K) + ) + + def __getitem__(self, at_from_0_to_1: float) -> Phi: + return super().__getitem__( + np.argmin(np.abs(self.ticks - at_from_0_to_1)).item() + ) + + def extra_repr(self) -> str: + return f"ticks={self.ticks}" diff --git a/src/vqvaes/xqgan/vq_loss.py b/src/vqvaes/xqgan/vq_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..318be86e57f8395665d0489f8137f1c7026b4f6c --- /dev/null +++ b/src/vqvaes/xqgan/vq_loss.py @@ -0,0 +1,356 @@ +# Modified from: +# taming-transformers: https://github.com/CompVis/taming-transformers +# muse-maskgit-pytorch: https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/vqgan_vae.py +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np + +from tokenizer.tokenizer_image.lpips import LPIPS +from tokenizer.tokenizer_image.discriminator_patchgan import ( + NLayerDiscriminator as PatchGANDiscriminator, +) +from tokenizer.tokenizer_image.discriminator_stylegan import ( + Discriminator as StyleGANDiscriminator, +) +from tokenizer.tokenizer_image.discriminator_dino import DinoDisc as DINODiscriminator +from tokenizer.tokenizer_image.diffaug import DiffAug +import wandb +import torch.distributed as tdist + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1.0 - logits_real)) + loss_fake = 
torch.mean(F.relu(1.0 + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.softplus(-logits_real)) + loss_fake = torch.mean(F.softplus(logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def non_saturating_d_loss(logits_real, logits_fake): + loss_real = torch.mean( + F.binary_cross_entropy_with_logits(torch.ones_like(logits_real), logits_real) + ) + loss_fake = torch.mean( + F.binary_cross_entropy_with_logits(torch.zeros_like(logits_fake), logits_fake) + ) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def hinge_gen_loss(logit_fake): + return -torch.mean(logit_fake) + + +def non_saturating_gen_loss(logit_fake): + return torch.mean( + F.binary_cross_entropy_with_logits(torch.ones_like(logit_fake), logit_fake) + ) + + +def adopt_weight(weight, global_step, threshold=0, value=0.0): + if global_step < threshold: + weight = value + return weight + + +def anneal_weight( + weight, + global_step, + threshold=0, + initial_value=0.3, + final_value=0.1, + anneal_steps=2000, +): + if global_step < threshold: + return initial_value + elif global_step < threshold + anneal_steps: + # Linearly interpolate between initial and final values within the anneal_steps + decay_ratio = (global_step - threshold) / anneal_steps + weight = initial_value - decay_ratio * (initial_value - final_value) + else: + # After annealing steps, set to final value + weight = final_value + return weight + + +class LeCAM_EMA(object): + def __init__(self, init=0.0, decay=0.999): + self.logits_real_ema = init + self.logits_fake_ema = init + self.decay = decay + + def update(self, logits_real, logits_fake): + self.logits_real_ema = self.logits_real_ema * self.decay + torch.mean( + logits_real + ).item() * (1 - self.decay) + self.logits_fake_ema = self.logits_fake_ema * self.decay + torch.mean( + logits_fake + ).item() * (1 - self.decay) + + +def lecam_reg(real_pred, fake_pred, lecam_ema): + reg = torch.mean(F.relu(real_pred - lecam_ema.logits_fake_ema).pow(2)) + torch.mean( + F.relu(lecam_ema.logits_real_ema - fake_pred).pow(2) + ) + return reg + + +class VQLoss(nn.Module): + def __init__( + self, + disc_start, + disc_loss="hinge", + disc_dim=64, + disc_type="patchgan", + image_size=256, + disc_num_layers=3, + disc_in_channels=3, + disc_weight=1.0, + disc_adaptive_weight=False, + gen_adv_loss="hinge", + reconstruction_loss="l2", + reconstruction_weight=1.0, + codebook_weight=1.0, + perceptual_weight=1.0, + lecam_loss_weight=None, + norm_type="bn", + aug_prob=1, + ): + super().__init__() + # discriminator loss + assert disc_type in ["patchgan", "stylegan", "dinodisc", "samdisc"] + assert disc_loss in ["hinge", "vanilla", "non-saturating"] + self.disc_type = disc_type + if disc_type == "patchgan": + self.discriminator = PatchGANDiscriminator( + input_nc=disc_in_channels, + n_layers=disc_num_layers, + ndf=disc_dim, + ) + elif disc_type == "stylegan": + self.discriminator = StyleGANDiscriminator( + input_nc=disc_in_channels, + image_size=image_size, + ) + elif disc_type == "dinodisc": + self.discriminator = DINODiscriminator( + norm_type=norm_type + ) # default 224 otherwise crop + self.daug = DiffAug(prob=aug_prob, cutout=0.2) + elif disc_type == "samdisc": + self.discriminator = SAMDiscriminator(norm_type=norm_type) + else: + raise ValueError(f"Unknown GAN discriminator type '{disc_type}'.") + if disc_loss == "hinge": + self.disc_loss = hinge_d_loss + elif disc_loss == "vanilla": + self.disc_loss = 
vanilla_d_loss + elif disc_loss == "non-saturating": + self.disc_loss = non_saturating_d_loss + else: + raise ValueError(f"Unknown GAN discriminator loss '{disc_loss}'.") + self.discriminator_iter_start = disc_start + self.disc_weight = disc_weight + self.disc_adaptive_weight = disc_adaptive_weight + + assert gen_adv_loss in ["hinge", "non-saturating"] + # gen_adv_loss + if gen_adv_loss == "hinge": + self.gen_adv_loss = hinge_gen_loss + elif gen_adv_loss == "non-saturating": + self.gen_adv_loss = non_saturating_gen_loss + else: + raise ValueError(f"Unknown GAN generator loss '{gen_adv_loss}'.") + + # perceptual loss + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + + # reconstruction loss + if reconstruction_loss == "l1": + self.rec_loss = F.l1_loss + elif reconstruction_loss == "l2": + self.rec_loss = F.mse_loss + else: + raise ValueError(f"Unknown rec loss '{reconstruction_loss}'.") + self.rec_weight = reconstruction_weight + + # codebook loss + self.codebook_weight = codebook_weight + + self.lecam_loss_weight = lecam_loss_weight + if self.lecam_loss_weight is not None: + self.lecam_ema = LeCAM_EMA() + + if tdist.get_rank() == 0: + self.wandb_tracker = wandb.init( + project="MSVQ", + ) + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer): + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + return d_weight.detach() + + def forward( + self, + codebook_loss, + sem_loss, + detail_loss, + dependency_loss, + inputs, + reconstructions, + optimizer_idx, + global_step, + last_layer=None, + logger=None, + log_every=100, + fade_blur_schedule=0, + ): + # generator update + if optimizer_idx == 0: + # reconstruction loss + rec_loss = self.rec_loss(inputs.contiguous(), reconstructions.contiguous()) + + # perceptual loss + p_loss = self.perceptual_loss( + inputs.contiguous(), reconstructions.contiguous() + ) + p_loss = torch.mean(p_loss) + + # discriminator loss + if self.disc_type == "dinodisc": + if fade_blur_schedule < 1e-6: + fade_blur_schedule = 0 + logits_fake = self.discriminator( + self.daug.aug(reconstructions.contiguous(), fade_blur_schedule) + ) + else: + logits_fake = self.discriminator(reconstructions.contiguous()) + generator_adv_loss = self.gen_adv_loss(logits_fake) + + if self.disc_adaptive_weight: + null_loss = self.rec_weight * rec_loss + self.perceptual_weight * p_loss + disc_adaptive_weight = self.calculate_adaptive_weight( + null_loss, generator_adv_loss, last_layer=last_layer + ) + else: + disc_adaptive_weight = 1 + disc_weight = adopt_weight( + self.disc_weight, global_step, threshold=self.discriminator_iter_start + ) + if sem_loss is None: + sem_loss = 0 + if detail_loss is None: + detail_loss = 0 + if dependency_loss is None: + dependency_loss = 0 + loss = ( + self.rec_weight * rec_loss + + self.perceptual_weight * p_loss + + disc_adaptive_weight * disc_weight * generator_adv_loss + + codebook_loss[0] + + codebook_loss[1] + + codebook_loss[2] + + sem_loss + + detail_loss + + dependency_loss + ) + + if global_step % log_every == 0: + rec_loss = self.rec_weight * rec_loss + p_loss = self.perceptual_weight * p_loss + generator_adv_loss = ( + disc_adaptive_weight * disc_weight * generator_adv_loss + ) + logger.info( + f"(Generator) rec_loss: {rec_loss:.4f}, perceptual_loss: {p_loss:.4f}, sem_loss: 
{sem_loss:.4f}, detail_loss: {detail_loss} " + f"dependency_loss: {dependency_loss} vq_loss: {codebook_loss[0]:.4f}, commit_loss: {codebook_loss[1]:.4f}, entropy_loss: {codebook_loss[2]:.4f}, " + f"codebook_usage: {codebook_loss[3]}, generator_adv_loss: {generator_adv_loss:.4f}, " + f"disc_adaptive_weight: {disc_adaptive_weight:.4f}, disc_weight: {disc_weight:.4f}" + ) + if tdist.get_rank() == 0: + self.wandb_tracker.log( + { + "rec_loss": rec_loss, + "perceptual_loss": p_loss, + "sem_loss": sem_loss, + "detail_loss": detail_loss, + "dependency_loss": dependency_loss, + "vq_loss": codebook_loss[0], + "commit_loss": codebook_loss[1], + "entropy_loss": codebook_loss[2], + "codebook_usage": np.mean(codebook_loss[3]), + "generator_adv_loss": generator_adv_loss, + "disc_adaptive_weight": disc_adaptive_weight, + "disc_weight": disc_weight, + }, + step=global_step, + ) + return loss + + # discriminator update + if optimizer_idx == 1: + + if self.disc_type == "dinodisc": + if fade_blur_schedule < 1e-6: + fade_blur_schedule = 0 + # add blur since disc is too strong + logits_fake = self.discriminator( + self.daug.aug( + reconstructions.contiguous().detach(), fade_blur_schedule + ) + ) + logits_real = self.discriminator( + self.daug.aug(inputs.contiguous().detach(), fade_blur_schedule) + ) + else: + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + logits_real = self.discriminator(inputs.contiguous().detach()) + + disc_weight = adopt_weight( + self.disc_weight, global_step, threshold=self.discriminator_iter_start + ) + + if self.lecam_loss_weight is not None: + self.lecam_ema.update(logits_real, logits_fake) + lecam_loss = lecam_reg(logits_real, logits_fake, self.lecam_ema) + non_saturate_d_loss = self.disc_loss(logits_real, logits_fake) + d_adversarial_loss = disc_weight * ( + lecam_loss * self.lecam_loss_weight + non_saturate_d_loss + ) + else: + d_adversarial_loss = disc_weight * self.disc_loss( + logits_real, logits_fake + ) + + if global_step % log_every == 0: + logits_real = logits_real.detach().mean() + logits_fake = logits_fake.detach().mean() + logger.info( + f"(Discriminator) " + f"discriminator_adv_loss: {d_adversarial_loss:.4f}, disc_weight: {disc_weight:.4f}, " + f"logits_real: {logits_real:.4f}, logits_fake: {logits_fake:.4f}" + ) + if tdist.get_rank() == 0: + self.wandb_tracker.log( + { + "discriminator_adv_loss": d_adversarial_loss, + "disc_weight": disc_weight, + "logits_real": logits_real, + "logits_fake": logits_fake, + }, + step=global_step, + ) + return d_adversarial_loss diff --git a/src/vqvaes/xqgan/vq_model.py b/src/vqvaes/xqgan/vq_model.py new file mode 100644 index 0000000000000000000000000000000000000000..789bec5d15a5fb9c064c3a2c13b8090da4c6ba9d --- /dev/null +++ b/src/vqvaes/xqgan/vq_model.py @@ -0,0 +1,564 @@ +# Modified from: +# taming-transformers: https://github.com/CompVis/taming-transformers +# maskgit: https://github.com/google-research/maskgit +from dataclasses import dataclass, field +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from tokenizer.tokenizer_image.cliploss import ClipLoss +from timm.models import create_model + + +@dataclass +class ModelArgs: + codebook_size: int = 16384 + codebook_embed_dim: int = 8 + codebook_l2_norm: bool = True + codebook_show_usage: bool = True + commit_loss_beta: float = 0.25 + entropy_loss_ratio: float = 0.0 + + encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4]) + decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 
1, 2, 2, 4]) + z_channels: int = 256 + dropout_p: float = 0.0 + + +class VQModel(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + self.encoder = Encoder( + ch_mult=config.encoder_ch_mult, + z_channels=config.z_channels, + dropout=config.dropout_p, + ) + self.decoder = Decoder( + ch_mult=config.decoder_ch_mult, + z_channels=config.z_channels, + dropout=config.dropout_p, + ) + + self.quantize = VectorQuantizer( + config.codebook_size, + config.codebook_embed_dim, + config.commit_loss_beta, + config.entropy_loss_ratio, + config.codebook_l2_norm, + config.codebook_show_usage, + ) + self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1) + self.post_quant_conv = nn.Conv2d( + config.codebook_embed_dim, config.z_channels, 1 + ) + + # Semantic loss to preserve dino semantics + self.semantic_guide = "dinov2" # 'none' # + if self.semantic_guide == "dinov2": + semantic_model = create_model( + "vit_small_patch14_dinov2.lvd142m", + pretrained=True, + img_size=256, + patch_size=16, + drop_path_rate=0.0, + ) + semantic_model.eval() + for param in semantic_model.parameters(): + param.requires_grad = False + self.semantic_model = semantic_model + + local_loss = False + gather_with_grad = False + rank = 0 + world_size = 8 + use_horovod = False + sem_loss_scale = 1.0 + + self.sem_loss_scale = sem_loss_scale + self.semantic_loss = ClipLoss( + local_loss=local_loss, + gather_with_grad=gather_with_grad, + cache_labels=True, + rank=rank, + world_size=world_size, + use_horovod=use_horovod, + ) + + self.sem_linear = torch.nn.Linear(384, config.codebook_embed_dim) + self.sem_loss_weight = 0.01 + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b, shape=None, channel_first=True): + quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first) + dec = self.decode(quant_b) + return dec + + def forward(self, input): + quant, diff, _ = self.encode(input) + dec = self.decode(quant) + + if self.semantic_guide != "none": + z_s = self.semantic_model(input) + z_q_ = torch.mean(quant, dim=(2, 3)).contiguous() + z_s = self.sem_linear(z_s).contiguous() + sem_loss = self.semantic_loss.forward( + z_s, z_q_, logit_scale=self.sem_loss_scale + ) + sem_loss = sem_loss * self.sem_loss_weight + else: + sem_loss = None + return dec, diff, sem_loss + + +class Encoder(nn.Module): + def __init__( + self, + in_channels=3, + ch=128, + ch_mult=(1, 1, 2, 2, 4), + num_res_blocks=2, + norm_type="group", + dropout=0.0, + resamp_with_conv=True, + z_channels=256, + ): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1) + + # downsampling + in_ch_mult = (1,) + tuple(ch_mult) + self.conv_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + conv_block = nn.Module() + # res & attn + res_block = nn.ModuleList() + attn_block = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + res_block.append( + ResnetBlock( + block_in, block_out, dropout=dropout, norm_type=norm_type + ) + ) + block_in = block_out + if i_level == self.num_resolutions - 1: + attn_block.append(AttnBlock(block_in, norm_type)) + conv_block.res = res_block + 
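The ch_mult lists that configure these encoder/decoder stacks set the spatial compression: every resolution level except the last contributes one stride-2 downsample, so the reduction factor is 2 ** (len(ch_mult) - 1). A minimal sketch of that relationship (latent_grid is a hypothetical helper, not part of the repo):

```python
# Minimal sketch (not from the repo; hypothetical helper name): how a ch_mult
# schedule maps image resolution to latent grid size. Every resolution level
# except the last adds one stride-2 Downsample.
def latent_grid(image_size, ch_mult=(1, 1, 2, 2, 4)):
    downsample_steps = len(ch_mult) - 1        # 4 steps for the default config
    return image_size // (2 ** downsample_steps)

assert latent_grid(256) == 16                        # VQ_16-style config: 16x16 tokens
assert latent_grid(256, ch_mult=(1, 2, 2, 4)) == 32  # VQ_8-style config: 32x32 tokens
```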
conv_block.attn = attn_block + # downsample + if i_level != self.num_resolutions - 1: + conv_block.downsample = Downsample(block_in, resamp_with_conv) + self.conv_blocks.append(conv_block) + + # middle + self.mid = nn.ModuleList() + self.mid.append( + ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) + ) + self.mid.append(AttnBlock(block_in, norm_type=norm_type)) + self.mid.append( + ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) + ) + + # end + self.norm_out = Normalize(block_in, norm_type) + self.conv_out = nn.Conv2d( + block_in, z_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + h = self.conv_in(x) + # downsampling + for i_level, block in enumerate(self.conv_blocks): + for i_block in range(self.num_res_blocks): + h = block.res[i_block](h) + if len(block.attn) > 0: + h = block.attn[i_block](h) + if i_level != self.num_resolutions - 1: + h = block.downsample(h) + + # middle + for mid_block in self.mid: + h = mid_block(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + z_channels=256, + ch=128, + ch_mult=(1, 1, 2, 2, 4), + num_res_blocks=2, + norm_type="group", + dropout=0.0, + resamp_with_conv=True, + out_channels=3, + ): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + + block_in = ch * ch_mult[self.num_resolutions - 1] + # z to block_in + self.conv_in = nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.ModuleList() + self.mid.append( + ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) + ) + self.mid.append(AttnBlock(block_in, norm_type=norm_type)) + self.mid.append( + ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) + ) + + # upsampling + self.conv_blocks = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + conv_block = nn.Module() + # res & attn + res_block = nn.ModuleList() + attn_block = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + res_block.append( + ResnetBlock( + block_in, block_out, dropout=dropout, norm_type=norm_type + ) + ) + block_in = block_out + if i_level == self.num_resolutions - 1: + attn_block.append(AttnBlock(block_in, norm_type)) + conv_block.res = res_block + conv_block.attn = attn_block + # downsample + if i_level != 0: + conv_block.upsample = Upsample(block_in, resamp_with_conv) + self.conv_blocks.append(conv_block) + + # end + self.norm_out = Normalize(block_in, norm_type) + self.conv_out = nn.Conv2d( + block_in, out_channels, kernel_size=3, stride=1, padding=1 + ) + + @property + def last_layer(self): + return self.conv_out.weight + + def forward(self, z): + # z to block_in + h = self.conv_in(z) + + # middle + for mid_block in self.mid: + h = mid_block(h) + + # upsampling + for i_level, block in enumerate(self.conv_blocks): + for i_block in range(self.num_res_blocks + 1): + h = block.res[i_block](h) + if len(block.attn) > 0: + h = block.attn[i_block](h) + if i_level != self.num_resolutions - 1: + h = block.upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class VectorQuantizer(nn.Module): + def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.entropy_loss_ratio = entropy_loss_ratio + self.l2_norm = l2_norm + 
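The forward pass of this quantizer (defined just below) reduces to a nearest-codebook lookup followed by a straight-through gradient estimator. A self-contained sketch under the same conventions, with vq_lookup as a hypothetical name; it omits the usage tracking and entropy term that the class adds on top:

```python
# Minimal sketch (illustration only): nearest-neighbour lookup + straight-through
# estimator, the core of the VectorQuantizer forward pass below.
import torch
import torch.nn.functional as F

def vq_lookup(z, codebook, beta=0.25, l2_norm=True):
    """z: (B, C, H, W) encoder features; codebook: (K, C) embedding table."""
    z = z.permute(0, 2, 3, 1).contiguous()               # (B, H, W, C)
    if l2_norm:
        z = F.normalize(z, dim=-1)
        codebook = F.normalize(codebook, dim=-1)
    flat = z.reshape(-1, z.shape[-1])                    # (B*H*W, C)
    # squared distances ||z||^2 + ||e||^2 - 2 z.e, as in the forward pass below
    d = flat.pow(2).sum(1, keepdim=True) + codebook.pow(2).sum(1) - 2 * flat @ codebook.t()
    idx = d.argmin(dim=1)
    z_q = codebook[idx].view(z.shape)
    vq_loss = (z_q - z.detach()).pow(2).mean()           # pulls codebook towards encoder output
    commit_loss = beta * (z_q.detach() - z).pow(2).mean()
    z_q = z + (z_q - z).detach()                         # straight-through gradient
    return z_q.permute(0, 3, 1, 2), vq_loss, commit_loss, idx

z_q, vq, commit, idx = vq_lookup(torch.randn(2, 8, 16, 16), torch.randn(16384, 8))
assert z_q.shape == (2, 8, 16, 16)
```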
self.show_usage = show_usage + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + if self.l2_norm: + self.embedding.weight.data = F.normalize( + self.embedding.weight.data, p=2, dim=-1 + ) + if self.show_usage: + self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536))) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + z = torch.einsum("b c h w -> b h w c", z).contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + if self.l2_norm: + z = F.normalize(z, p=2, dim=-1) + z_flattened = F.normalize(z_flattened, p=2, dim=-1) + embedding = F.normalize(self.embedding.weight, p=2, dim=-1) + else: + embedding = self.embedding.weight + + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(embedding**2, dim=1) + - 2 + * torch.einsum( + "bd,dn->bn", z_flattened, torch.einsum("n d -> d n", embedding) + ) + ) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = embedding[min_encoding_indices].view(z.shape) + perplexity = None + min_encodings = None + vq_loss = None + commit_loss = None + entropy_loss = None + codebook_usage = 0 + + if self.show_usage and self.training: + cur_len = min_encoding_indices.shape[0] + self.codebook_used[:-cur_len] = self.codebook_used[cur_len:].clone() + self.codebook_used[-cur_len:] = min_encoding_indices + codebook_usage = len(torch.unique(self.codebook_used)) / self.n_e + + # compute loss for embedding + if self.training: + vq_loss = torch.mean((z_q - z.detach()) ** 2) + commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = torch.einsum("b h w c -> b c h w", z_q) + + return ( + z_q, + (vq_loss, commit_loss, entropy_loss, codebook_usage), + (perplexity, min_encodings, min_encoding_indices), + ) + + def get_codebook_entry(self, indices, shape=None, channel_first=True): + # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel) + if self.l2_norm: + embedding = F.normalize(self.embedding.weight, p=2, dim=-1) + else: + embedding = self.embedding.weight + z_q = embedding[indices] # (b*h*w, c) + + if shape is not None: + if channel_first: + z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1]) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + else: + z_q = z_q.view(shape) + return z_q + + +class ResnetBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + norm_type="group", + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels, norm_type) + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.norm2 = Normalize(out_channels, norm_type) + self.dropout = nn.Dropout(dropout) + self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = nn.Conv2d( + in_channels, 
out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + return x + h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels, norm_type="group"): + super().__init__() + self.norm = Normalize(in_channels, norm_type) + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, norm_type="group"): + assert norm_type in ["group", "batch"] + if norm_type == "group": + return nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + elif norm_type == "batch": + return nn.SyncBatchNorm(in_channels) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = F.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01): + flat_affinity = affinity.reshape(-1, affinity.shape[-1]) + flat_affinity /= temperature + probs = F.softmax(flat_affinity, dim=-1) + log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1) + if loss_type == "softmax": + target_probs = probs + else: + raise ValueError("Entropy loss {} not supported".format(loss_type)) + avg_probs = torch.mean(target_probs, dim=0) + avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-5)) + sample_entropy = -torch.mean(torch.sum(target_probs * log_probs, dim=-1)) + loss = sample_entropy - avg_entropy + return loss + + +################################################################################# +# VQ Model Configs # 
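A toy check of what compute_entropy_loss above rewards: the per-sample entropy term pushes each row towards a confident assignment, while subtracting the batch-average entropy pushes different samples towards different codes, which discourages codebook collapse. A small sketch (entropy_terms is a hypothetical name mirroring the function above):

```python
# Toy check (illustration only): sample_entropy - avg_entropy is lowest when each
# sample is confident about one code while the batch spreads over many codes.
import torch
import torch.nn.functional as F

def entropy_terms(affinity, temperature=0.01):
    logits = affinity / temperature
    probs = F.softmax(logits, dim=-1)
    log_probs = F.log_softmax(logits + 1e-5, dim=-1)
    avg_probs = probs.mean(dim=0)
    avg_entropy = -(avg_probs * torch.log(avg_probs + 1e-5)).sum()
    sample_entropy = -(probs * log_probs).sum(dim=-1).mean()
    return sample_entropy, avg_entropy

spread = 10.0 * torch.eye(4)     # 4 samples, each closest to a different code
collapsed = torch.zeros(4, 4)
collapsed[:, 0] = 10.0           # every sample closest to code 0
for name, aff in [("spread", spread), ("collapsed", collapsed)]:
    s, a = entropy_terms(aff)
    print(name, "loss =", (s - a).item())   # "spread" yields the lower loss
```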
+################################################################################# +def VQ_8(**kwargs): + return VQModel( + ModelArgs(encoder_ch_mult=[1, 2, 2, 4], decoder_ch_mult=[1, 2, 2, 4], **kwargs) + ) + + +def VQ_16(**kwargs): + return VQModel( + ModelArgs( + encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs + ) + ) + + +VQ_models = {"VQ-16": VQ_16, "VQ-8": VQ_8} diff --git a/src/vqvaes/xqgan/xqgan.py b/src/vqvaes/xqgan/xqgan.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/vqvaes/xqgan/xqgan_model.py b/src/vqvaes/xqgan/xqgan_model.py new file mode 100644 index 0000000000000000000000000000000000000000..24b4c884035534ad9015853d74d82661c41f16ab --- /dev/null +++ b/src/vqvaes/xqgan/xqgan_model.py @@ -0,0 +1,1142 @@ +# Modified from: +# taming-transformers: https://github.com/CompVis/taming-transformers +# maskgit: https://github.com/google-research/maskgit +from dataclasses import dataclass, field +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.models import create_model + +import sys, os +from math import sqrt + +# current_dir = os.path.dirname(os.path.abspath(__file__)) +# project_root = os.path.abspath(os.path.join(current_dir, '../..')) +# +# sys.path.append(project_root) + +from .cliploss import ClipLoss +from .quant import VectorQuantizer2 +from .lookup_free_quantize import LFQ +from .dino_enc.dinov2 import DINOv2Encoder, DINOv2Decoder +from .latent_perturbation import add_perturbation +from datasets import Denormalize +from datasets import Normalize as ImgNormalize + +import torch.distributed as tdist + + +@dataclass +class ModelArgs: + codebook_size: int = 16384 + codebook_embed_dim: int = 8 + codebook_l2_norm: bool = True + codebook_show_usage: bool = True + commit_loss_beta: float = 0.25 + entropy_loss_ratio: float = 0.0 + + encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4]) + decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4]) + z_channels: int = 256 + dropout_p: float = 0.0 + + v_patch_nums: List[int] = field( + default_factory=lambda: [1, 2, 3, 4, 5, 6, 8, 10, 13, 16] + ) + enc_type: str = "cnn" + dec_type: str = "cnn" + semantic_guide: str = "dinov2" + detail_guide: str = "clip" + num_latent_tokens: int = 256 + encoder_model: str = "vit_small_patch14_dinov2.lvd142m" + decoder_model: str = "vit_small_patch14_dinov2.lvd142m" + abs_pos_embed: bool = False + share_quant_resi: int = 4 + product_quant: int = 1 + codebook_drop: float = 0.0 + half_sem: bool = False + start_drop: int = 1 + sem_loss_weight: float = 0.1 + detail_loss_weight: float = 0.1 + clip_norm: bool = False + sem_loss_scale: float = 1.0 + detail_loss_scale: float = 1.0 + guide_type_1: str = "class" + guide_type_2: str = "class" + + lfq: bool = False + scale: float = 1.0 + soft_entropy: bool = True + + dependency_loss_weight: float = 0.0 + + test_model: bool = False + + +class VQModel(nn.Module): + def __init__( + self, + config: ModelArgs, + ): + super().__init__() + self.config = config + self.enc_type = config.enc_type + self.dec_type = config.dec_type + self.product_quant = config.product_quant + self.half_sem = config.half_sem + self.start_drop = config.start_drop + self.clip_norm = config.clip_norm + config.num_latent_tokens = ( + config.num_latent_tokens * config.product_quant + ) # scale num_latent_tokens for PQ + + if config.enc_type == "cnn": + self.encoder = Encoder( + 
ch_mult=config.encoder_ch_mult, + z_channels=config.z_channels, + dropout=config.dropout_p, + ) + self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1) + elif config.enc_type == "dinov2": + self.encoder = DINOv2Encoder( + in_channels=3, + num_latent_tokens=config.num_latent_tokens, + model_name=config.encoder_model, # 'vit_small_patch14_dinov2.lvd142m', #'vit_base_patch14_dinov2.lvd142m', # + model_kwargs={"img_size": 256, "patch_size": 16, "drop_path_rate": 0.1}, + pretrained=True, + tuning_method="full", + tuning_kwargs={"r": 8}, + abs_pos_embed=config.abs_pos_embed, + product_quant=config.product_quant, + ) + self.quant_conv = nn.Conv2d( + self.encoder.embed_dim, config.codebook_embed_dim, 1 + ) + else: + raise NotImplementedError + + if config.dec_type == "cnn": + self.decoder = Decoder( + ch_mult=config.decoder_ch_mult, + z_channels=config.z_channels, + dropout=config.dropout_p, + ) + self.post_quant_conv = nn.Conv2d( + config.codebook_embed_dim, config.z_channels, 1 + ) + elif config.dec_type == "dinov2": + self.decoder = DINOv2Decoder( + in_channels=3, + num_latent_tokens=config.num_latent_tokens // self.product_quant, + model_name=config.decoder_model, + model_kwargs={"img_size": 256, "patch_size": 16, "drop_path_rate": 0.1}, + pretrained=True, + tuning_method="full", + tuning_kwargs={"r": 8}, + to_pixel="linear", + use_rope=False, + cond_latent=False, + abs_pos_embed=config.abs_pos_embed, + ) + self.post_quant_conv = nn.Conv2d( + config.codebook_embed_dim, self.decoder.embed_dim, 1 + ) + + self.V = self.vocab_size = config.codebook_size * self.product_quant + self.Cvae = config.codebook_embed_dim * self.product_quant + if self.product_quant > 1: + if len(config.v_patch_nums) == 1: + self.quantizes = nn.ModuleList( + [ + VectorQuantizer( + config.codebook_size, + config.codebook_embed_dim, + config.commit_loss_beta, + config.codebook_l2_norm, + ) + for _ in range(self.product_quant) + ] + ) + elif not config.lfq: + self.quantizes = nn.ModuleList( + [ + VectorQuantizer2( + config.codebook_size, + config.codebook_embed_dim, + v_patch_nums=config.v_patch_nums, + num_latent_tokens=config.num_latent_tokens + // self.product_quant, + share_quant_resi=config.share_quant_resi, + codebook_drop=config.codebook_drop, + ) + for _ in range(self.product_quant) + ] + ) + else: + self.quantizes = nn.ModuleList( + [ + LFQ( + config.codebook_size, + config.codebook_embed_dim, + v_patch_nums=config.v_patch_nums, + num_latent_tokens=config.num_latent_tokens + // self.product_quant, + share_quant_resi=config.share_quant_resi, + codebook_drop=config.codebook_drop, + using_znorm=config.codebook_l2_norm, + scale=config.scale, + entropy_weight=config.entropy_loss_ratio, + soft_entropy=config.soft_entropy, + ) + for _ in range(self.product_quant) + ] + ) + self.post_quant_conv = nn.Conv2d( + config.codebook_embed_dim * self.product_quant, + self.decoder.embed_dim, + 1, + ) + else: + if len(config.v_patch_nums) == 1: + self.quantize = VectorQuantizer( + config.codebook_size, + config.codebook_embed_dim, + config.commit_loss_beta, + config.codebook_l2_norm, + ) + elif not config.lfq: + self.quantize = VectorQuantizer2( + config.codebook_size, + config.codebook_embed_dim, + v_patch_nums=config.v_patch_nums, + num_latent_tokens=config.num_latent_tokens, + share_quant_resi=config.share_quant_resi, + ) + else: + self.quantize = LFQ( + config.codebook_size, + config.codebook_embed_dim, + v_patch_nums=config.v_patch_nums, + num_latent_tokens=config.num_latent_tokens, + 
share_quant_resi=config.share_quant_resi, + codebook_drop=config.codebook_drop, + using_znorm=config.codebook_l2_norm, + scale=config.scale, + entropy_weight=config.entropy_loss_ratio, + soft_entropy=config.soft_entropy, + ) + + self.codebook_embed_dim = config.codebook_embed_dim + self.v_patch_nums = config.v_patch_nums + self.codebook_drop = config.codebook_drop + # Semantic loss to preserve dino semantics + self.semantic_guide = config.semantic_guide + self.denormalize = Denormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + self.normalize = ImgNormalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + if self.semantic_guide == "dinov2": + semantic_model = create_model( + config.encoder_model, + pretrained=True, + img_size=256, + patch_size=16, + drop_path_rate=0.0, + ) + semantic_model.eval() + for param in semantic_model.parameters(): + param.requires_grad = False + self.semantic_model = ( + semantic_model # torch.compile(semantic_model, mode='max-autotune') + ) + + local_loss = False + gather_with_grad = True + rank = tdist.get_rank() + world_size = tdist.get_world_size() + use_horovod = False + sem_loss_scale = config.sem_loss_scale + + self.sem_loss_scale = sem_loss_scale + self.semantic_loss = ClipLoss( + local_loss=local_loss, + gather_with_grad=gather_with_grad, + cache_labels=True, + rank=rank, + world_size=world_size, + use_horovod=use_horovod, + ) + if not self.half_sem and self.product_quant > 1: + self.sem_linear = nn.Conv2d( + self.product_quant * config.codebook_embed_dim, + config.codebook_embed_dim, + 1, + ) + elif self.half_sem and self.product_quant == 1: + self.sem_linear = nn.Conv2d(768, config.codebook_embed_dim // 2, 1) + if self.enc_type == "cnn": + self.sem_linear = torch.nn.Linear(384, config.codebook_embed_dim) + + self.sem_loss_weight = config.sem_loss_weight + + self.detail_guide = config.detail_guide + if self.detail_guide != "none": + detail_model = create_model( + "vit_base_patch16_clip_224.openai", + pretrained=True, + img_size=256, + patch_size=16, + drop_path_rate=0.0, + ) + detail_model.eval() + for param in detail_model.parameters(): + param.requires_grad = False + self.detail_model = detail_model + + self.detail_loss_scale = config.detail_loss_scale + self.detail_loss = ClipLoss( + local_loss=False, + gather_with_grad=True, + cache_labels=True, + rank=tdist.get_rank(), + world_size=tdist.get_world_size(), + use_horovod=False, + ) + self.detail_loss_weight = config.detail_loss_weight + + self.guide_type_1 = config.guide_type_1 + self.guide_type_2 = config.guide_type_2 + self.dependency_loss_weight = config.dependency_loss_weight + + self.test_mode = config.test_model + + if self.test_mode: + self.eval() + [p.requires_grad_(False) for p in self.parameters()] + + def finetune(self, enc_tuning_method, dec_tuning_method): + self.encoder.finetine(enc_tuning_method) + self.decoder.finetine(dec_tuning_method) + + def encode(self, x): + h = self.encoder(x) + if self.enc_type == "dinov2": + b, l, c = h.shape + if self.product_quant > 1: + assert int(sqrt(l // self.product_quant)) ** 2 * self.product_quant == l + h = h.view(b, l, 1, c) + h = h.permute(0, 3, 1, 2) + else: + assert int(sqrt(l)) ** 2 == l + h = h.view(b, int(sqrt(l)), int(sqrt(l)), c) + h = h.permute(0, 3, 1, 2) + h = self.quant_conv(h) + return h + + def decode(self, quant, return_quant=False): + quant = self.post_quant_conv(quant) + if self.dec_type == "dinov2": + quant = quant.flatten(2).permute(0, 2, 1) + dec = self.decoder(quant) + return dec + + def decode_code( + self, + 
code_b, + ): + quant_b, usages, mean_vq_loss = self.quantize(code_b, ret_usages=True) + dec = self.decode(quant_b) + return dec + + def forward(self, input, epoch, alpha, beta, delta): + h = self.encode(input) + b, c, l, _ = h.shape + if len(self.v_patch_nums) == 1: + dropout_rand = None + else: + dropout_rand = torch.randint( + self.start_drop, len(self.v_patch_nums) + 1, (b,) + ) # to fix dropout across quantizers, skip first start_drop-1 quantizers + + if self.product_quant > 1: + h_list = h.chunk(chunks=self.product_quant, dim=2) + ( + quant_list, + usages_list, + mean_vq_loss_list, + commit_loss_list, + entropy_list, + ) = ([], [], [], [], []) + for i, h in enumerate(h_list): + h = h.view( + b, + -1, + int(sqrt(l // self.product_quant)), + int(sqrt(l // self.product_quant)), + ) + quant, usages, vq_loss, commit_loss, entropy_loss = self.quantizes[ + i + ].forward(h, ret_usages=True, dropout=dropout_rand) + quant_list.append(quant) + usages_list.append(usages) + mean_vq_loss_list.append(vq_loss) + commit_loss_list.append(commit_loss) + entropy_list.append(entropy_loss) + dependency_loss = self.dependency_loss_weight * orthogonal_cosine_loss( + torch.mean(quant_list[0], dim=(2, 3)).contiguous(), + torch.mean(quant_list[-1], dim=(2, 3)).contiguous(), + ) + usages = [sum(us) / self.product_quant for us in zip(*usages_list)] + mean_vq_loss = sum(mean_vq_loss_list) / self.product_quant + mean_commit_loss = sum(commit_loss_list) / self.product_quant + mean_entropy = sum(entropy_list) / self.product_quant + quant = torch.cat(quant_list, dim=1) + else: + dependency_loss = 0.0 + quant, usages, mean_vq_loss, mean_commit_loss, mean_entropy = ( + self.quantize.forward(h, ret_usages=True, dropout=dropout_rand) + ) + print(alpha, beta, delta) + quant = add_perturbation( + h, + quant, + self.quantize.z_channels, + self.quantize.codebook_norm, + self.quantize.embedding, + alpha, + beta, + delta, + ) + quant_list = [quant] + + dec = self.decode(quant) + + # normalize the inputs to dino's transform + input = self.normalize(self.denormalize(input)) + if self.semantic_guide != "none": + if self.guide_type_1 == "class": + z_s = self.semantic_model(input) + z_s = z_s[..., None, None] + else: + z_s = self.semantic_model.forward_features(input)[:, 1:, :] + z_s = z_s.reshape(b, 768, 16, 16) + if self.enc_type == "dinov2": + z_s = self.quant_conv(z_s).contiguous() + semantic_quant = quant_list[-1] + z_s = torch.mean(z_s, dim=(2, 3)).contiguous() + z_q_ = torch.mean(semantic_quant, dim=(2, 3)).contiguous() + elif self.enc_type == "cnn": + z_q_ = torch.mean(h, dim=(2, 3)).contiguous() + z_s = self.sem_linear(z_s).contiguous() + + n_drop = int(b * self.codebook_drop) + with torch.cuda.amp.autocast(enabled=False): + sem_loss_scale = self.sem_loss_scale + feat1 = z_s[n_drop:].float() + feat2 = z_q_[n_drop:].float() + if self.clip_norm: + feat1 = feat1 / feat1.norm(dim=1, keepdim=True) + feat2 = feat2 / feat2.norm(dim=1, keepdim=True) + sem_loss_scale = ( + (epoch % 200) / 200 * (100 - sem_loss_scale) + sem_loss_scale + if epoch < 200 + else 100 + ) + sem_loss = self.semantic_loss.forward( + feat1, feat2, logit_scale=sem_loss_scale + ) + sem_loss = sem_loss * self.sem_loss_weight + else: + sem_loss = None + + if self.detail_guide != "none": + assert ( + self.guide_type_2 == "patch" + ), "current only accept patch for detail guide" + if self.guide_type_2 == "class": + z_d = self.detail_model(input) + z_d = z_d[..., None, None] + else: + z_d = self.detail_model.forward_features(input)[:, 1:, :] + z_d = z_d.reshape(b, 
768, 16, 16) + if self.enc_type == "dinov2": + z_d = self.quant_conv(z_d).contiguous() + detail_quant = quant_list[0] + z_d = torch.mean(z_d, dim=(2, 3)).contiguous() + z_q_ = torch.mean(detail_quant, dim=(2, 3)).contiguous() + elif self.enc_type == "cnn": + pass + + n_drop = int(b * self.codebook_drop) + with torch.cuda.amp.autocast(enabled=False): + detail_loss_scale = self.detail_loss_scale + feat1 = z_d[n_drop:].float() + feat2 = z_q_[n_drop:].float() + if self.clip_norm: + feat1 = feat1 / feat1.norm(dim=1, keepdim=True) + feat2 = feat2 / feat2.norm(dim=1, keepdim=True) + detail_loss_scale = ( + (epoch % 200) / 200 * (100 - detail_loss_scale) + + detail_loss_scale + if epoch < 200 + else 100 + ) + detail_loss = self.detail_loss.forward( + feat1, feat2, logit_scale=detail_loss_scale + ) + detail_loss = detail_loss * self.detail_loss_weight + else: + detail_loss = None + + return ( + dec, + (mean_vq_loss, mean_commit_loss, mean_entropy, usages), + sem_loss, + detail_loss, + dependency_loss, + ) + + def img_to_reconstructed_img( + self, + x, + last_one=True, + ) -> List[torch.Tensor]: + h = self.encoder(x) + if self.enc_type == "dinov2": + b, l, c = h.shape + if self.product_quant > 1: + assert int(sqrt(l // self.product_quant)) ** 2 * self.product_quant == l + h = h.view(b, l, 1, c) + h = h.permute(0, 3, 1, 2) + else: + assert int(sqrt(l)) ** 2 == l + h = h.view(b, int(sqrt(l)), int(sqrt(l)), c) + h = h.permute(0, 3, 1, 2) + f = self.quant_conv(h) + + if self.product_quant > 1: + b, c, l, _ = f.shape + f_list = f.chunk(chunks=self.product_quant, dim=2) + f_list = [ + f.view( + b, + -1, + int(sqrt(l // self.product_quant)), + int(sqrt(l // self.product_quant)), + ) + for f in f_list + ] + if len(self.v_patch_nums) == 1: + f_hats_list = [ + self.quantizes[i].f_to_idxBl_or_fhat( + f, to_fhat=True, v_patch_nums=None + ) + for i, f in enumerate(f_list) + ] + else: + f_hats_list = [ + self.quantizes[i].f_to_idxBl_or_fhat( + f, to_fhat=True, v_patch_nums=self.v_patch_nums + ) + for i, f in enumerate(f_list) + ] + f_hats = [ + self.post_quant_conv(torch.cat(f_hats, dim=1)) + for f_hats in zip(*f_hats_list) + ] + else: + if len(self.v_patch_nums) == 1: + ls_f_hat_BChw = self.quantize.f_to_idxBl_or_fhat( + f, to_fhat=True, v_patch_nums=None + ) + else: + ls_f_hat_BChw = self.quantize.f_to_idxBl_or_fhat( + f, to_fhat=True, v_patch_nums=self.v_patch_nums + ) + f_hats = [self.post_quant_conv(f_hat) for f_hat in ls_f_hat_BChw] + + if self.dec_type == "dinov2": + f_hats = [f_hat.flatten(2).permute(0, 2, 1) for f_hat in f_hats] + + if last_one: + return self.decoder(f_hats[-1]).clamp_(-1, 1) + else: + return [self.decoder(f_hat).clamp_(-1, 1) for f_hat in f_hats] + + def img_to_sem_feat( + self, + x, + ) -> List[torch.Tensor]: + h = self.encoder(x) + if self.enc_type == "dinov2": + b, l, c = h.shape + if self.product_quant > 1: + assert int(sqrt(l // self.product_quant)) ** 2 * self.product_quant == l + h = h.view(b, l, 1, c) + h = h.permute(0, 3, 1, 2) + else: + assert int(sqrt(l)) ** 2 == l + h = h.view(b, int(sqrt(l)), int(sqrt(l)), c) + h = h.permute(0, 3, 1, 2) + f = self.quant_conv(h) + + b, c, l, _ = f.shape + f_list = f.chunk(chunks=self.product_quant, dim=2) + f_list = [ + f.view( + b, + -1, + int(sqrt(l // self.product_quant)), + int(sqrt(l // self.product_quant)), + ) + for f in f_list + ] + f_hats_list = [ + self.quantizes[i].f_to_idxBl_or_fhat( + f, to_fhat=True, v_patch_nums=self.v_patch_nums + ) + for i, f in enumerate(f_list) + ] + + z_q = f_hats_list[-1][ + -1 + ] # 
torch.mean(f_hats_list[-1][-1], dim=(2, 3)).contiguous() + return z_q + + def fhat_to_img(self, f_hat: torch.Tensor): + f_hat = self.post_quant_conv(f_hat) + if self.dec_type == "dinov2": + f_hat = f_hat.flatten(2).permute(0, 2, 1) + return self.decoder(f_hat).clamp_(-1, 1) + + def idxBl_to_var_input(self, gt_idx_Bl): + if self.product_quant > 1: + x_BLCv_wo_first_l_list = [ + self.quantizes[i].idxBl_to_var_input(gt_idx_Bl[i]) + for i in range(self.product_quant) + ] + return torch.cat(x_BLCv_wo_first_l_list, dim=-1) + else: + return self.quantize.idxBl_to_var_input(gt_idx_Bl) + + def get_next_autoregressive_input(self, si, SN, f_hat, h_BChw): + f_hat_list = f_hat.chunk(self.product_quant, dim=1) + h_BChw_list = h_BChw.chunk(self.product_quant, dim=1) + out_fhat_list, out_next_token_map_list = [], [] + for i, (f_hat, h_BChw) in enumerate(zip(f_hat_list, h_BChw_list)): + out_fhat, out_next_token_map = self.quantizes[ + i + ].get_next_autoregressive_input(si, SN, f_hat, h_BChw) + out_fhat_list.append(out_fhat) + out_next_token_map_list.append(out_next_token_map) + f_hat = torch.cat(out_fhat_list, dim=1) + next_token_map = torch.cat(out_next_token_map_list, dim=1) + return f_hat, next_token_map + + +class Encoder(nn.Module): + def __init__( + self, + in_channels=3, + ch=128, + ch_mult=(1, 1, 2, 2, 4), + num_res_blocks=2, + norm_type="group", + dropout=0.0, + resamp_with_conv=True, + z_channels=256, + ): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1) + + # downsampling + in_ch_mult = (1,) + tuple(ch_mult) + self.conv_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + conv_block = nn.Module() + # res & attn + res_block = nn.ModuleList() + attn_block = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + res_block.append( + ResnetBlock( + block_in, block_out, dropout=dropout, norm_type=norm_type + ) + ) + block_in = block_out + if i_level == self.num_resolutions - 1: + attn_block.append(AttnBlock(block_in, norm_type)) + conv_block.res = res_block + conv_block.attn = attn_block + # downsample + if i_level != self.num_resolutions - 1: + conv_block.downsample = Downsample(block_in, resamp_with_conv) + self.conv_blocks.append(conv_block) + + # middle + self.mid = nn.ModuleList() + self.mid.append( + ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) + ) + self.mid.append(AttnBlock(block_in, norm_type=norm_type)) + self.mid.append( + ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) + ) + + # end + self.norm_out = Normalize(block_in, norm_type) + self.conv_out = nn.Conv2d( + block_in, z_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + h = self.conv_in(x) + # downsampling + for i_level, block in enumerate(self.conv_blocks): + for i_block in range(self.num_res_blocks): + h = block.res[i_block](h) + if len(block.attn) > 0: + h = block.attn[i_block](h) + if i_level != self.num_resolutions - 1: + h = block.downsample(h) + + # middle + for mid_block in self.mid: + h = mid_block(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + z_channels=256, + ch=128, + ch_mult=(1, 1, 2, 2, 4), + num_res_blocks=2, + norm_type="group", + dropout=0.0, + resamp_with_conv=True, + out_channels=3, + ): + super().__init__() + 
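+        # The decoder mirrors the encoder: conv_in maps z to the widest block, a middle
+        # ResnetBlock/AttnBlock/ResnetBlock stack follows, then the upsampling stages walk
+        # ch_mult in reverse (attention only at the lowest resolution) into norm_out/conv_out.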
self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + + block_in = ch * ch_mult[self.num_resolutions - 1] + # z to block_in + self.conv_in = nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.ModuleList() + self.mid.append( + ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) + ) + self.mid.append(AttnBlock(block_in, norm_type=norm_type)) + self.mid.append( + ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) + ) + + # upsampling + self.conv_blocks = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + conv_block = nn.Module() + # res & attn + res_block = nn.ModuleList() + attn_block = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + res_block.append( + ResnetBlock( + block_in, block_out, dropout=dropout, norm_type=norm_type + ) + ) + block_in = block_out + if i_level == self.num_resolutions - 1: + attn_block.append(AttnBlock(block_in, norm_type)) + conv_block.res = res_block + conv_block.attn = attn_block + # downsample + if i_level != 0: + conv_block.upsample = Upsample(block_in, resamp_with_conv) + self.conv_blocks.append(conv_block) + + # end + self.norm_out = Normalize(block_in, norm_type) + self.conv_out = nn.Conv2d( + block_in, out_channels, kernel_size=3, stride=1, padding=1 + ) + + @property + def last_layer(self): + return self.conv_out.weight + + def forward(self, z): + # z to block_in + h = self.conv_in(z) + + # middle + for mid_block in self.mid: + h = mid_block(h) + + # upsampling + for i_level, block in enumerate(self.conv_blocks): + for i_block in range(self.num_res_blocks + 1): + h = block.res[i_block](h) + if len(block.attn) > 0: + h = block.attn[i_block](h) + if i_level != self.num_resolutions - 1: + h = block.upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class ResnetBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + norm_type="group", + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels, norm_type) + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.norm2 = Normalize(out_channels, norm_type) + self.dropout = nn.Dropout(dropout) + self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + return x + h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels, norm_type="group"): + super().__init__() + self.norm = Normalize(in_channels, norm_type) + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, 
stride=1, padding=0) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, norm_type="group"): + assert norm_type in ["group", "batch"] + if norm_type == "group": + return nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + elif norm_type == "batch": + return nn.SyncBatchNorm(in_channels) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = F.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01): + flat_affinity = affinity.reshape(-1, affinity.shape[-1]) + flat_affinity /= temperature + probs = F.softmax(flat_affinity, dim=-1) + log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1) + if loss_type == "softmax": + target_probs = probs + else: + raise ValueError("Entropy loss {} not supported".format(loss_type)) + avg_probs = torch.mean(target_probs, dim=0) + avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-5)) + sample_entropy = -torch.mean(torch.sum(target_probs * log_probs, dim=-1)) + loss = sample_entropy - avg_entropy + return loss + + +class VectorQuantizer(nn.Module): + + def __init__(self, vocab_size=8192, z_channels=32, beta=0.25, codebook_norm=True): + super().__init__() + # parameters + self.vocab_size = vocab_size + self.z_channels = z_channels + self.beta = beta + self.codebook_norm = codebook_norm + # self.restart_unused_codes = restart_unused_codes + + # embedding layer + self.embedding = nn.Embedding(self.vocab_size, self.z_channels) + self.embedding.weight.data.uniform_( + -1.0 / self.vocab_size, 1.0 / self.vocab_size + ) + if self.codebook_norm: + self.embedding.weight.data = F.normalize( + self.embedding.weight.data, p=2, dim=-1 + ) + + self.register_buffer( + "ema_vocab_hit_SV", torch.full((self.vocab_size,), fill_value=0.0) + ) + self.record_hit = 0 + + def no_weight_decay(self): + return [ + "embedding.weight", + ] + + def 
forward(self, z, ret_usages=True, dropout=None): + + vocab_hit_V = torch.zeros(self.vocab_size, dtype=torch.float, device=z.device) + + # reshape z -> (batch, height * width, channel) and flatten + z = torch.einsum("b c h w -> b h w c", z).contiguous() + z_flattened = z.view(-1, self.z_channels) + + if self.codebook_norm: + z = F.normalize(z, p=2, dim=-1) + z_flattened = F.normalize(z_flattened, p=2, dim=-1) + embedding = F.normalize(self.embedding.weight, p=2, dim=-1) + else: + embedding = self.embedding.weight + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(embedding**2, dim=1) + - 2 + * torch.einsum( + "bd,dn->bn", z_flattened, torch.einsum("n d -> d n", embedding) + ) + ) + + # argmin find indices and embeddings + min_encoding_indices = torch.argmin(d, dim=1) + + z_q = self.embedding(min_encoding_indices).view(z.shape) + if self.codebook_norm: + z_q = F.normalize(z_q, p=2, dim=-1) + + if ret_usages and self.training: + hit_V = min_encoding_indices.bincount(minlength=self.vocab_size).float() + handler = tdist.all_reduce(hit_V, async_op=True) + handler.wait() + if self.record_hit == 0: + self.ema_vocab_hit_SV.copy_(hit_V) + elif self.record_hit < 100: + self.ema_vocab_hit_SV.mul_(0.9).add_(hit_V.mul(0.1)) + else: + self.ema_vocab_hit_SV.mul_(0.99).add_(hit_V.mul(0.01)) + self.record_hit += 1 + vocab_hit_V.add_(hit_V) + + margin = ( + tdist.get_world_size() + * (z.numel() / self.z_channels) + / self.vocab_size + * 0.08 + ) + + codebook_usage = ( + self.ema_vocab_hit_SV >= margin + ).float().mean().item() * 100 + + # compute loss + commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + vq_loss = torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients - "straight-through" + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = torch.einsum("b h w c -> b c h w", z_q) + + return z_q, [codebook_usage], vq_loss, commit_loss, 0.0 + + def f_to_idxBl_or_fhat( + self, z: torch.Tensor, to_fhat: bool, v_patch_nums + ): # z_BChw is the feature from inp_img_no_grad + # reshape z -> (batch, height, width, channel) and flatten + z = torch.einsum("b c h w -> b h w c", z).contiguous() + z_flattened = z.view(-1, self.z_channels) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + if self.codebook_norm: + z = F.normalize(z, p=2, dim=-1) + z_flattened = F.normalize(z_flattened, p=2, dim=-1) + embedding = F.normalize(self.embedding.weight, p=2, dim=-1) + else: + embedding = self.embedding.weight + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(embedding**2, dim=1) + - 2 + * torch.einsum( + "bd,dn->bn", z_flattened, torch.einsum("n d -> d n", embedding) + ) + ) + + # argmin find indices and embeddings + min_encoding_indices = torch.argmin(d, dim=1) + + z_q = self.embedding(min_encoding_indices).view(z.shape) + if self.codebook_norm: + z_q = F.normalize(z_q, p=2, dim=-1) + + # reshape back to match original input shape + z_q = torch.einsum("b h w c -> b c h w", z_q) + + f_hat_or_idx_Bl: List[torch.Tensor] = [z_q if to_fhat else min_encoding_indices] + + return f_hat_or_idx_Bl + + +def orthogonal_cosine_loss(A, B): + A_norm = A / A.norm(dim=1, keepdim=True) + B_norm = B / B.norm(dim=1, keepdim=True) + loss = (A_norm * B_norm).sum(dim=1).mean() + return loss + + +################################################################################# +# VQ Model 
Configs # +################################################################################# +def VQ_8(**kwargs): + return VQModel( + ModelArgs(encoder_ch_mult=[1, 2, 2, 4], decoder_ch_mult=[1, 2, 2, 4], **kwargs) + ) + + +def VQ_16(**kwargs): + return VQModel( + ModelArgs( + encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs + ) + ) + + +VQ_models = {"VQ-16": VQ_16, "VQ-8": VQ_8} + +if __name__ == "__main__": + semantic_model = create_model( + "vit_small_patch14_dinov2.lvd142m", + pretrained=True, + img_size=256, + patch_size=16, + drop_path_rate=0.0, + ) + semantic_model.eval() diff --git a/src/vqvaes/xqgan/xqgan_train.py b/src/vqvaes/xqgan/xqgan_train.py new file mode 100644 index 0000000000000000000000000000000000000000..d76103d1513f99524a400bb5d530f692cec0d0d6 --- /dev/null +++ b/src/vqvaes/xqgan/xqgan_train.py @@ -0,0 +1,898 @@ +# Modified from: +# fast-DiT: https://github.com/chuanyangjin/fast-DiT/blob/main/train.py +# nanoGPT: https://github.com/karpathy/nanoGPT/blob/master/model.py +import torch + +# the first flag below was False when we tested this script but True makes A100 training a lot faster: +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.distributed import DistributedSampler +from torchvision.datasets import ImageFolder +from torchvision import transforms +from torchvision.utils import make_grid +from huggingface_hub import upload_folder + +import warnings + +warnings.filterwarnings("ignore") + +from PIL import Image +from tqdm import tqdm +import ruamel.yaml as yaml + +import os +import time +import argparse +from glob import glob +from copy import deepcopy +import sys +import math +import numpy as np + +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.abspath(os.path.join(current_dir, "../..")) +sys.path.append(project_root) +from utils.logger import create_logger +from utils.distributed import init_distributed_mode +from utils.ema import update_ema, requires_grad +from dataset.augmentation import random_crop_arr, center_crop_arr +from dataset.build import build_dataset +from tokenizer.tokenizer_image.xqgan_model import VQ_models +from tokenizer.tokenizer_image.vq_loss import VQLoss + +from timm.scheduler import create_scheduler_v2 as create_scheduler + +from evaluator import Evaluator +import tensorflow.compat.v1 as tf + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + +import warnings + +warnings.filterwarnings("ignore") + +import wandb + +################################################################################# +# Training Loop # +################################################################################# + + +def get_random_ratio( + randomness_anneal_start, randomness_anneal_end, end_ratio, cur_step +): + if cur_step < randomness_anneal_start: + return 1.0 + elif cur_step > randomness_anneal_end: + return end_ratio + else: + return ( + 1.0 + - (cur_step - randomness_anneal_start) + / (randomness_anneal_end - randomness_anneal_start) + * end_ratio + ) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--data-path", type=str, default="/mnt/localssd/ImageNet2012/train" + ) + parser.add_argument( + "--data-face-path", + type=str, + default=None, + help="face datasets to improve vq model", + ) + parser.add_argument( + "--cloud-save-path", + 
type=str, + default="output/debug", + help="please specify a cloud disk path, if not, local path", + ) + parser.add_argument( + "--no-local-save", + action="store_true", + help="no save checkpoints to local path for limited disk volume", + ) + parser.add_argument( + "--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16" + ) + parser.add_argument( + "--vq-ckpt", type=str, default=None, help="ckpt path for resume training" + ) + parser.add_argument( + "--finetune", action="store_true", help="finetune a pre-trained vq model" + ) + parser.add_argument("--ema", action="store_true", help="whether using ema training") + parser.add_argument( + "--codebook-size", + type=int, + default=16384, + help="codebook size for vector quantization", + ) + parser.add_argument( + "--codebook-embed-dim", + type=int, + default=8, + help="codebook dimension for vector quantization", + ) + parser.add_argument( + "--codebook-l2-norm", action="store_true", default=True, help="l2 norm codebook" + ) + parser.add_argument( + "--codebook-weight", + type=float, + default=1.0, + help="codebook loss weight for vector quantization", + ) + parser.add_argument( + "--entropy-loss-ratio", + type=float, + default=0.0, + help="entropy loss ratio in codebook loss", + ) + parser.add_argument( + "--commit-loss-beta", + type=float, + default=0.25, + help="commit loss beta in codebook loss", + ) + parser.add_argument( + "--reconstruction-weight", + type=float, + default=1.0, + help="reconstruction loss weight of image pixel", + ) + parser.add_argument( + "--reconstruction-loss", + type=str, + default="l2", + help="reconstruction loss type of image pixel", + ) + parser.add_argument( + "--perceptual-weight", + type=float, + default=1.0, + help="perceptual loss weight of LPIPS", + ) + parser.add_argument( + "--disc-weight", + type=float, + default=0.5, + help="discriminator loss weight for gan training", + ) + parser.add_argument( + "--disc-epoch-start", + type=int, + default=0, + help="iteration to start discriminator training and loss", + ) + parser.add_argument( + "--disc-start", + type=int, + default=0, + help="iteration to start discriminator training and loss", + ) # autoset + parser.add_argument( + "--disc-type", + type=str, + choices=["patchgan", "stylegan"], + default="patchgan", + help="discriminator type", + ) + parser.add_argument( + "--disc-loss", + type=str, + choices=["hinge", "vanilla", "non-saturating"], + default="hinge", + help="discriminator loss", + ) + parser.add_argument( + "--gen-loss", + type=str, + choices=["hinge", "non-saturating"], + default="hinge", + help="generator loss for gan training", + ) + parser.add_argument("--compile", action="store_true", default=False) + parser.add_argument("--dropout-p", type=float, default=0.0, help="dropout_p") + parser.add_argument("--results-dir", type=str, default="results_tokenizer_image") + parser.add_argument("--dataset", type=str, default="imagenet") + parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) + parser.add_argument("--epochs", type=int, default=40) + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument("--disc_lr", type=float, default=1e-4) + parser.add_argument("--max_grad_norm", type=float, default=0.0) + parser.add_argument("--lr_scheduler", type=str, default="none") + parser.add_argument( + "--weight-decay", type=float, default=0.0, help="Weight decay to use." + ) + parser.add_argument( + "--disc-weight-decay", type=float, default=0.0, help="Weight decay to use." 
+ ) + parser.add_argument( + "--beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--beta2", + type=float, + default=0.95, + help="The beta2 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--max-grad-norm", default=1.0, type=float, help="Max gradient norm." + ) + parser.add_argument("--global-batch-size", type=int, default=128) + parser.add_argument("--global-seed", type=int, default=0) + parser.add_argument("--num-workers", type=int, default=16) + parser.add_argument("--log-every", type=int, default=100) + parser.add_argument("--vis-every", type=int, default=5000) + parser.add_argument("--ckpt-every", type=int, default=10000) + parser.add_argument("--gradient-accumulation-steps", type=int, default=1) + parser.add_argument( + "--mixed-precision", type=str, default="bf16", choices=["none", "fp16", "bf16"] + ) + parser.add_argument("--save_best", action="store_true", default=False) + parser.add_argument( + "--val_data_path", type=str, default="/mnt/localssd/ImageNet2012/val" + ) + parser.add_argument("--sample_folder_dir", type=str, default="samples") + parser.add_argument( + "--reconstruction_folder_dir", type=str, default="reconstruction" + ) + parser.add_argument( + "--v-patch-nums", + type=int, + default=[1, 2, 3, 4, 5, 6, 8, 10, 13, 16], + nargs="+", + help="number of patch numbers of each scale", + ) + parser.add_argument("--enc_type", type=str, default="cnn") + parser.add_argument("--dec_type", type=str, default="cnn") + parser.add_argument("--semantic_guide", type=str, default="none") + parser.add_argument("--detail_guide", type=str, default="none") + parser.add_argument("--num_latent_tokens", type=int, default=256) + parser.add_argument( + "--encoder_model", + type=str, + default="vit_small_patch14_dinov2.lvd142m", + help="encoder model name", + ) + parser.add_argument( + "--decoder_model", + type=str, + default="vit_small_patch14_dinov2.lvd142m", + help="encoder model name", + ) + parser.add_argument("--disc_adaptive_weight", type=bool, default=False) + parser.add_argument("--abs_pos_embed", type=bool, default=False) + parser.add_argument("--product_quant", type=int, default=1) + parser.add_argument("--share_quant_resi", type=int, default=4) + parser.add_argument("--codebook_drop", type=float, default=0.0) + parser.add_argument("--half_sem", type=bool, default=False) + parser.add_argument("--start_drop", type=int, default=1) + parser.add_argument("--lecam_loss_weight", type=float, default=None) + parser.add_argument("--sem_loss_weight", type=float, default=0.1) + parser.add_argument("--detail_loss_weight", type=float, default=0.1) + parser.add_argument("--enc_tuning_method", type=str, default="full") + parser.add_argument("--dec_tuning_method", type=str, default="full") + parser.add_argument("--clip_norm", type=bool, default=False) + parser.add_argument("--sem_loss_scale", type=float, default=1.0) + parser.add_argument("--detail_loss_scale", type=float, default=1.0) + parser.add_argument("--config", type=str, default=None) + parser.add_argument("--norm_type", type=str, default="bn") + parser.add_argument("--aug_prob", type=float, default=1.0) + parser.add_argument("--aug_fade_steps", type=int, default=0) + parser.add_argument("--disc_reinit", type=int, default=0) + parser.add_argument("--debug_disc", type=bool, default=False) + parser.add_argument( + "--guide_type_1", type=str, default="class", choices=["patch", "class"] + ) + parser.add_argument( + "--guide_type_2", type=str, default="class", 
choices=["patch", "class"] + ) + parser.add_argument("--lfq", action="store_true", default=False, help="if use LFQ") + + parser.add_argument("--end-ratio", type=float, default=0.5) + parser.add_argument("--anneal-start", type=int, default=200) + parser.add_argument("--anneal-end", type=int, default=200) + + parser.add_argument("--alpha", type=float, default=0.0) + parser.add_argument("--beta", type=float, default=0.0) + parser.add_argument("--delta", type=int, default=100) + + args = parser.parse_args() + if args.config is not None: + with open(args.config, "r", encoding="utf-8") as f: + file_yaml = yaml.YAML() + config_args = file_yaml.load(f) + parser.set_defaults(**config_args) + + # re-parse command-line args to overwrite with any command-line inputs + args = parser.parse_args() + return args + + +def main(args): + """ + Trains a new model. + """ + assert torch.cuda.is_available(), "Training currently requires at least one GPU." + + # Setup DDP: + init_distributed_mode(args) + assert ( + args.global_batch_size % dist.get_world_size() == 0 + ), f"Batch size must be divisible by world size." + rank = dist.get_rank() + device = rank % torch.cuda.device_count() + seed = args.global_seed * dist.get_world_size() + rank + torch.manual_seed(seed) + torch.cuda.set_device(device) + + # Setup an experiment folder: + if rank == 0: + os.makedirs( + args.results_dir, exist_ok=True + ) # Make results folder (holds all experiment subfolders) + experiment_index = len(glob(f"{args.results_dir}/*")) + model_string_name = args.vq_model.replace("/", "-") + experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}" # Create an experiment folder + checkpoint_dir = ( + f"{experiment_dir}/checkpoints" # Stores saved model checkpoints + ) + os.makedirs(checkpoint_dir, exist_ok=True) + logger = create_logger(experiment_dir) + logger.info(f"Experiment directory created at {experiment_dir}") + + cloud_results_dir = f"{args.cloud_save_path}" + cloud_checkpoint_dir = f"{cloud_results_dir}" + os.makedirs(cloud_checkpoint_dir, exist_ok=True) + logger.info(f"Experiment directory created in cloud at {cloud_checkpoint_dir}") + + experiment_config = vars(args) + with open( + os.path.join(cloud_checkpoint_dir, "config.yaml"), "w", encoding="utf-8" + ) as f: + # Use the round_trip_dump method to preserve the order and style + file_yaml = yaml.YAML() + file_yaml.dump(experiment_config, f) + + else: + logger = create_logger(None) + + # training args + logger.info(f"{args}") + + # training env + logger.info( + f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}." 
+ ) + + # Setup data: + transform = transforms.Compose( + [ + transforms.Lambda( + lambda pil_image: random_crop_arr(pil_image, args.image_size) + ), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True + ), + ] + ) + dataset = build_dataset(args, transform=transform) + sampler = DistributedSampler( + dataset, + num_replicas=dist.get_world_size(), + rank=rank, + shuffle=True, + seed=args.global_seed, + ) + loader = DataLoader( + dataset, + batch_size=int(args.global_batch_size // dist.get_world_size()), + shuffle=False, + sampler=sampler, + num_workers=args.num_workers, + pin_memory=True, + drop_last=True, + ) + logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})") + + if args.save_best: + transform = transforms.Compose( + [ + transforms.Lambda( + lambda pil_image: center_crop_arr(pil_image, args.image_size) + ), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True + ), + ] + ) + args.data_path = args.val_data_path + val_dataset = build_dataset(args, transform=transform) + val_sampler = DistributedSampler( + val_dataset, + num_replicas=dist.get_world_size(), + rank=rank, + shuffle=False, + seed=args.global_seed, + ) + val_loader = DataLoader( + val_dataset, + batch_size=int(args.global_batch_size // dist.get_world_size()), + shuffle=False, + sampler=val_sampler, + num_workers=args.num_workers, + pin_memory=True, + drop_last=False, + ) + if rank % torch.cuda.device_count() == 0: + os.makedirs(args.sample_folder_dir, exist_ok=True) + os.makedirs(args.reconstruction_folder_dir, exist_ok=True) + logger.info(f"Saving .png samples at {args.sample_folder_dir}") + logger.info( + f"Saving .png reconstruction at {args.reconstruction_folder_dir}" + ) + + num_update_steps_per_epoch = len(loader) + max_train_steps = args.epochs * num_update_steps_per_epoch + args.disc_start = args.disc_epoch_start * num_update_steps_per_epoch + + # create and load model + vq_model = VQ_models[args.vq_model]( + codebook_size=args.codebook_size, + codebook_embed_dim=args.codebook_embed_dim, + commit_loss_beta=args.commit_loss_beta, + entropy_loss_ratio=args.entropy_loss_ratio, + dropout_p=args.dropout_p, + v_patch_nums=args.v_patch_nums, + enc_type=args.enc_type, + encoder_model=args.encoder_model, + dec_type=args.dec_type, + decoder_model=args.decoder_model, + semantic_guide=args.semantic_guide, + detail_guide=args.detail_guide, + num_latent_tokens=args.num_latent_tokens, + abs_pos_embed=args.abs_pos_embed, + share_quant_resi=args.share_quant_resi, + product_quant=args.product_quant, + codebook_drop=args.codebook_drop, + half_sem=args.half_sem, + start_drop=args.start_drop, + sem_loss_weight=args.sem_loss_weight, + detail_loss_weight=args.detail_loss_weight, + clip_norm=args.clip_norm, + sem_loss_scale=args.sem_loss_scale, + detail_loss_scale=args.detail_loss_scale, + guide_type_1=args.guide_type_1, + guide_type_2=args.guide_type_2, + lfq=args.lfq, + ) + logger.info( + f"VQ Model Parameters: {sum(p.numel() for p in vq_model.parameters()):,}" + ) + if args.ema: + ema = deepcopy(vq_model).to( + device + ) # Create an EMA of the model for use after training + requires_grad(ema, False) + logger.info( + f"VQ Model EMA Parameters: {sum(p.numel() for p in ema.parameters()):,}" + ) + vq_model = vq_model.to(device) + vq_loss = VQLoss( + disc_start=args.disc_start, + disc_weight=args.disc_weight, + disc_type=args.disc_type, + disc_loss=args.disc_loss, + 
gen_adv_loss=args.gen_loss, + image_size=args.image_size, + perceptual_weight=args.perceptual_weight, + reconstruction_weight=args.reconstruction_weight, + reconstruction_loss=args.reconstruction_loss, + codebook_weight=args.codebook_weight, + lecam_loss_weight=args.lecam_loss_weight, + disc_adaptive_weight=args.disc_adaptive_weight, + norm_type=args.norm_type, + aug_prob=args.aug_prob, + ).to(device) + logger.info( + f"Discriminator Parameters: {sum(p.numel() for p in vq_loss.discriminator.parameters()):,}" + ) + + args.lr = args.lr * args.global_batch_size / 128 + args.disc_lr = args.disc_lr * args.global_batch_size / 128 + # initialize a GradScaler. If enabled=False scaler is a no-op + scaler = torch.cuda.amp.GradScaler(enabled=(args.mixed_precision == "fp16")) + scaler_disc = torch.cuda.amp.GradScaler(enabled=(args.mixed_precision == "fp16")) + # Setup optimizer + optimizer = torch.optim.AdamW( + vq_model.parameters(), + lr=args.lr, + betas=(args.beta1, args.beta2), + weight_decay=args.weight_decay, + ) + optimizer_disc = torch.optim.AdamW( + vq_loss.discriminator.parameters(), + lr=args.disc_lr, + betas=(args.beta1, args.beta2), + weight_decay=args.disc_weight_decay, + ) + + # create lr scheduler + if args.lr_scheduler == "none": + vqvae_lr_scheduler = None + disc_lr_scheduler = None + else: + vqvae_lr_scheduler, _ = create_scheduler( + sched=args.lr_scheduler, + optimizer=optimizer, + patience_epochs=0, + step_on_epochs=True, + updates_per_epoch=num_update_steps_per_epoch, + num_epochs=args.epochs, + warmup_epochs=1, + min_lr=5e-5, + ) + disc_lr_scheduler, _ = create_scheduler( + sched=args.lr_scheduler, + optimizer=optimizer_disc, + patience_epochs=0, + step_on_epochs=True, + updates_per_epoch=num_update_steps_per_epoch, + num_epochs=args.epochs - args.disc_epoch_start, + warmup_epochs=int(0.02 * args.epochs), + min_lr=5e-5, + ) + + logger.info( + f"num_update_steps_per_epoch {num_update_steps_per_epoch:,} max_train_steps ({max_train_steps})" + ) + + # Prepare models for training: + if args.vq_ckpt: + checkpoint = torch.load(args.vq_ckpt, map_location="cpu") + vq_model.load_state_dict(checkpoint["model"]) + if args.ema: + ema.load_state_dict(checkpoint["ema"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + if not args.debug_disc: + vq_loss.discriminator.load_state_dict(checkpoint["discriminator"]) + optimizer_disc.load_state_dict(checkpoint["optimizer_disc"]) + else: + num_step = checkpoint["optimizer_disc"]["state"][ + next(iter(checkpoint["optimizer_disc"]["state"])) + ]["step"] + for param_state in optimizer_disc.state.values(): + param_state["step"] = num_step + if not args.finetune: + train_steps = ( + checkpoint["steps"] + if "steps" in checkpoint + else int(args.vq_ckpt.split("/")[-1].split(".")[0]) + ) + start_epoch = ( + int(train_steps / int(len(dataset) / args.global_batch_size)) + 1 + ) + train_steps = int(start_epoch * int(len(dataset) / args.global_batch_size)) + else: + train_steps = 0 + start_epoch = 0 + del checkpoint + vq_model.finetune(args.enc_tuning_method, args.dec_tuning_method) + logger.info(f"Resume training from checkpoint: {args.vq_ckpt}") + logger.info(f"Initial state: steps={train_steps}, epochs={start_epoch}") + else: + train_steps = 0 + start_epoch = 0 + if args.ema: + update_ema( + ema, vq_model, decay=0 + ) # Ensure EMA is initialized with synced weights + + if args.compile: + logger.info("compiling the model... 
(may take several minutes)") + vq_model = torch.compile(vq_model, mode="max-autotune") # requires PyTorch 2.0 + + vq_model = DDP(vq_model.to(device), device_ids=[args.gpu]) + vq_model.train() + if args.ema: + ema.eval() # EMA model should always be in eval mode + vq_loss = DDP(vq_loss.to(device), device_ids=[args.gpu]) + vq_loss.train() + + ptdtype = {"none": torch.float32, "bf16": torch.bfloat16, "fp16": torch.float16}[ + args.mixed_precision + ] + + # Variables for monitoring/logging purposes: + log_steps = 0 + running_loss = 0 + start_time = time.time() + curr_fid = None + + logger.info(f"Training for {args.epochs} epochs...") + for epoch in range(start_epoch, args.epochs): + ratio = get_random_ratio( + args.anneal_start, args.anneal_end, args.end_ratio, epoch + ) + delta = int(ratio * args.delta) + alpha = ratio * args.alpha + beta = args.beta + + sampler.set_epoch(epoch) + logger.info(f"Beginning epoch {epoch}...") + if args.disc_reinit != 0: + if epoch % args.disc_reinit == 0: + vq_loss.module.discriminator.reinit() + for x, y in loader: + imgs = x.to(device, non_blocking=True) + + if args.aug_fade_steps >= 0: + fade_blur_schedule = ( + 0 + if train_steps < args.disc_start + else min( + 1.0, (train_steps - args.disc_start) / (args.aug_fade_steps + 1) + ) + ) + fade_blur_schedule = 1 - fade_blur_schedule + else: + fade_blur_schedule = 0 + # generator training + optimizer.zero_grad() + with torch.cuda.amp.autocast(dtype=ptdtype): + recons_imgs, codebook_loss, sem_loss, detail_loss, dependency_loss = ( + vq_model(imgs, epoch, alpha, beta, delta) + ) + loss_gen = vq_loss( + codebook_loss, + sem_loss, + detail_loss, + dependency_loss, + imgs, + recons_imgs, + optimizer_idx=0, + global_step=train_steps + 1, + last_layer=vq_model.module.decoder.last_layer, + logger=logger, + log_every=args.log_every, + fade_blur_schedule=fade_blur_schedule, + ) + + scaler.scale(loss_gen).backward() + if args.max_grad_norm != 0.0: + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_( + vq_model.parameters(), args.max_grad_norm + ) + scaler.step(optimizer) + scaler.update() + if args.ema: + update_ema( + ema, vq_model.module._orig_mod if args.compile else vq_model.module + ) + + # discriminator training + optimizer_disc.zero_grad() + + with torch.cuda.amp.autocast(dtype=ptdtype): + loss_disc = vq_loss( + codebook_loss, + sem_loss, + detail_loss, + dependency_loss, + imgs, + recons_imgs, + optimizer_idx=1, + global_step=train_steps + 1, + logger=logger, + log_every=args.log_every, + fade_blur_schedule=fade_blur_schedule, + ) + scaler_disc.scale(loss_disc).backward() + if args.max_grad_norm != 0.0: + scaler_disc.unscale_(optimizer_disc) + torch.nn.utils.clip_grad_norm_( + vq_loss.module.discriminator.parameters(), args.max_grad_norm + ) + scaler_disc.step(optimizer_disc) + scaler_disc.update() + + # # Log loss values: + running_loss += loss_gen.item() + loss_disc.item() + + log_steps += 1 + train_steps += 1 + if train_steps % args.log_every == 0: + # Measure training speed: + torch.cuda.synchronize() + end_time = time.time() + steps_per_sec = log_steps / (end_time - start_time) + # Reduce loss history over all processes: + avg_loss = torch.tensor(running_loss / log_steps, device=device) + dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM) + avg_loss = avg_loss.item() / dist.get_world_size() + logger.info( + f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}" + ) + # Reset monitoring variables: + running_loss = 0 + log_steps = 0 + start_time = time.time() + + if 
dist.get_rank() == 0: + vq_loss.module.wandb_tracker.log( + {"lr": optimizer.param_groups[0]["lr"], "train_loss": avg_loss}, + step=train_steps, + ) + # show images and recon images + if train_steps % args.vis_every == 0: + with torch.no_grad(): + recons_with_scale = ( + vq_model.module.img_to_reconstructed_img( + imgs[:4], last_one=False + ) + ) + image = torch.cat(recons_with_scale + [imgs[:4]], dim=0) + image = torch.clamp(image, min=-1, max=1) + image = make_grid( + (image + 1) / 2, nrow=4, padding=0, pad_value=1.0 + ) + image = image.permute(1, 2, 0).mul_(255).cpu().numpy() + image = Image.fromarray(image.astype(np.uint8)) + + vq_loss.module.wandb_tracker.log( + {"recon_images": [wandb.Image(image)]}, step=train_steps + ) + + # Save checkpoint: + if train_steps % args.ckpt_every == 0 and train_steps > 0: + if args.save_best: + vq_model.eval() + total = 0 + samples = [] + gt = [] + for x, _ in tqdm( + val_loader, + desc=f"evaluation for step {train_steps:07d}", + disable=rank != 0, + ): + with torch.no_grad(): + x = x.to(device, non_blocking=True) + sample = vq_model.module.img_to_reconstructed_img(x) + sample = ( + torch.clamp(127.5 * sample + 128.0, 0, 255) + .permute(0, 2, 3, 1) + .to(torch.uint8) + .contiguous() + ) + x = ( + torch.clamp(127.5 * x + 128.0, 0, 255) + .permute(0, 2, 3, 1) + .to(torch.uint8) + .contiguous() + ) + + sample = torch.cat(dist.nn.all_gather(sample), dim=0) + x = torch.cat(dist.nn.all_gather(x), dim=0) + samples.append(sample.to("cpu", dtype=torch.uint8).numpy()) + gt.append(x.to("cpu", dtype=torch.uint8).numpy()) + + total += sample.shape[0] + vq_model.train() + logger.info(f"Evaluated a total of {total} images.") + dist.barrier() + + if rank == 0: + samples = np.concatenate(samples, axis=0) + gt = np.concatenate(gt, axis=0) + config = tf.ConfigProto( + allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph + ) + config.gpu_options.allow_growth = True + + evaluator = Evaluator(tf.Session(config=config), batch_size=32) + evaluator.warmup() + logger.info("computing reference batch activations...") + ref_acts = evaluator.read_activations(gt) + logger.info("computing/reading reference batch statistics...") + ref_stats, _ = evaluator.read_statistics(gt, ref_acts) + logger.info("computing sample batch activations...") + sample_acts = evaluator.read_activations(samples) + logger.info("computing/reading sample batch statistics...") + sample_stats, _ = evaluator.read_statistics( + samples, sample_acts + ) + FID = sample_stats.frechet_distance(ref_stats) + + logger.info(f"training step: {train_steps:07d}, FID {FID:07f}") + # track the best FID so far + if curr_fid is None: + curr_fid = [FID, train_steps] + elif FID <= curr_fid[0]: + # os.remove(f"{cloud_checkpoint_dir}/{curr_fid[1]:07d}.pt") + curr_fid = [FID, train_steps] + + vq_loss.module.wandb_tracker.log( + {"eval FID": FID}, step=train_steps + ) + + dist.barrier() + + if rank == 0: + if args.compile: + model_weight = vq_model.module._orig_mod.state_dict() + else: + model_weight = vq_model.module.state_dict() + checkpoint = { + "model": model_weight, + "optimizer": optimizer.state_dict(), + "discriminator": vq_loss.module.discriminator.state_dict(), + "optimizer_disc": optimizer_disc.state_dict(), + "steps": train_steps, + "args": args, + } + if args.ema: + checkpoint["ema"] = ema.state_dict() + if not args.no_local_save: + checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt" + torch.save(checkpoint, checkpoint_path) + logger.info(f"Saved checkpoint to {checkpoint_path}") +
+ # cloud_checkpoint_path = f"{cloud_checkpoint_dir}/{train_steps:07d}.pt" + # torch.save(checkpoint, cloud_checkpoint_path) + # logger.info(f"Saved checkpoint in cloud to {cloud_checkpoint_path}") + + if args.save_best: + last_checkpoint_path = f"{args.cloud_save_path}/last_ckpt.pt" + if os.path.exists(last_checkpoint_path): + os.remove(last_checkpoint_path) + else: + os.makedirs(f"{args.cloud_save_path}", exist_ok=True) + torch.save(checkpoint, last_checkpoint_path) + logger.info( + f"Saved checkpoint in cloud to {last_checkpoint_path}" + ) + if curr_fid[1] == train_steps: + best_checkpoint_path = ( + f"{args.cloud_save_path}/best_ckpt.pt" + ) + torch.save(checkpoint, best_checkpoint_path) + logger.info( + f"Saved checkpoint in cloud to {best_checkpoint_path}" + ) + + dist.barrier() + + if vqvae_lr_scheduler is not None: + vqvae_lr_scheduler.step(epoch + 1) + if disc_lr_scheduler is not None and epoch >= args.disc_epoch_start: + disc_lr_scheduler.step(epoch + 1 - args.disc_epoch_start) + + vq_model.eval() # important! This disables randomized embedding dropout + # do any sampling/FID calculation/etc. with ema (or model) in eval mode ... + + logger.info("Done!") + dist.destroy_process_group() + + +if __name__ == "__main__": + args = parse_args() + main(args)
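+
+# Example launch (illustrative only; the torchrun launcher, GPU count, and dataset
+# paths below are assumptions, and every flag used is defined in parse_args above):
+#   torchrun --nproc_per_node=8 xqgan_train.py \
+#       --data-path /path/to/ImageNet/train --val_data_path /path/to/ImageNet/val \
+#       --vq-model VQ-16 --image-size 256 --global-batch-size 128 --epochs 40 \
+#       --cloud-save-path output/vq16_run --save_best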