Commit 14ce5a9 by huaweilin · 1 Parent(s): 6a98d62

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full list.

Files changed (50)
  1. .gitattributes +7 -0
  2. .gitignore +6 -0
  3. app.py +166 -0
  4. assets/app_examples/0.png +3 -0
  5. assets/app_examples/1.png +3 -0
  6. assets/app_examples/2.png +3 -0
  7. assets/app_examples/3.png +3 -0
  8. assets/app_examples/4.png +3 -0
  9. assets/comparison_of_generation.png +3 -0
  10. assets/overview.png +3 -0
  11. evaluations/character_error_rate.py +27 -0
  12. evaluations/evaluate_images.py +130 -0
  13. evaluations/ocr.py +44 -0
  14. evaluations/word_error_rate.py +30 -0
  15. examples/get_result.py +94 -0
  16. examples/run.sh +85 -0
  17. examples/submit.sh +10 -0
  18. main.py +99 -0
  19. requirements.txt +21 -0
  20. src/__init__.py +0 -0
  21. src/data_loader.py +61 -0
  22. src/data_processing.py +89 -0
  23. src/model_processing.py +409 -0
  24. src/utils.py +47 -0
  25. src/vaes/gpt_image/gpt_image.py +48 -0
  26. src/vaes/stable_diffusion/vae.py +23 -0
  27. src/vqvaes/__init__.py +0 -0
  28. src/vqvaes/anole/anole.py +706 -0
  29. src/vqvaes/bsqvit/attention_mask.py +42 -0
  30. src/vqvaes/bsqvit/bsqvit.py +150 -0
  31. src/vqvaes/bsqvit/quantizer/bsq.py +223 -0
  32. src/vqvaes/bsqvit/quantizer/vq.py +152 -0
  33. src/vqvaes/bsqvit/stylegan_utils/custom_ops.py +126 -0
  34. src/vqvaes/bsqvit/stylegan_utils/misc.py +40 -0
  35. src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.cpp +99 -0
  36. src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.cu +176 -0
  37. src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.h +38 -0
  38. src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.py +226 -0
  39. src/vqvaes/bsqvit/stylegan_utils/ops/conv2d_gradfix.py +170 -0
  40. src/vqvaes/bsqvit/stylegan_utils/ops/conv2d_resample.py +155 -0
  41. src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.cpp +103 -0
  42. src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.cu +353 -0
  43. src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.h +59 -0
  44. src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.py +382 -0
  45. src/vqvaes/bsqvit/transformer.py +416 -0
  46. src/vqvaes/flowmo/flowmo.py +945 -0
  47. src/vqvaes/flowmo/lookup_free_quantize.py +396 -0
  48. src/vqvaes/infinity/conv.py +107 -0
  49. src/vqvaes/infinity/dynamic_resolution.py +147 -0
  50. src/vqvaes/infinity/flux_vqgan.py +771 -0
.gitattributes CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/app_examples/0.png filter=lfs diff=lfs merge=lfs -text
+ assets/app_examples/1.png filter=lfs diff=lfs merge=lfs -text
+ assets/overview.png filter=lfs diff=lfs merge=lfs -text
+ assets/app_examples/2.png filter=lfs diff=lfs merge=lfs -text
+ assets/app_examples/4.png filter=lfs diff=lfs merge=lfs -text
+ assets/comparison_of_generation.png filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
+ __pycache__/
+ *.ckpt
+ checkpoints/
+ results/
+ VTBench_models/
+ README.md
app.py ADDED
@@ -0,0 +1,166 @@
+ import os
+ import spaces
+ import subprocess
+ import sys
+
+ # REQUIREMENTS_FILE = "requirements.txt"
+ # if os.path.exists(REQUIREMENTS_FILE):
+ #     try:
+ #         print("Installing dependencies from requirements.txt...")
+ #         subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", REQUIREMENTS_FILE])
+ #         print("Dependencies installed successfully.")
+ #     except subprocess.CalledProcessError as e:
+ #         print(f"Failed to install dependencies: {e}")
+ # else:
+ #     print("requirements.txt not found.")
+
+ import gradio as gr
+ from src.data_processing import pil_to_tensor, tensor_to_pil
+ from PIL import Image
+ from src.model_processing import get_model
+ from huggingface_hub import snapshot_download
+ import torch
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(f"Running on: {device}")
+
+ MODEL_DIR = "./VTBench_models"
+ if not os.path.exists(MODEL_DIR):
+     print("Downloading VTBench_models from Hugging Face...")
+     snapshot_download(
+         repo_id="huaweilin/VTBench_models",
+         local_dir=MODEL_DIR,
+         local_dir_use_symlinks=False
+     )
+     print("Download complete.")
+
+ example_image_paths = [f"assets/app_examples/{i}.png" for i in range(0, 5)]
+
+ model_name_mapping = {
+     "SD3.5L": "SD3.5L",
+     "chameleon": "Chameleon",
+     # "flowmo_lo": "FlowMo Lo",
+     # "flowmo_hi": "FlowMo Hi",
+     # "gpt4o": "GPT-4o",
+     "janus_pro_1b": "Janus Pro 1B/7B",
+     # "llamagen-ds8": "LlamaGen ds8",
+     # "llamagen-ds16": "LlamaGen ds16",
+     # "llamagen-ds16-t2i": "LlamaGen ds16 T2I",
+     # "maskbit_16bit": "MaskBiT 16bit",
+     # "maskbit_18bit": "MaskBiT 18bit",
+     # "open_magvit2": "OpenMagViT",
+     # "titok_b64": "Titok-b64",
+     # "titok_bl64": "Titok-bl64",
+     # "titok_s128": "Titok-s128",
+     # "titok_bl128": "Titok-bl128",
+     # "titok_l32": "Titok-l32",
+     # "titok_sl256": "Titok-sl256",
+     # "var_256": "VAR-256",
+     # "var_512": "VAR-512",
+     # "FLUX.1-dev": "FLUX.1-dev",
+     # "infinity_d32": "Infinity-d32",
+     # "infinity_d64": "Infinity-d64",
+     # "bsqvit": "BSQ-VIT",
+ }
+
+ def load_model(model_name):
+     model, data_params = get_model(MODEL_DIR, model_name)
+     model = model.to(device)
+     model.eval()
+     return model, data_params
+
+ model_dict = {
+     model_name: load_model(model_name)
+     for model_name in model_name_mapping
+ }
+
+ placeholder_image = Image.new("RGBA", (512, 512), (0, 0, 0, 0))
+
+ @spaces.GPU
+ def process_selected_models(uploaded_image, selected_models):
+     results = []
+     for model_name in model_name_mapping:
+         if uploaded_image is None:
+             results.append(gr.update(value=placeholder_image, label=f"{model_name_mapping[model_name]} (No input)"))
+         elif model_name in selected_models:
+             try:
+                 model, data_params = model_dict[model_name]
+                 pixel_values = pil_to_tensor(uploaded_image, **data_params).unsqueeze(0).to(device)
+                 output = model(pixel_values)[0]
+                 reconstructed_image = tensor_to_pil(output[0].cpu(), **data_params)
+                 results.append(gr.update(value=reconstructed_image, label=model_name_mapping[model_name]))
+             except Exception as e:
+                 print(f"Error in model {model_name}: {e}")
+                 results.append(gr.update(value=placeholder_image, label=f"{model_name_mapping[model_name]} (Error)"))
+         else:
+             results.append(gr.update(value=placeholder_image, label=f"{model_name_mapping[model_name]} (Not selected)"))
+     return results
+
+ with gr.Blocks() as demo:
+     gr.Markdown("## VTBench")
+
+     gr.Markdown("---")
+
+     image_input = gr.Image(
+         type="pil",
+         label="Upload an image",
+         width=512,
+         height=512,
+     )
+
+     gr.Markdown("### Click on an example image to use it as input:")
+     example_rows = [example_image_paths[i:i+5] for i in range(0, len(example_image_paths), 5)]
+     for row in example_rows:
+         with gr.Row():
+             for path in row:
+                 ex_img = gr.Image(
+                     value=path,
+                     show_label=False,
+                     interactive=True,
+                     width=256,
+                     height=256,
+                 )
+
+                 def make_loader(p=path):
+                     def load_img():
+                         return Image.open(p)
+                     return load_img
+
+                 ex_img.select(fn=make_loader(), outputs=image_input)
+
+     gr.Markdown("---")
+
+     gr.Markdown("⚠️ **The more models you select, the longer the processing time will be.**")
+     model_selector = gr.CheckboxGroup(
+         choices=list(model_name_mapping.keys()),
+         label="Select models to run",
+         value=["SD3.5L", "chameleon", "janus_pro_1b"],
+         interactive=True,
+     )
+     run_button = gr.Button("Start Processing")
+
+     image_outputs = []
+     model_items = list(model_name_mapping.items())
+
+     n_columns = 5
+     output_rows = [model_items[i:i+n_columns] for i in range(0, len(model_items), n_columns)]
+
+     with gr.Column():
+         for row in output_rows:
+             with gr.Row():
+                 for model_name, display_name in row:
+                     out_img = gr.Image(
+                         label=f"{display_name} (Not run)",
+                         value=placeholder_image,
+                         width=512,
+                         height=512,
+                     )
+                     image_outputs.append(out_img)
+
+     run_button.click(
+         fn=process_selected_models,
+         inputs=[image_input, model_selector],
+         outputs=image_outputs
+     )
+
+ demo.launch()
assets/app_examples/0.png ADDED
Git LFS Details
  • SHA256: 4ea7967c763311298a587d3c8a7486913d63df0640e79e9e5cfd8cfdc9a4a558
  • Pointer size: 132 Bytes
  • Size of remote file: 2.85 MB
assets/app_examples/1.png ADDED
Git LFS Details
  • SHA256: 5d54dacbad49976f3105b8e69fede50f7a1cf7abe96ec5244c46eaaadfd688a6
  • Pointer size: 132 Bytes
  • Size of remote file: 1.29 MB
assets/app_examples/2.png ADDED
Git LFS Details
  • SHA256: 656a438b675122e0384097c91554fe4d810e8a0025b770443998e473574e5056
  • Pointer size: 132 Bytes
  • Size of remote file: 2.09 MB
assets/app_examples/3.png ADDED
Git LFS Details
  • SHA256: fdee23b36fae13bb5806738e96607f8690941b647f7d6db0865f5efd745d8360
  • Pointer size: 130 Bytes
  • Size of remote file: 89.4 kB
assets/app_examples/4.png ADDED
Git LFS Details
  • SHA256: 9040083842809702de7efc1944aed09b099919f1ed2267a7280fc2bab82df0b1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.66 MB
assets/comparison_of_generation.png ADDED
Git LFS Details
  • SHA256: e8ce5f1b645dceb72c01cce066b2d91bb19935877ac03a4ea69c74ed612e8212
  • Pointer size: 132 Bytes
  • Size of remote file: 3.48 MB
assets/overview.png ADDED
Git LFS Details
  • SHA256: d80cb837fa4594b76be6f8c912e1dcb13c5dc6a4a57b09e7861f2d20ce5c92e2
  • Pointer size: 132 Bytes
  • Size of remote file: 2.15 MB
evaluations/character_error_rate.py ADDED
@@ -0,0 +1,27 @@
+ import torch
+ from torchmetrics import Metric
+ from ocr import OCR
+ import Levenshtein
+
+
+ class CharacterErrorRate(Metric):
+     def __init__(self, ocr, dist_sync_on_step=False):
+         # super().__init__(dist_sync_on_step=dist_sync_on_step)
+         super().__init__()
+         self.add_state("total_errors", default=torch.tensor(0.0), dist_reduce_fx="sum")
+         self.add_state("total_chars", default=torch.tensor(0.0), dist_reduce_fx="sum")
+         self.ocr = ocr
+
+     def update(self, pred_images, target_images):
+         for pred_img, target_img in zip(pred_images, target_images):
+             pred_text = self.ocr.predict(pred_img)
+             target_text = self.ocr.predict(target_img)
+
+             dist = Levenshtein.distance(pred_text, target_text)
+             self.total_errors += dist
+             self.total_chars += len(target_text)
+
+     def compute(self):
+         if self.total_chars == 0:
+             return torch.tensor(0.0)
+         return self.total_errors / self.total_chars
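For intuition, the metric above reduces to a character-level Levenshtein distance normalized by the length of the reference text. A minimal sketch of the same computation on plain strings (skipping the OCR step; the example strings are made up):

    import Levenshtein

    # Toy version of CharacterErrorRate.update()/compute() on raw strings.
    pred_text = "VTBemch"
    target_text = "VTBench"

    total_errors = Levenshtein.distance(pred_text, target_text)  # 1 substitution
    cer = total_errors / len(target_text)                        # 1 / 7 ≈ 0.1429
    print(f"CER = {cer:.4f}")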
evaluations/evaluate_images.py ADDED
@@ -0,0 +1,130 @@
+ import os
+ import argparse
+ from PIL import Image
+ from tqdm import tqdm
+ from torchvision import transforms
+ from torch.utils.data import Dataset, DataLoader
+ import torch
+ import torch.nn.functional as F
+ from ocr import OCR
+ from character_error_rate import CharacterErrorRate
+ from word_error_rate import WordErrorRate
+ from torchmetrics.image import (
+     PeakSignalNoiseRatio,
+     StructuralSimilarityIndexMeasure,
+     LearnedPerceptualImagePatchSimilarity,
+     FrechetInceptionDistance,
+ )
+
+
+ class ImageFolderPairDataset(Dataset):
+     def __init__(self, dir1, dir2, transform=None):
+         self.dir1 = dir1
+         self.dir2 = dir2
+         self.filenames = sorted(os.listdir(dir1))
+         self.transform = transform
+
+     def __len__(self):
+         return len(self.filenames)
+
+     def __getitem__(self, idx):
+         name = self.filenames[idx]
+         img1 = Image.open(os.path.join(self.dir1, name)).convert("RGB")
+         img2 = Image.open(os.path.join(self.dir2, name)).convert("RGB")
+         if self.transform:
+             img1 = self.transform(img1)
+             img2 = self.transform(img2)
+         return img1, img2
+
+
+ def evaluate(args):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"Using device: {device}")
+
+     transform = transforms.Compose(
+         [transforms.Resize((args.image_size, args.image_size)), transforms.ToTensor()]
+     )
+
+     dataset = ImageFolderPairDataset(
+         args.original_dir, args.reconstructed_dir, transform
+     )
+     loader = DataLoader(
+         dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers
+     )
+
+     if "cer" in args.metrics or "wer" in args.metrics:
+         ocr = OCR(device)
+
+     # Metrics init
+     metrics = {}
+
+     if "psnr" in args.metrics:
+         metrics["psnr"] = PeakSignalNoiseRatio().to(device)
+     if "ssim" in args.metrics:
+         metrics["ssim"] = StructuralSimilarityIndexMeasure().to(device)
+     if "lpips" in args.metrics:
+         metrics["lpips"] = LearnedPerceptualImagePatchSimilarity().to(device)
+     if "fid" in args.metrics:
+         metrics["fid"] = FrechetInceptionDistance().to(device)
+     if "cer" in args.metrics:
+         metrics["cer"] = CharacterErrorRate(ocr)
+     if "wer" in args.metrics:
+         metrics["wer"] = WordErrorRate(ocr)
+
+     for batch in tqdm(loader, desc="Evaluating"):
+         # img1, img1_path, img2, img2_path = [b.to(device) for b in batch]
+         img1, img2 = [b.to(device) for b in batch]
+
+         if "psnr" in metrics:
+             metrics["psnr"].update(img2, img1)
+         if "ssim" in metrics:
+             metrics["ssim"].update(img2, img1)
+         if "lpips" in metrics:
+             metrics["lpips"].update(img2, img1)
+         if "cer" in metrics:
+             metrics["cer"].update(img2, img1)
+         if "wer" in metrics:
+             metrics["wer"].update(img2, img1)
+         if "fid" in metrics:
+             img1_uint8 = (img1 * 255).clamp(0, 255).to(torch.uint8)
+             img2_uint8 = (img2 * 255).clamp(0, 255).to(torch.uint8)
+             metrics["fid"].update(img1_uint8, real=True)
+             metrics["fid"].update(img2_uint8, real=False)
+
+     print("\nResults:")
+     for name, metric in metrics.items():
+         print(f"{name.upper()}", end="\t")
+     print()
+     for name, metric in metrics.items():
+         result = metric.compute().item()
+         print(f"{result:.4f}", end="\t")
+     print()
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--original_dir", type=str, required=True, help="Path to original images"
+     )
+     parser.add_argument(
+         "--reconstructed_dir",
+         type=str,
+         required=True,
+         help="Path to reconstructed images",
+     )
+     parser.add_argument(
+         "--metrics",
+         nargs="+",
+         default=["psnr", "ssim", "lpips", "fid"],
+         help="Metrics to compute: psnr, ssim, lpips, fid, cer, wer",
+     )
+     parser.add_argument(
+         "--batch_size", type=int, default=8, help="Batch size for processing"
+     )
+     parser.add_argument("--image_size", type=int, default=256, help="Image resize size")
+     parser.add_argument(
+         "--num_workers", type=int, default=4, help="Number of workers for DataLoader"
+     )
+     args = parser.parse_args()
+
+     evaluate(args)
evaluations/ocr.py ADDED
@@ -0,0 +1,44 @@
+ from PIL import Image
+ from transformers import AutoProcessor, AutoModelForImageTextToText
+ import torch
+
+
+ class OCR:
+     def __init__(self, device="cpu"):
+         self.device = torch.device(device)
+         self.model = AutoModelForImageTextToText.from_pretrained(
+             "google/gemma-3-12b-it",
+             torch_dtype=torch.bfloat16,
+         ).to(self.device)
+         self.processor = AutoProcessor.from_pretrained("google/gemma-3-12b-it")
+
+         self.messages = [
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "image"},
+                     {
+                         "type": "text",
+                         "text": "Extract and output only the text from the image in its original language. If there is no text, return nothing.",
+                     },
+                 ],
+             },
+         ]
+
+     def predict(self, image):
+         image = (
+             (image * 255).clamp(0, 255).to(torch.uint8).permute((1, 2, 0)).cpu().numpy()
+         )
+         image = Image.fromarray(image).convert("RGB").resize((1024, 1024))
+         prompt = self.processor.apply_chat_template(
+             self.messages, add_generation_prompt=True
+         )
+         inputs = self.processor(text=prompt, images=[image], return_tensors="pt").to(
+             self.device
+         )
+         with torch.no_grad():
+             generated_ids = self.model.generate(**inputs, max_new_tokens=1024)
+         generated_text = self.processor.batch_decode(
+             generated_ids[:, inputs.input_ids.shape[-1] :], skip_special_tokens=True
+         )[0]
+         return generated_text
evaluations/word_error_rate.py ADDED
@@ -0,0 +1,30 @@
+ import torch
+ from torchmetrics import Metric
+ import Levenshtein
+
+
+ class WordErrorRate(Metric):
+     def __init__(self, ocr, dist_sync_on_step=False):
+         # super().__init__(dist_sync_on_step=dist_sync_on_step)
+         super().__init__()
+         self.ocr = ocr
+         self.add_state("total_errors", default=torch.tensor(0.0), dist_reduce_fx="sum")
+         self.add_state("total_words", default=torch.tensor(0.0), dist_reduce_fx="sum")
+
+     def update(self, pred_images, target_images):
+         for pred_img, target_img in zip(pred_images, target_images):
+             pred_text = self.ocr.predict(pred_img)
+             target_text = self.ocr.predict(target_img)
+
+             pred_words = pred_text.strip().split()
+             target_words = target_text.strip().split()
+
+             dist = Levenshtein.distance(" ".join(pred_words), " ".join(target_words))
+
+             self.total_errors += dist
+             self.total_words += len(target_words)
+
+     def compute(self):
+         if self.total_words == 0:
+             return torch.tensor(0.0)
+         return self.total_errors / self.total_words
examples/get_result.py ADDED
@@ -0,0 +1,94 @@
+ import os
+ import pandas as pd
+
+ root_dir = "./"
+
+ model_name_mapping = {
+     "flowmo_lo": "FlowMo Lo",
+     "flowmo_hi": "FlowMo Hi",
+     "gpt4o": "GPT-4o",
+     "janus_pro_1b": "Janus Pro 1B/7B",
+     "llamagen-ds8": "LlamaGen ds8",
+     "llamagen-ds16": "LlamaGen ds16",
+     "llamagen-ds16-t2i": "LlamaGen ds16 T2I",
+     "maskbit_16bit": "MaskBiT 16bit",
+     "maskbit_18bit": "MaskBiT 18bit",
+     "open_magvit2": "OpenMagViT",
+     "titok_b64": "Titok-b64",
+     "titok_bl64": "Titok-bl64",
+     "titok_s128": "Titok-s128",
+     "titok_bl128": "Titok-bl128",
+     "titok_l32": "Titok-l32",
+     "titok_sl256": "Titok-sl256",
+     "var_256": "VAR-256",
+     "var_512": "VAR-512",
+     "SD3.5L": "SD3.5L",
+     "FLUX.1-dev": "FLUX.1-dev",
+     "infinity_d32": "Infinity-d32",
+     "infinity_d64": "Infinity-d64",
+     "chameleon": "Chameleon",
+     "bsqvit": "BSQ-VIT",
+ }
+
+ output_order = [
+     "FlowMo Lo",
+     "FlowMo Hi",
+     "MaskBiT 16bit",
+     "MaskBiT 18bit",
+     "Titok-l32",
+     "Titok-b64",
+     "Titok-s128",
+     "Titok-bl64",
+     "Titok-bl128",
+     "Titok-sl256",
+     "OpenMagViT",
+     "LlamaGen ds8",
+     "BSQ-VIT",
+     "VAR-256",
+     "Janus Pro 1B/7B",
+     "Chameleon",
+     "LlamaGen ds16",
+     "LlamaGen ds16 T2I",
+     "VAR-512",
+     "Infinity-d32",
+     "Infinity-d64",
+     "SD3.5L",
+     "FLUX.1-dev",
+     "GPT-4o",
+ ]
+
+ for dataset_name in os.listdir(root_dir):
+     dataset_path = os.path.join(root_dir, dataset_name)
+     if not os.path.isdir(dataset_path):
+         continue
+
+     results = {}
+
+     for model_dir in os.listdir(dataset_path):
+         model_path = os.path.join(dataset_path, model_dir)
+         result_file = os.path.join(model_path, "result.txt")
+
+         if os.path.isfile(result_file):
+             with open(result_file, "r", encoding="utf-8") as f:
+                 lines = f.readlines()
+
+             if len(lines) >= 2:
+                 metrics_line = lines[-2].strip()
+                 values_line = lines[-1].strip()
+
+                 metrics = metrics_line.split()
+                 values = values_line.split()
+
+                 mapped_name = model_name_mapping.get(model_dir, model_dir)
+                 results[mapped_name] = values
+
+     if results:
+         header = "\t".join(metrics)
+         print(f"{dataset_name}\t{header}")
+         for model_name in output_order:
+             if model_name in results:
+                 values = results[model_name]
+                 print(f"{model_name}\t" + "\t".join(values))
+             else:
+                 print(f"{model_name}\t" + "no result")
+         print()
examples/run.sh ADDED
@@ -0,0 +1,85 @@
+ #!/bin/bash
+
+ dataset_name_list=("task1-imagenet" "task1-high-resolution" "task1-varying-resolution" "task2-detail-preservation" "task3-movie-posters" "task3-arxiv-abstracts" "task3-multilingual_Chinese" "task3-multilingual_Hindi" "task3-multilingual_Japanese" "task3-multilingual_Korean")
+ model_name_list=("chameleon" "llamagen-ds16" "llamagen-ds8" "flowmo_lo" "flowmo_hi" "open_magvit2" "titok_l32" "titok_b64" "titok_s128" "titok_bl64" "titok_bl128" "titok_sl256" "janus_pro_1b" "maskbit_18bit" "maskbit_16bit" "var_256" "var_512" "SD3.5L" "gpt4o" "llamagen-ds16-t2i" "infinity_d32" "infinity_d64" "bsqvit" "FLUX.1-dev")
+
+ batch_size=1
+
+ if command -v sbatch >/dev/null 2>&1; then
+     has_slurm=true
+ else
+     has_slurm=false
+ fi
+
+ shell_dir=$(cd "$(dirname "$0")";pwd)
+ echo "shell_dir: ${shell_dir}"
+ base_path="${shell_dir}/../"
+
+ for dataset_name in "${dataset_name_list[@]}"
+ do
+     cd ${shell_dir}
+     folder_dir="${dataset_name}"
+     mkdir ${folder_dir}
+
+     metrics="fid ssim psnr lpips"
+     split_name="test"
+     n_take=-1
+
+     if [[ $dataset_name == task3-multilingual_* ]]; then
+         split_name="${dataset_name##*_}"
+         dataset_name="${dataset_name%_*}"
+     fi
+     if [ "$dataset_name" = "task1-imagenet" ]; then
+         split_name="val"
+     fi
+
+     if [ "$dataset_name" = "task1-varying-resolution" ]; then
+         batch_size=1
+     fi
+     if [ "$dataset_name" = "task3-movie-posters" ]; then
+         metrics="fid ssim psnr lpips cer wer"
+     fi
+     if [ "$dataset_name" = "task3-arxiv-abstracts" ]; then
+         metrics="fid ssim psnr lpips cer wer"
+     fi
+     if [ "$dataset_name" = "task3-multilingual" ]; then
+         metrics="fid ssim psnr lpips cer"
+     fi
+
+     for model_name in "${model_name_list[@]}"
+     do
+         if [ "$dataset_name" = "task1-imagenet" ] && [ "$model_name" = "gpt4o" ]; then
+             n_take=100
+         fi
+         cd ${shell_dir}
+
+         work_dir="${folder_dir}/${model_name}"
+         echo "model_name: ${model_name}, work_dir: ${work_dir}"
+         mkdir ${work_dir}
+
+         cp submit.sh ${work_dir}
+
+         cd ${work_dir}
+         sed -i "s|{model_name}|${model_name}|g" submit.sh
+         sed -i "s|{split_name}|${split_name}|g" submit.sh
+         sed -i "s|{dataset_name}|${dataset_name}|g" submit.sh
+         sed -i "s|{batch_size}|${batch_size}|g" submit.sh
+         sed -i "s|{base_path}|${base_path}|g" submit.sh
+         sed -i "s|{metrics}|${metrics}|g" submit.sh
+         sed -i "s|{n_take}|${n_take}|g" submit.sh
+
+         # if [ "$has_slurm" = true ]; then
+         #     res=$(sbatch ./submit.sh)
+         #     res=($res)
+         #     task_id=${res[-1]}
+         #     echo "task_id: ${task_id}"
+         #     touch "task_id_${task_id}"
+         # else
+         #     echo "Slurm not detected, running with bash..."
+         #     bash ./submit.sh
+         # fi
+
+         bash ./submit.sh
+
+     done
+ done
examples/submit.sh ADDED
@@ -0,0 +1,10 @@
+ #!/bin/bash
+ # Put your slurm commands here
+
+ accelerate launch --num_processes=1 {base_path}/main.py --batch_size {batch_size} --model_name {model_name} --split_name {split_name} --dataset_name {dataset_name} --output_dir {model_name}_results --n_take {n_take}
+ python {base_path}/evaluations/evaluate_images.py \
+     --original_dir {model_name}_results/original_images \
+     --reconstructed_dir {model_name}_results/reconstructed_images/ \
+     --metrics {metrics} \
+     --batch_size 16 \
+     --num_workers 8 | tee result.txt
main.py ADDED
@@ -0,0 +1,99 @@
+ import numpy as np
+ import os
+ import PIL
+ import pickle
+ import torch
+ import argparse
+ import json
+ from PIL import Image
+ import torch.nn as nn
+ import torch
+ from transformers import AutoProcessor, AutoModelForImageTextToText
+ from src.data_loader import DataCollatorForSupervisedDataset, get_dataset
+ from src.data_processing import tensor_to_pil
+ from src.model_processing import get_model
+ from PIL import Image
+ from accelerate import Accelerator
+ from torch.utils.data import DataLoader
+ from tqdm import tqdm
+ from concurrent.futures import ThreadPoolExecutor
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_name", type=str, default="chameleon")
+ parser.add_argument("--model_path", type=str, default=None)
+ parser.add_argument("--dataset_name", type=str, default="task3-movie-posters")
+ parser.add_argument("--split_name", type=str, default="test")
+ parser.add_argument("--batch_size", default=8, type=int)
+ parser.add_argument("--output_dir", type=str, default=None)
+ parser.add_argument("--begin_id", default=0, type=int)
+ parser.add_argument("--n_take", default=-1, type=int)
+ args = parser.parse_args()
+
+ batch_size = args.batch_size
+ output_dir = args.output_dir
+
+ accelerator = Accelerator()
+
+ if accelerator.is_main_process and output_dir is not None:
+     os.makedirs(output_dir, exist_ok=True)
+     os.makedirs(f"{output_dir}/original_images", exist_ok=True)
+     os.makedirs(f"{output_dir}/reconstructed_images", exist_ok=True)
+     os.makedirs(f"{output_dir}/results", exist_ok=True)
+
+ model, data_params = get_model(args.model_path, args.model_name)
+ dataset = get_dataset(args.dataset_name, args.split_name, None if args.n_take <= 0 else args.n_take)
+ data_collator = DataCollatorForSupervisedDataset(args.dataset_name, **data_params)
+ dataloader = DataLoader(
+     dataset, batch_size=batch_size, num_workers=0, collate_fn=data_collator
+ )
+
+ model, dataloader = accelerator.prepare(model, dataloader)
+ print("Model prepared...")
+
+
+ def save_results(
+     pixel_values, reconstructed_image, idx, output_dir, data_params
+ ):
+     if reconstructed_image is None:
+         return
+
+     ori_img = tensor_to_pil(pixel_values, **data_params)
+     rec_img = tensor_to_pil(reconstructed_image, **data_params)
+
+     ori_img.save(f"{output_dir}/original_images/{idx:08d}.png")
+     rec_img.save(f"{output_dir}/reconstructed_images/{idx:08d}.png")
+
+     result = {
+         "ori_img": ori_img,
+         "rec_img": rec_img,
+     }
+
+     with open(f"{output_dir}/results/{idx:08d}.pickle", "wb") as fw:
+         pickle.dump(result, fw)
+
+
+ executor = ThreadPoolExecutor(max_workers=16)
+ with torch.no_grad():
+     print("Begin data loading...")
+     for batch in tqdm(dataloader):
+         pixel_values = batch["image"]
+         reconstructed_images = model(pixel_values)
+         if isinstance(reconstructed_images, tuple):
+             reconstructed_images = reconstructed_images[0]
+
+         if output_dir is not None:
+             idx_list = batch["idx"]
+             original_images = pixel_values.detach().cpu()
+             if not isinstance(reconstructed_images, list):
+                 reconstructed_images = reconstructed_images.detach().cpu()
+             for i in range(pixel_values.shape[0]):
+                 executor.submit(
+                     save_results,
+                     original_images[i],
+                     reconstructed_images[i],
+                     idx_list[i],
+                     output_dir,
+                     data_params,
+                 )
+
+ executor.shutdown(wait=True)
requirements.txt ADDED
@@ -0,0 +1,21 @@
+ numpy
+ mup==1.0.0
+ einops
+ omegaconf
+ lightning==2.3.3
+ piq
+ python-Levenshtein
+ verovio
+ pytorch_fid
+ transformers
+ torch-fidelity
+ accelerate
+ datasets
+ git+https://github.com/deepseek-ai/Janus.git
+ diffusers
+ openai
+ imageio
+ huggingface_hub
+ gradio
+ torch
+ torchvision
src/__init__.py ADDED
File without changes
src/data_loader.py ADDED
@@ -0,0 +1,61 @@
+ import PIL
+ from PIL import Image
+ from dataclasses import dataclass, field
+ from datasets import load_dataset
+ import torch
+ from .data_processing import pil_to_tensor
+
+
+ @dataclass
+ class DataCollatorForSupervisedDataset(object):
+     """Collate examples for supervised fine-tuning."""
+
+     def __init__(self, dataset_name, **kwargs):
+         override_params = {}
+         if dataset_name == "DIV2K":
+             override_params = {
+                 "target_image_size": -1,
+                 "lock_ratio": True,
+                 "center_crop": False,
+                 "padding": False,
+             }
+         if dataset_name == "imagenet":
+             override_params = {"center_crop": True, "padding": False}
+         if dataset_name == "movie_posters":
+             override_params = {"center_crop": True, "padding": False}
+         if dataset_name == "high_quality_1024":
+             override_params = {"target_image_size": (1024, 1024)}
+
+         self.data_params = {**kwargs, **override_params}
+
+     def __call__(self, instances):
+         images = torch.stack(
+             [
+                 pil_to_tensor(instance["image"], **self.data_params)
+                 for instance in instances
+             ],
+             dim=0,
+         )
+         idx = [instance["idx"] for instance in instances]
+         return dict(image=images, idx=idx)
+
+
+ class ImagenetDataset(torch.utils.data.Dataset):
+     def __init__(self, dataset_name, split_name="test", n_take=None):
+         print(dataset_name, split_name)
+         ds = load_dataset("huaweilin/VTBench", name=dataset_name, split=split_name if n_take is None else f"{split_name}[:{n_take}]")
+         self.image_list = ds["image"]
+
+     def __len__(self):
+         return len(self.image_list)
+
+     def __getitem__(self, idx):
+         return dict(
+             image=self.image_list[idx],
+             idx=idx,
+         )
+
+
+ def get_dataset(dataset_name, split_name, n_take):
+     dataset = ImagenetDataset(dataset_name, split_name, n_take)
+     return dataset
src/data_processing.py ADDED
@@ -0,0 +1,89 @@
+ import numpy as np
+ import PIL
+ from PIL import Image
+ import torch
+
+
+ def pil_to_tensor(
+     img: Image.Image,
+     target_image_size=512,
+     lock_ratio=True,
+     center_crop=True,
+     padding=False,
+     standardize=True,
+     **kwarg
+ ) -> torch.Tensor:
+     if img.mode != "RGB":
+         img = img.convert("RGB")
+
+     if isinstance(target_image_size, int):
+         target_size = (target_image_size, target_image_size)
+         if target_image_size < 0:
+             target_size = img.size
+     else:
+         target_size = target_image_size  # (width, height)
+
+     if lock_ratio:
+         original_width, original_height = img.size
+         target_width, target_height = target_size
+
+         scale_w = target_width / original_width
+         scale_h = target_height / original_height
+
+         if center_crop:
+             scale = max(scale_w, scale_h)
+         elif padding:
+             scale = min(scale_w, scale_h)
+         else:
+             scale = 1.0  # fallback
+
+         new_size = (round(original_width * scale), round(original_height * scale))
+         img = img.resize(new_size, Image.LANCZOS)
+
+         if center_crop:
+             left = (img.width - target_width) // 2
+             top = (img.height - target_height) // 2
+             img = img.crop((left, top, left + target_width, top + target_height))
+         elif padding:
+             new_img = Image.new("RGB", target_size, (0, 0, 0))
+             left = (target_width - img.width) // 2
+             top = (target_height - img.height) // 2
+             new_img.paste(img, (left, top))
+             img = new_img
+     else:
+         img = img.resize(target_size, Image.LANCZOS)
+
+     np_img = np.array(img) / 255.0  # Normalize to [0, 1]
+     if standardize:
+         np_img = np_img * 2 - 1  # Scale to [-1, 1]
+     tensor_img = torch.from_numpy(np_img).permute(2, 0, 1).float()  # (C, H, W)
+
+     return tensor_img
+
+
+ def tensor_to_pil(chw_tensor: torch.Tensor, standardize=True, **kwarg) -> PIL.Image:
+     # Ensure detachment and move tensor to CPU.
+     detached_chw_tensor = chw_tensor.detach().cpu()
+
+     # Normalize tensor to [0, 1] range from [-1, 1] range.
+     if standardize:
+         normalized_chw_tensor = (
+             torch.clamp(detached_chw_tensor, -1.0, 1.0) + 1.0
+         ) / 2.0
+     else:
+         normalized_chw_tensor = torch.clamp(detached_chw_tensor, 0.0, 1.0)
+
+     # Permute CHW tensor to HWC format and convert to NumPy array.
+     hwc_array = normalized_chw_tensor.permute(1, 2, 0).numpy()
+
+     # Convert to an 8-bit unsigned integer format.
+     image_array_uint8 = (hwc_array * 255).astype(np.uint8)
+
+     # Convert NumPy array to PIL Image.
+     pil_image = Image.fromarray(image_array_uint8)
+
+     # Convert image to RGB if it is not already.
+     if pil_image.mode != "RGB":
+         pil_image = pil_image.convert("RGB")
+
+     return pil_image
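For reference, a minimal round-trip sketch of the two helpers above (the input and output file names are hypothetical; any RGB image works):

    from PIL import Image
    from src.data_processing import pil_to_tensor, tensor_to_pil

    img = Image.open("example_input.png")  # hypothetical file

    # Resize with the aspect ratio locked, center-crop to 512x512, scale to [-1, 1].
    x = pil_to_tensor(
        img, target_image_size=(512, 512), lock_ratio=True,
        center_crop=True, padding=False, standardize=True,
    )
    print(x.shape)  # torch.Size([3, 512, 512]), values in [-1, 1]

    # Map back to a PIL image; "standardize" must match the value used above.
    tensor_to_pil(x, standardize=True).save("example_roundtrip.png")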
src/model_processing.py ADDED
@@ -0,0 +1,409 @@
+ import requests
+ import os
+ import yaml
+ from .utils import get_ckpt, get_yaml_config
+
+
+ def download_ckpt_yaml(model_path, model_name, ckpt_path, yaml_url=None):
+     def download_file(url, save_path):
+         response = requests.get(url)
+         response.raise_for_status()
+         with open(save_path, 'wb') as f:
+             f.write(response.content)
+
+     # os.makedirs(model_path, exist_ok=True)
+     local_dir = os.path.join(model_path, model_name)
+     os.makedirs(local_dir, exist_ok=True)
+
+     ckpt_name = ckpt_path.split("/")[-1]
+     local_ckpt_path = os.path.join(local_dir, ckpt_name)
+     if not os.path.exists(local_ckpt_path):
+         print(f"Downloading CKPT to {local_ckpt_path}")
+         download_file(ckpt_path, local_ckpt_path)
+
+     if yaml_url:
+         yaml_name = yaml_url.split("/")[-1]
+         local_yaml_path = os.path.join(local_dir, yaml_name)
+         if not os.path.exists(local_yaml_path):
+             print(f"Downloading YAML to {local_yaml_path}")
+             download_file(yaml_url, local_yaml_path)
+         return local_ckpt_path, local_yaml_path
+
+     return local_ckpt_path, None
+
+
+ def get_model(model_path, model_name):
+     model = None
+     data_params = {
+         "target_image_size": (512, 512),
+         "lock_ratio": True,
+         "center_crop": True,
+         "padding": False,
+     }
+
+     if model_name.lower() == "anole":
+         from src.vqvaes.anole.anole import VQModel
+         yaml_url = "https://huggingface.co/GAIR/Anole-7b-v0.1/resolve/main/tokenizer/vqgan.yaml"
+         ckpt_path = "https://huggingface.co/GAIR/Anole-7b-v0.1/resolve/main/tokenizer/vqgan.ckpt"
+
+         if model_path is not None:
+             ckpt_path, yaml_url = download_ckpt_yaml(model_path, "anole", ckpt_path, yaml_url)
+         config = get_yaml_config(yaml_url)
+
+         params = config["model"]["params"]
+         if "lossconfig" in params:
+             del params["lossconfig"]
+         params["ckpt_path"] = ckpt_path
+         model = VQModel(**params)
+         data_params = {
+             "target_image_size": (512, 512),
+             "lock_ratio": True,
+             "center_crop": True,
+             "padding": False,
+         }
+
+     elif model_name.lower() == "chameleon":
+         from src.vqvaes.anole.anole import VQModel
+
+         yaml_url = "https://huggingface.co/huaweilin/chameleon_vqvae/resolve/main/vqgan.yaml"
+         ckpt_path = "https://huggingface.co/huaweilin/chameleon_vqvae/resolve/main/vqgan.ckpt"
+         if model_path is not None:
+             ckpt_path, yaml_url = download_ckpt_yaml(model_path, "chameleon", ckpt_path, yaml_url)
+         config = get_yaml_config(yaml_url)
+
+         params = config["model"]["params"]
+         if "lossconfig" in params:
+             del params["lossconfig"]
+         params["ckpt_path"] = ckpt_path
+         model = VQModel(**params)
+         data_params = {
+             "target_image_size": (512, 512),
+             "lock_ratio": True,
+             "center_crop": True,
+             "padding": False,
+         }
+
+     elif model_name.lower() == "llamagen-ds16":
+         from src.vqvaes.llamagen.llamagen import VQ_models
+         ckpt_path = "https://huggingface.co/FoundationVision/LlamaGen/resolve/main/vq_ds16_c2i.pt"
+         if model_path is not None:
+             ckpt_path, _ = download_ckpt_yaml(model_path, "llamagen-ds16", ckpt_path, None)
+
+         model = VQ_models["VQ-16"](codebook_size=16384, codebook_embed_dim=8)
+         model.load_state_dict(get_ckpt(ckpt_path, key="model"))
+         data_params = {
+             "target_image_size": (512, 512),
+             "lock_ratio": True,
+             "center_crop": True,
+             "padding": False,
+         }
+
+     elif model_name.lower() == "llamagen-ds16-t2i":
+         from src.vqvaes.llamagen.llamagen import VQ_models
+         ckpt_path = "https://huggingface.co/peizesun/llamagen_t2i/resolve/main/vq_ds16_t2i.pt"
+         if model_path is not None:
+             ckpt_path, _ = download_ckpt_yaml(model_path, "llamagen-ds16-t2i", ckpt_path, None)
+
+         model = VQ_models["VQ-16"](codebook_size=16384, codebook_embed_dim=8)
+         model.load_state_dict(get_ckpt(ckpt_path, key="model"))
+         data_params = {
+             "target_image_size": (512, 512),
+             "lock_ratio": True,
+             "center_crop": True,
+             "padding": False,
+         }
+
+     elif model_name.lower() == "llamagen-ds8":
+         from src.vqvaes.llamagen.llamagen import VQ_models
+         ckpt_path = "https://huggingface.co/FoundationVision/LlamaGen/resolve/main/vq_ds8_c2i.pt"
+         if model_path is not None:
+             ckpt_path, _ = download_ckpt_yaml(model_path, "llamagen-ds8", ckpt_path, None)
+
+         model = VQ_models["VQ-8"](codebook_size=16384, codebook_embed_dim=8)
+         model.load_state_dict(get_ckpt(ckpt_path, key="model"))
+         data_params = {
+             "target_image_size": (256, 256),
+             "lock_ratio": True,
+             "center_crop": True,
+             "padding": False,
+         }
+
+     elif model_name.lower() == "flowmo_lo":
+         from src.vqvaes.flowmo.flowmo import build_model
+         yaml_url = "https://raw.githubusercontent.com/kylesargent/FlowMo/refs/heads/main/flowmo/configs/base.yaml"
+         ckpt_path = "https://huggingface.co/ksarge/FlowMo/resolve/main/flowmo_lo.pth"
+         if model_path is not None:
+             ckpt_path, yaml_url = download_ckpt_yaml(model_path, "flowmo_lo", ckpt_path, yaml_url)
+         config = get_yaml_config(yaml_url)
+
+         config.model.context_dim = 18
+         model = build_model(config)
+         model.load_state_dict(
+             get_ckpt(ckpt_path, key="model_ema_state_dict")
+         )
+         data_params = {
+             "target_image_size": (256, 256),
+             "lock_ratio": True,
+             "center_crop": True,
+             "padding": False,
+         }
+
+     elif model_name.lower() == "flowmo_hi":
+         from src.vqvaes.flowmo.flowmo import build_model
+
+         yaml_url = "https://raw.githubusercontent.com/kylesargent/FlowMo/refs/heads/main/flowmo/configs/base.yaml"
+         ckpt_path = "https://huggingface.co/ksarge/FlowMo/resolve/main/flowmo_hi.pth"
+         if model_path is not None:
+             ckpt_path, yaml_url = download_ckpt_yaml(model_path, "flowmo_hi", ckpt_path, yaml_url)
+         config = get_yaml_config(yaml_url)
+
+         config.model.context_dim = 56
+         config.model.codebook_size_for_entropy = 14
+         model = build_model(config)
+         model.load_state_dict(
+             get_ckpt(ckpt_path, key="model_ema_state_dict")
+         )
+         data_params = {
+             "target_image_size": (256, 256),
+             "lock_ratio": True,
+             "center_crop": True,
+             "padding": False,
+         }
+
+     elif model_name.lower() == "open_magvit2":
+         from src.vqvaes.open_magvit2.open_magvit2 import VQModel
+
+         yaml_url = "https://raw.githubusercontent.com/TencentARC/SEED-Voken/refs/heads/main/configs/Open-MAGVIT2/gpu/imagenet_lfqgan_256_L.yaml"
+         ckpt_path = "https://huggingface.co/TencentARC/Open-MAGVIT2-Tokenizer-256-resolution/resolve/main/imagenet_256_L.ckpt"
+         if model_path is not None:
+             ckpt_path, yaml_url = download_ckpt_yaml(model_path, "open_magvit2", ckpt_path, yaml_url)
+         config = get_yaml_config(yaml_url)
+
+         model = VQModel(**config.model.init_args)
+         model.load_state_dict(get_ckpt(ckpt_path, key="state_dict"))
+         data_params = {
+             "target_image_size": (256, 256),
+             "lock_ratio": True,
+             "center_crop": True,
+             "padding": False,
+         }
+
+     elif "maskbit" in model_name.lower():
+         from src.vqvaes.maskbit.maskbit import ConvVQModel
+
+         if "16bit" in model_name.lower():
+             yaml_url = "https://raw.githubusercontent.com/markweberdev/maskbit/refs/heads/main/configs/tokenizer/maskbit_tokenizer_16bit.yaml"
+             ckpt_path = "https://huggingface.co/markweber/maskbit_tokenizer_16bit/resolve/main/maskbit_tokenizer_16bit.bin"
+             if model_path is not None:
+                 ckpt_path, yaml_url = download_ckpt_yaml(model_path, "maskbit-16bit", ckpt_path, yaml_url)
+         elif "18bit" in model_name.lower():
+             yaml_url = "https://raw.githubusercontent.com/markweberdev/maskbit/refs/heads/main/configs/tokenizer/maskbit_tokenizer_18bit.yaml"
+             ckpt_path = "https://huggingface.co/markweber/maskbit_tokenizer_18bit/resolve/main/maskbit_tokenizer_18bit.bin"
+             if model_path is not None:
+                 ckpt_path, yaml_url = download_ckpt_yaml(model_path, "maskbit-18bit", ckpt_path, yaml_url)
+         else:
+             raise Exception(f"Unsupported model: {model_name}")
+
+         config = get_yaml_config(yaml_url)
+         model = ConvVQModel(config.model.vq_model, legacy=False)
+         model.load_pretrained(get_ckpt(ckpt_path, key=None))
+         data_params = {
+             "target_image_size": (256, 256),
+             "lock_ratio": True,
+             "center_crop": True,
+             "padding": False,
+             "standardize": False,
+         }
+
+     elif "bsqvit" in model_name.lower():
+         from src.vqvaes.bsqvit.bsqvit import VITBSQModel
+
+         yaml_url = "https://huggingface.co/huaweilin/bsqvit_256x256/resolve/main/config.yaml"
+         ckpt_path = "https://huggingface.co/huaweilin/bsqvit_256x256/resolve/main/checkpoint.pt"
+         if model_path is not None:
+             ckpt_path, yaml_url = download_ckpt_yaml(model_path, "bsqvit", ckpt_path, yaml_url)
+
+         config = get_yaml_config(yaml_url)
+         model = VITBSQModel(**config["model"]["params"])
+         model.init_from_ckpt(get_ckpt(ckpt_path, key="state_dict"))
+         data_params = {
+             "target_image_size": (256, 256),
+             "lock_ratio": True,
+             "center_crop": True,
+             "padding": False,
+             "standardize": False,
+         }
+
+     elif "titok" in model_name.lower():
+         from src.vqvaes.titok.titok import TiTok
+
+         ckpt_path = None
+         if "bl64" in model_name.lower():
+             ckpt_path = "yucornetto/tokenizer_titok_bl64_vq8k_imagenet"
+         elif "bl128" in model_name.lower():
+             ckpt_path = "yucornetto/tokenizer_titok_bl128_vq8k_imagenet"
+         elif "sl256" in model_name.lower():
+             ckpt_path = "yucornetto/tokenizer_titok_sl256_vq8k_imagenet"
+         elif "l32" in model_name.lower():
+             ckpt_path = "yucornetto/tokenizer_titok_l32_imagenet"
+         elif "b64" in model_name.lower():
+             ckpt_path = "yucornetto/tokenizer_titok_b64_imagenet"
+         elif "s128" in model_name.lower():
+             ckpt_path = "yucornetto/tokenizer_titok_s128_imagenet"
+         else:
+             raise Exception(f"Unsupported model: {model_name}")
+
+         model = TiTok.from_pretrained(ckpt_path)
+         data_params = {
+             "target_image_size": (256, 256),
+             "lock_ratio": True,
+             "center_crop": True,
+             "padding": False,
+             "standardize": False,
+         }
+
+     elif "janus_pro" in model_name.lower():
+         from janus.models import MultiModalityCausalLM
+         from src.vqvaes.janus_pro.janus_pro import forward
+         import types
+
+         model = MultiModalityCausalLM.from_pretrained(
+             "deepseek-ai/Janus-Pro-7B", trust_remote_code=True
+         ).gen_vision_model
+         model.forward = types.MethodType(forward, model)
+         data_params = {
+             "target_image_size": (384, 384),
+             "lock_ratio": True,
+             "center_crop": False,
+             "padding": True,
+         }
+
+     elif "var" in model_name.lower():
+         from src.vqvaes.var.var_vq import VQVAE
+
+         ckpt_path = "https://huggingface.co/FoundationVision/var/resolve/main/vae_ch160v4096z32.pth"
+         if model_path is not None:
+             ckpt_path, _ = download_ckpt_yaml(model_path, "var", ckpt_path, None)
+
+         v_patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
+         if "512" in model_name.lower():
+             v_patch_nums = (1, 2, 3, 4, 6, 9, 13, 18, 24, 32)
+         model = VQVAE(
+             vocab_size=4096,
+             z_channels=32,
+             ch=160,
+             test_mode=True,
+             share_quant_resi=4,
+             v_patch_nums=v_patch_nums,
+         )
+         model.load_state_dict(get_ckpt(ckpt_path, key=None))
+         data_params = {
+             "target_image_size": (
+                 (512, 512) if "512" in model_name.lower() else (256, 256)
+             ),
+             "lock_ratio": True,
+             "center_crop": False,
+             "padding": True,
+             "standardize": False,
+         }
+
+     elif (
+         "infinity" in model_name.lower()
+     ):  # "infinity_d32", "infinity_d64", "infinity_d56_f8_14_patchify"
+         from src.vqvaes.infinity.vae import vae_model
+
+         if "d32" in model_name:
+             ckpt_path = "https://huggingface.co/FoundationVision/Infinity/resolve/main/infinity_vae_d32.pth"
+             codebook_dim = 32
+             if model_path is not None:
+                 ckpt_path, _ = download_ckpt_yaml(model_path, "infinity-d32", ckpt_path, None)
+         elif "d64" in model_name:
+             ckpt_path = "https://huggingface.co/FoundationVision/Infinity/resolve/main/infinity_vae_d64.pth"
+             codebook_dim = 64
+             if model_path is not None:
+                 ckpt_path, _ = download_ckpt_yaml(model_path, "infinity-d64", ckpt_path, None)
+
+         schedule_mode = "dynamic"
+         codebook_size = 2**codebook_dim
+         patch_size = 16
+         encoder_ch_mult = [1, 2, 4, 4, 4]
+         decoder_ch_mult = [1, 2, 4, 4, 4]
+
+         ckpt = get_ckpt(ckpt_path, key=None)
+         model = vae_model(
+             ckpt,
+             schedule_mode,
+             codebook_dim,
+             codebook_size,
+             patch_size=patch_size,
+             encoder_ch_mult=encoder_ch_mult,
+             decoder_ch_mult=decoder_ch_mult,
+             test_mode=True,
+         )
+
+         data_params = {
+             "target_image_size": (1024, 1024),
+             "lock_ratio": True,
+             "center_crop": False,
+             "padding": True,
+             "standardize": False,
+         }
+
+     elif "sd3.5l" in model_name.lower():  # SD3.5L
+         from src.vaes.stable_diffusion.vae import forward
+         from diffusers import AutoencoderKL
+         import types
+
+         model = AutoencoderKL.from_pretrained(
+             "huaweilin/stable-diffusion-3.5-large-vae", subfolder="vae"
+         )
+         model.forward = types.MethodType(forward, model)
+         data_params = {
+             "target_image_size": (1024, 1024),
+             "lock_ratio": True,
+             "center_crop": False,
+             "padding": True,
+             "standardize": True,
+         }
+
+     elif "FLUX.1-dev".lower() in model_name.lower():  # FLUX.1-dev
+         from src.vaes.stable_diffusion.vae import forward
+         from diffusers import AutoencoderKL
+         import types
+
+         model = AutoencoderKL.from_pretrained(
+             "black-forest-labs/FLUX.1-dev", subfolder="vae"
+         )
+         model.forward = types.MethodType(forward, model)
+         data_params = {
+             "target_image_size": (1024, 1024),
+             "lock_ratio": True,
+             "center_crop": False,
+             "padding": True,
+             "standardize": True,
+         }
+
+     elif "gpt4o" in model_name.lower():
+         from src.vaes.gpt_image.gpt_image import GPTImage
+
+         data_params = {
+             "target_image_size": (1024, 1024),
+             "lock_ratio": True,
+             "center_crop": False,
+             "padding": True,
+             "standardize": False,
+         }
+         model = GPTImage(data_params)
+
+     else:
+         raise Exception(f"Unsupported model: \"{model_name}\"")
+
+     try:
+         trainable_params = sum(p.numel() for p in model.parameters())
+         print("trainable_params:", trainable_params)
+     except Exception as e:
+         print(e)
+         pass
+
+     model.eval()
+     return model, data_params
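As a point of reference, get_model returns both the model and the preprocessing parameters it expects, which is how app.py and main.py consume it. A minimal reconstruction sketch mirroring that usage (the model directory and file names below are placeholders):

    import torch
    from PIL import Image
    from src.model_processing import get_model
    from src.data_processing import pil_to_tensor, tensor_to_pil

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # "./VTBench_models" and "example_input.png" are hypothetical paths.
    model, data_params = get_model("./VTBench_models", "chameleon")
    model = model.to(device).eval()

    x = pil_to_tensor(Image.open("example_input.png"), **data_params).unsqueeze(0).to(device)
    with torch.no_grad():
        recon = model(x)[0]  # forward() returns a tuple whose first element is the reconstruction
    tensor_to_pil(recon[0].cpu(), **data_params).save("example_reconstruction.png")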
src/utils.py ADDED
@@ -0,0 +1,47 @@
+ import os
+ from omegaconf import OmegaConf
+ import torch
+ import tempfile
+ from safetensors.torch import load_file
+ import requests
+ import yaml
+
+ def get_ckpt(path, key="state_dict"):
+     is_url = path.startswith("http://") or path.startswith("https://")
+     suffix = os.path.splitext(path)[-1]
+
+     if is_url:
+         print(f"Loading checkpoint from URL: {path}")
+         with tempfile.NamedTemporaryFile(suffix=suffix) as tmp_file:
+             response = requests.get(path)
+             response.raise_for_status()
+             tmp_file.write(response.content)
+             tmp_file.flush()
+             ckpt_path = tmp_file.name
+
+             if suffix == ".safetensors":
+                 checkpoint = load_file(ckpt_path)
+             else:
+                 checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
+     else:
+         print(f"Loading checkpoint from local path: {path}")
+         if suffix == ".safetensors":
+             checkpoint = load_file(path)
+         else:
+             checkpoint = torch.load(path, map_location="cpu", weights_only=False)
+
+     if key is not None and key in checkpoint:
+         checkpoint = checkpoint[key]
+
+     return checkpoint
+
+
+ def get_yaml_config(path):
+     if path.startswith("http://") or path.startswith("https://"):
+         response = requests.get(path)
+         response.raise_for_status()
+         config = OmegaConf.create(response.text)
+     else:
+         with open(path, 'r') as f:
+             config = OmegaConf.load(f)
+     return config
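A small usage sketch of these helpers, using URLs that appear verbatim in src/model_processing.py above (whether the checkpoint contains the requested key is handled by the fallback inside get_ckpt):

    from src.utils import get_ckpt, get_yaml_config

    # URLs taken from src/model_processing.py (LlamaGen ds16 checkpoint, FlowMo base config).
    ckpt_url = "https://huggingface.co/FoundationVision/LlamaGen/resolve/main/vq_ds16_c2i.pt"
    yaml_url = "https://raw.githubusercontent.com/kylesargent/FlowMo/refs/heads/main/flowmo/configs/base.yaml"

    state_dict = get_ckpt(ckpt_url, key="model")  # downloaded to a temp file, returns checkpoint["model"]
    config = get_yaml_config(yaml_url)            # OmegaConf object, works for remote or local paths
    print(len(state_dict), list(config.keys())[:5])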
src/vaes/gpt_image/gpt_image.py ADDED
@@ -0,0 +1,48 @@
+ import base64
+ from torchvision.transforms.functional import to_pil_image
+ from openai import OpenAI
+ import io
+ import torch
+ import numpy as np
+ from PIL import Image
+ from ...data_processing import tensor_to_pil, pil_to_tensor
+
+
+ class GPTImage:
+     def __init__(self, data_params):
+         self.client = OpenAI(organization="org-xZTnLOf1k9s04LEoKKjl4jOB")
+         self.prompt = "Please recreate the exact same image without any alterations. Please preserve the original resolution (1024*1024)."
+         self.data_params = data_params
+
+     def eval(self):
+         pass
+
+     def __call__(self, *args, **kwargs):
+         return self.forward(*args, **kwargs)
+
+     def forward(self, input):
+         results = []
+         for image in input:
+             image = tensor_to_pil(image, **self.data_params)
+             buffer = io.BytesIO()
+             image.save(buffer, format="PNG")
+             buffer.seek(0)
+             image_file = ("image.png", buffer, "image/png")
+
+             try:
+                 result = self.client.images.edit(
+                     model="gpt-image-1",
+                     image=image_file,
+                     prompt=self.prompt,
+                     n=1,
+                     size="1024x1024",
+                 )
+                 image_base64 = result.data[0].b64_json
+                 image_bytes = base64.b64decode(image_base64)
+                 image = Image.open(io.BytesIO(image_bytes))
+                 results.append(pil_to_tensor(image, **self.data_params))
+             except Exception as e:
+                 print("💥 Unexpected error occurred:", e)
+                 results.append(None)
+
+         return results, None, None
src/vaes/stable_diffusion/vae.py ADDED
@@ -0,0 +1,23 @@
+ def forward(
+     self,
+     sample,
+     sample_posterior=False,
+     return_dict=True,
+     generator=None,
+ ):
+     r"""
+     Args:
+         sample (`torch.Tensor`): Input sample.
+         sample_posterior (`bool`, *optional*, defaults to `False`):
+             Whether to sample from the posterior.
+         return_dict (`bool`, *optional*, defaults to `True`):
+             Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+     """
+     x = sample
+     posterior = self.encode(x).latent_dist
+     if sample_posterior:
+         z = posterior.sample(generator=generator)
+     else:
+         z = posterior.mode()
+     dec = self.decode(z).sample
+     return dec, None, None
src/vqvaes/__init__.py ADDED
File without changes
src/vqvaes/anole/anole.py ADDED
@@ -0,0 +1,706 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ # This source code is licensed under the Chameleon License found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Contents of this file are taken from https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/models/vqgan.py
+ [with minimal dependencies]
+
+ This implementation is inference-only -- training steps and optimizer components
+ introduce significant additional dependencies
+ """
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from ...utils import get_ckpt
+
+
+ class VectorQuantizer2(nn.Module):
+     """
+     Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
+     avoids costly matrix multiplications and allows for post-hoc remapping of indices.
+     """
+
+     # NOTE: due to a bug the beta term was applied to the wrong term. for
+     # backwards compatibility we use the buggy version by default, but you can
+     # specify legacy=False to fix it.
+     def __init__(
+         self,
+         n_e,
+         e_dim,
+         beta,
+         remap=None,
+         unknown_index="random",
+         sane_index_shape=False,
+         legacy=True,
+     ):
+         super().__init__()
+         self.n_e = n_e
+         self.e_dim = e_dim
+         self.beta = beta
+         self.legacy = legacy
+
+         self.embedding = nn.Embedding(self.n_e, self.e_dim)
+         self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+         self.remap = remap
+         if self.remap is not None:
+             self.register_buffer("used", torch.tensor(np.load(self.remap)))
+             self.re_embed = self.used.shape[0]
+             self.unknown_index = unknown_index  # "random" or "extra" or integer
+             if self.unknown_index == "extra":
+                 self.unknown_index = self.re_embed
+                 self.re_embed = self.re_embed + 1
+             print(
+                 f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+                 f"Using {self.unknown_index} for unknown indices."
+             )
+         else:
+             self.re_embed = n_e
+
+         self.sane_index_shape = sane_index_shape
+
+     def remap_to_used(self, inds):
+         ishape = inds.shape
+         assert len(ishape) > 1
+         inds = inds.reshape(ishape[0], -1)
+         used = self.used.to(inds)
+         match = (inds[:, :, None] == used[None, None, ...]).long()
+         new = match.argmax(-1)
+         unknown = match.sum(2) < 1
+         if self.unknown_index == "random":
+             new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
+                 device=new.device
+             )
+         else:
+             new[unknown] = self.unknown_index
+         return new.reshape(ishape)
+
+     def unmap_to_all(self, inds):
+         ishape = inds.shape
+         assert len(ishape) > 1
+         inds = inds.reshape(ishape[0], -1)
+         used = self.used.to(inds)
+         if self.re_embed > self.used.shape[0]:  # extra token
+             inds[inds >= self.used.shape[0]] = 0  # simply set to zero
+         back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
+         return back.reshape(ishape)
+
+     def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
+         assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
+         assert rescale_logits is False, "Only for interface compatible with Gumbel"
+         assert return_logits is False, "Only for interface compatible with Gumbel"
+         # reshape z -> (batch, height, width, channel) and flatten
+         z = z.permute(0, 2, 3, 1).contiguous()
+         z_flattened = z.view(-1, self.e_dim)
+         # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+         d = (
+             torch.sum(z_flattened**2, dim=1, keepdim=True)
+             + torch.sum(self.embedding.weight**2, dim=1)
+             - 2
+             * torch.einsum(
+                 "bd,dn->bn", z_flattened, self.embedding.weight.transpose(0, 1)
+             )
+         )
+
+         min_encoding_indices = torch.argmin(d, dim=1)
+         z_q = self.embedding(min_encoding_indices).view(z.shape)
+         perplexity = None
+         min_encodings = None
+
+         # compute loss for embedding
+         if not self.legacy:
+             loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
+                 (z_q - z.detach()) ** 2
+             )
+         else:
+             loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
+                 (z_q - z.detach()) ** 2
+             )
+
+         # preserve gradients
+         z_q = z + (z_q - z).detach()
+
+         # reshape back to match original input shape
+         z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+         if self.remap is not None:
+             min_encoding_indices = min_encoding_indices.reshape(
+                 z.shape[0], -1
+             )  # add batch axis
+             min_encoding_indices = self.remap_to_used(min_encoding_indices)
136
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
137
+
138
+ if self.sane_index_shape:
139
+ min_encoding_indices = min_encoding_indices.reshape(
140
+ z_q.shape[0], z_q.shape[2], z_q.shape[3]
141
+ )
142
+
143
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
144
+
145
+ def get_codebook_entry(self, indices, shape):
146
+ # shape specifying (batch, height, width, channel)
147
+ if self.remap is not None:
148
+ indices = indices.reshape(shape[0], -1) # add batch axis
149
+ indices = self.unmap_to_all(indices)
150
+ indices = indices.reshape(-1) # flatten again
151
+
152
+ # get quantized latent vectors
153
+ z_q = self.embedding(indices)
154
+
155
+ if shape is not None:
156
+ z_q = z_q.view(shape)
157
+ # reshape back to match original input shape
158
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
159
+
160
+ return z_q
161
+
162
+
163
+ # Alias
164
+ VectorQuantizer = VectorQuantizer2
165
+
166
+
167
+ def nonlinearity(x):
168
+ # swish
169
+ return x * torch.sigmoid(x)
170
+
171
+
172
+ def Normalize(in_channels, num_groups=32):
173
+ return torch.nn.GroupNorm(
174
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
175
+ )
176
+
177
+
178
+ class Upsample(nn.Module):
179
+ def __init__(self, in_channels, with_conv):
180
+ super().__init__()
181
+ self.with_conv = with_conv
182
+ if self.with_conv:
183
+ self.conv = torch.nn.Conv2d(
184
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
185
+ )
186
+
187
+ def forward(self, x):
188
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
189
+ if self.with_conv:
190
+ x = self.conv(x)
191
+ return x
192
+
193
+
194
+ class Downsample(nn.Module):
195
+ def __init__(self, in_channels, with_conv):
196
+ super().__init__()
197
+ self.with_conv = with_conv
198
+ if self.with_conv:
199
+ # no asymmetric padding in torch conv, must do it ourselves
200
+ self.conv = torch.nn.Conv2d(
201
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
202
+ )
203
+
204
+ def forward(self, x):
205
+ if self.with_conv:
206
+ pad = (0, 1, 0, 1)
207
+ x = F.pad(x, pad, mode="constant", value=0)
208
+ x = self.conv(x)
209
+ else:
210
+ x = F.avg_pool2d(x, kernel_size=2, stride=2)
211
+ return x
212
+
213
+
214
+ class ResnetBlock(nn.Module):
215
+ def __init__(
216
+ self,
217
+ *,
218
+ in_channels,
219
+ out_channels=None,
220
+ conv_shortcut=False,
221
+ dropout,
222
+ temb_channels=512,
223
+ ):
224
+ super().__init__()
225
+ self.in_channels = in_channels
226
+ out_channels = in_channels if out_channels is None else out_channels
227
+ self.out_channels = out_channels
228
+ self.use_conv_shortcut = conv_shortcut
229
+
230
+ self.norm1 = Normalize(in_channels)
231
+ self.conv1 = torch.nn.Conv2d(
232
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
233
+ )
234
+ if temb_channels > 0:
235
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
236
+ self.norm2 = Normalize(out_channels)
237
+ self.dropout = torch.nn.Dropout(dropout)
238
+ self.conv2 = torch.nn.Conv2d(
239
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
240
+ )
241
+ if self.in_channels != self.out_channels:
242
+ if self.use_conv_shortcut:
243
+ self.conv_shortcut = torch.nn.Conv2d(
244
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
245
+ )
246
+ else:
247
+ self.nin_shortcut = torch.nn.Conv2d(
248
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
249
+ )
250
+
251
+ def forward(self, x, temb):
252
+ h = x
253
+ h = self.norm1(h)
254
+ h = nonlinearity(h)
255
+ h = self.conv1(h)
256
+
257
+ if temb is not None:
258
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
259
+
260
+ h = self.norm2(h)
261
+ h = nonlinearity(h)
262
+ h = self.dropout(h)
263
+ h = self.conv2(h)
264
+
265
+ if self.in_channels != self.out_channels:
266
+ if self.use_conv_shortcut:
267
+ x = self.conv_shortcut(x)
268
+ else:
269
+ x = self.nin_shortcut(x)
270
+
271
+ return x + h
272
+
273
+
274
+ class AttnBlock(nn.Module):
275
+ def __init__(self, in_channels):
276
+ super().__init__()
277
+ self.in_channels = in_channels
278
+
279
+ self.norm = Normalize(in_channels)
280
+ self.q = torch.nn.Conv2d(
281
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
282
+ )
283
+ self.k = torch.nn.Conv2d(
284
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
285
+ )
286
+ self.v = torch.nn.Conv2d(
287
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
288
+ )
289
+ self.proj_out = torch.nn.Conv2d(
290
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
291
+ )
292
+
293
+ def forward(self, x):
294
+ h_ = x
295
+ h_ = self.norm(h_)
296
+ q = self.q(h_)
297
+ k = self.k(h_)
298
+ v = self.v(h_)
299
+
300
+ # compute attention
301
+ b, c, h, w = q.shape
302
+ q = q.reshape(b, c, h * w)
303
+ q = q.permute(0, 2, 1) # b,hw,c
304
+ k = k.reshape(b, c, h * w) # b,c,hw
305
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
306
+ w_ = w_ * (int(c) ** (-0.5))
307
+ w_ = F.softmax(w_, dim=2)
308
+
309
+ # attend to values
310
+ v = v.reshape(b, c, h * w)
311
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
312
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
313
+ h_ = h_.reshape(b, c, h, w)
314
+
315
+ h_ = self.proj_out(h_)
316
+
317
+ return x + h_
318
+
319
+
320
+ def make_attn(in_channels, attn_type="vanilla"):
321
+ assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
322
+ # print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
323
+ if attn_type == "vanilla":
324
+ return AttnBlock(in_channels)
325
+ elif attn_type == "none":
326
+ return nn.Identity(in_channels)
327
+ else:
328
+ raise ValueError("Unexpected attention type")
329
+
330
+
331
+ class Encoder(nn.Module):
332
+ def __init__(
333
+ self,
334
+ *,
335
+ ch,
336
+ out_ch,
337
+ ch_mult=(1, 2, 4, 8),
338
+ num_res_blocks,
339
+ attn_resolutions,
340
+ dropout=0.0,
341
+ resamp_with_conv=True,
342
+ in_channels,
343
+ resolution,
344
+ z_channels,
345
+ double_z=True,
346
+ use_linear_attn=False,
347
+ attn_type="vanilla",
348
+ **ignore_kwargs,
349
+ ):
350
+ super().__init__()
351
+ if use_linear_attn:
352
+ attn_type = "linear"
353
+ self.ch = ch
354
+ self.temb_ch = 0
355
+ self.num_resolutions = len(ch_mult)
356
+ self.num_res_blocks = num_res_blocks
357
+ self.resolution = resolution
358
+ self.in_channels = in_channels
359
+
360
+ # downsampling
361
+ self.conv_in = torch.nn.Conv2d(
362
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
363
+ )
364
+
365
+ curr_res = resolution
366
+ in_ch_mult = (1,) + tuple(ch_mult)
367
+ self.in_ch_mult = in_ch_mult
368
+ self.down = nn.ModuleList()
369
+ for i_level in range(self.num_resolutions):
370
+ block = nn.ModuleList()
371
+ attn = nn.ModuleList()
372
+ block_in = ch * in_ch_mult[i_level]
373
+ block_out = ch * ch_mult[i_level]
374
+ for i_block in range(self.num_res_blocks):
375
+ block.append(
376
+ ResnetBlock(
377
+ in_channels=block_in,
378
+ out_channels=block_out,
379
+ temb_channels=self.temb_ch,
380
+ dropout=dropout,
381
+ )
382
+ )
383
+ block_in = block_out
384
+ if curr_res in attn_resolutions:
385
+ attn.append(make_attn(block_in, attn_type=attn_type))
386
+ down = nn.Module()
387
+ down.block = block
388
+ down.attn = attn
389
+ if i_level != self.num_resolutions - 1:
390
+ down.downsample = Downsample(block_in, resamp_with_conv)
391
+ curr_res = curr_res // 2
392
+ self.down.append(down)
393
+
394
+ # middle
395
+ self.mid = nn.Module()
396
+ self.mid.block_1 = ResnetBlock(
397
+ in_channels=block_in,
398
+ out_channels=block_in,
399
+ temb_channels=self.temb_ch,
400
+ dropout=dropout,
401
+ )
402
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
403
+ self.mid.block_2 = ResnetBlock(
404
+ in_channels=block_in,
405
+ out_channels=block_in,
406
+ temb_channels=self.temb_ch,
407
+ dropout=dropout,
408
+ )
409
+
410
+ # end
411
+ self.norm_out = Normalize(block_in)
412
+ self.conv_out = torch.nn.Conv2d(
413
+ block_in,
414
+ 2 * z_channels if double_z else z_channels,
415
+ kernel_size=3,
416
+ stride=1,
417
+ padding=1,
418
+ )
419
+
420
+ def forward(self, x):
421
+ # timestep embedding
422
+ temb = None
423
+
424
+ # downsampling
425
+ hs = [self.conv_in(x)]
426
+ for i_level in range(self.num_resolutions):
427
+ for i_block in range(self.num_res_blocks):
428
+ h = self.down[i_level].block[i_block](hs[-1], temb)
429
+ if len(self.down[i_level].attn) > 0:
430
+ h = self.down[i_level].attn[i_block](h)
431
+ hs.append(h)
432
+ if i_level != self.num_resolutions - 1:
433
+ hs.append(self.down[i_level].downsample(hs[-1]))
434
+
435
+ # middle
436
+ h = hs[-1]
437
+ h = self.mid.block_1(h, temb)
438
+ h = self.mid.attn_1(h)
439
+ h = self.mid.block_2(h, temb)
440
+
441
+ # end
442
+ h = self.norm_out(h)
443
+ h = nonlinearity(h)
444
+ h = self.conv_out(h)
445
+ return h
446
+
447
+
448
+ class Decoder(nn.Module):
449
+ def __init__(
450
+ self,
451
+ *,
452
+ ch,
453
+ out_ch,
454
+ ch_mult=(1, 2, 4, 8),
455
+ num_res_blocks,
456
+ attn_resolutions,
457
+ dropout=0.0,
458
+ resamp_with_conv=True,
459
+ in_channels,
460
+ resolution,
461
+ z_channels,
462
+ give_pre_end=False,
463
+ tanh_out=False,
464
+ use_linear_attn=False,
465
+ attn_type="vanilla",
466
+ **ignorekwargs,
467
+ ):
468
+ super().__init__()
469
+ if use_linear_attn:
470
+ attn_type = "linear"
471
+ self.ch = ch
472
+ self.temb_ch = 0
473
+ self.num_resolutions = len(ch_mult)
474
+ self.num_res_blocks = num_res_blocks
475
+ self.resolution = resolution
476
+ self.in_channels = in_channels
477
+ self.give_pre_end = give_pre_end
478
+ self.tanh_out = tanh_out
479
+
480
+ # compute in_ch_mult, block_in and curr_res at lowest res
481
+ block_in = ch * ch_mult[self.num_resolutions - 1]
482
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
483
+ self.z_shape = (1, z_channels, curr_res, curr_res)
484
+
485
+ # z to block_in
486
+ self.conv_in = torch.nn.Conv2d(
487
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
488
+ )
489
+
490
+ # middle
491
+ self.mid = nn.Module()
492
+ self.mid.block_1 = ResnetBlock(
493
+ in_channels=block_in,
494
+ out_channels=block_in,
495
+ temb_channels=self.temb_ch,
496
+ dropout=dropout,
497
+ )
498
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
499
+ self.mid.block_2 = ResnetBlock(
500
+ in_channels=block_in,
501
+ out_channels=block_in,
502
+ temb_channels=self.temb_ch,
503
+ dropout=dropout,
504
+ )
505
+
506
+ # upsampling
507
+ self.up = nn.ModuleList()
508
+ for i_level in reversed(range(self.num_resolutions)):
509
+ block = nn.ModuleList()
510
+ attn = nn.ModuleList()
511
+ block_out = ch * ch_mult[i_level]
512
+ for i_block in range(self.num_res_blocks + 1):
513
+ block.append(
514
+ ResnetBlock(
515
+ in_channels=block_in,
516
+ out_channels=block_out,
517
+ temb_channels=self.temb_ch,
518
+ dropout=dropout,
519
+ )
520
+ )
521
+ block_in = block_out
522
+ if curr_res in attn_resolutions:
523
+ attn.append(make_attn(block_in, attn_type=attn_type))
524
+ up = nn.Module()
525
+ up.block = block
526
+ up.attn = attn
527
+ if i_level != 0:
528
+ up.upsample = Upsample(block_in, resamp_with_conv)
529
+ curr_res = curr_res * 2
530
+ self.up.insert(0, up) # prepend to get consistent order
531
+
532
+ # end
533
+ self.norm_out = Normalize(block_in)
534
+ self.conv_out = torch.nn.Conv2d(
535
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
536
+ )
537
+
538
+ def forward(self, z):
539
+ # assert z.shape[1:] == self.z_shape[1:]
540
+ self.last_z_shape = z.shape
541
+
542
+ # timestep embedding
543
+ temb = None
544
+
545
+ # z to block_in
546
+ h = self.conv_in(z)
547
+
548
+ # middle
549
+ h = self.mid.block_1(h, temb)
550
+ h = self.mid.attn_1(h)
551
+ h = self.mid.block_2(h, temb)
552
+
553
+ # upsampling
554
+ for i_level in reversed(range(self.num_resolutions)):
555
+ for i_block in range(self.num_res_blocks + 1):
556
+ h = self.up[i_level].block[i_block](h, temb)
557
+ if len(self.up[i_level].attn) > 0:
558
+ h = self.up[i_level].attn[i_block](h)
559
+ if i_level != 0:
560
+ h = self.up[i_level].upsample(h)
561
+
562
+ # end
563
+ if self.give_pre_end:
564
+ return h
565
+
566
+ h = self.norm_out(h)
567
+ h = nonlinearity(h)
568
+ h = self.conv_out(h)
569
+ if self.tanh_out:
570
+ h = torch.tanh(h)
571
+ return h
572
+
573
+
574
+ class VQModel(nn.Module):
575
+ def __init__(
576
+ self,
577
+ ddconfig,
578
+ n_embed,
579
+ embed_dim,
580
+ ckpt_path=None,
581
+ ignore_keys=[],
582
+ image_key="image",
583
+ colorize_nlabels=None,
584
+ monitor=None,
585
+ scheduler_config=None,
586
+ lr_g_factor=1.0,
587
+ remap=None,
588
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
589
+ ):
590
+ super().__init__()
591
+ self.image_key = image_key
592
+ self.encoder = Encoder(**ddconfig)
593
+ self.decoder = Decoder(**ddconfig)
594
+ self.quantize = VectorQuantizer(
595
+ n_embed,
596
+ embed_dim,
597
+ beta=0.25,
598
+ remap=remap,
599
+ sane_index_shape=sane_index_shape,
600
+ )
601
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
602
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
603
+ if ckpt_path is not None:
604
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
605
+ self.image_key = image_key
606
+ if colorize_nlabels is not None:
607
+ assert isinstance(colorize_nlabels, int)
608
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
609
+ if monitor is not None:
610
+ self.monitor = monitor
611
+ self.scheduler_config = scheduler_config
612
+ self.lr_g_factor = lr_g_factor
613
+
614
+ def init_from_ckpt(self, path, ignore_keys=list()):
615
+ if path.startswith("http://") or path.startswith("https://"):
616
+ sd = get_ckpt(path)
617
+ else:
618
+ print(f"Loading checkpoint from local path: {path}")
619
+ sd = torch.load(path, map_location="cpu")["state_dict"]
620
+
621
+ keys = list(sd.keys())
622
+ for k in keys:
623
+ for ik in ignore_keys:
624
+ if k.startswith(ik):
625
+ print(f"Deleting key {k} from state_dict.")
626
+ del sd[k]
627
+
628
+ self.load_state_dict(sd, strict=False)
629
+ print(f"VQModel loaded from {path}")
630
+
631
+ def encode(self, x):
632
+ h = self.encoder(x)
633
+ h = self.quant_conv(h)
634
+ quant, emb_loss, info = self.quantize(h)
635
+ return quant, emb_loss, info
636
+
637
+ def decode(self, quant):
638
+ quant = self.post_quant_conv(quant)
639
+ dec = self.decoder(quant)
640
+ return dec
641
+
642
+ def decode_code(self, code_b):
643
+ quant_b = self.quantize.embed_code(code_b)
644
+ dec = self.decode(quant_b)
645
+ return dec
646
+
647
+ # def forward(self, input):
648
+ # quant, diff, _ = self.encode(input)
649
+ # dec = self.decode(quant)
650
+ # return dec, diff
651
+
652
+ def forward(self, input):
653
+ quant, diff, [_, _, img_toks] = self.encode(input)
654
+
655
+ batch_size, n_channel, height, width = (
656
+ input.shape[0],
657
+ quant.shape[-1],
658
+ quant.shape[-2],
659
+ quant.shape[-3],
660
+ )
661
+ codebook_entry = self.quantize.get_codebook_entry(
662
+ img_toks, (batch_size, n_channel, height, width)
663
+ )
664
+ pixels = self.decode(codebook_entry)
665
+
666
+ return pixels, img_toks, quant
667
+
668
+ def get_input(self, batch, k):
669
+ x = batch[k]
670
+ if len(x.shape) == 3:
671
+ x = x[..., None]
672
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
673
+ return x.float()
674
+
675
+ def get_last_layer(self):
676
+ return self.decoder.conv_out.weight
677
+
678
+ def log_images(self, batch, **kwargs):
679
+ log = dict()
680
+ x = self.get_input(batch, self.image_key)
681
+ x = x.to(self.device)
682
+ xrec, _ = self(x)
683
+ if x.shape[1] > 3:
684
+ # colorize with random projection
685
+ assert xrec.shape[1] > 3
686
+ x = self.to_rgb(x)
687
+ xrec = self.to_rgb(xrec)
688
+ log["inputs"] = x
689
+ log["reconstructions"] = xrec
690
+ return log
691
+
692
+ @property
693
+ def device(self):
694
+ return next(self.parameters()).device
695
+
696
+ @property
697
+ def dtype(self):
698
+ return next(self.parameters()).dtype
699
+
700
+ def to_rgb(self, x):
701
+ assert self.image_key == "segmentation"
702
+ if not hasattr(self, "colorize"):
703
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
704
+ x = F.conv2d(x, weight=self.colorize)
705
+ x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
706
+ return x
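For orientation, a minimal reconstruction round trip with the VQModel above. The ddconfig keys follow the taming-transformers convention expected by Encoder/Decoder; the concrete values here are illustrative assumptions, not the Anole checkpoint's configuration:

import torch

ddconfig = dict(  # illustrative values only
    ch=64, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=1,
    attn_resolutions=[16], dropout=0.0, in_channels=3,
    resolution=64, z_channels=16, double_z=False,
)
model = VQModel(ddconfig, n_embed=1024, embed_dim=16).eval()

x = torch.randn(2, 3, 64, 64)
with torch.no_grad():
    pixels, img_toks, quant = model(x)  # decodes via get_codebook_entry, as in forward() above
print(pixels.shape)    # torch.Size([2, 3, 64, 64])
print(img_toks.shape)  # flat token ids, one per latent position: torch.Size([512])
print(quant.shape)     # torch.Size([2, 16, 16, 16])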
src/vqvaes/bsqvit/attention_mask.py ADDED
@@ -0,0 +1,42 @@
1
+ import torch
2
+
3
+
4
+ def get_attention_mask(sequence_length, device, mask_type="block-causal", **kwargs):
5
+ if mask_type.lower() == 'none' or mask_type is None:
6
+ return None
7
+ elif mask_type.lower() == 'block-causal':
8
+ return _block_causal_mask_impl(sequence_length, device, **kwargs)
9
+ elif mask_type.lower() == 'causal':
10
+ return _causal_mask_impl(sequence_length, device, **kwargs)
11
+ else:
12
+ raise NotImplementedError(f"Mask type {mask_type} not implemented")
13
+
14
+
15
+ def _block_causal_mask_impl(sequence_length, device, block_size=16, **kwargs):
16
+ """
17
+ Create a block-causal mask
18
+ """
19
+ assert sequence_length % block_size == 0, "for block causal masks sequence length must be divisible by block size"
20
+ blocks = torch.ones(sequence_length // block_size, block_size, block_size, device=device)
21
+ block_diag_enable_mask = torch.block_diag(*blocks)
22
+ causal_enable_mask = torch.ones(sequence_length, sequence_length, device=device).tril_(0)
23
+ disable_mask = ((block_diag_enable_mask + causal_enable_mask) < 0.5)
24
+ return disable_mask
25
+
26
+
27
+ def _causal_mask_impl(sequence_length, device, **kwargs):
28
+ """
29
+ Create a causal mask
30
+ """
31
+ causal_disable_mask = torch.triu(
32
+ torch.full((sequence_length, sequence_length), float('-inf'), dtype=torch.float32, device=device),
33
+ diagonal=1,
34
+ )
35
+ return causal_disable_mask
36
+
37
+
38
+ if __name__ == '__main__':
39
+ mask = get_attention_mask(9, "cuda", mask_type="block-causal", block_size=3)
40
+ print(mask)
41
+ mask = get_attention_mask(9, "cuda", mask_type="causal")
42
+ print(mask)
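Beyond the __main__ check above, the two mask flavors plug directly into torch.nn.MultiheadAttention: the block-causal variant returns a boolean disable mask (True = not allowed to attend, matching PyTorch's convention), while the causal variant returns an additive float mask. A small sketch, on CPU to avoid the hard-coded CUDA device:

import torch

mask = get_attention_mask(9, "cpu", mask_type="block-causal", block_size=3)

attn = torch.nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
x = torch.randn(2, 9, 16)
out, _ = attn(x, x, x, attn_mask=mask)  # a 2D (L, S) mask broadcasts over batch and heads
print(out.shape)  # torch.Size([2, 9, 16])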
src/vqvaes/bsqvit/bsqvit.py ADDED
@@ -0,0 +1,150 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .quantizer.bsq import BinarySphericalQuantizer
6
+ from .quantizer.vq import VectorQuantizer
7
+ from .transformer import TransformerDecoder, TransformerEncoder
8
+
9
+
10
+ class VITVQModel(nn.Module):
11
+ def __init__(self, vitconfig, n_embed, embed_dim,
12
+ l2_norm=False, logit_laplace=False, ckpt_path=None, ignore_keys=[],
13
+ grad_checkpointing=False, selective_checkpointing=False,
14
+ clamp_range=(0, 1),
15
+ dvitconfig=None,
16
+ ):
17
+ super().__init__()
18
+ self.encoder = TransformerEncoder(**vitconfig)
19
+ dvitconfig = vitconfig if dvitconfig is None else dvitconfig
20
+ self.decoder = TransformerDecoder(**dvitconfig, logit_laplace=logit_laplace)
21
+ if self.training and grad_checkpointing:
22
+ self.encoder.set_grad_checkpointing(True, selective=selective_checkpointing)
23
+ self.decoder.set_grad_checkpointing(True, selective=selective_checkpointing)
24
+
25
+ self.n_embed = n_embed
26
+ self.embed_dim = embed_dim
27
+ self.l2_norm = l2_norm
28
+ self.setup_quantizer()
29
+
30
+ self.quant_embed = nn.Linear(in_features=vitconfig['width'], out_features=embed_dim)
31
+ self.post_quant_embed = nn.Linear(in_features=embed_dim, out_features=dvitconfig['width'])
32
+ self.l2_norm = l2_norm
33
+ self.logit_laplace = logit_laplace
34
+ self.clamp_range = clamp_range
35
+
36
+ if ckpt_path is not None:
37
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
38
+
39
+ def setup_quantizer(self):
40
+ self.quantize = VectorQuantizer(self.n_embed, self.embed_dim, l2_norm=self.l2_norm, beta=0.25, input_format='blc')
41
+
42
+ # def init_from_ckpt(self, ckpt_path, ignore_keys=[]):
43
+ def init_from_ckpt(self, state_dict, ignore_keys=[]):
44
+ state_dict = {k[7:]: v for k, v in state_dict.items() if k.startswith('module.')}
45
+ filtered_state_dict = {k: v for k, v in state_dict.items() if all([not k.startswith(ig) for ig in ignore_keys])}
46
+ missing_keys, unexpected_keys = self.load_state_dict(filtered_state_dict, strict=False)
47
+ print(f"missing_keys: {missing_keys}")
48
+ print(f"unexpected_keys: {unexpected_keys}")
49
+
50
+ def encode(self, x, skip_quantize=False):
51
+ h = self.encoder(x)
52
+ h = self.quant_embed(h)
53
+ if skip_quantize:
54
+ assert not self.training, 'skip_quantize should be used in eval mode only.'
55
+ if self.l2_norm:
56
+ h = F.normalize(h, dim=-1)
57
+ return h, {}, {}
58
+ quant, loss, info = self.quantize(h)
59
+ return quant, loss, info
60
+
61
+ def decode(self, quant):
62
+ h = self.post_quant_embed(quant)
63
+ x = self.decoder(h)
64
+ return x
65
+
66
+ def clamp(self, x):
67
+ if self.logit_laplace:
68
+ dec, _ = x.chunk(2, dim=1)
69
+ x = self.logit_laplace_loss.unmap(F.sigmoid(dec))
70
+ else:
71
+ x = x.clamp_(self.clamp_range[0], self.clamp_range[1])
72
+ return x
73
+
74
+ def forward(self, input, skip_quantize=False):
75
+ if self.logit_laplace:
76
+ input = self.logit_laplace_loss.inmap(input)
77
+ quant, loss, info = self.encode(input, skip_quantize=skip_quantize)
78
+ dec = self.decode(quant)
79
+ if self.logit_laplace:
80
+ dec, lnb = dec.chunk(2, dim=1)
81
+ logit_laplace_loss = self.logit_laplace_loss(dec, lnb, input)
82
+ info.update({'logit_laplace_loss': logit_laplace_loss})
83
+ dec = self.logit_laplace_loss.unmap(F.sigmoid(dec))
84
+ else:
85
+ dec = dec.clamp_(self.clamp_range[0], self.clamp_range[1])
86
+ return dec, loss, info
87
+
88
+ def get_last_layer(self):
89
+ return self.decoder.conv_out.weight
90
+
91
+
92
+ class VITBSQModel(VITVQModel):
93
+ def __init__(self, vitconfig, embed_dim, embed_group_size=9,
94
+ l2_norm=False, logit_laplace=False, ckpt_path=None, ignore_keys=[],
95
+ grad_checkpointing=False, selective_checkpointing=False,
96
+ clamp_range=(0, 1),
97
+ dvitconfig=None, beta=0., gamma0=1.0, gamma=1.0, zeta=1.0,
98
+ persample_entropy_compute='group',
99
+ cb_entropy_compute='group',
100
+ post_q_l2_norm=False,
101
+ inv_temperature=1.,
102
+ ):
103
+ # set quantizer params
104
+ self.beta = beta # commit loss
105
+ self.gamma0 = gamma0 # entropy
106
+ self.gamma = gamma # entropy penalty
107
+ self.zeta = zeta # lpips
108
+ self.embed_group_size = embed_group_size
109
+ self.persample_entropy_compute = persample_entropy_compute
110
+ self.cb_entropy_compute = cb_entropy_compute
111
+ self.post_q_l2_norm = post_q_l2_norm
112
+ self.inv_temperature = inv_temperature
113
+
114
+ # call init
115
+ super().__init__(
116
+ vitconfig,
117
+ 2 ** embed_dim,
118
+ embed_dim,
119
+ l2_norm=l2_norm,
120
+ logit_laplace=logit_laplace,
121
+ ckpt_path=ckpt_path,
122
+ ignore_keys=ignore_keys,
123
+ grad_checkpointing=grad_checkpointing,
124
+ selective_checkpointing=selective_checkpointing,
125
+ clamp_range=clamp_range,
126
+ dvitconfig=dvitconfig,
127
+ )
128
+
129
+
130
+ def setup_quantizer(self):
131
+ self.quantize = BinarySphericalQuantizer(
132
+ self.embed_dim, self.beta, self.gamma0, self.gamma, self.zeta,
133
+ group_size=self.embed_group_size,
134
+ persample_entropy_compute=self.persample_entropy_compute,
135
+ cb_entropy_compute=self.cb_entropy_compute,
136
+ input_format='blc',
137
+ l2_norm=self.post_q_l2_norm,
138
+ inv_temperature=self.inv_temperature,
139
+ )
140
+
141
+ def encode(self, x, skip_quantize=False):
142
+ h = self.encoder(x)
143
+ h = self.quant_embed(h)
144
+ if self.l2_norm:
145
+ h = F.normalize(h, dim=-1)
146
+ if skip_quantize:
147
+ assert not self.training, 'skip_quantize should be used in eval mode only.'
148
+ return h, {}, {}
149
+ quant, loss, info = self.quantize(h)
150
+ return quant, loss, info
src/vqvaes/bsqvit/quantizer/bsq.py ADDED
@@ -0,0 +1,223 @@
1
+ from einops import rearrange, reduce
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.autograd import Function
5
+
6
+
7
+ class DifferentiableEntropyFunction(Function):
8
+ @staticmethod
9
+ def forward(ctx, zq, basis, K, eps):
10
+ zb = (zq + 1) / 2
11
+ zi = ((zb * basis).sum(-1)).to(torch.int64)
12
+ cnt = torch.scatter_reduce(torch.zeros(2**K, device=zq.device, dtype=zq.dtype),
13
+ 0,
14
+ zi.flatten(),
15
+ torch.ones_like(zi.flatten()).to(zq.dtype),
16
+ 'sum')
17
+ prob = (cnt + eps) / (cnt + eps).sum()
18
+ H = -(prob * torch.log(prob)).sum()
19
+ ctx.save_for_backward(zq, zi, prob)
20
+ ctx.K = K
21
+ return H
22
+
23
+ @staticmethod
24
+ def backward(ctx, grad_output):
25
+ zq, zi, prob = ctx.saved_tensors
26
+ grad_array = -grad_output * (torch.log(prob) + 1) / zi.numel() / ctx.K
27
+ reord_grad = grad_array[zi.flatten()].reshape(zi.shape)
28
+ grad_input = reord_grad.unsqueeze(-1) * zq
29
+ return grad_input, None, None, None, None
30
+
31
+
32
+ def codebook_entropy(zq, basis, K, eps=1e-4):
33
+ return DifferentiableEntropyFunction.apply(zq, basis, K, eps)
34
+
35
+
36
+ class BinarySphericalQuantizer(nn.Module):
37
+ def __init__(self, embed_dim, beta, gamma0, gamma, zeta,
38
+ input_format='bchw',
39
+ soft_entropy=True, group_size=9,
40
+ persample_entropy_compute='group',
41
+ cb_entropy_compute='group',
42
+ l2_norm=False,
43
+ inv_temperature=1):
44
+ super().__init__()
45
+ self.embed_dim = embed_dim
46
+ self.beta = beta # loss weight for commit loss
47
+ self.gamma0 = gamma0 # loss weight for entropy penalty
48
+ self.gamma = gamma # loss weight for entropy penalty
49
+ self.zeta = zeta # loss weight for entire entropy penalty
50
+ self.input_format = input_format
51
+ assert self.embed_dim % group_size == 0, "embed_dim must be divisible by group_size"
52
+ self.num_groups = self.embed_dim // group_size
53
+ self.group_size = group_size
54
+ assert persample_entropy_compute in ['group', 'analytical'], "persample_entropy_compute must be either 'group' or 'analytical'"
55
+ assert cb_entropy_compute in ['group', 'nce'], "cb_entropy_compute must be either 'group' or 'nce'"
56
+ self.persample_entropy_compute = persample_entropy_compute
57
+ self.cb_entropy_compute = cb_entropy_compute
58
+ self.l2_norm = l2_norm
59
+ self.inv_temperature = inv_temperature
60
+
61
+ self.register_buffer('basis', 2 ** torch.arange(embed_dim - 1, -1, -1))
62
+ self.register_buffer('group_basis', 2 ** torch.arange(group_size - 1, -1, -1))
63
+
64
+ self.num_dimensions = 2 ** embed_dim
65
+ self.bits_per_index = embed_dim
66
+
67
+ # we only need to keep the codebook portion up to the group size
68
+ # because we approximate the H loss with this subcode
69
+ group_codes = torch.arange(2 ** self.group_size)
70
+ group_codebook = self.indexes_to_codes(group_codes).float()[:, -group_size:]
71
+ self.register_buffer('group_codebook', group_codebook, persistent=False)
72
+
73
+ self.soft_entropy = soft_entropy # soft_entropy: Sec 3.2 of https://arxiv.org/pdf/1911.05894.pdf
74
+
75
+ def quantize(self, z):
76
+ assert z.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {z.shape[-1]}"
77
+
78
+ zhat = torch.where(z > 0,
79
+ torch.tensor(1, dtype=z.dtype, device=z.device),
80
+ torch.tensor(-1, dtype=z.dtype, device=z.device))
81
+ return z + (zhat - z).detach()
82
+
83
+ def forward(self, z):
84
+ if self.input_format == 'bchw':
85
+ z = rearrange(z, 'b c h w -> b h w c')
86
+ zq = self.quantize(z)
87
+
88
+ indices = self.codes_to_indexes(zq.detach())
89
+ group_indices = self.codes_to_group_indexes(zq.detach())
90
+ if not self.training:
91
+ used_codes = torch.unique(indices, return_counts=False)
92
+ else:
93
+ used_codes = None
94
+
95
+ q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
96
+
97
+ if self.soft_entropy:
98
+ persample_entropy, cb_entropy, avg_prob = self.soft_entropy_loss(z)
99
+ entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
100
+ else:
101
+ zb_by_sample= ((zq + 1)/2).reshape(z.shape[0], -1, z.shape[-1]).to(torch.float32)
102
+ persample_entropy = self.get_hard_per_sample_entropy(zb_by_sample)
103
+ cb_entropy = codebook_entropy(zq, self.basis, self.embed_dim)
104
+ entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
+ avg_prob = None # not computed in the hard-entropy path; keeps the return dict below well-defined
105
+
106
+ zq = zq * q_scale
107
+
108
+ # commit loss
109
+ commit_loss = self.beta * torch.mean(((zq.detach() - z) ** 2).sum(dim=-1))
110
+
111
+ if self.input_format == 'bchw':
112
+ zq = rearrange(zq, 'b h w c -> b c h w')
113
+
114
+ return (
115
+ zq,
116
+ commit_loss + self.zeta * entropy_penalty / self.inv_temperature,
117
+ {"H": cb_entropy, "used_codes": used_codes, "indices": indices, "group_indices": group_indices,
118
+ "avg_prob": avg_prob}
119
+ )
120
+
121
+ def soft_entropy_loss(self, z):
122
+ # if we divide the code in subgroups of size group_size, the codebook will be of size 2 ** group_size
123
+ # the sub-code is the last group_size bits of the full code
124
+ group_code_book = self.group_codebook / (self.embed_dim ** 0.5 if self.l2_norm else 1)
125
+ divided_z = rearrange(z, '... (g c) -> ... g c', c=self.group_size)
126
+
127
+ # we calculate the distance between the divided_z and the codebook for each subgroup
128
+ distance = - 2 * torch.einsum('... g c, d c ->... g d', divided_z, group_code_book)
129
+ prob = (-distance * self.inv_temperature).softmax(dim = -1)
130
+ if self.persample_entropy_compute == 'analytical':
131
+ if self.l2_norm:
132
+ p = torch.sigmoid(-4 * z / (self.embed_dim ** 0.5) * self.inv_temperature)
133
+ else:
134
+ p = torch.sigmoid(-4 * z * self.inv_temperature)
135
+ prob = torch.stack([p, 1-p], dim=-1)
136
+ per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
137
+ else:
138
+ per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
139
+
140
+ # macro average of the probability of each subgroup
141
+ avg_prob = reduce(prob, '... g d ->g d', 'mean')
142
+ codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False)
143
+
144
+ # the approximation of the entropy is the sum of the entropy of each subgroup
145
+ return per_sample_entropy, codebook_entropy.sum(), avg_prob
146
+
147
+ def get_hard_per_sample_entropy(self, zb_by_sample):
148
+ probs_per_dim = zb_by_sample.sum(1) / zb_by_sample.shape[1]
149
+ persample_entropy = - probs_per_dim * torch.log(probs_per_dim + 1e-8) - (1 - probs_per_dim) * torch.log(1 - probs_per_dim + 1e-8)
150
+ persample_entropy = persample_entropy.sum(-1)
151
+ return persample_entropy.mean()
152
+
153
+ def codes_to_indexes(self, zhat):
154
+ """Converts a `code` to an index in the codebook.
155
+ Args:
156
+ zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}
157
+ """
158
+ assert zhat.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {zhat.shape[-1]}"
159
+ return ((zhat + 1) / 2 * self.basis).sum(axis=-1).to(torch.int64)
160
+
161
+ def codes_to_group_indexes(self, zhat):
162
+ """Converts a `code` to a list of indexes (in groups) in the codebook.
163
+ Args:
164
+ zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}
165
+ """
166
+ zhat_in_group = rearrange(zhat, 'b ... (g c) -> b ... g c', c=self.group_size)
167
+ return ((zhat_in_group + 1) / 2 * self.group_basis).sum(axis=-1).to(torch.int64)
168
+
169
+ def indexes_to_codes(self, indices):
170
+ """Inverse of `codes_to_indexes`."""
171
+ indices = indices.unsqueeze(-1)
172
+ codes_non_centered = torch.remainder(
173
+ torch.floor_divide(indices, self.basis), 2
174
+ )
175
+ return codes_non_centered * 2 - 1
176
+
177
+ def group_indexes_to_codes(self, group_indices):
178
+ """Inverse of `codes_to_group_indexes`."""
179
+ group_indices = group_indices.unsqueeze(-1)
180
+ codes_non_centered = torch.remainder(
181
+ torch.floor_divide(group_indices, self.group_basis), 2
182
+ )
183
+ codes_non_centered = rearrange(codes_non_centered, 'b ... g c -> b ... (g c)')
184
+ return codes_non_centered * 2 - 1
185
+
186
+ def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True):
187
+ if normalize:
188
+ probs = (count + eps) / (count + eps).sum(dim=dim, keepdim =True)
189
+ else:
190
+ probs = count
191
+ H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim)
192
+ return H
193
+
194
+ def get_group_codebook_entry(self, group_indices):
195
+ z_q = self.group_indexes_to_codes(group_indices)
196
+ q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
197
+ z_q = z_q * q_scale
198
+ if self.input_format == 'bchw':
199
+ h = w = int(z_q.shape[1] ** 0.5)
200
+ assert h * w == z_q.shape[1], 'Invalid sequence length'
201
+ z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
202
+ return z_q
203
+
204
+ def get_codebook_entry(self, indices):
205
+ z_q = self.indexes_to_codes(indices)
206
+ q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
207
+ z_q = z_q * q_scale
208
+ if self.input_format == 'bchw':
209
+ h = w = int(z_q.shape[1] ** 0.5)
210
+ assert h * w == z_q.shape[1], 'Invalid sequence length'
211
+ z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
212
+ return z_q
213
+
214
+
215
+ if __name__ == "__main__":
216
+ K = 8
217
+ # zq = torch.randint(0, 2, (4, 32, K), dtype=torch.bfloat16, device='cuda') * 2 - 1
218
+ zq = torch.zeros((4, 32, K), dtype=torch.bfloat16, device='cuda') * 2 - 1
219
+ basis = (2 ** torch.arange(K - 1, -1, -1)).to(torch.bfloat16).cuda()
220
+ zq.requires_grad = True
221
+ h = codebook_entropy(zq, basis, K)
222
+ h.backward()
223
+ print(zq.grad, zq)
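A small round trip with BinarySphericalQuantizer in the 'blc' layout used by the ViT tokenizer: each channel is binarized to ±1 (scaled by 1/sqrt(d) when l2_norm=True), and codes_to_indexes / get_codebook_entry map between those hypercube-corner codes and integer indices. Hyperparameter values are illustrative only:

import torch
import torch.nn.functional as F

embed_dim, group_size = 18, 9  # embed_dim must be divisible by group_size
bsq = BinarySphericalQuantizer(
    embed_dim, beta=0.0, gamma0=1.0, gamma=1.0, zeta=1.0,
    input_format='blc', group_size=group_size, l2_norm=True,
).eval()

z = F.normalize(torch.randn(2, 16, embed_dim), dim=-1)
with torch.no_grad():
    zq, loss, info = bsq(z)
codes = bsq.get_codebook_entry(info["indices"])  # indices -> scaled ±1 codes
print(zq.shape, codes.shape, torch.allclose(zq, codes))  # expected: matching shapes, True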
src/vqvaes/bsqvit/quantizer/vq.py ADDED
@@ -0,0 +1,152 @@
1
+ from einops import rearrange
2
+ import numpy as np
3
+ import torch
4
+ import torch.distributed as dist
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+
9
+ class VectorQuantizer(nn.Module):
10
+ def __init__(self, n_embed, embed_dim, l2_norm, beta, input_format='bchw'):
11
+ super().__init__()
12
+
13
+ self.n_embed = n_embed
14
+ self.embed_dim = embed_dim
15
+ self.l2_norm = l2_norm
16
+ self.beta = beta
17
+ assert input_format in ['bchw', 'blc']
18
+ self.input_format = input_format
19
+
20
+ self.embedding = nn.Embedding(n_embed, embed_dim)
21
+ self.embedding.weight.data.uniform_(-1 / n_embed, 1 / n_embed)
22
+ self.bits_per_index = int(np.ceil(np.log2(n_embed)))
23
+
24
+ def forward(self, z):
25
+ batch = z.shape[0]
26
+ if self.input_format == 'bchw':
27
+ z = rearrange(z, 'b c h w -> b h w c')
28
+
29
+ if self.l2_norm:
30
+ z = F.normalize(z, dim=-1)
31
+ z_flatten = z.reshape(-1, self.embed_dim)
32
+ embedding_weight = F.normalize(self.embedding.weight, dim=-1)
33
+ d = -z_flatten @ embedding_weight.t()
34
+ else:
35
+ z_flatten = z.reshape(-1, self.embed_dim)
36
+ d = torch.sum(z_flatten ** 2, dim=1, keepdim=True) + torch.sum(self.embedding.weight ** 2, dim=1) - 2 * z_flatten @ self.embedding.weight.t()
37
+
38
+ min_encoding_indices = torch.argmin(d.detach(), dim=1)
39
+ if not self.training:
40
+ used_codes = torch.unique(min_encoding_indices, return_counts=False)
41
+ else:
42
+ used_codes = None
43
+ cb_usage = F.one_hot(min_encoding_indices, self.n_embed).sum(0)
44
+ cb_entropy = self.get_entropy(cb_usage)
45
+
46
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
47
+ if self.l2_norm:
48
+ z_q = F.normalize(z_q, dim=-1)
49
+
50
+ # fix the issue with loss scaling
51
+ # loss weight should not associate with the dimensionality of words
52
+ # loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
53
+ loss = self.beta * torch.mean(((z_q.detach() - z) ** 2).sum(dim=-1)) + torch.mean(((z_q - z.detach()) ** 2).sum(dim=-1))
54
+
55
+ z_q = z + (z_q - z).detach()
56
+ if self.input_format == 'bchw':
57
+ z_q = rearrange(z_q, 'b h w c -> b c h w')
58
+ return z_q, loss, {"H":cb_entropy, "used_codes": used_codes, 'indices': min_encoding_indices.view(batch, -1)}
59
+
60
+ def get_entropy(self, count, eps=1e-4):
61
+ probs = (count + eps) / (count + eps).sum()
62
+ H = -(probs * torch.log(probs)).sum()
63
+ return H
64
+
65
+
66
+ def get_codebook_entry(self, indices):
67
+ z_q = self.embedding(indices)
68
+ if self.l2_norm:
69
+ z_q = F.normalize(z_q, dim=-1)
70
+
71
+ if self.input_format == 'bchw':
72
+ h = w = int(z_q.shape[1] ** 0.5)
73
+ assert h * w == z_q.shape[1], 'Invalid sequence length'
74
+ z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
75
+ return z_q
76
+
77
+
78
+ class EMAVectorQuantizer(nn.Module):
79
+ def __init__(self, n_embed, embed_dim, l2_norm, beta, decay=0.99, eps=1e-5, random_restart=True, restart_threshold=1.0, input_format='bchw'):
80
+ super().__init__()
81
+
82
+ self.n_embed = n_embed
83
+ self.embed_dim = embed_dim
84
+ self.l2_norm = l2_norm
85
+ self.beta = beta
86
+ self.decay = decay
87
+ self.eps = eps
88
+ self.random_restart = random_restart
89
+ self.restart_threshold = restart_threshold
90
+ self.input_format = input_format
91
+
92
+ self.embedding = nn.Embedding(n_embed, embed_dim)
93
+ self.embedding.weight.data.uniform_(-1 / n_embed, 1 / n_embed) # TODO (yzhao): test other initialization methods
94
+ self.register_buffer("ema_cluster_size", torch.zeros(self.n_embed))
95
+ self.embedding_avg = nn.Parameter(torch.Tensor(self.n_embed, self.embed_dim))
96
+ self.embedding_avg.data.copy_(self.embedding.weight.data)
97
+
98
+ def _tile(self, z):
99
+ n_z, embedding_dim = z.shape
100
+ if n_z < self.n_embed:
101
+ n_repeats = (self.n_embed + n_z - 1) // n_z
102
+ std = 0.01 / np.sqrt(embedding_dim)
103
+ z = z.repeat(n_repeats, 1)
104
+ z = z + torch.randn_like(z) * std
105
+ return z
106
+
107
+ def forward(self, z):
108
+ if self.input_format == 'bchw':
109
+ z = rearrange(z, 'b c h w -> b h w c')
110
+ z_flatten = z.reshape(-1, self.embed_dim)
111
+
112
+ d = torch.sum(z_flatten ** 2, dim=1, keepdim=True) + torch.sum(self.embedding.weight ** 2, dim=1) - 2 * z_flatten @ self.embedding.weight.t()
113
+
114
+ encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
115
+ encodings = torch.zeros(encoding_indices.size(0), self.n_embed, device=z.device)
116
+ encodings.scatter_(1, encoding_indices, 1)
117
+
118
+ z_q = self.embedding(encoding_indices).view(z.shape)
119
+ if self.l2_norm:
120
+ z = F.normalize(z, dim=-1)
121
+ z_q = F.normalize(z_q, dim=-1)
122
+
123
+ if self.training:
124
+ # EMA update cluster size
125
+ encodings_sum = encodings.sum(0)
126
+ if dist.is_initialized(): dist.all_reduce(encodings_sum)
127
+ self.ema_cluster_size.data.mul_(self.decay).add_(encodings_sum, alpha=1-self.decay)
128
+
129
+ # EMA update of the embedding vectors
130
+ dw = encodings.t() @ z_flatten
131
+ if dist.is_initialized(): dist.all_reduce(dw)
132
+ self.embedding_avg.data.mul_(self.decay).add_(dw, alpha=1-self.decay)
133
+
134
+ # Laplace smoothing of the cluster size
135
+ n = torch.sum(self.ema_cluster_size)
136
+ weights = (self.ema_cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
137
+ self.embedding.weight.data = self.embedding_avg.data / weights.unsqueeze(1)
138
+
139
+ if self.random_restart:
140
+ zz = self._tile(z_flatten)
141
+ _k_rand = zz[torch.randperm(zz.size(0))][:self.n_embed]
142
+ if dist.is_initialized(): dist.broadcast(_k_rand, 0)
143
+ usage = (self.ema_cluster_size.view(-1, 1) > self.restart_threshold).float()
144
+ self.embedding.weight.data.mul_(usage).add_(_k_rand * (1 - usage))
145
+
146
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
147
+
148
+ z_q = z + (z_q - z).detach()
149
+ if self.input_format == 'bchw':
150
+ z_q = rearrange(z_q, 'b h w c -> b c h w')
151
+ # TODO (yzhao): monitor utility of the dictionary
152
+ return z_q, loss, {}
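For comparison, a minimal forward pass through the plain VectorQuantizer above in the same 'blc' layout; the codebook size and dimensions are illustrative assumptions:

import torch

vq = VectorQuantizer(n_embed=1024, embed_dim=32, l2_norm=True, beta=0.25, input_format='blc').eval()

z = torch.randn(2, 64, 32)  # (batch, sequence length, channels)
z_q, loss, info = vq(z)
print(z_q.shape)              # torch.Size([2, 64, 32])
print(info["indices"].shape)  # torch.Size([2, 64]) -- one codebook index per token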
src/vqvaes/bsqvit/stylegan_utils/custom_ops.py ADDED
@@ -0,0 +1,126 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import glob
11
+ import torch
12
+ import torch.utils.cpp_extension
13
+ import importlib
14
+ import hashlib
15
+ import shutil
16
+ from pathlib import Path
17
+
18
+ from torch.utils.file_baton import FileBaton
19
+
20
+ #----------------------------------------------------------------------------
21
+ # Global options.
22
+
23
+ verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
24
+
25
+ #----------------------------------------------------------------------------
26
+ # Internal helper funcs.
27
+
28
+ def _find_compiler_bindir():
29
+ patterns = [
30
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
31
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
32
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
33
+ 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
34
+ ]
35
+ for pattern in patterns:
36
+ matches = sorted(glob.glob(pattern))
37
+ if len(matches):
38
+ return matches[-1]
39
+ return None
40
+
41
+ #----------------------------------------------------------------------------
42
+ # Main entry point for compiling and loading C++/CUDA plugins.
43
+
44
+ _cached_plugins = dict()
45
+
46
+ def get_plugin(module_name, sources, **build_kwargs):
47
+ assert verbosity in ['none', 'brief', 'full']
48
+
49
+ # Already cached?
50
+ if module_name in _cached_plugins:
51
+ return _cached_plugins[module_name]
52
+
53
+ # Print status.
54
+ if verbosity == 'full':
55
+ print(f'Setting up PyTorch plugin "{module_name}"...')
56
+ elif verbosity == 'brief':
57
+ print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
58
+
59
+ try: # pylint: disable=too-many-nested-blocks
60
+ # Make sure we can find the necessary compiler binaries.
61
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
62
+ compiler_bindir = _find_compiler_bindir()
63
+ if compiler_bindir is None:
64
+ raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
65
+ os.environ['PATH'] += ';' + compiler_bindir
66
+
67
+ # Compile and load.
68
+ verbose_build = (verbosity == 'full')
69
+
70
+ # Incremental build md5sum trickery. Copies all the input source files
71
+ # into a cached build directory under a combined md5 digest of the input
72
+ # source files. Copying is done only if the combined digest has changed.
73
+ # This keeps input file timestamps and filenames the same as in previous
74
+ # extension builds, allowing for fast incremental rebuilds.
75
+ #
76
+ # This optimization is done only in case all the source files reside in
77
+ # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
78
+ # environment variable is set (we take this as a signal that the user
79
+ # actually cares about this.)
80
+ source_dirs_set = set(os.path.dirname(source) for source in sources)
81
+ if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
82
+ all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
83
+
84
+ # Compute a combined hash digest for all source files in the same
85
+ # custom op directory (usually .cu, .cpp, .py and .h files).
86
+ hash_md5 = hashlib.md5()
87
+ for src in all_source_files:
88
+ with open(src, 'rb') as f:
89
+ hash_md5.update(f.read())
90
+ build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
91
+ digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
92
+
93
+ if not os.path.isdir(digest_build_dir):
94
+ os.makedirs(digest_build_dir, exist_ok=True)
95
+ baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
96
+ if baton.try_acquire():
97
+ try:
98
+ for src in all_source_files:
99
+ shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
100
+ finally:
101
+ baton.release()
102
+ else:
103
+ # Someone else is copying source files under the digest dir,
104
+ # wait until done and continue.
105
+ baton.wait()
106
+ digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
107
+ torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
108
+ verbose=verbose_build, sources=digest_sources, **build_kwargs)
109
+ else:
110
+ torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
111
+ module = importlib.import_module(module_name)
112
+
113
+ except:
114
+ if verbosity == 'brief':
115
+ print('Failed!')
116
+ raise
117
+
118
+ # Print status and add to cache.
119
+ if verbosity == 'full':
120
+ print(f'Done setting up PyTorch plugin "{module_name}".')
121
+ elif verbosity == 'brief':
122
+ print('Done.')
123
+ _cached_plugins[module_name] = module
124
+ return module
125
+
126
+ #----------------------------------------------------------------------------
src/vqvaes/bsqvit/stylegan_utils/misc.py ADDED
@@ -0,0 +1,40 @@
1
+ import torch
2
+ import warnings
3
+
4
+
5
+ #----------------------------------------------------------------------------
6
+ # Symbolic assert.
7
+
8
+ try:
9
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
10
+ except AttributeError:
11
+ symbolic_assert = torch.Assert # 1.7.0
12
+
13
+ #----------------------------------------------------------------------------
14
+ # Context manager to suppress known warnings in torch.jit.trace().
15
+
16
+ class suppress_tracer_warnings(warnings.catch_warnings):
17
+ def __enter__(self):
18
+ super().__enter__()
19
+ warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
20
+ return self
21
+
22
+ #----------------------------------------------------------------------------
23
+ # Assert that the shape of a tensor matches the given list of integers.
24
+ # None indicates that the size of a dimension is allowed to vary.
25
+ # Performs symbolic assertion when used in torch.jit.trace().
26
+
27
+ def assert_shape(tensor, ref_shape):
28
+ if tensor.ndim != len(ref_shape):
29
+ raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
30
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
31
+ if ref_size is None:
32
+ pass
33
+ elif isinstance(ref_size, torch.Tensor):
34
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
35
+ symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
36
+ elif isinstance(size, torch.Tensor):
37
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
38
+ symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
39
+ elif size != ref_size:
40
+ raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
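assert_shape is a small contract checker: None leaves a dimension unconstrained, anything else must match exactly (symbolically when running under torch.jit.trace()). For example:

import torch

x = torch.randn(4, 3, 256, 256)
assert_shape(x, [None, 3, 256, 256])  # passes: batch size is free to vary
try:
    assert_shape(x, [None, 1, 256, 256])
except AssertionError as e:
    print(e)  # Wrong size for dimension 1: got 3, expected 1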
src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.cpp ADDED
@@ -0,0 +1,99 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "bias_act.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
+ static bool has_same_layout(torch::Tensor x, torch::Tensor y)
17
+ {
18
+ if (x.dim() != y.dim())
19
+ return false;
20
+ for (int64_t i = 0; i < x.dim(); i++)
21
+ {
22
+ if (x.size(i) != y.size(i))
23
+ return false;
24
+ if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
25
+ return false;
26
+ }
27
+ return true;
28
+ }
29
+
30
+ //------------------------------------------------------------------------
31
+
32
+ static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
33
+ {
34
+ // Validate arguments.
35
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
36
+ TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
37
+ TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
38
+ TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
39
+ TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
40
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
41
+ TORCH_CHECK(b.dim() == 1, "b must have rank 1");
42
+ TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
43
+ TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
44
+ TORCH_CHECK(grad >= 0, "grad must be non-negative");
45
+
46
+ // Validate layout.
47
+ TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
48
+ TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
49
+ TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
50
+ TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
51
+ TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
52
+
53
+ // Create output tensor.
54
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
55
+ torch::Tensor y = torch::empty_like(x);
56
+ TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
57
+
58
+ // Initialize CUDA kernel parameters.
59
+ bias_act_kernel_params p;
60
+ p.x = x.data_ptr();
61
+ p.b = (b.numel()) ? b.data_ptr() : NULL;
62
+ p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
63
+ p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
64
+ p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
65
+ p.y = y.data_ptr();
66
+ p.grad = grad;
67
+ p.act = act;
68
+ p.alpha = alpha;
69
+ p.gain = gain;
70
+ p.clamp = clamp;
71
+ p.sizeX = (int)x.numel();
72
+ p.sizeB = (int)b.numel();
73
+ p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
74
+
75
+ // Choose CUDA kernel.
76
+ void* kernel;
77
+ AT_DISPATCH_REDUCED_FLOATING_TYPES(x.scalar_type(), "upfirdn2d_cuda", [&]
78
+ {
79
+ kernel = choose_bias_act_kernel<scalar_t>(p);
80
+ });
81
+ TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
82
+
83
+ // Launch CUDA kernel.
84
+ p.loopX = 4;
85
+ int blockSize = 4 * 32;
86
+ int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
87
+ void* args[] = {&p};
88
+ AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
89
+ return y;
90
+ }
91
+
92
+ //------------------------------------------------------------------------
93
+
94
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
95
+ {
96
+ m.def("bias_act", &bias_act);
97
+ }
98
+
99
+ //------------------------------------------------------------------------
src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.cu ADDED
@@ -0,0 +1,176 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <c10/util/Half.h>
11
+ #include "bias_act.h"
12
+
13
+ //------------------------------------------------------------------------
14
+ // Helpers.
15
+
16
+ template <class T> struct InternalType;
17
+ template <> struct InternalType<double> { typedef double scalar_t; };
18
+ template <> struct InternalType<float> { typedef float scalar_t; };
19
+ template <> struct InternalType<c10::Half> { typedef float scalar_t; };
20
+ template <> struct InternalType<at::BFloat16> { typedef float scalar_t; };
21
+
22
+ //------------------------------------------------------------------------
23
+ // CUDA kernel.
24
+
25
+ template <class T, int A>
26
+ __global__ void bias_act_kernel(bias_act_kernel_params p)
27
+ {
28
+ typedef typename InternalType<T>::scalar_t scalar_t;
29
+ int G = p.grad;
30
+ scalar_t alpha = (scalar_t)p.alpha;
31
+ scalar_t gain = (scalar_t)p.gain;
32
+ scalar_t clamp = (scalar_t)p.clamp;
33
+ scalar_t one = (scalar_t)1;
34
+ scalar_t two = (scalar_t)2;
35
+ scalar_t expRange = (scalar_t)80;
36
+ scalar_t halfExpRange = (scalar_t)40;
37
+ scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
38
+ scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
39
+
40
+ // Loop over elements.
41
+ int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
42
+ for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
43
+ {
44
+ // Load.
45
+ scalar_t x = (scalar_t)((const T*)p.x)[xi];
46
+ scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
47
+ scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
48
+ scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
49
+ scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
50
+ scalar_t yy = (gain != 0) ? yref / gain : 0;
51
+ scalar_t y = 0;
52
+
53
+ // Apply bias.
54
+ ((G == 0) ? x : xref) += b;
55
+
56
+ // linear
57
+ if (A == 1)
58
+ {
59
+ if (G == 0) y = x;
60
+ if (G == 1) y = x;
61
+ }
62
+
63
+ // relu
64
+ if (A == 2)
65
+ {
66
+ if (G == 0) y = (x > 0) ? x : 0;
67
+ if (G == 1) y = (yy > 0) ? x : 0;
68
+ }
69
+
70
+ // lrelu
71
+ if (A == 3)
72
+ {
73
+ if (G == 0) y = (x > 0) ? x : x * alpha;
74
+ if (G == 1) y = (yy > 0) ? x : x * alpha;
75
+ }
76
+
77
+ // tanh
78
+ if (A == 4)
79
+ {
80
+ if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
81
+ if (G == 1) y = x * (one - yy * yy);
82
+ if (G == 2) y = x * (one - yy * yy) * (-two * yy);
83
+ }
84
+
85
+ // sigmoid
86
+ if (A == 5)
87
+ {
88
+ if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
89
+ if (G == 1) y = x * yy * (one - yy);
90
+ if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
91
+ }
92
+
93
+ // elu
94
+ if (A == 6)
95
+ {
96
+ if (G == 0) y = (x >= 0) ? x : exp(x) - one;
97
+ if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
98
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
99
+ }
100
+
101
+ // selu
102
+ if (A == 7)
103
+ {
104
+ if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
105
+ if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
106
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
107
+ }
108
+
109
+ // softplus
110
+ if (A == 8)
111
+ {
112
+ if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
113
+ if (G == 1) y = x * (one - exp(-yy));
114
+ if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
115
+ }
116
+
117
+ // swish
118
+ if (A == 9)
119
+ {
120
+ if (G == 0)
121
+ y = (x < -expRange) ? 0 : x / (exp(-x) + one);
122
+ else
123
+ {
124
+ scalar_t c = exp(xref);
125
+ scalar_t d = c + one;
126
+ if (G == 1)
127
+ y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
128
+ else
129
+ y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
130
+ yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
131
+ }
132
+ }
133
+
134
+ // Apply gain.
135
+ y *= gain * dy;
136
+
137
+ // Clamp.
138
+ if (clamp >= 0)
139
+ {
140
+ if (G == 0)
141
+ y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
142
+ else
143
+ y = (yref > -clamp & yref < clamp) ? y : 0;
144
+ }
145
+
146
+ // Store.
147
+ ((T*)p.y)[xi] = (T)y;
148
+ }
149
+ }
150
+
151
+ //------------------------------------------------------------------------
152
+ // CUDA kernel selection.
153
+
154
+ template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
155
+ {
156
+ if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
157
+ if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
158
+ if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
159
+ if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
160
+ if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
161
+ if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
162
+ if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
163
+ if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
164
+ if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
165
+ return NULL;
166
+ }
167
+
168
+ //------------------------------------------------------------------------
169
+ // Template specializations.
170
+
171
+ template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
172
+ template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
173
+ template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
174
+ template void* choose_bias_act_kernel<at::BFloat16> (const bias_act_kernel_params& p);
175
+
176
+ //------------------------------------------------------------------------
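The G == 1 and G == 2 branches above express first and second derivatives of each activation through the saved reference values (yy = yref / gain, or xref for swish) instead of recomputing the forward pass. A quick standalone check of two of those identities with plain autograd, purely illustrative:

    # Minimal sketch (assumption: plain PyTorch; this checks the math, not the kernel itself).
    import torch

    x = torch.randn(1000, dtype=torch.float64, requires_grad=True)

    y = torch.sigmoid(x)
    (g,) = torch.autograd.grad(y.sum(), x)
    assert torch.allclose(g, y.detach() * (1 - y.detach()))   # sigmoid: dy/dx = y * (1 - y), cf. A == 5, G == 1

    y = torch.tanh(x)
    (g,) = torch.autograd.grad(y.sum(), x)
    assert torch.allclose(g, 1 - y.detach() ** 2)             # tanh: dy/dx = 1 - y^2, cf. A == 4, G == 1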
src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.h ADDED
@@ -0,0 +1,38 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ //------------------------------------------------------------------------
10
+ // CUDA kernel parameters.
11
+
12
+ struct bias_act_kernel_params
13
+ {
14
+ const void* x; // [sizeX]
15
+ const void* b; // [sizeB] or NULL
16
+ const void* xref; // [sizeX] or NULL
17
+ const void* yref; // [sizeX] or NULL
18
+ const void* dy; // [sizeX] or NULL
19
+ void* y; // [sizeX]
20
+
21
+ int grad;
22
+ int act;
23
+ float alpha;
24
+ float gain;
25
+ float clamp;
26
+
27
+ int sizeX;
28
+ int sizeB;
29
+ int stepB;
30
+ int loopX;
31
+ };
32
+
33
+ //------------------------------------------------------------------------
34
+ // CUDA kernel selection.
35
+
36
+ template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37
+
38
+ //------------------------------------------------------------------------
src/vqvaes/bsqvit/stylegan_utils/ops/bias_act.py ADDED
@@ -0,0 +1,226 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom PyTorch ops for efficient bias and activation."""
10
+
11
+ import os
12
+ import warnings
13
+ import numpy as np
14
+ import torch
15
+ import traceback
16
+ from typing import Any
17
+
18
+ from .. import custom_ops
19
+
20
+
21
+ class EasyDict(dict):
22
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
23
+
24
+ def __getattr__(self, name: str) -> Any:
25
+ try:
26
+ return self[name]
27
+ except KeyError:
28
+ raise AttributeError(name)
29
+
30
+ def __setattr__(self, name: str, value: Any) -> None:
31
+ self[name] = value
32
+
33
+ def __delattr__(self, name: str) -> None:
34
+ del self[name]
35
+
36
+ #----------------------------------------------------------------------------
37
+
38
+ activation_funcs = {
39
+ 'linear': EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
40
+ 'relu': EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
41
+ 'lrelu': EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
42
+ 'tanh': EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
43
+ 'sigmoid': EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
44
+ 'elu': EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
45
+ 'selu': EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
46
+ 'softplus': EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
47
+ 'swish': EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
48
+ }
49
+
50
+ #----------------------------------------------------------------------------
51
+
52
+ _inited = False
53
+ _plugin = None
54
+ _null_tensor = torch.empty([0])
55
+
56
+ def _init():
57
+ global _inited, _plugin
58
+ if not _inited:
59
+ _inited = True
60
+ sources = ['bias_act.cpp', 'bias_act.cu']
61
+ sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
62
+ try:
63
+ _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
64
+ except:
65
+ warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
66
+ return _plugin is not None
67
+
68
+ #----------------------------------------------------------------------------
69
+
70
+ def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
71
+ r"""Fused bias and activation function.
72
+
73
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
74
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
75
+ the fused op is considerably more efficient than performing the same calculation
76
+ using standard PyTorch ops. It supports first and second order gradients,
77
+ but not third order gradients.
78
+
79
+ Args:
80
+ x: Input activation tensor. Can be of any shape.
81
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
82
+ as `x`. The shape must be known, and it must match the dimension of `x`
83
+ corresponding to `dim`.
84
+ dim: The dimension in `x` corresponding to the elements of `b`.
85
+ The value of `dim` is ignored if `b` is not specified.
86
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
87
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
88
+ See `activation_funcs` for a full list. `None` is not allowed.
89
+ alpha: Shape parameter for the activation function, or `None` to use the default.
90
+ gain: Scaling factor for the output tensor, or `None` to use default.
91
+ See `activation_funcs` for the default scaling of each activation function.
92
+ If unsure, consider specifying 1.
93
+ clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
94
+ the clamping (default).
95
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
96
+
97
+ Returns:
98
+ Tensor of the same shape and datatype as `x`.
99
+ """
100
+ assert isinstance(x, torch.Tensor)
101
+ assert impl in ['ref', 'cuda']
102
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
103
+ return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
104
+ return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
105
+
106
+ #----------------------------------------------------------------------------
107
+
108
+ def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
109
+ """Slow reference implementation of `bias_act()` using standard PyTorch ops.
110
+ """
111
+ assert isinstance(x, torch.Tensor)
112
+ assert clamp is None or clamp >= 0
113
+ spec = activation_funcs[act]
114
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
115
+ gain = float(gain if gain is not None else spec.def_gain)
116
+ clamp = float(clamp if clamp is not None else -1)
117
+
118
+ # Add bias.
119
+ if b is not None:
120
+ assert isinstance(b, torch.Tensor) and b.ndim == 1
121
+ assert 0 <= dim < x.ndim
122
+ assert b.shape[0] == x.shape[dim]
123
+ x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
124
+
125
+ # Evaluate activation function.
126
+ alpha = float(alpha)
127
+ x = spec.func(x, alpha=alpha)
128
+
129
+ # Scale by gain.
130
+ gain = float(gain)
131
+ if gain != 1:
132
+ x = x * gain
133
+
134
+ # Clamp.
135
+ if clamp >= 0:
136
+ x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
137
+ return x
138
+
139
+ #----------------------------------------------------------------------------
140
+
141
+ _bias_act_cuda_cache = dict()
142
+
143
+ def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
144
+ """Fast CUDA implementation of `bias_act()` using custom ops.
145
+ """
146
+ # Parse arguments.
147
+ assert clamp is None or clamp >= 0
148
+ spec = activation_funcs[act]
149
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
150
+ gain = float(gain if gain is not None else spec.def_gain)
151
+ clamp = float(clamp if clamp is not None else -1)
152
+
153
+ # Lookup from cache.
154
+ key = (dim, act, alpha, gain, clamp)
155
+ if key in _bias_act_cuda_cache:
156
+ return _bias_act_cuda_cache[key]
157
+
158
+ # Forward op.
159
+ class BiasActCuda(torch.autograd.Function):
160
+ @staticmethod
161
+ def forward(ctx, x, b): # pylint: disable=arguments-differ
162
+ ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
163
+ x = x.contiguous(memory_format=ctx.memory_format)
164
+ b = b.contiguous() if b is not None else _null_tensor
165
+ y = x
166
+ if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
167
+ y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
168
+ ctx.save_for_backward(
169
+ x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
170
+ b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
171
+ y if 'y' in spec.ref else _null_tensor)
172
+ return y
173
+
174
+ @staticmethod
175
+ def backward(ctx, dy): # pylint: disable=arguments-differ
176
+ dy = dy.contiguous(memory_format=ctx.memory_format)
177
+ x, b, y = ctx.saved_tensors
178
+ dx = None
179
+ db = None
180
+
181
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
182
+ dx = dy
183
+ if act != 'linear' or gain != 1 or clamp >= 0:
184
+ dx = BiasActCudaGrad.apply(dy, x, b, y)
185
+
186
+ if ctx.needs_input_grad[1]:
187
+ db = dx.sum([i for i in range(dx.ndim) if i != dim])
188
+
189
+ return dx, db
190
+
191
+ # Backward op.
192
+ class BiasActCudaGrad(torch.autograd.Function):
193
+ @staticmethod
194
+ def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
195
+ ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
196
+ dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
197
+ ctx.save_for_backward(
198
+ dy if spec.has_2nd_grad else _null_tensor,
199
+ x, b, y)
200
+ return dx
201
+
202
+ @staticmethod
203
+ def backward(ctx, d_dx): # pylint: disable=arguments-differ
204
+ d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
205
+ dy, x, b, y = ctx.saved_tensors
206
+ d_dy = None
207
+ d_x = None
208
+ d_b = None
209
+ d_y = None
210
+
211
+ if ctx.needs_input_grad[0]:
212
+ d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
213
+
214
+ if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
215
+ d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
216
+
217
+ if spec.has_2nd_grad and ctx.needs_input_grad[2]:
218
+ d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
219
+
220
+ return d_dy, d_x, d_b, d_y
221
+
222
+ # Add to cache.
223
+ _bias_act_cuda_cache[key] = BiasActCuda
224
+ return BiasActCuda
225
+
226
+ #----------------------------------------------------------------------------
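A minimal usage sketch for the `bias_act()` wrapper defined above, assuming the repository root is on PYTHONPATH and a CUDA device is available; if the plugin fails to build, the call falls back to `_bias_act_ref` with a warning.

    # Illustrative only; shapes and tolerances are arbitrary choices.
    import torch
    from src.vqvaes.bsqvit.stylegan_utils.ops import bias_act

    x = torch.randn(4, 64, 32, 32, device='cuda')   # NCHW activations
    b = torch.randn(64, device='cuda')              # per-channel bias along dim=1
    y = bias_act.bias_act(x, b, dim=1, act='lrelu')                  # fused bias + leaky ReLU (default gain sqrt(2))
    y_ref = bias_act.bias_act(x, b, dim=1, act='lrelu', impl='ref')  # pure-PyTorch reference path
    print(torch.allclose(y, y_ref, rtol=1e-3, atol=1e-4))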
src/vqvaes/bsqvit/stylegan_utils/ops/conv2d_gradfix.py ADDED
@@ -0,0 +1,170 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom replacement for `torch.nn.functional.conv2d` that supports
10
+ arbitrarily high order gradients with zero performance penalty."""
11
+
12
+ import warnings
13
+ import contextlib
14
+ import torch
15
+
16
+ # pylint: disable=redefined-builtin
17
+ # pylint: disable=arguments-differ
18
+ # pylint: disable=protected-access
19
+
20
+ #----------------------------------------------------------------------------
21
+
22
+ enabled = False # Enable the custom op by setting this to true.
23
+ weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
24
+
25
+ @contextlib.contextmanager
26
+ def no_weight_gradients():
27
+ global weight_gradients_disabled
28
+ old = weight_gradients_disabled
29
+ weight_gradients_disabled = True
30
+ yield
31
+ weight_gradients_disabled = old
32
+
33
+ #----------------------------------------------------------------------------
34
+
35
+ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
36
+ if _should_use_custom_op(input):
37
+ return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
38
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
39
+
40
+ def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
41
+ if _should_use_custom_op(input):
42
+ return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
43
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
44
+
45
+ #----------------------------------------------------------------------------
46
+
47
+ def _should_use_custom_op(input):
48
+ assert isinstance(input, torch.Tensor)
49
+ if (not enabled) or (not torch.backends.cudnn.enabled):
50
+ return False
51
+ if input.device.type != 'cuda':
52
+ return False
53
+ if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
54
+ return True
55
+ warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
56
+ return False
57
+
58
+ def _tuple_of_ints(xs, ndim):
59
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
60
+ assert len(xs) == ndim
61
+ assert all(isinstance(x, int) for x in xs)
62
+ return xs
63
+
64
+ #----------------------------------------------------------------------------
65
+
66
+ _conv2d_gradfix_cache = dict()
67
+
68
+ def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
69
+ # Parse arguments.
70
+ ndim = 2
71
+ weight_shape = tuple(weight_shape)
72
+ stride = _tuple_of_ints(stride, ndim)
73
+ padding = _tuple_of_ints(padding, ndim)
74
+ output_padding = _tuple_of_ints(output_padding, ndim)
75
+ dilation = _tuple_of_ints(dilation, ndim)
76
+
77
+ # Lookup from cache.
78
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
79
+ if key in _conv2d_gradfix_cache:
80
+ return _conv2d_gradfix_cache[key]
81
+
82
+ # Validate arguments.
83
+ assert groups >= 1
84
+ assert len(weight_shape) == ndim + 2
85
+ assert all(stride[i] >= 1 for i in range(ndim))
86
+ assert all(padding[i] >= 0 for i in range(ndim))
87
+ assert all(dilation[i] >= 0 for i in range(ndim))
88
+ if not transpose:
89
+ assert all(output_padding[i] == 0 for i in range(ndim))
90
+ else: # transpose
91
+ assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
92
+
93
+ # Helpers.
94
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
95
+ def calc_output_padding(input_shape, output_shape):
96
+ if transpose:
97
+ return [0, 0]
98
+ return [
99
+ input_shape[i + 2]
100
+ - (output_shape[i + 2] - 1) * stride[i]
101
+ - (1 - 2 * padding[i])
102
+ - dilation[i] * (weight_shape[i + 2] - 1)
103
+ for i in range(ndim)
104
+ ]
105
+
106
+ # Forward & backward.
107
+ class Conv2d(torch.autograd.Function):
108
+ @staticmethod
109
+ def forward(ctx, input, weight, bias):
110
+ assert weight.shape == weight_shape
111
+ if not transpose:
112
+ output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
113
+ else: # transpose
114
+ output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
115
+ ctx.save_for_backward(input, weight)
116
+ return output
117
+
118
+ @staticmethod
119
+ def backward(ctx, grad_output):
120
+ input, weight = ctx.saved_tensors
121
+ grad_input = None
122
+ grad_weight = None
123
+ grad_bias = None
124
+
125
+ if ctx.needs_input_grad[0]:
126
+ p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
127
+ grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
128
+ assert grad_input.shape == input.shape
129
+
130
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
131
+ grad_weight = Conv2dGradWeight.apply(grad_output, input)
132
+ assert grad_weight.shape == weight_shape
133
+
134
+ if ctx.needs_input_grad[2]:
135
+ grad_bias = grad_output.sum([0, 2, 3])
136
+
137
+ return grad_input, grad_weight, grad_bias
138
+
139
+ # Gradient with respect to the weights.
140
+ class Conv2dGradWeight(torch.autograd.Function):
141
+ @staticmethod
142
+ def forward(ctx, grad_output, input):
143
+ op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
144
+ flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
145
+ grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
146
+ assert grad_weight.shape == weight_shape
147
+ ctx.save_for_backward(grad_output, input)
148
+ return grad_weight
149
+
150
+ @staticmethod
151
+ def backward(ctx, grad2_grad_weight):
152
+ grad_output, input = ctx.saved_tensors
153
+ grad2_grad_output = None
154
+ grad2_input = None
155
+
156
+ if ctx.needs_input_grad[0]:
157
+ grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
158
+ assert grad2_grad_output.shape == grad_output.shape
159
+
160
+ if ctx.needs_input_grad[1]:
161
+ p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
162
+ grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
163
+ assert grad2_input.shape == input.shape
164
+
165
+ return grad2_grad_output, grad2_input
166
+
167
+ _conv2d_gradfix_cache[key] = Conv2d
168
+ return Conv2d
169
+
170
+ #----------------------------------------------------------------------------
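A minimal usage sketch for the module above, assuming the repository root is on PYTHONPATH; unless `enabled` is set and the PyTorch version check in `_should_use_custom_op()` passes, the calls fall back to the standard `torch.nn.functional` ops.

    # Illustrative only; tensor shapes are arbitrary.
    import torch
    from src.vqvaes.bsqvit.stylegan_utils.ops import conv2d_gradfix

    conv2d_gradfix.enabled = True
    x = torch.randn(2, 16, 64, 64, device='cuda', requires_grad=True)
    w = torch.randn(32, 16, 3, 3, device='cuda', requires_grad=True)
    y = conv2d_gradfix.conv2d(x, w, padding=1)

    # Temporarily skip gradients w.r.t. the weights without detaching them:
    with conv2d_gradfix.no_weight_gradients():
        y = conv2d_gradfix.conv2d(x, w, padding=1)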
src/vqvaes/bsqvit/stylegan_utils/ops/conv2d_resample.py ADDED
@@ -0,0 +1,155 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """2D convolution with optional up/downsampling."""
10
+
11
+ import torch
12
+
13
+ from .. import misc
14
+ from . import conv2d_gradfix
15
+ from . import upfirdn2d
16
+ from .upfirdn2d import _parse_padding
17
+ from .upfirdn2d import _get_filter_size
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ def _get_weight_shape(w):
22
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
23
+ shape = [int(sz) for sz in w.shape]
24
+ misc.assert_shape(w, shape)
25
+ return shape
26
+
27
+ #----------------------------------------------------------------------------
28
+
29
+ def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
30
+ """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
31
+ """
32
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
33
+
34
+ # Flip weight if requested.
35
+ if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
36
+ w = w.flip([2, 3])
37
+
38
+ # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
39
+ # 1x1 kernel + memory_format=channels_last + less than 64 channels.
40
+ if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
41
+ if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
42
+ if out_channels <= 4 and groups == 1:
43
+ in_shape = x.shape
44
+ x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
45
+ x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
46
+ else:
47
+ x = x.to(memory_format=torch.contiguous_format)
48
+ w = w.to(memory_format=torch.contiguous_format)
49
+ x = conv2d_gradfix.conv2d(x, w, groups=groups)
50
+ return x.to(memory_format=torch.channels_last)
51
+
52
+ # Otherwise => execute using conv2d_gradfix.
53
+ op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
54
+ return op(x, w, stride=stride, padding=padding, groups=groups)
55
+
56
+ #----------------------------------------------------------------------------
57
+
58
+ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
59
+ r"""2D convolution with optional up/downsampling.
60
+
61
+ Padding is performed only once at the beginning, not between the operations.
62
+
63
+ Args:
64
+ x: Input tensor of shape
65
+ `[batch_size, in_channels, in_height, in_width]`.
66
+ w: Weight tensor of shape
67
+ `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
68
+ f: Low-pass filter for up/downsampling. Must be prepared beforehand by
69
+ calling upfirdn2d.setup_filter(). None = identity (default).
70
+ up: Integer upsampling factor (default: 1).
71
+ down: Integer downsampling factor (default: 1).
72
+ padding: Padding with respect to the upsampled image. Can be a single number
73
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
74
+ (default: 0).
75
+ groups: Split input channels into N groups (default: 1).
76
+ flip_weight: False = convolution, True = correlation (default: True).
77
+ flip_filter: False = convolution, True = correlation (default: False).
78
+
79
+ Returns:
80
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
81
+ """
82
+ # Validate arguments.
83
+ assert isinstance(x, torch.Tensor) and (x.ndim == 4)
84
+ assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
85
+ assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
86
+ assert isinstance(up, int) and (up >= 1)
87
+ assert isinstance(down, int) and (down >= 1)
88
+ assert isinstance(groups, int) and (groups >= 1)
89
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
90
+ fw, fh = _get_filter_size(f)
91
+ px0, px1, py0, py1 = _parse_padding(padding)
92
+
93
+ # Adjust padding to account for up/downsampling.
94
+ if up > 1:
95
+ px0 += (fw + up - 1) // 2
96
+ px1 += (fw - up) // 2
97
+ py0 += (fh + up - 1) // 2
98
+ py1 += (fh - up) // 2
99
+ if down > 1:
100
+ px0 += (fw - down + 1) // 2
101
+ px1 += (fw - down) // 2
102
+ py0 += (fh - down + 1) // 2
103
+ py1 += (fh - down) // 2
104
+
105
+ # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
106
+ if kw == 1 and kh == 1 and (down > 1 and up == 1):
107
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
108
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
109
+ return x
110
+
111
+ # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
112
+ if kw == 1 and kh == 1 and (up > 1 and down == 1):
113
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
114
+ x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
115
+ return x
116
+
117
+ # Fast path: downsampling only => use strided convolution.
118
+ if down > 1 and up == 1:
119
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
120
+ x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
121
+ return x
122
+
123
+ # Fast path: upsampling with optional downsampling => use transpose strided convolution.
124
+ if up > 1:
125
+ if groups == 1:
126
+ w = w.transpose(0, 1)
127
+ else:
128
+ w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
129
+ w = w.transpose(1, 2)
130
+ w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
131
+ px0 -= kw - 1
132
+ px1 -= kw - up
133
+ py0 -= kh - 1
134
+ py1 -= kh - up
135
+ pxt = max(min(-px0, -px1), 0)
136
+ pyt = max(min(-py0, -py1), 0)
137
+ x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
138
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
139
+ if down > 1:
140
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
141
+ return x
142
+
143
+ # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
144
+ if up == 1 and down == 1:
145
+ if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
146
+ return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
147
+
148
+ # Fallback: Generic reference implementation.
149
+ x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
150
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
151
+ if down > 1:
152
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
153
+ return x
154
+
155
+ #----------------------------------------------------------------------------
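A minimal usage sketch for `conv2d_resample()` above: a 3x3 convolution combined with 2x upsampling through a 4-tap low-pass filter. It assumes the repository root is on PYTHONPATH and that the companion upfirdn2d module provides `setup_filter()` as the docstring prescribes.

    # Illustrative only; shapes and filter taps are arbitrary choices.
    import torch
    from src.vqvaes.bsqvit.stylegan_utils.ops import conv2d_resample, upfirdn2d

    x = torch.randn(2, 16, 32, 32, device='cuda')
    w = torch.randn(16, 16, 3, 3, device='cuda')
    f = upfirdn2d.setup_filter([1, 3, 3, 1], device=torch.device('cuda'))  # separable low-pass filter
    y = conv2d_resample.conv2d_resample(x, w, f=f, up=2, padding=1)
    print(y.shape)   # expected: torch.Size([2, 16, 64, 64])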
src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.cpp ADDED
@@ -0,0 +1,103 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "upfirdn2d.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
+ static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
17
+ {
18
+ // Validate arguments.
19
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
20
+ TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
21
+ TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
22
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
23
+ TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
24
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
25
+ TORCH_CHECK(f.dim() == 2, "f must be rank 2");
26
+ TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
27
+ TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
28
+ TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
29
+
30
+ // Create output tensor.
31
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
32
+ int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
33
+ int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
34
+ TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
35
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
36
+ TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
37
+
38
+ // Initialize CUDA kernel parameters.
39
+ upfirdn2d_kernel_params p;
40
+ p.x = x.data_ptr();
41
+ p.f = f.data_ptr<float>();
42
+ p.y = y.data_ptr();
43
+ p.up = make_int2(upx, upy);
44
+ p.down = make_int2(downx, downy);
45
+ p.pad0 = make_int2(padx0, pady0);
46
+ p.flip = (flip) ? 1 : 0;
47
+ p.gain = gain;
48
+ p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
49
+ p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
50
+ p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
51
+ p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
52
+ p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
53
+ p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
54
+ p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
55
+ p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
56
+
57
+ // Choose CUDA kernel.
58
+ upfirdn2d_kernel_spec spec;
59
+ AT_DISPATCH_REDUCED_FLOATING_TYPES(x.scalar_type(), "upfirdn2d_cuda", [&]
60
+ {
61
+ spec = choose_upfirdn2d_kernel<scalar_t>(p);
62
+ });
63
+
64
+ // Set looping options.
65
+ p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
66
+ p.loopMinor = spec.loopMinor;
67
+ p.loopX = spec.loopX;
68
+ p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
69
+ p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
70
+
71
+ // Compute grid size.
72
+ dim3 blockSize, gridSize;
73
+ if (spec.tileOutW < 0) // large
74
+ {
75
+ blockSize = dim3(4, 32, 1);
76
+ gridSize = dim3(
77
+ ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
78
+ (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
79
+ p.launchMajor);
80
+ }
81
+ else // small
82
+ {
83
+ blockSize = dim3(256, 1, 1);
84
+ gridSize = dim3(
85
+ ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
86
+ (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
87
+ p.launchMajor);
88
+ }
89
+
90
+ // Launch CUDA kernel.
91
+ void* args[] = {&p};
92
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
93
+ return y;
94
+ }
95
+
96
+ //------------------------------------------------------------------------
97
+
98
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
99
+ {
100
+ m.def("upfirdn2d", &upfirdn2d);
101
+ }
102
+
103
+ //------------------------------------------------------------------------
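The output allocation above uses the closed-form size outW = (inW * upx + padx0 + padx1 - filterW + downx) / downx (and analogously for the height). A small illustrative check of that arithmetic:

    # Minimal sketch of the output-size formula used when allocating y.
    def upfirdn2d_out_size(in_size, up, pad0, pad1, taps, down):
        return (in_size * up + pad0 + pad1 - taps + down) // down

    print(upfirdn2d_out_size(32, up=2, pad0=2, pad1=1, taps=4, down=1))  # -> 64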
src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.cu ADDED
@@ -0,0 +1,353 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <c10/util/Half.h>
11
+ #include "upfirdn2d.h"
12
+
13
+ //------------------------------------------------------------------------
14
+ // Helpers.
15
+
16
+ template <class T> struct InternalType;
17
+ template <> struct InternalType<double> { typedef double scalar_t; };
18
+ template <> struct InternalType<float> { typedef float scalar_t; };
19
+ template <> struct InternalType<c10::Half> { typedef float scalar_t; };
20
+ template <> struct InternalType<at::BFloat16> { typedef float scalar_t; };
21
+
22
+ static __device__ __forceinline__ int floor_div(int a, int b)
23
+ {
24
+ int t = 1 - a / b;
25
+ return (a + t * b) / b - t;
26
+ }
27
+
28
+ //------------------------------------------------------------------------
29
+ // Generic CUDA implementation for large filters.
30
+
31
+ template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
32
+ {
33
+ typedef typename InternalType<T>::scalar_t scalar_t;
34
+
35
+ // Calculate thread index.
36
+ int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
37
+ int outY = minorBase / p.launchMinor;
38
+ minorBase -= outY * p.launchMinor;
39
+ int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
40
+ int majorBase = blockIdx.z * p.loopMajor;
41
+ if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
42
+ return;
43
+
44
+ // Setup Y receptive field.
45
+ int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
46
+ int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
47
+ int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
48
+ int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
49
+ if (p.flip)
50
+ filterY = p.filterSize.y - 1 - filterY;
51
+
52
+ // Loop over major, minor, and X.
53
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
54
+ for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
55
+ {
56
+ int nc = major * p.sizeMinor + minor;
57
+ int n = nc / p.inSize.z;
58
+ int c = nc - n * p.inSize.z;
59
+ for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
60
+ {
61
+ // Setup X receptive field.
62
+ int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
63
+ int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
64
+ int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
65
+ int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
66
+ if (p.flip)
67
+ filterX = p.filterSize.x - 1 - filterX;
68
+
69
+ // Initialize pointers.
70
+ const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
71
+ const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
72
+ int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
73
+ int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
74
+
75
+ // Inner loop.
76
+ scalar_t v = 0;
77
+ for (int y = 0; y < h; y++)
78
+ {
79
+ for (int x = 0; x < w; x++)
80
+ {
81
+ v += (scalar_t)(*xp) * (scalar_t)(*fp);
82
+ xp += p.inStride.x;
83
+ fp += filterStepX;
84
+ }
85
+ xp += p.inStride.y - w * p.inStride.x;
86
+ fp += filterStepY - w * filterStepX;
87
+ }
88
+
89
+ // Store result.
90
+ v *= p.gain;
91
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
92
+ }
93
+ }
94
+ }
95
+
96
+ //------------------------------------------------------------------------
97
+ // Specialized CUDA implementation for small filters.
98
+
99
+ template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
100
+ static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
101
+ {
102
+ typedef typename InternalType<T>::scalar_t scalar_t;
103
+ const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
104
+ const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
105
+ __shared__ volatile scalar_t sf[filterH][filterW];
106
+ __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
107
+
108
+ // Calculate tile index.
109
+ int minorBase = blockIdx.x;
110
+ int tileOutY = minorBase / p.launchMinor;
111
+ minorBase -= tileOutY * p.launchMinor;
112
+ minorBase *= loopMinor;
113
+ tileOutY *= tileOutH;
114
+ int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
115
+ int majorBase = blockIdx.z * p.loopMajor;
116
+ if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
117
+ return;
118
+
119
+ // Load filter (flipped).
120
+ for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
121
+ {
122
+ int fy = tapIdx / filterW;
123
+ int fx = tapIdx - fy * filterW;
124
+ scalar_t v = 0;
125
+ if (fx < p.filterSize.x & fy < p.filterSize.y)
126
+ {
127
+ int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
128
+ int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
129
+ v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
130
+ }
131
+ sf[fy][fx] = v;
132
+ }
133
+
134
+ // Loop over major and X.
135
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
136
+ {
137
+ int baseNC = major * p.sizeMinor + minorBase;
138
+ int n = baseNC / p.inSize.z;
139
+ int baseC = baseNC - n * p.inSize.z;
140
+ for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
141
+ {
142
+ // Load input pixels.
143
+ int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
144
+ int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
145
+ int tileInX = floor_div(tileMidX, upx);
146
+ int tileInY = floor_div(tileMidY, upy);
147
+ __syncthreads();
148
+ for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
149
+ {
150
+ int relC = inIdx;
151
+ int relInX = relC / loopMinor;
152
+ int relInY = relInX / tileInW;
153
+ relC -= relInX * loopMinor;
154
+ relInX -= relInY * tileInW;
155
+ int c = baseC + relC;
156
+ int inX = tileInX + relInX;
157
+ int inY = tileInY + relInY;
158
+ scalar_t v = 0;
159
+ if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
160
+ v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
161
+ sx[relInY][relInX][relC] = v;
162
+ }
163
+
164
+ // Loop over output pixels.
165
+ __syncthreads();
166
+ for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
167
+ {
168
+ int relC = outIdx;
169
+ int relOutX = relC / loopMinor;
170
+ int relOutY = relOutX / tileOutW;
171
+ relC -= relOutX * loopMinor;
172
+ relOutX -= relOutY * tileOutW;
173
+ int c = baseC + relC;
174
+ int outX = tileOutX + relOutX;
175
+ int outY = tileOutY + relOutY;
176
+
177
+ // Setup receptive field.
178
+ int midX = tileMidX + relOutX * downx;
179
+ int midY = tileMidY + relOutY * downy;
180
+ int inX = floor_div(midX, upx);
181
+ int inY = floor_div(midY, upy);
182
+ int relInX = inX - tileInX;
183
+ int relInY = inY - tileInY;
184
+ int filterX = (inX + 1) * upx - midX - 1; // flipped
185
+ int filterY = (inY + 1) * upy - midY - 1; // flipped
186
+
187
+ // Inner loop.
188
+ if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
189
+ {
190
+ scalar_t v = 0;
191
+ #pragma unroll
192
+ for (int y = 0; y < filterH / upy; y++)
193
+ #pragma unroll
194
+ for (int x = 0; x < filterW / upx; x++)
195
+ v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
196
+ v *= p.gain;
197
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
198
+ }
199
+ }
200
+ }
201
+ }
202
+ }
203
+
204
+ //------------------------------------------------------------------------
205
+ // CUDA kernel selection.
206
+
207
+ template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
208
+ {
209
+ int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
210
+
211
+ upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
212
+ if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
213
+
214
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
215
+ {
216
+ if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
217
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
218
+ if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
219
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
220
+ if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
221
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
222
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
223
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
224
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
225
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
226
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
227
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
228
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
229
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
230
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
231
+ }
232
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
233
+ {
234
+ if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
235
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
236
+ if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
237
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
238
+ if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
239
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
240
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
241
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
242
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
243
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
244
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
245
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
246
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
247
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
248
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
249
+ }
250
+ if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
251
+ {
252
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
253
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
254
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
255
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
256
+ }
257
+ if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
258
+ {
259
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
260
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
261
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
262
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
263
+ }
264
+ if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
265
+ {
266
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
267
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
268
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
269
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
270
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
271
+ }
272
+ if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
273
+ {
274
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
275
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
276
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
277
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
278
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
279
+ }
280
+ if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
281
+ {
282
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
283
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
284
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
285
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
286
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
287
+ }
288
+ if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
289
+ {
290
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
291
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
292
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
293
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
294
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
295
+ }
296
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
297
+ {
298
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
299
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
300
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
301
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
302
+ }
303
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
304
+ {
305
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
306
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
307
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
308
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
309
+ }
310
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
311
+ {
312
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
313
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,8,1>, 64,8,1, 1};
314
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
315
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,8,1>, 64,8,1, 1};
316
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
317
+ }
318
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
319
+ {
320
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
321
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,1,8>, 64,1,8, 1};
322
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
323
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,1,8>, 64,1,8, 1};
324
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
325
+ }
326
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
327
+ {
328
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
329
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 32,16,1>, 32,16,1, 1};
330
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
331
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 32,16,1>, 32,16,1, 1};
332
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
333
+ }
334
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
335
+ {
336
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
337
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 1,64,8>, 1,64,8, 1};
338
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
339
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 1,64,8>, 1,64,8, 1};
340
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
341
+ }
342
+ return spec;
343
+ }
344
+
345
+ //------------------------------------------------------------------------
346
+ // Template specializations.
347
+
348
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
349
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
350
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half> (const upfirdn2d_kernel_params& p);
351
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<at::BFloat16> (const upfirdn2d_kernel_params& p);
352
+
353
+ //------------------------------------------------------------------------
src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.h ADDED
@@ -0,0 +1,59 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <cuda_runtime.h>
10
+
11
+ //------------------------------------------------------------------------
12
+ // CUDA kernel parameters.
13
+
14
+ struct upfirdn2d_kernel_params
15
+ {
16
+ const void* x;
17
+ const float* f;
18
+ void* y;
19
+
20
+ int2 up;
21
+ int2 down;
22
+ int2 pad0;
23
+ int flip;
24
+ float gain;
25
+
26
+ int4 inSize; // [width, height, channel, batch]
27
+ int4 inStride;
28
+ int2 filterSize; // [width, height]
29
+ int2 filterStride;
30
+ int4 outSize; // [width, height, channel, batch]
31
+ int4 outStride;
32
+ int sizeMinor;
33
+ int sizeMajor;
34
+
35
+ int loopMinor;
36
+ int loopMajor;
37
+ int loopX;
38
+ int launchMinor;
39
+ int launchMajor;
40
+ };
41
+
42
+ //------------------------------------------------------------------------
43
+ // CUDA kernel specialization.
44
+
45
+ struct upfirdn2d_kernel_spec
46
+ {
47
+ void* kernel;
48
+ int tileOutW;
49
+ int tileOutH;
50
+ int loopMinor;
51
+ int loopX;
52
+ };
53
+
54
+ //------------------------------------------------------------------------
55
+ // CUDA kernel selection.
56
+
57
+ template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
58
+
59
+ //------------------------------------------------------------------------
src/vqvaes/bsqvit/stylegan_utils/ops/upfirdn2d.py ADDED
@@ -0,0 +1,382 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom PyTorch ops for efficient resampling of 2D images."""
10
+
11
+ import os
12
+ import warnings
13
+ import numpy as np
14
+ import torch
15
+ import traceback
16
+
17
+ from .. import custom_ops, misc
18
+ from . import conv2d_gradfix
19
+
20
+ #----------------------------------------------------------------------------
21
+
22
+ _inited = False
23
+ _plugin = None
24
+
25
+ def _init():
26
+ global _inited, _plugin
27
+ if not _inited:
28
+ sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
29
+ sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
30
+ try:
31
+ _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
32
+ except:
33
+ warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
34
+ return _plugin is not None
35
+
36
+ def _parse_scaling(scaling):
37
+ if isinstance(scaling, int):
38
+ scaling = [scaling, scaling]
39
+ assert isinstance(scaling, (list, tuple))
40
+ assert all(isinstance(x, int) for x in scaling)
41
+ sx, sy = scaling
42
+ assert sx >= 1 and sy >= 1
43
+ return sx, sy
44
+
45
+ def _parse_padding(padding):
46
+ if isinstance(padding, int):
47
+ padding = [padding, padding]
48
+ assert isinstance(padding, (list, tuple))
49
+ assert all(isinstance(x, int) for x in padding)
50
+ if len(padding) == 2:
51
+ padx, pady = padding
52
+ padding = [padx, padx, pady, pady]
53
+ padx0, padx1, pady0, pady1 = padding
54
+ return padx0, padx1, pady0, pady1
55
+
56
+ def _get_filter_size(f):
57
+ if f is None:
58
+ return 1, 1
59
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
60
+ fw = f.shape[-1]
61
+ fh = f.shape[0]
62
+ with misc.suppress_tracer_warnings():
63
+ fw = int(fw)
64
+ fh = int(fh)
65
+ misc.assert_shape(f, [fh, fw][:f.ndim])
66
+ assert fw >= 1 and fh >= 1
67
+ return fw, fh
68
+
69
+ #----------------------------------------------------------------------------
70
+
71
+ def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
72
+ r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
73
+
74
+ Args:
75
+ f: Torch tensor, numpy array, or python list of the shape
76
+ `[filter_height, filter_width]` (non-separable),
77
+ `[filter_taps]` (separable),
78
+ `[]` (impulse), or
79
+ `None` (identity).
80
+ device: Result device (default: cpu).
81
+ normalize: Normalize the filter so that it retains the magnitude
82
+ for constant input signal (DC)? (default: True).
83
+ flip_filter: Flip the filter? (default: False).
84
+ gain: Overall scaling factor for signal magnitude (default: 1).
85
+ separable: Return a separable filter? (default: select automatically).
86
+
87
+ Returns:
88
+ Float32 tensor of the shape
89
+ `[filter_height, filter_width]` (non-separable) or
90
+ `[filter_taps]` (separable).
91
+ """
92
+ # Validate.
93
+ if f is None:
94
+ f = 1
95
+ f = torch.as_tensor(f, dtype=torch.float32)
96
+ assert f.ndim in [0, 1, 2]
97
+ assert f.numel() > 0
98
+ if f.ndim == 0:
99
+ f = f[np.newaxis]
100
+
101
+ # Separable?
102
+ if separable is None:
103
+ separable = (f.ndim == 1 and f.numel() >= 8)
104
+ if f.ndim == 1 and not separable:
105
+ f = f.ger(f)
106
+ assert f.ndim == (1 if separable else 2)
107
+
108
+ # Apply normalize, flip, gain, and device.
109
+ if normalize:
110
+ f /= f.sum()
111
+ if flip_filter:
112
+ f = f.flip(list(range(f.ndim)))
113
+ f = f * (gain ** (f.ndim / 2))
114
+ f = f.to(device=device)
115
+ return f
116
+
117
+ #----------------------------------------------------------------------------
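As a quick illustration of `setup_filter()`, the sketch below builds the common `[1, 3, 3, 1]` binomial filter; the tap values are illustrative only and assume the function above is in scope:

    import torch
    f = setup_filter([1, 3, 3, 1])   # fewer than 8 taps, so the 1-D input is expanded via an outer product
    print(f.shape)                   # torch.Size([4, 4]); taps are normalized to sum to 1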
118
+
119
+ def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
120
+ r"""Pad, upsample, filter, and downsample a batch of 2D images.
121
+
122
+ Performs the following sequence of operations for each channel:
123
+
124
+ 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
125
+
126
+ 2. Pad the image with the specified number of zeros on each side (`padding`).
127
+ Negative padding corresponds to cropping the image.
128
+
129
+ 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
130
+ so that the footprint of all output pixels lies within the input image.
131
+
132
+ 4. Downsample the image by keeping every Nth pixel (`down`).
133
+
134
+ This sequence of operations bears close resemblance to scipy.signal.upfirdn().
135
+ The fused op is considerably more efficient than performing the same calculation
136
+ using standard PyTorch ops. It supports gradients of arbitrary order.
137
+
138
+ Args:
139
+ x: Float32/float64/float16 input tensor of the shape
140
+ `[batch_size, num_channels, in_height, in_width]`.
141
+ f: Float32 FIR filter of the shape
142
+ `[filter_height, filter_width]` (non-separable),
143
+ `[filter_taps]` (separable), or
144
+ `None` (identity).
145
+ up: Integer upsampling factor. Can be a single int or a list/tuple
146
+ `[x, y]` (default: 1).
147
+ down: Integer downsampling factor. Can be a single int or a list/tuple
148
+ `[x, y]` (default: 1).
149
+ padding: Padding with respect to the upsampled image. Can be a single number
150
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
151
+ (default: 0).
152
+ flip_filter: False = convolution, True = correlation (default: False).
153
+ gain: Overall scaling factor for signal magnitude (default: 1).
154
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
155
+
156
+ Returns:
157
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
158
+ """
159
+ assert isinstance(x, torch.Tensor)
160
+ assert impl in ['ref', 'cuda']
161
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
162
+ return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
163
+ return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
164
+
165
+ #----------------------------------------------------------------------------
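A minimal shape check for `upfirdn2d()` through the reference path (a sketch with made-up sizes; `impl='ref'` stays on standard PyTorch ops, so no CUDA plugin is required):

    import torch
    x = torch.randn(1, 3, 8, 8)
    y = upfirdn2d(x, f=None, up=2, impl='ref')   # identity filter, zeros inserted between pixels
    print(y.shape)                               # torch.Size([1, 3, 16, 16])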
166
+
167
+ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
168
+ """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
169
+ """
170
+ # Validate arguments.
171
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
172
+ if f is None:
173
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
174
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
175
+ assert f.dtype == torch.float32 and not f.requires_grad
176
+ batch_size, num_channels, in_height, in_width = x.shape
177
+ upx, upy = _parse_scaling(up)
178
+ downx, downy = _parse_scaling(down)
179
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
180
+
181
+ # Upsample by inserting zeros.
182
+ x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
183
+ x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
184
+ x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
185
+
186
+ # Pad or crop.
187
+ x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
188
+ x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
189
+
190
+ # Setup filter.
191
+ f = f * (gain ** (f.ndim / 2))
192
+ f = f.to(x.dtype)
193
+ if not flip_filter:
194
+ f = f.flip(list(range(f.ndim)))
195
+
196
+ # Convolve with the filter.
197
+ f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
198
+ if f.ndim == 4:
199
+ x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
200
+ else:
201
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
202
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
203
+
204
+ # Downsample by throwing away pixels.
205
+ x = x[:, :, ::downy, ::downx]
206
+ return x
207
+
208
+ #----------------------------------------------------------------------------
209
+
210
+ _upfirdn2d_cuda_cache = dict()
211
+
212
+ def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
213
+ """Fast CUDA implementation of `upfirdn2d()` using custom ops.
214
+ """
215
+ # Parse arguments.
216
+ upx, upy = _parse_scaling(up)
217
+ downx, downy = _parse_scaling(down)
218
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
219
+
220
+ # Lookup from cache.
221
+ key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
222
+ if key in _upfirdn2d_cuda_cache:
223
+ return _upfirdn2d_cuda_cache[key]
224
+
225
+ # Forward op.
226
+ class Upfirdn2dCuda(torch.autograd.Function):
227
+ @staticmethod
228
+ def forward(ctx, x, f): # pylint: disable=arguments-differ
229
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
230
+ if f is None:
231
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
232
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
233
+ y = x
234
+ if f.ndim == 2:
235
+ y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
236
+ else:
237
+ y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
238
+ y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
239
+ ctx.save_for_backward(f)
240
+ ctx.x_shape = x.shape
241
+ return y
242
+
243
+ @staticmethod
244
+ def backward(ctx, dy): # pylint: disable=arguments-differ
245
+ f, = ctx.saved_tensors
246
+ _, _, ih, iw = ctx.x_shape
247
+ _, _, oh, ow = dy.shape
248
+ fw, fh = _get_filter_size(f)
249
+ p = [
250
+ fw - padx0 - 1,
251
+ iw * upx - ow * downx + padx0 - upx + 1,
252
+ fh - pady0 - 1,
253
+ ih * upy - oh * downy + pady0 - upy + 1,
254
+ ]
255
+ dx = None
256
+ df = None
257
+
258
+ if ctx.needs_input_grad[0]:
259
+ dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
260
+
261
+ assert not ctx.needs_input_grad[1]
262
+ return dx, df
263
+
264
+ # Add to cache.
265
+ _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
266
+ return Upfirdn2dCuda
267
+
268
+ #----------------------------------------------------------------------------
269
+
270
+ def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
271
+ r"""Filter a batch of 2D images using the given 2D FIR filter.
272
+
273
+ By default, the result is padded so that its shape matches the input.
274
+ User-specified padding is applied on top of that, with negative values
275
+ indicating cropping. Pixels outside the image are assumed to be zero.
276
+
277
+ Args:
278
+ x: Float32/float64/float16 input tensor of the shape
279
+ `[batch_size, num_channels, in_height, in_width]`.
280
+ f: Float32 FIR filter of the shape
281
+ `[filter_height, filter_width]` (non-separable),
282
+ `[filter_taps]` (separable), or
283
+ `None` (identity).
284
+ padding: Padding with respect to the output. Can be a single number or a
285
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
286
+ (default: 0).
287
+ flip_filter: False = convolution, True = correlation (default: False).
288
+ gain: Overall scaling factor for signal magnitude (default: 1).
289
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
290
+
291
+ Returns:
292
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
293
+ """
294
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
295
+ fw, fh = _get_filter_size(f)
296
+ p = [
297
+ padx0 + fw // 2,
298
+ padx1 + (fw - 1) // 2,
299
+ pady0 + fh // 2,
300
+ pady1 + (fh - 1) // 2,
301
+ ]
302
+ return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
303
+
304
+ #----------------------------------------------------------------------------
305
+
306
+ def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
307
+ r"""Upsample a batch of 2D images using the given 2D FIR filter.
308
+
309
+ By default, the result is padded so that its shape is a multiple of the input.
310
+ User-specified padding is applied on top of that, with negative values
311
+ indicating cropping. Pixels outside the image are assumed to be zero.
312
+
313
+ Args:
314
+ x: Float32/float64/float16 input tensor of the shape
315
+ `[batch_size, num_channels, in_height, in_width]`.
316
+ f: Float32 FIR filter of the shape
317
+ `[filter_height, filter_width]` (non-separable),
318
+ `[filter_taps]` (separable), or
319
+ `None` (identity).
320
+ up: Integer upsampling factor. Can be a single int or a list/tuple
321
+ `[x, y]` (default: 1).
322
+ padding: Padding with respect to the output. Can be a single number or a
323
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
324
+ (default: 0).
325
+ flip_filter: False = convolution, True = correlation (default: False).
326
+ gain: Overall scaling factor for signal magnitude (default: 1).
327
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
328
+
329
+ Returns:
330
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
331
+ """
332
+ upx, upy = _parse_scaling(up)
333
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
334
+ fw, fh = _get_filter_size(f)
335
+ p = [
336
+ padx0 + (fw + upx - 1) // 2,
337
+ padx1 + (fw - upx) // 2,
338
+ pady0 + (fh + upy - 1) // 2,
339
+ pady1 + (fh - upy) // 2,
340
+ ]
341
+ return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
342
+
343
+ #----------------------------------------------------------------------------
344
+
345
+ def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
346
+ r"""Downsample a batch of 2D images using the given 2D FIR filter.
347
+
348
+ By default, the result is padded so that its shape is a fraction of the input.
349
+ User-specified padding is applied on top of that, with negative values
350
+ indicating cropping. Pixels outside the image are assumed to be zero.
351
+
352
+ Args:
353
+ x: Float32/float64/float16 input tensor of the shape
354
+ `[batch_size, num_channels, in_height, in_width]`.
355
+ f: Float32 FIR filter of the shape
356
+ `[filter_height, filter_width]` (non-separable),
357
+ `[filter_taps]` (separable), or
358
+ `None` (identity).
359
+ down: Integer downsampling factor. Can be a single int or a list/tuple
360
+ `[x, y]` (default: 1).
361
+ padding: Padding with respect to the input. Can be a single number or a
362
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
363
+ (default: 0).
364
+ flip_filter: False = convolution, True = correlation (default: False).
365
+ gain: Overall scaling factor for signal magnitude (default: 1).
366
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
367
+
368
+ Returns:
369
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
370
+ """
371
+ downx, downy = _parse_scaling(down)
372
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
373
+ fw, fh = _get_filter_size(f)
374
+ p = [
375
+ padx0 + (fw - downx + 1) // 2,
376
+ padx1 + (fw - downx) // 2,
377
+ pady0 + (fh - downy + 1) // 2,
378
+ pady1 + (fh - downy) // 2,
379
+ ]
380
+ return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
381
+
382
+ #----------------------------------------------------------------------------
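Taken together, the wrappers above compose into a simple resampling round trip. The following sketch uses illustrative sizes and assumes `[1, 3, 3, 1]` filter taps:

    import torch
    f = setup_filter([1, 3, 3, 1])
    x = torch.randn(2, 3, 32, 32)
    hi = upsample2d(x, f, up=2, impl='ref')       # torch.Size([2, 3, 64, 64])
    lo = downsample2d(hi, f, down=2, impl='ref')  # torch.Size([2, 3, 32, 32])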
src/vqvaes/bsqvit/transformer.py ADDED
@@ -0,0 +1,416 @@
1
+ from collections import OrderedDict
2
+ from typing import Callable, Optional, Union
3
+ from einops import rearrange
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.utils.checkpoint import checkpoint
7
+ from timm.models.layers import to_2tuple
8
+ from timm.models.layers import trunc_normal_
9
+ from timm.models.layers import DropPath
10
+
11
+ from .attention_mask import get_attention_mask
12
+
13
+
14
+ class LayerScale(nn.Module):
15
+ def __init__(self, dim, init_values=1e-5, inplace=False):
16
+ super().__init__()
17
+ self.inplace = inplace
18
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
19
+
20
+ def forward(self, x):
21
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
22
+
23
+
24
+ class ResidualAttentionBlock(nn.Module):
25
+ def __init__(
26
+ self,
27
+ d_model: int,
28
+ n_head: int,
29
+ mlp_ratio: float = 4.0,
30
+ ls_init_value: float = None,
31
+ drop: float = 0.,
32
+ attn_drop: float = 0.,
33
+ drop_path: float = 0.,
34
+ act_layer: Callable = nn.GELU,
35
+ norm_layer: Callable = nn.LayerNorm,
36
+ use_preln: bool = True,
37
+ ):
38
+ super().__init__()
39
+
40
+ self.ln_1 = norm_layer(d_model)
41
+ self.attn = nn.MultiheadAttention(d_model, n_head, dropout=attn_drop)
42
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
43
+
44
+ self.ln_2 = norm_layer(d_model)
45
+ mlp_width = int(d_model * mlp_ratio)
46
+ self.mlp = nn.Sequential(OrderedDict([
47
+ ("c_fc", nn.Linear(d_model, mlp_width)),
48
+ ("gelu", act_layer()),
49
+ # dropout disabled here, following the JAX implementation.
50
+ # Reference: https://github.com/google-research/magvit/blob/main/videogvt/models/simplified_bert.py#L112
51
+ # ("drop1", nn.Dropout(drop)),
52
+ ("c_proj", nn.Linear(mlp_width, d_model)),
53
+ ("drop2", nn.Dropout(drop)),
54
+ ]))
55
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
56
+
57
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
58
+
59
+ self.use_preln = use_preln
60
+
61
+ def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False):
62
+ attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
63
+ return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask, is_causal=is_causal)[0]
64
+
65
+ def checkpoint_forward(self, x: torch.Tensor,
66
+ attn_mask: Optional[torch.Tensor] = None,
67
+ is_causal: bool = False):
68
+ state = x
69
+ if self.use_preln:
70
+ x = checkpoint(self.ln_1, x, use_reentrant=False)
71
+ x = self.attention(x, attn_mask, is_causal)
72
+ x = checkpoint(self.ls_1, x, use_reentrant=False)
73
+ state = state + self.drop_path(x)
74
+ x = checkpoint(self.ln_2, state, use_reentrant=False)
75
+ x = self.mlp(x)
76
+ x = checkpoint(self.ls_2, x, use_reentrant=False)
77
+ state = state + self.drop_path(x)
78
+ else:
79
+ x = self.attention(x, attn_mask, is_causal)
80
+ x = state + self.drop_path(x)
81
+ state = checkpoint(self.ln_1, x, use_reentrant=False)
82
+ x = self.mlp(state)
83
+ state = state + self.drop_path(x)
84
+ state = checkpoint(self.ln_2, state, use_reentrant=False)
85
+ return state
86
+
87
+ def forward(self, x: torch.Tensor,
88
+ attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False,
89
+ selective_checkpointing: bool = False):
90
+ if selective_checkpointing:
91
+ return self.checkpoint_forward(x, attn_mask, is_causal=is_causal)
92
+ if self.use_preln:
93
+ x = x + self.drop_path(self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask, is_causal=is_causal)))
94
+ x = x + self.drop_path(self.ls_2(self.mlp(self.ln_2(x))))
95
+ else:
96
+ x = x + self.drop_path(self.attention(x, attn_mask=attn_mask, is_causal=is_causal))
97
+ x = self.ln_1(x)
98
+ x = x + self.drop_path(self.mlp(x))
99
+ x = self.ln_2(x)
100
+ return x
101
+
102
+
103
+ class Transformer(nn.Module):
104
+ def __init__(self,
105
+ width: int,
106
+ layers: int,
107
+ heads: int,
108
+ mlp_ratio: float = 4.0,
109
+ ls_init_value: float = None,
110
+ drop: float = 0.,
111
+ attn_drop: float = 0.,
112
+ drop_path: float = 0.,
113
+ act_layer: nn.Module = nn.GELU,
114
+ norm_layer: nn.Module = nn.LayerNorm,
115
+ use_preln: bool = True,
116
+ ):
117
+ super().__init__()
118
+ self.width = width
119
+ self.layers = layers
120
+ self.grad_checkpointing = False
121
+ self.selective_checkpointing = False
122
+ self.grad_checkpointing_params = {'use_reentrant': False}
123
+ if attn_drop == 0 and drop == 0 and drop_path == 0:
124
+ self.grad_checkpointing_params.update({'preserve_rng_state': False})
125
+ else:
126
+ self.grad_checkpointing_params.update({'preserve_rng_state': True})
127
+
128
+ self.resblocks = nn.ModuleList([
129
+ ResidualAttentionBlock(
130
+ width, heads, mlp_ratio, ls_init_value=ls_init_value,
131
+ drop=drop, attn_drop=attn_drop, drop_path=drop_path,
132
+ act_layer=act_layer, norm_layer=norm_layer,
133
+ use_preln=use_preln)
134
+ for _ in range(layers)
135
+ ])
136
+
137
+ def forward(self, x: torch.Tensor,
138
+ attn_mask: Optional[torch.Tensor] = None,
139
+ is_causal: bool = False):
140
+ for r in self.resblocks:
141
+ if self.training and self.grad_checkpointing and not torch.jit.is_scripting():
142
+ if not self.selective_checkpointing:
143
+ x = checkpoint(r, x, attn_mask, is_causal=is_causal, **self.grad_checkpointing_params)
144
+ else:
145
+ x = r(x, attn_mask=attn_mask, is_causal=is_causal, selective_checkpointing=True)
146
+ else:
147
+ x = r(x, attn_mask=attn_mask, is_causal=is_causal)
148
+ return x
149
+
150
+
151
+ class TransformerEncoder(nn.Module):
152
+ def __init__(self,
153
+ image_size: int,
154
+ patch_size: int,
155
+ width: int,
156
+ layers: int,
157
+ heads: int,
158
+ mlp_ratio: float,
159
+ num_frames: int = 1,
160
+ cross_frames: bool = True,
161
+ ls_init_value: float = None,
162
+ drop_rate: float = 0.,
163
+ attn_drop_rate: float = 0.,
164
+ drop_path_rate: float = 0.,
165
+ ln_pre: bool = True,
166
+ ln_post: bool = True,
167
+ act_layer: str = 'gelu',
168
+ norm_layer: str = 'layer_norm',
169
+ mask_type: Union[str, None] = 'none',
170
+ mask_block_size: int = -1
171
+ ):
172
+ super().__init__()
173
+ self.image_size = to_2tuple(image_size)
174
+ self.patch_size = to_2tuple(patch_size)
175
+ self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1])
176
+ self.patches_per_frame = self.grid_size[0] * self.grid_size[1]
177
+ self.mask_type = mask_type
178
+ self.mask_block_size = mask_block_size
179
+
180
+ if act_layer.lower() == 'gelu':
181
+ self.act_layer = nn.GELU
182
+ else:
183
+ raise ValueError(f"Unsupported activation function: {act_layer}")
184
+ if norm_layer.lower() == 'layer_norm':
185
+ self.norm_layer = nn.LayerNorm
186
+ else:
187
+ raise ValueError(f"Unsupported normalization: {norm_layer}")
188
+
189
+ self.conv1 = nn.Linear(
190
+ in_features=3 * self.patch_size[0] * self.patch_size[1],
191
+ out_features=width,
192
+ bias=not ln_pre
193
+ )
194
+
195
+ scale = width ** -0.5
196
+ self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1], width))
197
+ assert num_frames >= 1
198
+ self.num_frames = num_frames
199
+ self.cross_frames = cross_frames
200
+ if num_frames > 1 and cross_frames:
201
+ self.temporal_positional_embedding = nn.Parameter(torch.zeros(num_frames, width))
202
+ else:
203
+ self.temporal_positional_embedding = None
204
+
205
+ self.ln_pre = self.norm_layer(width) if ln_pre else nn.Identity()
206
+
207
+ self.transformer = Transformer(
208
+ width, layers, heads, mlp_ratio, ls_init_value=ls_init_value,
209
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate,
210
+ act_layer=self.act_layer, norm_layer=self.norm_layer,
211
+ )
212
+
213
+ self.ln_post = self.norm_layer(width)
214
+
215
+ self.init_parameters()
216
+
217
+ def init_parameters(self):
218
+ if self.positional_embedding is not None:
219
+ nn.init.normal_(self.positional_embedding, std=0.02)
220
+ trunc_normal_(self.conv1.weight, std=0.02)
221
+ for block in self.transformer.resblocks:
222
+ for n, p in block.named_parameters():
223
+ if 'weight' in n:
224
+ if 'ln' not in n:
225
+ trunc_normal_(p, std=0.02)
226
+ elif 'bias' in n:
227
+ nn.init.zeros_(p)
228
+ else:
229
+ raise NotImplementedError(f'Unknown parameters named {n}')
230
+
231
+ @torch.jit.ignore
232
+ def set_grad_checkpointing(self, enable=True, selective=False):
233
+ self.transformer.grad_checkpointing = enable
234
+ self.transformer.selective_checkpointing = selective
235
+
236
+
237
+ def forward(self, x):
238
+ if self.num_frames == 1:
239
+ x = rearrange(
240
+ x, "b c (hh sh) (ww sw) -> b (hh ww) (c sh sw)",
241
+ sh=self.patch_size[0], sw=self.patch_size[1]
242
+ )
243
+ x = self.conv1(x)
244
+ x = x + self.positional_embedding.to(x.dtype)
245
+ elif self.cross_frames:
246
+ num_frames = x.shape[2]
247
+ assert num_frames <= self.num_frames, 'Number of frames should be less than or equal to the model setting'
248
+ x = rearrange(
249
+ x, "b c t (hh sh) (ww sw) -> b (t hh ww) (c sh sw)",
250
+ sh=self.patch_size[0], sw=self.patch_size[1]
251
+ )
252
+ x = self.conv1(x)
253
+ tile_pos_embed = self.positional_embedding.repeat(num_frames, 1)
254
+ tile_tem_embed = self.temporal_positional_embedding[:num_frames].repeat_interleave(self.patches_per_frame, 0)
255
+ total_pos_embed = tile_pos_embed + tile_tem_embed
256
+ x = x + total_pos_embed.to(x.dtype).squeeze(0)
257
+ else:
258
+ x = rearrange(
259
+ x, "b c t (hh sh) (ww sw) -> (b t) (hh ww) (c sh sw)",
260
+ sh=self.patch_size[0], sw=self.patch_size[1]
261
+ )
262
+ x = self.conv1(x)
263
+ x = x + self.positional_embedding.to(x.dtype)
264
+
265
+ x = self.ln_pre(x)
266
+ x = x.permute(1, 0, 2)
267
+ block_size = self.grid_size[0] * self.grid_size[1] if self.mask_block_size <= 0 else self.mask_block_size
268
+ attn_mask = get_attention_mask(x.size(0), x.device, mask_type=self.mask_type, block_size=block_size)
269
+ x = self.transformer(x, attn_mask, is_causal=self.mask_type == 'causal')
270
+ x = x.permute(1, 0, 2)
271
+ x = self.ln_post(x)
272
+
273
+ return x
274
+
275
+
276
+ class TransformerDecoder(nn.Module):
277
+ def __init__(self,
278
+ image_size: int,
279
+ patch_size: int,
280
+ width: int,
281
+ layers: int,
282
+ heads: int,
283
+ mlp_ratio: float,
284
+ num_frames: int = 1,
285
+ cross_frames: bool = True,
286
+ ls_init_value: float = None,
287
+ drop_rate: float = 0.,
288
+ attn_drop_rate: float = 0.,
289
+ drop_path_rate: float = 0.,
290
+ ln_pre: bool = True,
291
+ ln_post: bool = True,
292
+ act_layer: str = 'gelu',
293
+ norm_layer: str = 'layer_norm',
294
+ use_ffn_output: bool = True,
295
+ dim_ffn_output: int = 3072,
296
+ logit_laplace: bool = False,
297
+ mask_type: Union[str, None] = 'none',
298
+ mask_block_size: int = -1
299
+ ):
300
+ super().__init__()
301
+ self.image_size = to_2tuple(image_size)
302
+ self.patch_size = to_2tuple(patch_size)
303
+ self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1])
304
+ self.patches_per_frame = self.grid_size[0] * self.grid_size[1]
305
+ self.mask_type = mask_type
306
+ self.mask_block_size = mask_block_size
307
+
308
+ if act_layer.lower() == 'gelu':
309
+ self.act_layer = nn.GELU
310
+ else:
311
+ raise ValueError(f"Unsupported activation function: {act_layer}")
312
+ if norm_layer.lower() == 'layer_norm':
313
+ self.norm_layer = nn.LayerNorm
314
+ else:
315
+ raise ValueError(f"Unsupported normalization: {norm_layer}")
316
+
317
+ self.use_ffn_output = use_ffn_output
318
+ if use_ffn_output:
319
+ self.ffn = nn.Sequential(
320
+ nn.Linear(width, dim_ffn_output),
321
+ nn.Tanh(),
322
+ )
323
+ self.conv_out = nn.Linear(
324
+ in_features=dim_ffn_output,
325
+ out_features=3 * self.patch_size[0] * self.patch_size[1] * (1 + logit_laplace)
326
+ )
327
+ else:
328
+ self.ffn = nn.Identity()
329
+ self.conv_out = nn.Linear(
330
+ in_features=width,
331
+ out_features=3 * self.patch_size[0] * self.patch_size[1] * (1 + logit_laplace)
332
+ )
333
+
334
+ scale = width ** -0.5
335
+ self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1], width))
336
+ assert num_frames >= 1
337
+ self.num_frames = num_frames
338
+ self.cross_frames = cross_frames
339
+ if num_frames > 1 and cross_frames:
340
+ self.temporal_positional_embedding = nn.Parameter(torch.zeros(num_frames, width))
341
+ else:
342
+ self.temporal_positional_embedding = None
343
+
344
+ self.ln_pre = self.norm_layer(width) if ln_pre else nn.Identity()
345
+
346
+ self.transformer = Transformer(
347
+ width, layers, heads, mlp_ratio, ls_init_value=ls_init_value,
348
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate,
349
+ act_layer=self.act_layer, norm_layer=self.norm_layer,
350
+ )
351
+
352
+ self.ln_post = self.norm_layer(width) if ln_post else nn.Identity()
353
+
354
+ self.init_parameters()
355
+
356
+ def init_parameters(self):
357
+ if self.positional_embedding is not None:
358
+ nn.init.normal_(self.positional_embedding, std=0.02)
359
+
360
+ for block in self.transformer.resblocks:
361
+ for n, p in block.named_parameters():
362
+ if 'weight' in n:
363
+ if 'ln' not in n:
364
+ trunc_normal_(p, std=0.02)
365
+ elif 'bias' in n:
366
+ nn.init.zeros_(p)
367
+ else:
368
+ raise NotImplementedError(f'Unknown parameters named {n}')
369
+ if self.use_ffn_output:
370
+ trunc_normal_(self.ffn[0].weight, std=0.02)
371
+ trunc_normal_(self.conv_out.weight, std=0.02)
372
+
373
+ @torch.jit.ignore
374
+ def set_grad_checkpointing(self, enable=True, selective=False):
375
+ self.transformer.grad_checkpointing = enable
376
+ self.transformer.selective_checkpointing = selective
377
+
378
+ def forward(self, x):
379
+ if self.num_frames == 1 or not self.cross_frames:
380
+ x = x + self.positional_embedding.to(x.dtype)
381
+ else:
382
+ num_frames = x.shape[1] // self.patches_per_frame
383
+ assert num_frames <= self.num_frames, 'Number of frames should be less than or equal to the model setting'
384
+ tile_pos_embed = self.positional_embedding.repeat(num_frames, 1)
385
+ tile_tem_embed = self.temporal_positional_embedding[:num_frames].repeat_interleave(self.patches_per_frame, 0)
386
+ total_pos_embed = tile_pos_embed + tile_tem_embed
387
+ x = x + total_pos_embed.to(x.dtype).squeeze(0)
388
+ x = self.ln_pre(x)
389
+ x = x.permute(1, 0, 2)
390
+ block_size = self.grid_size[0] * self.grid_size[1] if self.mask_block_size <= 0 else self.mask_block_size
391
+ attn_mask = get_attention_mask(x.size(0), x.device, mask_type=self.mask_type, block_size=block_size)
392
+ x = self.transformer(x, attn_mask, is_causal=self.mask_type == 'causal')
393
+ x = x.permute(1, 0, 2)
394
+ x = self.ln_post(x)
395
+ x = self.ffn(x)
396
+ x = self.conv_out(x)
397
+ if self.num_frames == 1:
398
+ x = rearrange(
399
+ x, "b (hh ww) (c sh sw) -> b c (hh sh) (ww sw)",
400
+ hh=self.grid_size[0], ww=self.grid_size[1],
401
+ sh=self.patch_size[0], sw=self.patch_size[1]
402
+ )
403
+ elif self.cross_frames:
404
+ x = rearrange(
405
+ x, "b (t hh ww) (c sh sw) -> b c t (hh sh) (ww sw)",
406
+ t=num_frames, hh=self.grid_size[0], ww=self.grid_size[1],
407
+ sh=self.patch_size[0], sw=self.patch_size[1]
408
+ )
409
+ else:
410
+ x = rearrange(
411
+ x, "(b t) (hh ww) (c sh sw) -> b c t (hh sh) (ww sw)",
412
+ t=num_frames, hh=self.grid_size[0], ww=self.grid_size[1],
413
+ sh=self.patch_size[0], sw=self.patch_size[1]
414
+ )
415
+
416
+ return x
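For orientation, a small sketch of how the encoder above might be exercised on its own; the sizes are illustrative and it assumes the default `mask_type='none'` yields no restrictive attention mask:

    import torch
    enc = TransformerEncoder(image_size=64, patch_size=16, width=256,
                             layers=2, heads=4, mlp_ratio=4.0)
    x = torch.randn(2, 3, 64, 64)
    tokens = enc(x)          # torch.Size([2, 16, 256]): one width-dim token per 16x16 patch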
src/vqvaes/flowmo/flowmo.py ADDED
@@ -0,0 +1,945 @@
1
+ """Model code for FlowMo.
2
+
3
+ Sources: https://github.com/feizc/FluxMusic/blob/main/train.py
4
+ https://github.com/black-forest-labs/flux/tree/main/src/flux
5
+ """
6
+
7
+ import ast
8
+ import itertools
9
+ import math
10
+ from dataclasses import dataclass
11
+ from typing import List, Tuple
12
+
13
+ import einops
14
+ import torch
15
+ from einops import rearrange, repeat
16
+ from mup import MuReadout
17
+ from torch import Tensor, nn
18
+ import argparse
19
+ import contextlib
20
+ import copy
21
+ import glob
22
+ import os
23
+ import subprocess
24
+ import tempfile
25
+ import time
26
+
27
+ import fsspec
28
+ import psutil
29
+ import torch
30
+ import torch.distributed as dist
31
+ from mup import MuReadout, set_base_shapes
32
+ from omegaconf import OmegaConf
33
+ from torch.utils.data import DataLoader
34
+
35
+ from .lookup_free_quantize import LFQ
36
+
37
+ MUP_ENABLED = True
38
+
39
+
40
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
41
+ b, h, l, d = q.shape
42
+ q, k = apply_rope(q, k, pe)
43
+
44
+ if torch.__version__ == "2.0.1+cu117": # tmp workaround
45
+ if d != 64:
46
+ print("MUP is broken in this setting! Be careful!")
47
+ x = torch.nn.functional.scaled_dot_product_attention(
48
+ q,
49
+ k,
50
+ v,
51
+ )
52
+ else:
53
+ x = torch.nn.functional.scaled_dot_product_attention(
54
+ q,
55
+ k,
56
+ v,
57
+ scale=8.0 / d if MUP_ENABLED else None,
58
+ )
59
+ assert x.shape == q.shape
60
+ x = rearrange(x, "B H L D -> B L (H D)")
61
+ return x
62
+
63
+
64
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
65
+ assert dim % 2 == 0
66
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
67
+ omega = 1.0 / (theta**scale)
68
+ out = torch.einsum("...n,d->...nd", pos, omega)
69
+ out = torch.stack(
70
+ [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)],
71
+ dim=-1,
72
+ )
73
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
74
+ return out.float()
75
+
76
+
77
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
78
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
79
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
80
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
81
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
82
+
83
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
84
+
85
+
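A shape sketch for `rope()` with illustrative sizes (4 positions, an 8-dim axis), assuming the function above is in scope:

    import torch
    pos = torch.arange(4, dtype=torch.float32)[None]   # [1, 4]
    pe = rope(pos, dim=8, theta=10_000)
    print(pe.shape)                                     # torch.Size([1, 4, 4, 2, 2]): per-position 2x2 rotations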
86
+ def _get_diagonal_gaussian(parameters):
87
+ mean, logvar = torch.chunk(parameters, 2, dim=1)
88
+ logvar = torch.clamp(logvar, -30.0, 20.0)
89
+ return mean, logvar
90
+
91
+
92
+ def _sample_diagonal_gaussian(mean, logvar):
93
+ std = torch.exp(0.5 * logvar)
94
+ x = mean + std * torch.randn(mean.shape, device=mean.device)
95
+ return x
96
+
97
+
98
+ def _kl_diagonal_gaussian(mean, logvar):
99
+ var = torch.exp(logvar)
100
+ return 0.5 * torch.sum(torch.pow(mean, 2) + var - 1.0 - logvar, dim=1).mean()
101
+
102
+
103
+ class EmbedND(nn.Module):
104
+ def __init__(self, dim: int, theta: int, axes_dim):
105
+ super().__init__()
106
+ self.dim = dim
107
+ self.theta = theta
108
+ self.axes_dim = axes_dim
109
+
110
+ def forward(self, ids: Tensor) -> Tensor:
111
+ n_axes = ids.shape[-1]
112
+ emb = torch.cat(
113
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
114
+ dim=-3,
115
+ )
116
+
117
+ return emb.unsqueeze(1)
118
+
119
+
120
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
121
+ """
122
+ Create sinusoidal timestep embeddings.
123
+ :param t: a 1-D Tensor of N indices, one per batch element.
124
+ These may be fractional.
125
+ :param dim: the dimension of the output.
126
+ :param max_period: controls the minimum frequency of the embeddings.
127
+ :return: an (N, D) Tensor of positional embeddings.
128
+ """
129
+ t = time_factor * t
130
+ half = dim // 2
131
+ freqs = torch.exp(
132
+ -math.log(max_period)
133
+ * torch.arange(start=0, end=half, dtype=torch.float32)
134
+ / half
135
+ ).to(t.device)
136
+
137
+ args = t[:, None].float() * freqs[None]
138
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
139
+ if dim % 2:
140
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
141
+ if torch.is_floating_point(t):
142
+ embedding = embedding.to(t)
143
+ return embedding
144
+
145
+
146
+ class MLPEmbedder(nn.Module):
147
+ def __init__(self, in_dim: int, hidden_dim: int):
148
+ super().__init__()
149
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
150
+ self.silu = nn.SiLU()
151
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
152
+
153
+ def forward(self, x: Tensor) -> Tensor:
154
+ return self.out_layer(self.silu(self.in_layer(x)))
155
+
156
+
157
+ class RMSNorm(torch.nn.Module):
158
+ def __init__(self, dim: int):
159
+ super().__init__()
160
+ self.scale = nn.Parameter(torch.ones(dim))
161
+
162
+ def forward(self, x: Tensor):
163
+ x_dtype = x.dtype
164
+ x = x.float()
165
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
166
+ return (x * rrms).to(dtype=x_dtype) * self.scale
167
+
168
+
169
+ class QKNorm(torch.nn.Module):
170
+ def __init__(self, dim: int):
171
+ super().__init__()
172
+ self.query_norm = RMSNorm(dim)
173
+ self.key_norm = RMSNorm(dim)
174
+
175
+ def forward(self, q: Tensor, k: Tensor, v: Tensor):
176
+ q = self.query_norm(q)
177
+ k = self.key_norm(k)
178
+ return q.to(v), k.to(v)
179
+
180
+
181
+ class SelfAttention(nn.Module):
182
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
183
+ super().__init__()
184
+ self.num_heads = num_heads
185
+ head_dim = dim // num_heads
186
+
187
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
188
+ self.norm = QKNorm(head_dim)
189
+ self.proj = nn.Linear(dim, dim)
190
+
191
+ def forward(self, x: Tensor, pe: Tensor) -> Tensor:
192
+ qkv = self.qkv(x)
193
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
194
+ q, k = self.norm(q, k, v)
195
+ x = attention(q, k, v, pe=pe)
196
+ x = self.proj(x)
197
+ return x
198
+
199
+
200
+ @dataclass
201
+ class ModulationOut:
202
+ shift: Tensor
203
+ scale: Tensor
204
+ gate: Tensor
205
+
206
+
207
+ class Modulation(nn.Module):
208
+ def __init__(self, dim: int, double: bool):
209
+ super().__init__()
210
+ self.is_double = double
211
+ self.multiplier = 6 if double else 3
212
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
213
+
214
+ self.lin.weight[dim * 2 : dim * 3].data[:] = 0.0
215
+ self.lin.bias[dim * 2 : dim * 3].data[:] = 0.0
216
+ self.lin.weight[dim * 5 : dim * 6].data[:] = 0.0
217
+ self.lin.bias[dim * 5 : dim * 6].data[:] = 0.0
218
+
219
+ def forward(self, vec: Tensor) -> Tuple[ModulationOut, ModulationOut]:
220
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(
221
+ self.multiplier, dim=-1
222
+ )
223
+ return (
224
+ ModulationOut(*out[:3]),
225
+ ModulationOut(*out[3:]) if self.is_double else None,
226
+ )
227
+
228
+
229
+ class DoubleStreamBlock(nn.Module):
230
+ def __init__(
231
+ self,
232
+ hidden_size: int,
233
+ num_heads: int,
234
+ mlp_ratio: float,
235
+ qkv_bias: bool = False,
236
+ ):
237
+ super().__init__()
238
+
239
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
240
+ self.num_heads = num_heads
241
+ self.hidden_size = hidden_size
242
+
243
+ self.img_mod = Modulation(hidden_size, double=True)
244
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
245
+ self.img_attn = SelfAttention(
246
+ dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
247
+ )
248
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
249
+ self.img_mlp = nn.Sequential(
250
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
251
+ nn.GELU(approximate="tanh"),
252
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
253
+ )
254
+
255
+ self.txt_mod = Modulation(hidden_size, double=True)
256
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
257
+ self.txt_attn = SelfAttention(
258
+ dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
259
+ )
260
+
261
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
262
+ self.txt_mlp = nn.Sequential(
263
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
264
+ nn.GELU(approximate="tanh"),
265
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
266
+ )
267
+
268
+ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
269
+ pe_single, pe_double = pe
270
+ p = 1
271
+ if vec is None:
272
+ img_mod1, img_mod2 = ModulationOut(0, 1 - p, 1), ModulationOut(0, 1 - p, 1)
273
+ txt_mod1, txt_mod2 = ModulationOut(0, 1 - p, 1), ModulationOut(0, 1 - p, 1)
274
+ else:
275
+ img_mod1, img_mod2 = self.img_mod(vec)
276
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
277
+
278
+ # prepare image for attention
279
+ img_modulated = self.img_norm1(img)
280
+ img_modulated = (p + img_mod1.scale) * img_modulated + img_mod1.shift
281
+ img_qkv = self.img_attn.qkv(img_modulated)
282
+ img_q, img_k, img_v = rearrange(
283
+ img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
284
+ )
285
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
286
+
287
+ # prepare txt for attention
288
+ txt_modulated = self.txt_norm1(txt)
289
+ txt_modulated = (p + txt_mod1.scale) * txt_modulated + txt_mod1.shift
290
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
291
+ txt_q, txt_k, txt_v = rearrange(
292
+ txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
293
+ )
294
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
295
+
296
+ # run actual attention
297
+ q = torch.cat((txt_q, img_q), dim=2)
298
+ k = torch.cat((txt_k, img_k), dim=2)
299
+ v = torch.cat((txt_v, img_v), dim=2)
300
+
301
+ attn = attention(q, k, v, pe=pe_double)
302
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
303
+
304
+ # calculate the img bloks
305
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
306
+ img = img + img_mod2.gate * self.img_mlp(
307
+ (p + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
308
+ )
309
+
310
+ # calculate the txt bloks
311
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
312
+ txt = txt + txt_mod2.gate * self.txt_mlp(
313
+ (p + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
314
+ )
315
+ return img, txt
316
+
317
+
318
+ class LastLayer(nn.Module):
319
+ def __init__(
320
+ self,
321
+ hidden_size: int,
322
+ patch_size: int,
323
+ out_channels: int,
324
+ readout_zero_init=False,
325
+ ):
326
+ super().__init__()
327
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
328
+
329
+ if MUP_ENABLED:
330
+ self.linear = MuReadout(
331
+ hidden_size,
332
+ patch_size * patch_size * out_channels,
333
+ bias=True,
334
+ readout_zero_init=readout_zero_init,
335
+ )
336
+ else:
337
+ self.linear = nn.Linear(
338
+ hidden_size, patch_size * patch_size * out_channels, bias=True
339
+ )
340
+
341
+ self.adaLN_modulation = nn.Sequential(
342
+ nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
343
+ )
344
+
345
+ def forward(self, x: Tensor, vec) -> Tensor:
346
+ if vec is None:
347
+ pass
348
+ else:
349
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
350
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
351
+ x = self.norm_final(x)
352
+ x = self.linear(x)
353
+ return x
354
+
355
+
356
+ @dataclass
357
+ class FluxParams:
358
+ in_channels: int
359
+ patch_size: int
360
+ context_dim: int
361
+ hidden_size: int
362
+ mlp_ratio: float
363
+ num_heads: int
364
+ depth: int
365
+ axes_dim: List[int]
366
+ theta: int
367
+ qkv_bias: bool
368
+
369
+
370
+ DIT_ZOO = dict(
371
+ dit_xl_4=dict(
372
+ hidden_size=1152,
373
+ mlp_ratio=4.0,
374
+ num_heads=16,
375
+ axes_dim=[8, 28, 28],
376
+ theta=10_000,
377
+ qkv_bias=True,
378
+ ),
379
+ dit_l_4=dict(
380
+ hidden_size=1024,
381
+ mlp_ratio=4.0,
382
+ num_heads=16,
383
+ axes_dim=[8, 28, 28],
384
+ theta=10_000,
385
+ qkv_bias=True,
386
+ ),
387
+ dit_b_4=dict(
388
+ hidden_size=768,
389
+ mlp_ratio=4.0,
390
+ num_heads=12,
391
+ axes_dim=[8, 28, 28],
392
+ theta=10_000,
393
+ qkv_bias=True,
394
+ ),
395
+ dit_s_4=dict(
396
+ hidden_size=384,
397
+ mlp_ratio=4.0,
398
+ num_heads=6,
399
+ axes_dim=[8, 28, 28],
400
+ theta=10_000,
401
+ qkv_bias=True,
402
+ ),
403
+ dit_mup_test=dict(
404
+ hidden_size=768,
405
+ mlp_ratio=4.0,
406
+ num_heads=12,
407
+ axes_dim=[8, 28, 28],
408
+ theta=10_000,
409
+ qkv_bias=True,
410
+ ),
411
+ )
412
+
413
+
414
+ def prepare_idxs(img, code_length, patch_size):
415
+ bs, c, h, w = img.shape
416
+
417
+ img_ids = torch.zeros(h // patch_size, w // patch_size, 3, device=img.device)
418
+ img_ids[..., 1] = (
419
+ img_ids[..., 1] + torch.arange(h // patch_size, device=img.device)[:, None]
420
+ )
421
+ img_ids[..., 2] = (
422
+ img_ids[..., 2] + torch.arange(w // patch_size, device=img.device)[None, :]
423
+ )
424
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
425
+
426
+ txt_ids = (
427
+ torch.zeros((bs, code_length, 3), device=img.device)
428
+ + torch.arange(code_length, device=img.device)[None, :, None]
429
+ )
430
+ return img_ids, txt_ids
431
+
432
+
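A shape sketch for `prepare_idxs()` with illustrative sizes (a 256x256 image, patch_size=4, 256 latent tokens):

    import torch
    img = torch.zeros(1, 3, 256, 256)
    img_ids, txt_ids = prepare_idxs(img, code_length=256, patch_size=4)
    print(img_ids.shape, txt_ids.shape)   # torch.Size([1, 4096, 3]) torch.Size([1, 256, 3])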
433
+ class Flux(nn.Module):
434
+ """
435
+ Transformer model for flow matching on sequences.
436
+ """
437
+
438
+ def __init__(self, params: FluxParams, name="", lsg=False):
439
+ super().__init__()
440
+
441
+ self.name = name
442
+ self.lsg = lsg
443
+ self.params = params
444
+ self.in_channels = params.in_channels
445
+ self.patch_size = params.patch_size
446
+ self.out_channels = self.in_channels
447
+ if params.hidden_size % params.num_heads != 0:
448
+ raise ValueError(
449
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
450
+ )
451
+ pe_dim = params.hidden_size // params.num_heads
452
+ if sum(params.axes_dim) != pe_dim:
453
+ raise ValueError(
454
+ f"Got {params.axes_dim} but expected positional dim {pe_dim}"
455
+ )
456
+ self.hidden_size = params.hidden_size
457
+ self.num_heads = params.num_heads
458
+ self.pe_embedder = EmbedND(
459
+ dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
460
+ )
461
+
462
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
463
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
464
+ self.txt_in = nn.Linear(params.context_dim, self.hidden_size)
465
+
466
+ self.double_blocks = nn.ModuleList(
467
+ [
468
+ DoubleStreamBlock(
469
+ self.hidden_size,
470
+ self.num_heads,
471
+ mlp_ratio=params.mlp_ratio,
472
+ qkv_bias=params.qkv_bias,
473
+ )
474
+ for idx in range(params.depth)
475
+ ]
476
+ )
477
+
478
+ self.final_layer_img = LastLayer(
479
+ self.hidden_size, 1, self.out_channels, readout_zero_init=False
480
+ )
481
+ self.final_layer_txt = LastLayer(
482
+ self.hidden_size, 1, params.context_dim, readout_zero_init=False
483
+ )
484
+
485
+ def forward(
486
+ self,
487
+ img: Tensor,
488
+ img_ids: Tensor,
489
+ txt: Tensor,
490
+ txt_ids: Tensor,
491
+ timesteps: Tensor,
492
+ ) -> Tensor:
493
+ b, c, h, w = img.shape
494
+
495
+ img = rearrange(
496
+ img,
497
+ "b c (gh ph) (gw pw) -> b (gh gw) (ph pw c)",
498
+ ph=self.patch_size,
499
+ pw=self.patch_size,
500
+ )
501
+ if img.ndim != 3 or txt.ndim != 3:
502
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
503
+ img = self.img_in(img)
504
+
505
+ if timesteps is None:
506
+ vec = None
507
+ else:
508
+ vec = self.time_in(timestep_embedding(timesteps, 256))
509
+
510
+ txt = self.txt_in(txt)
511
+ pe_single = self.pe_embedder(torch.cat((txt_ids,), dim=1))
512
+ pe_double = self.pe_embedder(torch.cat((txt_ids, img_ids), dim=1))
513
+
514
+ for block in self.double_blocks:
515
+ img, txt = block(img=img, txt=txt, pe=(pe_single, pe_double), vec=vec)
516
+
517
+ img = self.final_layer_img(img, vec=vec)
518
+ img = rearrange(
519
+ img,
520
+ "b (gh gw) (ph pw c) -> b c (gh ph) (gw pw)",
521
+ ph=self.patch_size,
522
+ pw=self.patch_size,
523
+ gh=h // self.patch_size,
524
+ gw=w // self.patch_size,
525
+ )
526
+
527
+ txt = self.final_layer_txt(txt, vec=vec)
528
+ return img, txt, {"final_txt": txt}
529
+
530
+
531
+ def get_weights_to_fix(model):
532
+ with torch.no_grad():
533
+ for name, module in itertools.chain(model.named_modules()):
534
+ if "double_blocks" in name and isinstance(module, torch.nn.Linear):
535
+ yield name, module.weight
536
+
537
+
538
+ class FlowMo(nn.Module):
539
+ def __init__(self, width, config):
540
+ super().__init__()
541
+ code_length = config.model.code_length
542
+ context_dim = config.model.context_dim
543
+ enc_depth = config.model.enc_depth
544
+ dec_depth = config.model.dec_depth
545
+
546
+ patch_size = config.model.patch_size
547
+ self.config = config
548
+
549
+ self.image_size = config.data.image_size
550
+ self.patch_size = config.model.patch_size
551
+ self.code_length = code_length
552
+ self.dit_mode = "dit_b_4"
553
+ self.context_dim = context_dim
554
+ self.encoder_context_dim = context_dim * (
555
+ 1 + (self.config.model.quantization_type == "kl")
556
+ )
557
+
558
+ if config.model.quantization_type == "lfq":
559
+ self.quantizer = LFQ(
560
+ codebook_size=2**self.config.model.codebook_size_for_entropy,
561
+ dim=self.config.model.codebook_size_for_entropy,
562
+ num_codebooks=1,
563
+ token_factorization=False,
564
+ )
565
+
566
+ if self.config.model.enc_mup_width is not None:
567
+ enc_width = self.config.model.enc_mup_width
568
+ else:
569
+ enc_width = width
570
+
571
+ encoder_params = FluxParams(
572
+ in_channels=3 * patch_size**2,
573
+ context_dim=self.encoder_context_dim,
574
+ patch_size=patch_size,
575
+ depth=enc_depth,
576
+ **DIT_ZOO[self.dit_mode],
577
+ )
578
+ decoder_params = FluxParams(
579
+ in_channels=3 * patch_size**2,
580
+ context_dim=context_dim + 1,
581
+ patch_size=patch_size,
582
+ depth=dec_depth,
583
+ **DIT_ZOO[self.dit_mode],
584
+ )
585
+
586
+ # width=4, dit_b_4 is the usual model
587
+ encoder_params.hidden_size = enc_width * (encoder_params.hidden_size // 4)
588
+ decoder_params.hidden_size = width * (decoder_params.hidden_size // 4)
589
+ encoder_params.axes_dim = [
590
+ (d // 4) * enc_width for d in encoder_params.axes_dim
591
+ ]
592
+ decoder_params.axes_dim = [(d // 4) * width for d in decoder_params.axes_dim]
593
+
594
+ self.encoder = Flux(encoder_params, name="encoder")
595
+ self.decoder = Flux(decoder_params, name="decoder")
596
+
597
+ @torch.compile
598
+ def encode(self, img):
599
+ b, c, h, w = img.shape
600
+
601
+ img_idxs, txt_idxs = prepare_idxs(img, self.code_length, self.patch_size)
602
+ txt = torch.zeros(
603
+ (b, self.code_length, self.encoder_context_dim), device=img.device
604
+ )
605
+
606
+ _, code, aux = self.encoder(img, img_idxs, txt, txt_idxs, timesteps=None)
607
+
608
+ return code, aux
609
+
610
+ def _decode(self, img, code, timesteps):
611
+ b, c, h, w = img.shape
612
+
613
+ img_idxs, txt_idxs = prepare_idxs(
614
+ img,
615
+ self.code_length,
616
+ self.patch_size,
617
+ )
618
+ pred, _, decode_aux = self.decoder(
619
+ img, img_idxs, code, txt_idxs, timesteps=timesteps
620
+ )
621
+ return pred, decode_aux
622
+
623
+ @torch.compile
624
+ def decode(self, *args, **kwargs):
625
+ return self._decode(*args, **kwargs)
626
+
627
+ @torch.compile
628
+ def decode_checkpointed(self, *args, **kwargs):
629
+ # Need to compile(checkpoint), not checkpoint(compile)
630
+ assert not kwargs, kwargs
631
+ return torch.utils.checkpoint.checkpoint(
632
+ self._decode,
633
+ *args,
634
+ # WARNING: Do not use_reentrant=True with compile, it will silently
635
+ # produce incorrect gradients!
636
+ use_reentrant=False,
637
+ )
638
+
639
+ @torch.compile
640
+ def _quantize(self, code):
641
+ """
642
+ Args:
643
+ code: [b codelength context dim]
644
+
645
+ Returns:
646
+ quantized code of the same shape
647
+ """
648
+ b, t, f = code.shape
649
+ indices = None
650
+ if self.config.model.quantization_type == "noop":
651
+ quantized = code
652
+ quantizer_loss = torch.tensor(0.0).to(code.device)
653
+ elif self.config.model.quantization_type == "kl":
654
+ # colocating features of same token before split is maybe slightly
655
+ # better?
656
+ mean, logvar = _get_diagonal_gaussian(
657
+ einops.rearrange(code, "b t f -> b (f t)")
658
+ )
659
+ code = einops.rearrange(
660
+ _sample_diagonal_gaussian(mean, logvar),
661
+ "b (f t) -> b t f",
662
+ f=f // 2,
663
+ t=t,
664
+ )
665
+ quantizer_loss = _kl_diagonal_gaussian(mean, logvar)
666
+ elif self.config.model.quantization_type == "lfq":
667
+ assert f % self.config.model.codebook_size_for_entropy == 0, f
668
+ code = einops.rearrange(
669
+ code,
670
+ "b t (fg fh) -> b fg (t fh)",
671
+ fg=self.config.model.codebook_size_for_entropy,
672
+ )
673
+
674
+ (quantized, entropy_aux_loss, indices), breakdown = self.quantizer(
675
+ code, return_loss_breakdown=True
676
+ )
677
+ assert quantized.shape == code.shape
678
+ quantized = einops.rearrange(quantized, "b fg (t fh) -> b t (fg fh)", t=t)
679
+
680
+ quantizer_loss = (
681
+ entropy_aux_loss * self.config.model.entropy_loss_weight
682
+ + breakdown.commitment * self.config.model.commit_loss_weight
683
+ )
684
+ code = quantized
685
+ else:
686
+ raise NotImplementedError
687
+ return code, indices, quantizer_loss
688
+
689
+ # def forward(
690
+ # self,
691
+ # img,
692
+ # noised_img,
693
+ # timesteps,
694
+ # enable_cfg=True,
695
+ # ):
696
+ # aux = {}
697
+ #
698
+ # code, encode_aux = self.encode(img)
699
+ #
700
+ # aux["original_code"] = code
701
+ #
702
+ # b, t, f = code.shape
703
+ #
704
+ # code, _, aux["quantizer_loss"] = self._quantize(code)
705
+ #
706
+ # mask = torch.ones_like(code[..., :1])
707
+ # code = torch.concatenate([code, mask], axis=-1)
708
+ # code_pre_cfg = code
709
+ #
710
+ # if self.config.model.enable_cfg and enable_cfg:
711
+ # cfg_mask = (torch.rand((b,), device=code.device) > 0.1)[:, None, None]
712
+ # code = code * cfg_mask
713
+ #
714
+ # v_est, decode_aux = self.decode(noised_img, code, timesteps)
715
+ # aux.update(decode_aux)
716
+ #
717
+ # if self.config.model.posttrain_sample:
718
+ # aux["posttrain_sample"] = self.reconstruct_checkpoint(code_pre_cfg)
719
+ #
720
+ # return v_est, aux
721
+
722
+ def forward(self, img):
723
+ return self.reconstruct(img)
724
+
725
+ def reconstruct_checkpoint(self, code):
726
+ with torch.autocast(
727
+ "cuda",
728
+ dtype=torch.bfloat16,
729
+ ):
730
+ bs, *_ = code.shape
731
+
732
+ z = torch.randn((bs, 3, self.image_size, self.image_size)).cuda()
733
+ ts = (
734
+ torch.rand((bs, self.config.model.posttrain_sample_k + 1))
735
+ .cumsum(dim=1)
736
+ .cuda()
737
+ )
738
+ ts = ts - ts[:, :1]
739
+ ts = (ts / ts[:, -1:]).flip(dims=(1,))
740
+ dts = ts[:, :-1] - ts[:, 1:]
741
+
742
+ for i, (t, dt) in enumerate((zip(ts.T, dts.T))):
743
+ if self.config.model.posttrain_sample_enable_cfg:
744
+ mask = (torch.rand((bs,), device=code.device) > 0.1)[
745
+ :, None, None
746
+ ].to(code.dtype)
747
+ code_t = code * mask
748
+ else:
749
+ code_t = code
750
+
751
+ vc, _ = self.decode_checkpointed(z, code_t, t)
752
+
753
+ z = z - dt[:, None, None, None] * vc
754
+ return z
755
+
756
+ @torch.no_grad()
757
+ def reconstruct(self, images, dtype=torch.bfloat16, code=None):
758
+ """
759
+ Args:
760
+ images in [bchw] [-1, 1]
761
+
762
+ Returns:
763
+ images in [bchw] [-1, 1]
764
+ """
765
+ model = self
766
+ config = self.config.eval.sampling
767
+
768
+ with torch.autocast(
769
+ "cuda",
770
+ dtype=dtype,
771
+ ):
772
+ bs, c, h, w = images.shape
773
+ prequantized_code = None  # otherwise unbound at the return below when a precomputed code is passed in
+ if code is None:
774
+ x = images.cuda()
775
+ prequantized_code = model.encode(x)[0].cuda()
776
+ code, indices, _ = model._quantize(prequantized_code)
777
+
778
+ z = torch.randn((bs, 3, h, w)).cuda()
779
+
780
+ mask = torch.ones_like(code[..., :1])
781
+ code = torch.concatenate([code * mask, mask], axis=-1)
782
+
783
+ cfg_mask = 0.0
784
+ null_code = code * cfg_mask if config.cfg != 1.0 else None
785
+
786
+ samples = rf_sample(
787
+ model,
788
+ z,
789
+ code,
790
+ null_code=null_code,
791
+ sample_steps=config.sample_steps,
792
+ cfg=config.cfg,
793
+ schedule=config.schedule,
794
+ )[-1].clip(-1, 1)
795
+ return samples.to(torch.float32), code, prequantized_code
796
+
797
+
798
+ def rf_loss(config, model, batch, aux_state):
799
+ x = batch["image"]
800
+ b = x.size(0)
801
+
802
+ if config.opt.schedule == "lognormal":
803
+ nt = torch.randn((b,)).to(x.device)
804
+ t = torch.sigmoid(nt)
805
+ elif config.opt.schedule == "fat_lognormal":
806
+ nt = torch.randn((b,)).to(x.device)
807
+ t = torch.sigmoid(nt)
808
+ t = torch.where(torch.rand_like(t) <= 0.9, t, torch.rand_like(t))
809
+ elif config.opt.schedule == "uniform":
810
+ t = torch.rand((b,), device=x.device)
811
+ elif config.opt.schedule.startswith("debug"):
812
+ p = float(config.opt.schedule.split("_")[1])
813
+ t = torch.ones((b,), device=x.device) * p
814
+ else:
815
+ raise NotImplementedError
816
+
817
+ t = t.view([b, *([1] * len(x.shape[1:]))])
818
+ z1 = torch.randn_like(x)
819
+ zt = (1 - t) * x + t * z1
820
+
821
+ zt, t = zt.to(x.dtype), t.to(x.dtype)
822
+
823
+ vtheta, aux = model(
824
+ img=x,
825
+ noised_img=zt,
826
+ timesteps=t.reshape((b,)),
827
+ )
828
+
829
+ diff = z1 - vtheta - x
830
+ x_pred = zt - vtheta * t
831
+
832
+ loss = ((diff) ** 2).mean(dim=list(range(1, len(x.shape))))
833
+ loss = loss.mean()
834
+
835
+ aux["loss_dict"] = {}
836
+ aux["loss_dict"]["diffusion_loss"] = loss
837
+ aux["loss_dict"]["quantizer_loss"] = aux["quantizer_loss"]
838
+
839
+ if config.opt.lpips_weight != 0.0:
840
+ aux_loss = 0.0
841
+ if config.model.posttrain_sample:
842
+ x_pred = aux["posttrain_sample"]
843
+
844
+ lpips_dist = aux_state["lpips_model"](x, x_pred)
845
+ lpips_dist = (config.opt.lpips_weight * lpips_dist).mean() + aux_loss
846
+ aux["loss_dict"]["lpips_loss"] = lpips_dist
847
+ else:
848
+ lpips_dist = 0.0
849
+
850
+ loss = loss + aux["quantizer_loss"] + lpips_dist
851
+ aux["loss_dict"]["total_loss"] = loss
852
+ return loss, aux
853
+
854
+
855
+ def _edm_to_flow_convention(noise_level):
856
+ # z = x + \sigma z'
857
+ return noise_level / (1 + noise_level)
858
+
859
+
860
+ def rf_sample(
861
+ model,
862
+ z,
863
+ code,
864
+ null_code=None,
865
+ sample_steps=25,
866
+ cfg=2.0,
867
+ schedule="linear",
868
+ ):
869
+ b = z.size(0)
870
+ if schedule == "linear":
871
+ ts = torch.arange(1, sample_steps + 1).flip(0) / sample_steps
872
+ dts = torch.ones_like(ts) * (1.0 / sample_steps)
873
+ elif schedule.startswith("pow"):
874
+ p = float(schedule.split("_")[1])
875
+ ts = torch.arange(0, sample_steps + 1).flip(0) ** (1 / p) / sample_steps ** (
876
+ 1 / p
877
+ )
878
+ dts = ts[:-1] - ts[1:]
879
+ else:
880
+ raise NotImplementedError
881
+
882
+ if model.config.eval.sampling.cfg_interval is None:
883
+ interval = None
884
+ else:
885
+ cfg_lo, cfg_hi = ast.literal_eval(model.config.eval.sampling.cfg_interval)
886
+ interval = _edm_to_flow_convention(cfg_lo), _edm_to_flow_convention(cfg_hi)
887
+
888
+ images = []
889
+ for i, (t, dt) in enumerate((zip(ts, dts))):
890
+ timesteps = torch.tensor([t] * b).to(z.device)
891
+ vc, decode_aux = model.decode(img=z, timesteps=timesteps, code=code)
892
+
893
+ if null_code is not None and (
894
+ interval is None
895
+ or ((t.item() >= interval[0]) and (t.item() <= interval[1]))
896
+ ):
897
+ vu, _ = model.decode(img=z, timesteps=timesteps, code=null_code)
898
+ vc = vu + cfg * (vc - vu)
899
+
900
+ z = z - dt * vc
901
+ images.append(z)
902
+ return images
903
+
904
+
905
+ def build_model(config):
906
+ with tempfile.TemporaryDirectory() as log_dir:
907
+ MUP_ENABLED = config.model.enable_mup
908
+ model_partial = FlowMo
909
+
910
+ shared_kwargs = dict(config=config)
911
+ model = model_partial(
912
+ **shared_kwargs,
913
+ width=config.model.mup_width,
914
+ ).cuda()
915
+
916
+ if config.model.enable_mup:
917
+ print("Mup enabled!")
918
+ with torch.device("cpu"):
919
+ base_model = model_partial(
920
+ **shared_kwargs, width=config.model.mup_width
921
+ )
922
+ delta_model = model_partial(
923
+ **shared_kwargs,
924
+ width=(
925
+ config.model.mup_width * 4 if config.model.mup_width == 1 else 1
926
+ ),
927
+ )
928
+ true_model = model_partial(
929
+ **shared_kwargs, width=config.model.mup_width
930
+ )
931
+
932
+ if torch.distributed.is_initialized():
933
+ bsh_path = os.path.join(log_dir, f"{dist.get_rank()}.bsh")
934
+ else:
935
+ bsh_path = os.path.join(log_dir, "0.bsh")
936
+ set_base_shapes(
937
+ true_model, base_model, delta=delta_model, savefile=bsh_path
938
+ )
939
+
940
+ model = set_base_shapes(model, base=bsh_path)
941
+
942
+ for module in model.modules():
943
+ if isinstance(module, MuReadout):
944
+ module.width_mult = lambda: module.weight.infshape.width_mult()
945
+ return model
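
The sampling loop in `rf_sample` above is a plain fixed-step Euler integrator over the learned velocity field, with optional classifier-free guidance mixing. Below is a minimal, self-contained sketch of that loop; `velocity_fn` is a hypothetical stand-in for the FlowMo decoder call and is not part of this repository.

import torch

def euler_rf_sample(velocity_fn, z, code, null_code=None, sample_steps=25, cfg=2.0):
    # Integrate from t=1 (pure noise) down to t=0 (data) with fixed Euler steps,
    # mirroring the "linear" schedule branch of rf_sample.
    b = z.size(0)
    ts = torch.arange(1, sample_steps + 1).flip(0) / sample_steps      # 1.0, ..., 1/N
    dts = torch.full_like(ts, 1.0 / sample_steps)
    for t, dt in zip(ts, dts):
        timesteps = torch.full((b,), t.item(), device=z.device)
        vc = velocity_fn(z, code, timesteps)                           # conditional velocity
        if null_code is not None:
            vu = velocity_fn(z, null_code, timesteps)                  # unconditional velocity
            vc = vu + cfg * (vc - vu)                                  # classifier-free guidance
        z = z - dt * vc                                                # Euler step toward the data
    return z

# Toy usage with a stand-in predictor, just to exercise the loop shape-wise.
velocity_fn = lambda z, code, t: z
out = euler_rf_sample(velocity_fn, torch.randn(2, 3, 8, 8), code=None)
print(out.shape)  # torch.Size([2, 3, 8, 8])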
src/vqvaes/flowmo/lookup_free_quantize.py ADDED
@@ -0,0 +1,396 @@
1
+ """
2
+ Code is from https://github.com/TencentARC/SEED-Voken. Thanks!
3
+
4
+ Lookup Free Quantization
5
+ Proposed in https://arxiv.org/abs/2310.05737
6
+
7
+ In the simplest setup, each dimension is quantized into {-1, 1}.
8
+ An entropy penalty is used to encourage utilization.
9
+
10
+ Refer to
11
+ https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/lookup_free_quantization.py
12
+ https://github.com/theAdamColton/ijepa-enhanced/blob/7edef5f7288ae8f537f0db8a10044a2a487f70c9/ijepa_enhanced/lfq.py
13
+ """
14
+
15
+ from collections import namedtuple
16
+ from math import ceil, log2
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from einops import pack, rearrange, reduce, unpack
21
+ from torch import einsum
22
+ from torch.nn import Module
23
+
24
+ # constants
25
+
26
+ LossBreakdown = namedtuple(
27
+ "LossBreakdown",
28
+ ["per_sample_entropy", "codebook_entropy", "commitment", "avg_probs"],
29
+ )
30
+
31
+ # helper functions
32
+
33
+
34
+ def exists(v):
35
+ return v is not None
36
+
37
+
38
+ def default(*args):
39
+ for arg in args:
40
+ if exists(arg):
41
+ return arg() if callable(arg) else arg
42
+ return None
43
+
44
+
45
+ def pack_one(t, pattern):
46
+ return pack([t], pattern)
47
+
48
+
49
+ def unpack_one(t, ps, pattern):
50
+ return unpack(t, ps, pattern)[0]
51
+
52
+
53
+ # entropy
54
+
55
+ # def log(t, eps = 1e-5):
56
+ # return t.clamp(min = eps).log()
57
+
58
+
59
+ def entropy(prob):
60
+ return (-prob * torch.log(prob + 1e-5)).sum(dim=-1)
61
+
62
+
63
+ # class
64
+
65
+
66
+ def mult_along_first_dims(x, y):
67
+ """
68
+ returns x * y elementwise along the leading dimensions of y
69
+ """
70
+ ndim_to_expand = x.ndim - y.ndim
71
+ for _ in range(ndim_to_expand):
72
+ y = y.unsqueeze(-1)
73
+ return x * y
74
+
75
+
76
+ def masked_mean(x, m):
77
+ """
78
+ takes the mean of the elements of x that are not masked
79
+ the mean is taken along the shared leading dims of m
80
+ equivalent to: x[m].mean(tuple(range(m.ndim)))
81
+
82
+ The benefit of using masked_mean rather than using
83
+ tensor indexing is that masked_mean is much faster
84
+ for torch-compile on batches.
85
+
86
+ The drawback is larger floating point errors
87
+ """
88
+ x = mult_along_first_dims(x, m)
89
+ x = x / m.sum()
90
+ return x.sum(tuple(range(m.ndim)))
91
+
92
+
93
+ def entropy_loss(
94
+ logits,
95
+ mask=None,
96
+ # temperature=0.01,
97
+ sample_minimization_weight=1.0,
98
+ batch_maximization_weight=1.0,
99
+ eps=1e-5,
100
+ ):
101
+ """
102
+ Entropy loss of unnormalized logits
103
+
104
+ logits: Affinities are over the last dimension
105
+
106
+ https://github.com/google-research/magvit/blob/05e8cfd6559c47955793d70602d62a2f9b0bdef5/videogvt/train_lib/losses.py#L279
107
+ LANGUAGE MODEL BEATS DIFFUSION — TOKENIZER IS KEY TO VISUAL GENERATION (2024)
108
+ """
109
+ # import pdb
110
+ # pdb.set_trace()
111
+ # print(logits.shape)
112
+ # raise
113
+
114
+ temperature = 0.1
115
+ probs = F.softmax(logits / temperature, -1)
116
+ log_probs = F.log_softmax(logits / temperature + eps, -1)
117
+
118
+ if mask is not None:
119
+ # avg_probs = probs[mask].mean(tuple(range(probs.ndim - 1)))
120
+ # avg_probs = einx.mean("... D -> D", probs[mask])
121
+
122
+ avg_probs = masked_mean(probs, mask)
123
+ # avg_probs = einx.mean("... D -> D", avg_probs)
124
+ else:
125
+ avg_probs = reduce(probs, "... D -> D", "mean")
126
+
127
+ avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + eps))
128
+
129
+ sample_entropy = -torch.sum(probs * log_probs, -1)
130
+ if mask is not None:
131
+ # sample_entropy = sample_entropy[mask].mean()
132
+ sample_entropy = masked_mean(sample_entropy, mask).mean()
133
+ else:
134
+ sample_entropy = torch.mean(sample_entropy)
135
+
136
+ loss = (sample_minimization_weight * sample_entropy) - (
137
+ batch_maximization_weight * avg_entropy
138
+ )
139
+
140
+ return sample_entropy, avg_entropy, loss
141
+
142
+
143
+ class LFQ(Module):
144
+ def __init__(
145
+ self,
146
+ *,
147
+ dim=None,
148
+ codebook_size=None,
149
+ num_codebooks=1,
150
+ sample_minimization_weight=1.0,
151
+ batch_maximization_weight=1.0,
152
+ token_factorization=False,
153
+ factorized_bits=[9, 9],
154
+ ):
155
+ super().__init__()
156
+
157
+ # some assert validations
158
+
159
+ assert exists(dim) or exists(
160
+ codebook_size
161
+ ), "either dim or codebook_size must be specified for LFQ"
162
+ assert (
163
+ not exists(codebook_size) or log2(codebook_size).is_integer()
164
+ ), f"your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})"
165
+
166
+ self.codebook_size = default(codebook_size, lambda: 2**dim)
167
+ self.codebook_dim = int(log2(codebook_size))
168
+
169
+ codebook_dims = self.codebook_dim * num_codebooks
170
+ dim = default(dim, codebook_dims)
171
+
172
+ has_projections = dim != codebook_dims
173
+ self.has_projections = has_projections
174
+
175
+ self.dim = dim
176
+ self.codebook_dim = self.codebook_dim
177
+ self.num_codebooks = num_codebooks
178
+
179
+ # for entropy loss
180
+ self.sample_minimization_weight = sample_minimization_weight
181
+ self.batch_maximization_weight = batch_maximization_weight
182
+
183
+ # for no auxiliary loss, during inference
184
+ self.token_factorization = token_factorization
185
+ if not self.token_factorization: # for first stage model
186
+ self.register_buffer(
187
+ "mask", 2 ** torch.arange(self.codebook_dim), persistent=False
188
+ )
189
+ else:
190
+ self.factorized_bits = factorized_bits
191
+ self.register_buffer(
192
+ "pre_mask", 2 ** torch.arange(factorized_bits[0]), persistent=False
193
+ )
194
+ self.register_buffer(
195
+ "post_mask", 2 ** torch.arange(factorized_bits[1]), persistent=False
196
+ )
197
+
198
+ self.register_buffer("zero", torch.tensor(0.0), persistent=False)
199
+
200
+ # codes
201
+
202
+ all_codes = torch.arange(codebook_size)
203
+ bits = self.indices_to_bits(all_codes)
204
+ codebook = bits * 2.0 - 1.0
205
+
206
+ self.register_buffer("codebook", codebook, persistent=False)
207
+
208
+ @property
209
+ def dtype(self):
210
+ return self.codebook.dtype
211
+
212
+ def indices_to_bits(self, x):
213
+ """
214
+ x: long tensor of indices
215
+
216
+ returns bits, least-significant bit first (mask = 2**arange)
217
+ """
218
+ mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long)
219
+ # x is now a bool bit tensor, LSB first along the last dimension
220
+ x = (x.unsqueeze(-1) & mask) != 0
221
+ return x
222
+
223
+ def get_codebook_entry(self, x, bhwc, order): # 0610
224
+ if self.token_factorization:
225
+ if order == "pre":
226
+ mask = 2 ** torch.arange(
227
+ self.factorized_bits[0], device=x.device, dtype=torch.long
228
+ )
229
+ else:
230
+ mask = 2 ** torch.arange(
231
+ self.factorized_bits[1], device=x.device, dtype=torch.long
232
+ )
233
+ else:
234
+ mask = 2 ** torch.arange(
235
+ self.codebook_dim, device=x.device, dtype=torch.long
236
+ )
237
+
238
+ x = (x.unsqueeze(-1) & mask) != 0
239
+ x = x * 2.0 - 1.0 # back to the float
240
+ ## reshape back to the spatial layout (B, C, H, W)
241
+ b, h, w, c = bhwc
242
+ x = rearrange(x, "b (h w) c -> b h w c", h=h, w=w, c=c)
243
+ x = rearrange(x, "b h w c -> b c h w")
244
+ return x
245
+
246
+ def bits_to_indices(self, bits):
247
+ """
248
+ bits: bool tensor of big endian bits, where the last dimension is the bit dimension
249
+
250
+ returns indices, which are long integers in [0, self.codebook_size - 1]
251
+ """
252
+ assert bits.shape[-1] == self.codebook_dim
253
+ indices = 2 ** torch.arange(
254
+ 0,
255
+ self.codebook_dim,
256
+ 1,
257
+ dtype=torch.long,
258
+ device=bits.device,
259
+ )
260
+ return (bits * indices).sum(-1)
261
+
262
+ def decode(self, x):
263
+ """
264
+ x: ... NH
265
+ where NH is number of codebook heads
266
+ A longtensor of codebook indices, containing values from
267
+ 0 to self.codebook_size - 1
268
+ """
269
+ x = self.indices_to_bits(x)
270
+ # to some sort of float
271
+ x = x.to(self.dtype)
272
+ # -1 or 1
273
+ x = x * 2 - 1
274
+ x = rearrange(x, "... NC Z-> ... (NC Z)")
275
+ return x
276
+
277
+ def forward(
278
+ self,
279
+ x,
280
+ inv_temperature=100.0,
281
+ return_loss_breakdown=False,
282
+ mask=None,
283
+ return_loss=True,
284
+ ):
285
+ """
286
+ einstein notation
287
+ b - batch
288
+ n - sequence (or flattened spatial dimensions)
289
+ d - feature dimension, which is also log2(codebook size)
290
+ c - number of codebook dim
291
+ """
292
+ # x = x.tanh() * 1.5
293
+
294
+ x = rearrange(x, "b d ... -> b ... d")
295
+ x, ps = pack_one(x, "b * d")
296
+ # split out number of codebooks
297
+
298
+ x = rearrange(x, "b n (c d) -> b n c d", c=self.num_codebooks)
299
+
300
+ codebook_value = torch.Tensor([1.0]).to(device=x.device, dtype=x.dtype)
301
+ quantized = torch.where(
302
+ x > 0, codebook_value, -codebook_value
303
+ ) # higher than 0 filled
304
+
305
+ # calculate indices
306
+ if self.token_factorization:
307
+ indices_pre = reduce(
308
+ (quantized[..., : self.factorized_bits[0]] > 0).int()
309
+ * self.pre_mask.int(),
310
+ "b n c d -> b n c",
311
+ "sum",
312
+ )
313
+ indices_post = reduce(
314
+ (quantized[..., self.factorized_bits[0] :] > 0).int()
315
+ * self.post_mask.int(),
316
+ "b n c d -> b n c",
317
+ "sum",
318
+ )
319
+ else:
320
+ # print(quantized.shape)
321
+ indices = reduce(
322
+ (quantized > 0).int() * self.mask.int(), "b n c d -> b n c", "sum"
323
+ )
324
+ # print(indices.shape)
325
+
326
+ # entropy aux loss
327
+
328
+ if self.training and return_loss:
329
+ logits = 2 * einsum("... i d, j d -> ... i j", x, self.codebook)
330
+ # the same as euclidean distance up to a constant
331
+ # import pdb
332
+ # pdb.set_trace()
333
+
334
+ per_sample_entropy, codebook_entropy, entropy_aux_loss = entropy_loss(
335
+ logits=logits,
336
+ sample_minimization_weight=self.sample_minimization_weight,
337
+ batch_maximization_weight=self.batch_maximization_weight,
338
+ )
339
+
340
+ avg_probs = self.zero
341
+ else:
342
+ # logits = 2 * einsum('... i d, j d -> ... i j', x, self.codebook)
343
+ # probs = F.softmax(logits / 0.01, -1)
344
+ # avg_probs = reduce(probs, "b n c d -> b d", "mean")
345
+ # avg_probs = torch.sum(avg_probs, 0) #batch dimension
346
+ # if not training, just return dummy 0
347
+ per_sample_entropy = codebook_entropy = self.zero
348
+ ## calculate the codebook_entropy needed for one batch evaluation
349
+ entropy_aux_loss = self.zero
350
+ avg_probs = self.zero
351
+
352
+ # commit loss
353
+
354
+ if self.training:
355
+ commit_loss = F.mse_loss(x, quantized.detach(), reduction="none")
356
+
357
+ if exists(mask):
358
+ commit_loss = commit_loss[mask]
359
+
360
+ commit_loss = commit_loss.mean()
361
+ else:
362
+ commit_loss = self.zero
363
+
364
+ # use straight-through gradients (optionally with custom activation fn) if training
365
+
366
+ quantized = x + (quantized - x).detach() # transfer to quantized
367
+
368
+ # merge back codebook dim
369
+
370
+ quantized = rearrange(quantized, "b n c d -> b n (c d)")
371
+
372
+ # reconstitute image or video dimensions
373
+
374
+ quantized = unpack_one(quantized, ps, "b * d")
375
+ quantized = rearrange(quantized, "b ... d -> b d ...")
376
+
377
+ if self.token_factorization:
378
+ indices_pre = unpack_one(indices_pre, ps, "b * c")
379
+ indices_post = unpack_one(indices_post, ps, "b * c")
380
+ indices_pre = indices_pre.flatten()
381
+ indices_post = indices_post.flatten()
382
+ indices = (indices_pre, indices_post)
383
+ else:
384
+ # print(indices.shape, ps)
385
+ indices = unpack_one(indices, ps, "b * c")
386
+ # print(indices.shape)
387
+ indices = indices.flatten()
388
+
389
+ ret = (quantized, entropy_aux_loss, indices)
390
+
391
+ if not return_loss_breakdown:
392
+ return ret
393
+
394
+ return ret, LossBreakdown(
395
+ per_sample_entropy, codebook_entropy, commit_loss, avg_probs
396
+ )
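
The heart of the LFQ module above is a sign quantizer: each feature dimension is snapped to ±1, the sign pattern is packed into an integer code index, and a straight-through estimator keeps gradients flowing. The following is a minimal standalone sketch of that step (illustrative names, single codebook, no entropy or commitment losses):

import torch

def lfq_quantize(x):
    # x: (..., d) continuous features; returns (quantized, integer indices).
    hard = torch.where(x > 0, torch.ones_like(x), -torch.ones_like(x))   # each dim -> {-1, +1}
    mask = 2 ** torch.arange(x.shape[-1], device=x.device)               # bit i weighs 2**i (LSB first)
    indices = ((hard > 0).long() * mask).sum(dim=-1)                     # pack sign pattern into an index
    quantized = x + (hard - x).detach()                                  # straight-through gradient
    return quantized, indices

x = torch.randn(2, 4, 6, requires_grad=True)        # 6 bits -> codebook size 2**6 = 64
q, idx = lfq_quantize(x)
print(q.shape, idx.shape, int(idx.max()) < 64)      # torch.Size([2, 4, 6]) torch.Size([2, 4]) True
q.sum().backward()                                  # gradients reach x through the straight-through path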
src/vqvaes/infinity/conv.py ADDED
@@ -0,0 +1,107 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from einops import rearrange
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class Conv(nn.Module):
8
+ def __init__(
9
+ self,
10
+ in_channels,
11
+ out_channels,
12
+ kernel_size,
13
+ stride=1,
14
+ padding=0,
15
+ cnn_type="2d",
16
+ causal_offset=0,
17
+ temporal_down=False,
18
+ ):
19
+ super().__init__()
20
+ self.cnn_type = cnn_type
21
+ self.slice_seq_len = 17
22
+
23
+ if cnn_type == "2d":
24
+ self.conv = nn.Conv2d(
25
+ in_channels, out_channels, kernel_size, stride=stride, padding=padding
26
+ )
27
+ if cnn_type == "3d":
28
+ if not temporal_down:
29
+ stride = (1, stride, stride)
30
+ else:
31
+ stride = (stride, stride, stride)
32
+ self.conv = nn.Conv3d(
33
+ in_channels, out_channels, kernel_size, stride=stride, padding=0
34
+ )
35
+ if isinstance(kernel_size, int):
36
+ kernel_size = (kernel_size, kernel_size, kernel_size)
37
+ self.padding = (
38
+ kernel_size[0] - 1 + causal_offset, # Temporal causal padding
39
+ padding, # Height padding
40
+ padding, # Width padding
41
+ )
42
+ self.causal_offset = causal_offset
43
+ self.stride = stride
44
+ self.kernel_size = kernel_size
45
+
46
+ def forward(self, x):
47
+ if self.cnn_type == "2d":
48
+ if x.ndim == 5:
49
+ B, C, T, H, W = x.shape
50
+ x = rearrange(x, "B C T H W -> (B T) C H W")
51
+ x = self.conv(x)
52
+ x = rearrange(x, "(B T) C H W -> B C T H W", T=T)
53
+ return x
54
+ else:
55
+ return self.conv(x)
56
+ if self.cnn_type == "3d":
57
+ assert (
58
+ self.stride[0] == 1 or self.stride[0] == 2
59
+ ), f"only temporal stride = 1 or 2 are supported"
60
+ xs = []
61
+ for i in range(0, x.shape[2], self.slice_seq_len + self.stride[0] - 1):
62
+ st = i
63
+ en = min(i + self.slice_seq_len, x.shape[2])
64
+ _x = x[:, :, st:en, :, :]
65
+ if i == 0:
66
+ _x = F.pad(
67
+ _x,
68
+ (
69
+ self.padding[2],
70
+ self.padding[2], # Width
71
+ self.padding[1],
72
+ self.padding[1], # Height
73
+ self.padding[0],
74
+ 0,
75
+ ),
76
+ ) # Temporal
77
+ else:
78
+ padding_0 = self.kernel_size[0] - 1
79
+ _x = F.pad(
80
+ _x,
81
+ (
82
+ self.padding[2],
83
+ self.padding[2], # Width
84
+ self.padding[1],
85
+ self.padding[1], # Height
86
+ padding_0,
87
+ 0,
88
+ ),
89
+ ) # Temporal
90
+ _x[
91
+ :,
92
+ :,
93
+ :padding_0,
94
+ self.padding[1] : _x.shape[-2] - self.padding[1],
95
+ self.padding[2] : _x.shape[-1] - self.padding[2],
96
+ ] += x[:, :, i - padding_0 : i, :, :]
97
+ _x = self.conv(_x)
98
+ xs.append(_x)
99
+ try:
100
+ x = torch.cat(xs, dim=2)
101
+ except:
102
+ device = x.device
103
+ del x
104
+ xs = [_x.cpu().pin_memory() for _x in xs]
105
+ torch.cuda.empty_cache()
106
+ x = torch.cat([_x.cpu() for _x in xs], dim=2).to(device=device)
107
+ return x
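
For the `cnn_type == "3d"` branch above, the key detail is the causal temporal padding: (kernel_size[0] - 1) frames are padded only on the past side, so an output frame never depends on future frames. A small standalone illustration of that padding layout (not repository code):

import torch
import torch.nn.functional as F

k_t = 3                                              # temporal kernel size
conv = torch.nn.Conv3d(1, 1, kernel_size=(k_t, 3, 3), stride=1, padding=0)
x = torch.randn(1, 1, 5, 8, 8)                       # (B, C, T, H, W)

# F.pad order for 5D input: (W_left, W_right, H_left, H_right, T_left, T_right)
x_padded = F.pad(x, (1, 1, 1, 1, k_t - 1, 0))        # spatial pad both sides, temporal pad past only
y = conv(x_padded)
print(y.shape)                                       # torch.Size([1, 1, 5, 8, 8]): same T, no look-ahead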
src/vqvaes/infinity/dynamic_resolution.py ADDED
@@ -0,0 +1,147 @@
1
+ import json
2
+ import numpy as np
3
+ import tqdm
4
+
5
+ vae_stride = 16
6
+ ratio2hws = {
7
+ 1.000: [
8
+ (1, 1),
9
+ (2, 2),
10
+ (4, 4),
11
+ (6, 6),
12
+ (8, 8),
13
+ (12, 12),
14
+ (16, 16),
15
+ (20, 20),
16
+ (24, 24),
17
+ (32, 32),
18
+ (40, 40),
19
+ (48, 48),
20
+ (64, 64),
21
+ ],
22
+ 1.250: [
23
+ (1, 1),
24
+ (2, 2),
25
+ (3, 3),
26
+ (5, 4),
27
+ (10, 8),
28
+ (15, 12),
29
+ (20, 16),
30
+ (25, 20),
31
+ (30, 24),
32
+ (35, 28),
33
+ (45, 36),
34
+ (55, 44),
35
+ (70, 56),
36
+ ],
37
+ 1.333: [
38
+ (1, 1),
39
+ (2, 2),
40
+ (4, 3),
41
+ (8, 6),
42
+ (12, 9),
43
+ (16, 12),
44
+ (20, 15),
45
+ (24, 18),
46
+ (28, 21),
47
+ (36, 27),
48
+ (48, 36),
49
+ (60, 45),
50
+ (72, 54),
51
+ ],
52
+ 1.500: [
53
+ (1, 1),
54
+ (2, 2),
55
+ (3, 2),
56
+ (6, 4),
57
+ (9, 6),
58
+ (15, 10),
59
+ (21, 14),
60
+ (27, 18),
61
+ (33, 22),
62
+ (39, 26),
63
+ (48, 32),
64
+ (63, 42),
65
+ (78, 52),
66
+ ],
67
+ 1.750: [
68
+ (1, 1),
69
+ (2, 2),
70
+ (3, 3),
71
+ (7, 4),
72
+ (11, 6),
73
+ (14, 8),
74
+ (21, 12),
75
+ (28, 16),
76
+ (35, 20),
77
+ (42, 24),
78
+ (56, 32),
79
+ (70, 40),
80
+ (84, 48),
81
+ ],
82
+ 2.000: [
83
+ (1, 1),
84
+ (2, 2),
85
+ (4, 2),
86
+ (6, 3),
87
+ (10, 5),
88
+ (16, 8),
89
+ (22, 11),
90
+ (30, 15),
91
+ (38, 19),
92
+ (46, 23),
93
+ (60, 30),
94
+ (74, 37),
95
+ (90, 45),
96
+ ],
97
+ 2.500: [
98
+ (1, 1),
99
+ (2, 2),
100
+ (5, 2),
101
+ (10, 4),
102
+ (15, 6),
103
+ (20, 8),
104
+ (25, 10),
105
+ (30, 12),
106
+ (40, 16),
107
+ (50, 20),
108
+ (65, 26),
109
+ (80, 32),
110
+ (100, 40),
111
+ ],
112
+ 3.000: [
113
+ (1, 1),
114
+ (2, 2),
115
+ (6, 2),
116
+ (9, 3),
117
+ (15, 5),
118
+ (21, 7),
119
+ (27, 9),
120
+ (36, 12),
121
+ (45, 15),
122
+ (54, 18),
123
+ (72, 24),
124
+ (90, 30),
125
+ (111, 37),
126
+ ],
127
+ }
128
+ full_ratio2hws = {}
129
+ for ratio, hws in ratio2hws.items():
130
+ full_ratio2hws[ratio] = hws
131
+ full_ratio2hws[int(1 / ratio * 1000) / 1000] = [(item[1], item[0]) for item in hws]
132
+
133
+ dynamic_resolution_h_w = {}
134
+ predefined_HW_Scales_dynamic = {}
135
+ for ratio in full_ratio2hws:
136
+ dynamic_resolution_h_w[ratio] = {}
137
+ for ind, leng in enumerate([7, 10, 13]):
138
+ h, w = (
139
+ full_ratio2hws[ratio][leng - 1][0],
140
+ full_ratio2hws[ratio][leng - 1][1],
141
+ ) # feature map size
142
+ pixel = (h * vae_stride, w * vae_stride) # The original image (H, W)
143
+ dynamic_resolution_h_w[ratio][pixel[1]] = {
144
+ "pixel": pixel,
145
+ "scales": full_ratio2hws[ratio][:leng],
146
+ } # W as key
147
+ predefined_HW_Scales_dynamic[(h, w)] = full_ratio2hws[ratio][:leng]
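
The tables built above map, for each aspect ratio, an image width in pixels to its final feature-map size and the multi-scale schedule leading up to it (the feature map is the image size divided by the VAE stride of 16). Recomputing the ratio-1.000 entries by hand shows the shape of the mapping:

vae_stride = 16
hws_1_to_1 = [(1, 1), (2, 2), (4, 4), (6, 6), (8, 8), (12, 12), (16, 16),
              (20, 20), (24, 24), (32, 32), (40, 40), (48, 48), (64, 64)]

for leng in (7, 10, 13):
    h, w = hws_1_to_1[leng - 1]                      # final feature-map size
    pixel = (h * vae_stride, w * vae_stride)         # corresponding image resolution
    print(pixel, "-> scales:", hws_1_to_1[:leng])
# (256, 256) uses 7 scales, (512, 512) uses 10, (1024, 1024) uses 13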
src/vqvaes/infinity/flux_vqgan.py ADDED
@@ -0,0 +1,771 @@
1
+ import argparse
2
+ import os
3
+ import imageio
4
+ import torch
5
+ import numpy as np
6
+ from einops import rearrange
7
+ from torch import Tensor, nn
8
+ import torch.nn.functional as F
9
+ import torchvision
10
+ from torchvision import transforms
11
+ from safetensors.torch import load_file
12
+ import torch.utils.checkpoint as checkpoint
13
+
14
+ from .conv import Conv
15
+ from .multiscale_bsq import MultiScaleBSQ
16
+
17
+ ptdtype = {None: torch.float32, "fp32": torch.float32, "bf16": torch.bfloat16}
18
+
19
+
20
+ class Normalize(nn.Module):
21
+ def __init__(self, in_channels, norm_type, norm_axis="spatial"):
22
+ super().__init__()
23
+ self.norm_axis = norm_axis
24
+ assert norm_type in ["group", "batch", "no"]
25
+ if norm_type == "group":
26
+ if in_channels % 32 == 0:
27
+ self.norm = nn.GroupNorm(
28
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
29
+ )
30
+ elif in_channels % 24 == 0:
31
+ self.norm = nn.GroupNorm(
32
+ num_groups=24, num_channels=in_channels, eps=1e-6, affine=True
33
+ )
34
+ else:
35
+ raise NotImplementedError
36
+ elif norm_type == "batch":
37
+ self.norm = nn.SyncBatchNorm(
38
+ in_channels, track_running_stats=False
39
+ ) # RuntimeError: in-place grad error if track_running_stats is set to True
40
+ elif norm_type == "no":
41
+ self.norm = nn.Identity()
42
+
43
+ def forward(self, x):
44
+ if self.norm_axis == "spatial":
45
+ if x.ndim == 4:
46
+ x = self.norm(x)
47
+ else:
48
+ B, C, T, H, W = x.shape
49
+ x = rearrange(x, "B C T H W -> (B T) C H W")
50
+ x = self.norm(x)
51
+ x = rearrange(x, "(B T) C H W -> B C T H W", T=T)
52
+ elif self.norm_axis == "spatial-temporal":
53
+ x = self.norm(x)
54
+ else:
55
+ raise NotImplementedError
56
+ return x
57
+
58
+
59
+ def swish(x: Tensor) -> Tensor:
60
+ try:
61
+ return x * torch.sigmoid(x)
62
+ except RuntimeError:  # likely CUDA OOM: fall back to computing on CPU, then move back
63
+ device = x.device
64
+ x = x.cpu().pin_memory()
65
+ return (x * torch.sigmoid(x)).to(device=device)
66
+
67
+
68
+ class AttnBlock(nn.Module):
69
+ def __init__(self, in_channels, norm_type="group", cnn_param=None):
70
+ super().__init__()
71
+ self.in_channels = in_channels
72
+
73
+ self.norm = Normalize(
74
+ in_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"]
75
+ )
76
+
77
+ self.q = Conv(in_channels, in_channels, kernel_size=1)
78
+ self.k = Conv(in_channels, in_channels, kernel_size=1)
79
+ self.v = Conv(in_channels, in_channels, kernel_size=1)
80
+ self.proj_out = Conv(in_channels, in_channels, kernel_size=1)
81
+
82
+ def attention(self, h_: Tensor) -> Tensor:
83
+ B, _, T, _, _ = h_.shape
84
+ h_ = self.norm(h_)
85
+ h_ = rearrange(h_, "B C T H W -> (B T) C H W") # spatial attention only
86
+ q = self.q(h_)
87
+ k = self.k(h_)
88
+ v = self.v(h_)
89
+
90
+ b, c, h, w = q.shape
91
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
92
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
93
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
94
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
95
+
96
+ return rearrange(h_, "(b t) 1 (h w) c -> b c t h w", h=h, w=w, c=c, b=B, t=T)
97
+
98
+ def forward(self, x: Tensor) -> Tensor:
99
+ return x + self.proj_out(self.attention(x))
100
+
101
+
102
+ class ResnetBlock(nn.Module):
103
+ def __init__(
104
+ self, in_channels: int, out_channels: int, norm_type="group", cnn_param=None
105
+ ):
106
+ super().__init__()
107
+ self.in_channels = in_channels
108
+ out_channels = in_channels if out_channels is None else out_channels
109
+ self.out_channels = out_channels
110
+
111
+ self.norm1 = Normalize(
112
+ in_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"]
113
+ )
114
+ if cnn_param["res_conv_2d"] in ["half", "full"]:
115
+ self.conv1 = Conv(
116
+ in_channels,
117
+ out_channels,
118
+ kernel_size=3,
119
+ stride=1,
120
+ padding=1,
121
+ cnn_type="2d",
122
+ )
123
+ else:
124
+ self.conv1 = Conv(
125
+ in_channels,
126
+ out_channels,
127
+ kernel_size=3,
128
+ stride=1,
129
+ padding=1,
130
+ cnn_type=cnn_param["cnn_type"],
131
+ )
132
+ self.norm2 = Normalize(
133
+ out_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"]
134
+ )
135
+ if cnn_param["res_conv_2d"] in ["full"]:
136
+ self.conv2 = Conv(
137
+ out_channels,
138
+ out_channels,
139
+ kernel_size=3,
140
+ stride=1,
141
+ padding=1,
142
+ cnn_type="2d",
143
+ )
144
+ else:
145
+ self.conv2 = Conv(
146
+ out_channels,
147
+ out_channels,
148
+ kernel_size=3,
149
+ stride=1,
150
+ padding=1,
151
+ cnn_type=cnn_param["cnn_type"],
152
+ )
153
+ if self.in_channels != self.out_channels:
154
+ self.nin_shortcut = Conv(
155
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
156
+ )
157
+
158
+ def forward(self, x):
159
+ h = x
160
+ h = self.norm1(h)
161
+ h = swish(h)
162
+ h = self.conv1(h)
163
+
164
+ h = self.norm2(h)
165
+ h = swish(h)
166
+ h = self.conv2(h)
167
+
168
+ if self.in_channels != self.out_channels:
169
+ x = self.nin_shortcut(x)
170
+
171
+ return x + h
172
+
173
+
174
+ class Downsample(nn.Module):
175
+ def __init__(
176
+ self, in_channels, cnn_type="2d", spatial_down=False, temporal_down=False
177
+ ):
178
+ super().__init__()
179
+ assert spatial_down
180
+ if cnn_type == "2d":
181
+ self.pad = (0, 1, 0, 1)
182
+ if cnn_type == "3d":
183
+ self.pad = (
184
+ 0,
185
+ 1,
186
+ 0,
187
+ 1,
188
+ 0,
189
+ 0,
190
+ ) # add padding to the right for h-axis and w-axis. No padding for t-axis
191
+ # no asymmetric padding in torch conv, must do it ourselves
192
+ self.conv = Conv(
193
+ in_channels,
194
+ in_channels,
195
+ kernel_size=3,
196
+ stride=2,
197
+ padding=0,
198
+ cnn_type=cnn_type,
199
+ temporal_down=temporal_down,
200
+ )
201
+
202
+ def forward(self, x: Tensor):
203
+ x = nn.functional.pad(x, self.pad, mode="constant", value=0)
204
+ x = self.conv(x)
205
+ return x
206
+
207
+
208
+ class Upsample(nn.Module):
209
+ def __init__(
210
+ self,
211
+ in_channels,
212
+ cnn_type="2d",
213
+ spatial_up=False,
214
+ temporal_up=False,
215
+ use_pxsl=False,
216
+ ):
217
+ super().__init__()
218
+ if cnn_type == "2d":
219
+ self.scale_factor = 2
220
+ self.causal_offset = 0
221
+ else:
222
+ assert spatial_up
223
+ if temporal_up:
224
+ self.scale_factor = (2, 2, 2)
225
+ self.causal_offset = -1
226
+ else:
227
+ self.scale_factor = (1, 2, 2)
228
+ self.causal_offset = 0
229
+ self.use_pxsl = use_pxsl
230
+ if self.use_pxsl:
231
+ self.conv = Conv(
232
+ in_channels,
233
+ in_channels * 4,
234
+ kernel_size=3,
235
+ stride=1,
236
+ padding=1,
237
+ cnn_type=cnn_type,
238
+ causal_offset=self.causal_offset,
239
+ )
240
+ self.pxsl = nn.PixelShuffle(2)
241
+ else:
242
+ self.conv = Conv(
243
+ in_channels,
244
+ in_channels,
245
+ kernel_size=3,
246
+ stride=1,
247
+ padding=1,
248
+ cnn_type=cnn_type,
249
+ causal_offset=self.causal_offset,
250
+ )
251
+
252
+ def forward(self, x: Tensor):
253
+ if self.use_pxsl:
254
+ x = self.conv(x)
255
+ x = self.pxsl(x)
256
+ else:
257
+ try:
258
+ x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest")
259
+ except RuntimeError:  # likely CUDA OOM: interpolate channel-by-channel instead
260
+ # shard across channel
261
+ _xs = []
262
+ for i in range(x.shape[1]):
263
+ _x = F.interpolate(
264
+ x[:, i : i + 1, ...],
265
+ scale_factor=self.scale_factor,
266
+ mode="nearest",
267
+ )
268
+ _xs.append(_x)
269
+ x = torch.cat(_xs, dim=1)
270
+ x = self.conv(x)
271
+ return x
272
+
273
+
274
+ class Encoder(nn.Module):
275
+ def __init__(
276
+ self,
277
+ ch: int,
278
+ ch_mult: list[int],
279
+ num_res_blocks: int,
280
+ z_channels: int,
281
+ in_channels=3,
282
+ patch_size=8,
283
+ temporal_patch_size=4,
284
+ norm_type="group",
285
+ cnn_param=None,
286
+ use_checkpoint=False,
287
+ use_vae=True,
288
+ ):
289
+ super().__init__()
290
+ self.max_down = np.log2(patch_size)
291
+ self.temporal_max_down = np.log2(temporal_patch_size)
292
+ self.temporal_down_offset = self.max_down - self.temporal_max_down
293
+ self.ch = ch
294
+ self.num_resolutions = len(ch_mult)
295
+ self.num_res_blocks = num_res_blocks
296
+ self.in_channels = in_channels
297
+ self.cnn_param = cnn_param
298
+ self.use_checkpoint = use_checkpoint
299
+ # downsampling
300
+ # self.conv_in = Conv(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
301
+ # cnn_param["cnn_type"] = "2d" for images, cnn_param["cnn_type"] = "3d" for videos
302
+ if cnn_param["conv_in_out_2d"] == "yes": # "yes" for video
303
+ self.conv_in = Conv(
304
+ in_channels, ch, kernel_size=3, stride=1, padding=1, cnn_type="2d"
305
+ )
306
+ else:
307
+ self.conv_in = Conv(
308
+ in_channels,
309
+ ch,
310
+ kernel_size=3,
311
+ stride=1,
312
+ padding=1,
313
+ cnn_type=cnn_param["cnn_type"],
314
+ )
315
+
316
+ in_ch_mult = (1,) + tuple(ch_mult)
317
+ self.in_ch_mult = in_ch_mult
318
+ self.down = nn.ModuleList()
319
+ block_in = self.ch
320
+ for i_level in range(self.num_resolutions):
321
+ block = nn.ModuleList()
322
+ attn = nn.ModuleList()
323
+ block_in = ch * in_ch_mult[i_level]
324
+ block_out = ch * ch_mult[i_level]
325
+ for _ in range(self.num_res_blocks):
326
+ block.append(
327
+ ResnetBlock(
328
+ in_channels=block_in,
329
+ out_channels=block_out,
330
+ norm_type=norm_type,
331
+ cnn_param=cnn_param,
332
+ )
333
+ )
334
+ block_in = block_out
335
+ down = nn.Module()
336
+ down.block = block
337
+ down.attn = attn
338
+ # downsample, stride=1, stride=2, stride=2 for 4x8x8 Video VAE
339
+ spatial_down = True if i_level < self.max_down else False
340
+ temporal_down = (
341
+ True
342
+ if i_level < self.max_down and i_level >= self.temporal_down_offset
343
+ else False
344
+ )
345
+ if spatial_down or temporal_down:
346
+ down.downsample = Downsample(
347
+ block_in,
348
+ cnn_type=cnn_param["cnn_type"],
349
+ spatial_down=spatial_down,
350
+ temporal_down=temporal_down,
351
+ )
352
+ self.down.append(down)
353
+
354
+ # middle
355
+ self.mid = nn.Module()
356
+ self.mid.block_1 = ResnetBlock(
357
+ in_channels=block_in,
358
+ out_channels=block_in,
359
+ norm_type=norm_type,
360
+ cnn_param=cnn_param,
361
+ )
362
+ if cnn_param["cnn_attention"] == "yes":
363
+ self.mid.attn_1 = AttnBlock(block_in, norm_type, cnn_param=cnn_param)
364
+ self.mid.block_2 = ResnetBlock(
365
+ in_channels=block_in,
366
+ out_channels=block_in,
367
+ norm_type=norm_type,
368
+ cnn_param=cnn_param,
369
+ )
370
+
371
+ # end
372
+ self.norm_out = Normalize(
373
+ block_in, norm_type, norm_axis=cnn_param["cnn_norm_axis"]
374
+ )
375
+ if cnn_param["conv_inner_2d"] == "yes":
376
+ self.conv_out = Conv(
377
+ block_in,
378
+ (int(use_vae) + 1) * z_channels,
379
+ kernel_size=3,
380
+ stride=1,
381
+ padding=1,
382
+ cnn_type="2d",
383
+ )
384
+ else:
385
+ self.conv_out = Conv(
386
+ block_in,
387
+ (int(use_vae) + 1) * z_channels,
388
+ kernel_size=3,
389
+ stride=1,
390
+ padding=1,
391
+ cnn_type=cnn_param["cnn_type"],
392
+ )
393
+
394
+ def forward(self, x, return_hidden=False):
395
+ if not self.use_checkpoint:
396
+ return self._forward(x, return_hidden=return_hidden)
397
+ else:
398
+ return checkpoint.checkpoint(
399
+ self._forward, x, return_hidden, use_reentrant=False
400
+ )
401
+
402
+ def _forward(self, x: Tensor, return_hidden=False) -> Tensor:
403
+ # downsampling
404
+ h0 = self.conv_in(x)
405
+ hs = [h0]
406
+ for i_level in range(self.num_resolutions):
407
+ for i_block in range(self.num_res_blocks):
408
+ h = self.down[i_level].block[i_block](hs[-1])
409
+ if len(self.down[i_level].attn) > 0:
410
+ h = self.down[i_level].attn[i_block](h)
411
+ hs.append(h)
412
+ if hasattr(self.down[i_level], "downsample"):
413
+ hs.append(self.down[i_level].downsample(hs[-1]))
414
+
415
+ # middle
416
+ h = hs[-1]
417
+ hs_mid = [h]
418
+ h = self.mid.block_1(h)
419
+ if self.cnn_param["cnn_attention"] == "yes":
420
+ h = self.mid.attn_1(h)
421
+ h = self.mid.block_2(h)
422
+ hs_mid.append(h)
423
+ # end
424
+ h = self.norm_out(h)
425
+ h = swish(h)
426
+ h = self.conv_out(h)
427
+ if return_hidden:
428
+ return h, hs, hs_mid
429
+ else:
430
+ return h
431
+
432
+
433
+ class Decoder(nn.Module):
434
+ def __init__(
435
+ self,
436
+ ch: int,
437
+ ch_mult: list[int],
438
+ num_res_blocks: int,
439
+ z_channels: int,
440
+ out_ch=3,
441
+ patch_size=8,
442
+ temporal_patch_size=4,
443
+ norm_type="group",
444
+ cnn_param=None,
445
+ use_checkpoint=False,
446
+ use_freq_dec=False, # use frequency features for decoder
447
+ use_pxsf=False,
448
+ ):
449
+ super().__init__()
450
+ self.max_up = np.log2(patch_size)
451
+ self.temporal_max_up = np.log2(temporal_patch_size)
452
+ self.temporal_up_offset = self.max_up - self.temporal_max_up
453
+ self.ch = ch
454
+ self.num_resolutions = len(ch_mult)
455
+ self.num_res_blocks = num_res_blocks
456
+ self.ffactor = 2 ** (self.num_resolutions - 1)
457
+ self.cnn_param = cnn_param
458
+ self.use_checkpoint = use_checkpoint
459
+ self.use_freq_dec = use_freq_dec
460
+ self.use_pxsf = use_pxsf
461
+
462
+ # compute in_ch_mult, block_in and curr_res at lowest res
463
+ block_in = ch * ch_mult[self.num_resolutions - 1]
464
+
465
+ # z to block_in
466
+ if cnn_param["conv_inner_2d"] == "yes":
467
+ self.conv_in = Conv(
468
+ z_channels, block_in, kernel_size=3, stride=1, padding=1, cnn_type="2d"
469
+ )
470
+ else:
471
+ self.conv_in = Conv(
472
+ z_channels,
473
+ block_in,
474
+ kernel_size=3,
475
+ stride=1,
476
+ padding=1,
477
+ cnn_type=cnn_param["cnn_type"],
478
+ )
479
+
480
+ # middle
481
+ self.mid = nn.Module()
482
+ self.mid.block_1 = ResnetBlock(
483
+ in_channels=block_in,
484
+ out_channels=block_in,
485
+ norm_type=norm_type,
486
+ cnn_param=cnn_param,
487
+ )
488
+ if cnn_param["cnn_attention"] == "yes":
489
+ self.mid.attn_1 = AttnBlock(
490
+ block_in, norm_type=norm_type, cnn_param=cnn_param
491
+ )
492
+ self.mid.block_2 = ResnetBlock(
493
+ in_channels=block_in,
494
+ out_channels=block_in,
495
+ norm_type=norm_type,
496
+ cnn_param=cnn_param,
497
+ )
498
+
499
+ # upsampling
500
+ self.up = nn.ModuleList()
501
+ for i_level in reversed(range(self.num_resolutions)):
502
+ block = nn.ModuleList()
503
+ attn = nn.ModuleList()
504
+ block_out = ch * ch_mult[i_level]
505
+ for _ in range(self.num_res_blocks + 1):
506
+ block.append(
507
+ ResnetBlock(
508
+ in_channels=block_in,
509
+ out_channels=block_out,
510
+ norm_type=norm_type,
511
+ cnn_param=cnn_param,
512
+ )
513
+ )
514
+ block_in = block_out
515
+ up = nn.Module()
516
+ up.block = block
517
+ up.attn = attn
518
+ # upsample, stride=1, stride=2, stride=2 for 4x8x8 Video VAE, offset 1 compared with encoder
519
+ # https://github.com/black-forest-labs/flux/blob/b4f689aaccd40de93429865793e84a734f4a6254/src/flux/modules/autoencoder.py#L228
520
+ spatial_up = True if 1 <= i_level <= self.max_up else False
521
+ temporal_up = (
522
+ True
523
+ if 1 <= i_level <= self.max_up
524
+ and i_level >= self.temporal_up_offset + 1
525
+ else False
526
+ )
527
+ if spatial_up or temporal_up:
528
+ up.upsample = Upsample(
529
+ block_in,
530
+ cnn_type=cnn_param["cnn_type"],
531
+ spatial_up=spatial_up,
532
+ temporal_up=temporal_up,
533
+ use_pxsl=self.use_pxsf,
534
+ )
535
+ self.up.insert(0, up) # prepend to get consistent order
536
+
537
+ # end
538
+ self.norm_out = Normalize(
539
+ block_in, norm_type, norm_axis=cnn_param["cnn_norm_axis"]
540
+ )
541
+ if cnn_param["conv_in_out_2d"] == "yes":
542
+ self.conv_out = Conv(
543
+ block_in, out_ch, kernel_size=3, stride=1, padding=1, cnn_type="2d"
544
+ )
545
+ else:
546
+ self.conv_out = Conv(
547
+ block_in,
548
+ out_ch,
549
+ kernel_size=3,
550
+ stride=1,
551
+ padding=1,
552
+ cnn_type=cnn_param["cnn_type"],
553
+ )
554
+
555
+ def forward(self, z):
556
+ if not self.use_checkpoint:
557
+ return self._forward(z)
558
+ else:
559
+ return checkpoint.checkpoint(self._forward, z, use_reentrant=False)
560
+
561
+ def _forward(self, z: Tensor) -> Tensor:
562
+ # z to block_in
563
+ h = self.conv_in(z)
564
+
565
+ # middle
566
+ h = self.mid.block_1(h)
567
+ if self.cnn_param["cnn_attention"] == "yes":
568
+ h = self.mid.attn_1(h)
569
+ h = self.mid.block_2(h)
570
+
571
+ # upsampling
572
+ for i_level in reversed(range(self.num_resolutions)):
573
+ for i_block in range(self.num_res_blocks + 1):
574
+ h = self.up[i_level].block[i_block](h)
575
+ if len(self.up[i_level].attn) > 0:
576
+ h = self.up[i_level].attn[i_block](h)
577
+ if hasattr(self.up[i_level], "upsample"):
578
+ h = self.up[i_level].upsample(h)
579
+
580
+ # end
581
+ h = self.norm_out(h)
582
+ h = swish(h)
583
+ h = self.conv_out(h)
584
+ return h
585
+
586
+
587
+ class AutoEncoder(nn.Module):
588
+ def __init__(self, args):
589
+ super().__init__()
590
+ self.args = args
591
+ cnn_param = dict(
592
+ cnn_type=args.cnn_type,
593
+ conv_in_out_2d=args.conv_in_out_2d,
594
+ res_conv_2d=args.res_conv_2d,
595
+ cnn_attention=args.cnn_attention,
596
+ cnn_norm_axis=args.cnn_norm_axis,
597
+ conv_inner_2d=args.conv_inner_2d,
598
+ )
599
+ self.encoder = Encoder(
600
+ ch=args.base_ch,
601
+ ch_mult=args.encoder_ch_mult,
602
+ num_res_blocks=args.num_res_blocks,
603
+ z_channels=args.codebook_dim,
604
+ patch_size=args.patch_size,
605
+ temporal_patch_size=args.temporal_patch_size,
606
+ cnn_param=cnn_param,
607
+ use_checkpoint=args.use_checkpoint,
608
+ use_vae=args.use_vae,
609
+ )
610
+ self.decoder = Decoder(
611
+ ch=args.base_ch,
612
+ ch_mult=args.decoder_ch_mult,
613
+ num_res_blocks=args.num_res_blocks,
614
+ z_channels=args.codebook_dim,
615
+ patch_size=args.patch_size,
616
+ temporal_patch_size=args.temporal_patch_size,
617
+ cnn_param=cnn_param,
618
+ use_checkpoint=args.use_checkpoint,
619
+ use_freq_dec=args.use_freq_dec,
620
+ use_pxsf=args.use_pxsf, # pixelshuffle for upsampling
621
+ )
622
+ self.z_drop = nn.Dropout(args.z_drop)
623
+ self.scale_factor = 0.3611
624
+ self.shift_factor = 0.1159
625
+ self.codebook_dim = self.embed_dim = args.codebook_dim
626
+
627
+ self.gan_feat_weight = args.gan_feat_weight
628
+ self.video_perceptual_weight = args.video_perceptual_weight
629
+ self.recon_loss_type = args.recon_loss_type
630
+ self.l1_weight = args.l1_weight
631
+ self.use_vae = args.use_vae
632
+ self.kl_weight = args.kl_weight
633
+ self.lfq_weight = args.lfq_weight
634
+ self.image_gan_weight = args.image_gan_weight # image GAN loss weight
635
+ self.video_gan_weight = args.video_gan_weight # video GAN loss weight
636
+ self.perceptual_weight = args.perceptual_weight
637
+ self.flux_weight = args.flux_weight
638
+ self.cycle_weight = args.cycle_weight
639
+ self.cycle_feat_weight = args.cycle_feat_weight
640
+ self.cycle_gan_weight = args.cycle_gan_weight
641
+
642
+ self.flux_image_encoder = None
643
+
644
+ if not args.use_vae:
645
+ if args.quantizer_type == "MultiScaleBSQ":
646
+ self.quantizer = MultiScaleBSQ(
647
+ dim=args.codebook_dim, # this is the input feature dimension, defaults to log2(codebook_size) if not defined
648
+ codebook_size=args.codebook_size, # codebook size, must be a power of 2
649
+ entropy_loss_weight=args.entropy_loss_weight, # how much weight to place on entropy loss
650
+ diversity_gamma=args.diversity_gamma, # within entropy loss, how much weight to give to diversity of codes, taken from https://arxiv.org/abs/1911.05894
651
+ preserve_norm=args.preserve_norm, # preserve norm of the input for BSQ
652
+ ln_before_quant=args.ln_before_quant, # use layer norm before quantization
653
+ ln_init_by_sqrt=args.ln_init_by_sqrt, # layer norm init value 1/sqrt(d)
654
+ commitment_loss_weight=args.commitment_loss_weight, # loss weight of commitment loss
655
+ new_quant=args.new_quant,
656
+ use_decay_factor=args.use_decay_factor,
657
+ mask_out=args.mask_out,
658
+ use_stochastic_depth=args.use_stochastic_depth,
659
+ drop_rate=args.drop_rate,
660
+ schedule_mode=args.schedule_mode,
661
+ keep_first_quant=args.keep_first_quant,
662
+ keep_last_quant=args.keep_last_quant,
663
+ remove_residual_detach=args.remove_residual_detach,
664
+ use_out_phi=args.use_out_phi,
665
+ use_out_phi_res=args.use_out_phi_res,
666
+ random_flip=args.random_flip,
667
+ flip_prob=args.flip_prob,
668
+ flip_mode=args.flip_mode,
669
+ max_flip_lvl=args.max_flip_lvl,
670
+ random_flip_1lvl=args.random_flip_1lvl,
671
+ flip_lvl_idx=args.flip_lvl_idx,
672
+ drop_when_test=args.drop_when_test,
673
+ drop_lvl_idx=args.drop_lvl_idx,
674
+ drop_lvl_num=args.drop_lvl_num,
675
+ )
676
+ self.quantize = self.quantizer
677
+ self.vocab_size = args.codebook_size
678
+ else:
679
+ raise NotImplementedError(f"{args.quantizer_type} not supported")
680
+
681
+ def forward(self, x):
682
+ is_image = x.ndim == 4
683
+ if not is_image:
684
+ B, C, T, H, W = x.shape
685
+ else:
686
+ B, C, H, W = x.shape
687
+ T = 1
688
+ enc_dtype = ptdtype[self.args.encoder_dtype]
689
+
690
+ with torch.amp.autocast("cuda", dtype=enc_dtype):
691
+ h, hs, hs_mid = self.encoder(x, return_hidden=True) # B C H W or B C T H W
692
+ hs = [_h.detach() for _h in hs]
693
+ hs_mid = [_h.detach() for _h in hs_mid]
694
+ h = h.to(dtype=torch.float32)
695
+ # print(z.shape)
696
+ # Multiscale LFQ
697
+ z, all_indices, _, _, all_loss, _ = self.quantizer(h)
698
+ x_recon = self.decoder(z)
699
+ vq_output = {
700
+ "commitment_loss": torch.mean(all_loss)
701
+ * self.lfq_weight, # here commitment loss is sum of commitment loss and entropy penalty
702
+ "encodings": all_indices,
703
+ }
704
+ # return x_recon, vq_output
705
+ return x_recon, None, z
706
+
707
+ def encode_for_raw_features(
708
+ self, x, scale_schedule, return_residual_norm_per_scale=False
709
+ ):
710
+ is_image = x.ndim == 4
711
+ if not is_image:
712
+ B, C, T, H, W = x.shape
713
+ else:
714
+ B, C, H, W = x.shape
715
+ T = 1
716
+
717
+ enc_dtype = ptdtype[self.args.encoder_dtype]
718
+ with torch.amp.autocast("cuda", dtype=enc_dtype):
719
+ h, hs, hs_mid = self.encoder(x, return_hidden=True) # B C H W or B C T H W
720
+
721
+ hs = [_h.detach() for _h in hs]
722
+ hs_mid = [_h.detach() for _h in hs_mid]
723
+ h = h.to(dtype=torch.float32)
724
+ return h, hs, hs_mid
725
+
726
+ def encode(self, x, scale_schedule, return_residual_norm_per_scale=False):
727
+ h, hs, hs_mid = self.encode_for_raw_features(
728
+ x, scale_schedule, return_residual_norm_per_scale
729
+ )
730
+ # Multiscale LFQ
731
+ (
732
+ z,
733
+ all_indices,
734
+ all_bit_indices,
735
+ residual_norm_per_scale,
736
+ all_loss,
737
+ var_input,
738
+ ) = self.quantizer(
739
+ h,
740
+ scale_schedule=scale_schedule,
741
+ return_residual_norm_per_scale=return_residual_norm_per_scale,
742
+ )
743
+ return h, z, all_indices, all_bit_indices, residual_norm_per_scale, var_input
744
+
745
+ def decode(self, z):
746
+ x_recon = self.decoder(z)
747
+ x_recon = torch.clamp(x_recon, min=-1, max=1)
748
+ return x_recon
749
+
750
+ def decode_from_indices(self, all_indices, scale_schedule, label_type):
751
+ summed_codes = 0
752
+ for idx_Bl in all_indices:
753
+ codes = self.quantizer.lfq.indices_to_codes(idx_Bl, label_type)
754
+ summed_codes += F.interpolate(
755
+ codes, size=scale_schedule[-1], mode=self.quantizer.z_interplote_up
756
+ )
757
+ assert summed_codes.shape[-3] == 1
758
+ x_recon = self.decoder(summed_codes.squeeze(-3))
759
+ x_recon = torch.clamp(x_recon, min=-1, max=1)
760
+ return summed_codes, x_recon
761
+
762
+ @staticmethod
763
+ def add_model_specific_args(parent_parser):
764
+ parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
765
+ parser.add_argument("--flux_weight", type=float, default=0)
766
+ parser.add_argument("--cycle_weight", type=float, default=0)
767
+ parser.add_argument("--cycle_feat_weight", type=float, default=0)
768
+ parser.add_argument("--cycle_gan_weight", type=float, default=0)
769
+ parser.add_argument("--cycle_loop", type=int, default=0)
770
+ parser.add_argument("--z_drop", type=float, default=0.0)
771
+ return parser
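
`decode_from_indices` above reconstructs an image by turning each scale's indices back into codes, upsampling every scale to the finest grid, summing, and running the decoder once on the sum. A minimal sketch of that summation step with synthetic codes (the real interpolation mode comes from the quantizer's `z_interplote_up` setting):

import torch
import torch.nn.functional as F

final_hw = (16, 16)
# Synthetic per-scale codes shaped (B, C, T=1, H, W); in the model these come from the quantizer.
codes_per_scale = [torch.randn(1, 8, 1, s, s) for s in (1, 2, 4, 8, 16)]

summed_codes = 0
for codes in codes_per_scale:
    summed_codes = summed_codes + F.interpolate(codes, size=(1, *final_hw), mode="trilinear")

decoder_input = summed_codes.squeeze(-3)              # drop the singleton temporal axis
print(decoder_input.shape)                            # torch.Size([1, 8, 16, 16]) -> fed to the decoder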