Upload 37 files
- Text2Human/configs/index_pred_net.yml +84 -0
- Text2Human/configs/parsing_gen.yml +40 -0
- Text2Human/configs/parsing_token.yml +47 -0
- Text2Human/configs/sample_from_parsing.yml +93 -0
- Text2Human/configs/sample_from_pose.yml +107 -0
- Text2Human/configs/sampler.yml +83 -0
- Text2Human/configs/vqvae_bottom.yml +72 -0
- Text2Human/configs/vqvae_top.yml +53 -0
- models/__init__.py +42 -0
- models/archs/__init__.py +0 -0
- models/archs/__pycache__/__init__.cpython-38.pyc +0 -0
- models/archs/__pycache__/fcn_arch.cpython-38.pyc +0 -0
- models/archs/__pycache__/shape_attr_embedding_arch.cpython-38.pyc +0 -0
- models/archs/__pycache__/transformer_arch.cpython-38.pyc +0 -0
- models/archs/__pycache__/unet_arch.cpython-38.pyc +0 -0
- models/archs/__pycache__/vqgan_arch.cpython-38.pyc +0 -0
- models/archs/fcn_arch.py +418 -0
- models/archs/shape_attr_embedding_arch.py +35 -0
- models/archs/transformer_arch.py +273 -0
- models/archs/unet_arch.py +693 -0
- models/archs/vqgan_arch.py +1203 -0
- models/hierarchy_inference_model.py +363 -0
- models/hierarchy_vqgan_model.py +374 -0
- models/losses/__init__.py +0 -0
- models/losses/__pycache__/__init__.cpython-38.pyc +0 -0
- models/losses/__pycache__/accuracy.cpython-38.pyc +0 -0
- models/losses/__pycache__/cross_entropy_loss.cpython-38.pyc +0 -0
- models/losses/__pycache__/segmentation_loss.cpython-38.pyc +0 -0
- models/losses/__pycache__/vqgan_loss.cpython-38.pyc +0 -0
- models/losses/accuracy.py +46 -0
- models/losses/cross_entropy_loss.py +246 -0
- models/losses/segmentation_loss.py +25 -0
- models/losses/vqgan_loss.py +114 -0
- models/parsing_gen_model.py +220 -0
- models/sample_model.py +498 -0
- models/transformer_model.py +482 -0
- models/vqgan_model.py +551 -0
Text2Human/configs/index_pred_net.yml
ADDED
@@ -0,0 +1,84 @@
name: index_prediction_network
use_tb_logger: true
set_CUDA_VISIBLE_DEVICES: ~
gpu_ids: [3]

# dataset configs
batch_size: 4
num_workers: 4
train_img_dir: ./datasets/train_images
test_img_dir: ./datasets/test_images
segm_dir: ./datasets/segm
pose_dir: ./datasets/densepose
train_ann_file: ./datasets/texture_ann/train
val_ann_file: ./datasets/texture_ann/val
test_ann_file: ./datasets/texture_ann/test
downsample_factor: 2

model_type: VQGANTextureAwareSpatialHierarchyInferenceModel
# network configs
embed_dim: 256
n_embed: 1024
codebook_spatial_size: 2

# bottom level vqvae
bot_n_embed: 512
bot_double_z: false
bot_z_channels: 256
bot_resolution: 512
bot_in_channels: 3
bot_out_ch: 3
bot_ch: 128
bot_ch_mult: [1, 1, 2, 4]
bot_num_res_blocks: 2
bot_attn_resolutions: [64]
bot_dropout: 0.0
bot_vae_path: ./pretrained_models/vqvae_bottom.pth

# top level vqgan
top_double_z: false
top_z_channels: 256
top_resolution: 512
top_in_channels: 3
top_out_ch: 3
top_ch: 128
top_ch_mult: [1, 1, 2, 2, 4]
top_num_res_blocks: 2
top_attn_resolutions: [32]
top_dropout: 0.0
top_vae_path: ./pretrained_models/vqvae_top.pth

# unet configs
encoder_in_channels: 256
fc_in_channels: 64
fc_in_index: 4
fc_channels: 64
fc_num_convs: 1
fc_concat_input: False
fc_dropout_ratio: 0.1
fc_num_classes: 512
fc_align_corners: False

disc_layers: 3
disc_weight_max: 1
disc_start_step: 30001
n_channels: 3
ndf: 64
nf: 128
perceptual_weight: 1.0

num_segm_classes: 24

# training configs
val_freq: 5
print_freq: 100
weight_decay: 0
manual_seed: 2021
num_epochs: 100
lr: !!float 1.0e-04
lr_decay: step
gamma: 1.0
step: 50
optimizer: Adam
loss_function: cross_entropy

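These YAML option files are consumed as plain dictionaries. As a rough illustration (not the repository's own entry point, which is not part of this upload), a config like the one above could be parsed as follows; note that `lr: !!float 1.0e-04` uses the explicit `!!float` tag so PyYAML returns a float rather than the string '1.0e-04':

# Hypothetical loader sketch; the loader function name and flow are assumptions.
import yaml

def load_options(path):
    # yaml.safe_load resolves the `!!float` tag used for the learning rate.
    with open(path, 'r') as f:
        opt = yaml.safe_load(f)
    return opt

opt = load_options('Text2Human/configs/index_pred_net.yml')
print(opt['model_type'])  # VQGANTextureAwareSpatialHierarchyInferenceModel
print(type(opt['lr']))    # <class 'float'>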
Text2Human/configs/parsing_gen.yml
ADDED
@@ -0,0 +1,40 @@
name: parsing_generation
use_tb_logger: true
set_CUDA_VISIBLE_DEVICES: ~
gpu_ids: [3]

# dataset configs
batch_size: 8
num_workers: 4
segm_dir: ./datasets/segm
pose_dir: ./datasets/densepose
train_ann_file: ./datasets/shape_ann/train_ann_file.txt
val_ann_file: ./datasets/shape_ann/val_ann_file.txt
test_ann_file: ./datasets/shape_ann/test_ann_file.txt
downsample_factor: 2

model_type: ParsingGenModel
# network configs
embedder_dim: 8
embedder_out_dim: 128
attr_class_num: [2, 4, 6, 5, 4, 3, 5, 5, 3, 2, 2, 2, 2, 2, 2]
encoder_in_channels: 1
fc_in_channels: 64
fc_in_index: 4
fc_channels: 64
fc_num_convs: 1
fc_concat_input: False
fc_dropout_ratio: 0.1
fc_num_classes: 24
fc_align_corners: False

# training configs
val_freq: 5
print_freq: 100
weight_decay: 0
manual_seed: 2021
num_epochs: 100
lr: !!float 1e-4
lr_decay: step
gamma: 0.1
step: 50
Text2Human/configs/parsing_token.yml
ADDED
@@ -0,0 +1,47 @@
name: parsing_tokenization
use_tb_logger: true
set_CUDA_VISIBLE_DEVICES: ~
gpu_ids: [3]

# dataset configs
batch_size: 4
num_workers: 4
train_img_dir: ./datasets/train_images
test_img_dir: ./datasets/test_images
segm_dir: ./datasets/segm
pose_dir: ./datasets/densepose
train_ann_file: ./datasets/texture_ann/train
val_ann_file: ./datasets/texture_ann/val
test_ann_file: ./datasets/texture_ann/test
downsample_factor: 2

model_type: VQSegmentationModel
# network configs
embed_dim: 32
n_embed: 1024
image_key: "segmentation"
n_labels: 24
double_z: false
z_channels: 32
resolution: 512
in_channels: 24
out_ch: 24
ch: 64
ch_mult: [1, 1, 2, 2, 4]
num_res_blocks: 1
attn_resolutions: [16]
dropout: 0.0

num_segm_classes: 24


# training configs
val_freq: 5
print_freq: 100
weight_decay: 0
manual_seed: 2021
num_epochs: 100
lr: !!float 4.5e-05
lr_decay: step
gamma: 0.1
step: 50
Text2Human/configs/sample_from_parsing.yml
ADDED
@@ -0,0 +1,93 @@
name: sample_from_parsing
use_tb_logger: true
set_CUDA_VISIBLE_DEVICES: ~
gpu_ids: [3]

# dataset configs
batch_size: 4
num_workers: 4
test_img_dir: ./datasets/test_images
segm_dir: ./datasets/segm
pose_dir: ./datasets/densepose
test_ann_file: ./datasets/texture_ann/test
downsample_factor: 2

model_type: SampleFromParsingModel
# network configs
embed_dim: 256
n_embed: 1024
codebook_spatial_size: 2

# bottom level vqvae
bot_n_embed: 512
bot_codebook_spatial_size: 2
bot_double_z: false
bot_z_channels: 256
bot_resolution: 512
bot_in_channels: 3
bot_out_ch: 3
bot_ch: 128
bot_ch_mult: [1, 1, 2, 4]
bot_num_res_blocks: 2
bot_attn_resolutions: [64]
bot_dropout: 0.0
bot_vae_path: ./pretrained_models/vqvae_bottom.pth

# top level vqgan
top_double_z: false
top_z_channels: 256
top_resolution: 512
top_in_channels: 3
top_out_ch: 3
top_ch: 128
top_ch_mult: [1, 1, 2, 2, 4]
top_num_res_blocks: 2
top_attn_resolutions: [32]
top_dropout: 0.0
top_vae_path: ./pretrained_models/vqvae_top.pth

# unet configs
index_pred_encoder_in_channels: 256
index_pred_fc_in_channels: 64
index_pred_fc_in_index: 4
index_pred_fc_channels: 64
index_pred_fc_num_convs: 1
index_pred_fc_concat_input: False
index_pred_fc_dropout_ratio: 0.1
index_pred_fc_num_classes: 512
index_pred_fc_align_corners: False
pretrained_index_network: ./pretrained_models/index_pred_net.pth

# segmentation tokenization
segm_double_z: false
segm_z_channels: 32
segm_resolution: 512
segm_in_channels: 24
segm_out_ch: 24
segm_ch: 64
segm_ch_mult: [1, 1, 2, 2, 4]
segm_num_res_blocks: 1
segm_attn_resolutions: [16]
segm_dropout: 0.0
segm_num_segm_classes: 24
segm_n_embed: 1024
segm_embed_dim: 32
segm_token_path: ./pretrained_models/parsing_token.pth

# sampler configs
codebook_size: 18432
segm_codebook_size: 1024
texture_codebook_size: 18
bert_n_emb: 512
bert_n_layers: 24
bert_n_head: 8
block_size: 512 # 32 x 16
latent_shape: [32, 16]
embd_pdrop: 0.0
resid_pdrop: 0.0
attn_pdrop: 0.0
num_head: 18
pretrained_sampler: ./pretrained_models/sampler.pth

manual_seed: 2021
sample_steps: 256
Text2Human/configs/sample_from_pose.yml
ADDED
@@ -0,0 +1,107 @@
name: sample_from_pose
use_tb_logger: true
set_CUDA_VISIBLE_DEVICES: ~
gpu_ids: [3]

# dataset configs
batch_size: 4
num_workers: 4
pose_dir: ./datasets/densepose
texture_ann_file: ./datasets/texture_ann/test
shape_ann_path: ./datasets/shape_ann/test_ann_file.txt
downsample_factor: 2

model_type: SampleFromPoseModel
# network configs
embed_dim: 256
n_embed: 1024
codebook_spatial_size: 2

# bottom level vqgan
bot_n_embed: 512
bot_codebook_spatial_size: 2
bot_double_z: false
bot_z_channels: 256
bot_resolution: 512
bot_in_channels: 3
bot_out_ch: 3
bot_ch: 128
bot_ch_mult: [1, 1, 2, 4]
bot_num_res_blocks: 2
bot_attn_resolutions: [64]
bot_dropout: 0.0
bot_vae_path: ./pretrained_models/vqvae_bottom.pth

# top level vqgan
top_double_z: false
top_z_channels: 256
top_resolution: 512
top_in_channels: 3
top_out_ch: 3
top_ch: 128
top_ch_mult: [1, 1, 2, 2, 4]
top_num_res_blocks: 2
top_attn_resolutions: [32]
top_dropout: 0.0
top_vae_path: ./pretrained_models/vqvae_top.pth

# unet configs
index_pred_encoder_in_channels: 256
index_pred_fc_in_channels: 64
index_pred_fc_in_index: 4
index_pred_fc_channels: 64
index_pred_fc_num_convs: 1
index_pred_fc_concat_input: False
index_pred_fc_dropout_ratio: 0.1
index_pred_fc_num_classes: 512
index_pred_fc_align_corners: False
pretrained_index_network: ./pretrained_models/index_pred_net.pth

# segmentation tokenization
segm_double_z: false
segm_z_channels: 32
segm_resolution: 512
segm_in_channels: 24
segm_out_ch: 24
segm_ch: 64
segm_ch_mult: [1, 1, 2, 2, 4]
segm_num_res_blocks: 1
segm_attn_resolutions: [16]
segm_dropout: 0.0
segm_num_segm_classes: 24
segm_n_embed: 1024
segm_embed_dim: 32
segm_token_path: ./pretrained_models/parsing_token.pth

# sampler configs
codebook_size: 18432
segm_codebook_size: 1024
texture_codebook_size: 18
bert_n_emb: 512
bert_n_layers: 24
bert_n_head: 8
block_size: 512 # 32 x 16
latent_shape: [32, 16]
embd_pdrop: 0.0
resid_pdrop: 0.0
attn_pdrop: 0.0
num_head: 18
pretrained_sampler: ./pretrained_models/sampler.pth

# shape network configs
shape_embedder_dim: 8
shape_embedder_out_dim: 128
shape_attr_class_num: [2, 4, 6, 5, 4, 3, 5, 5, 3, 2, 2, 2, 2, 2, 2]
shape_encoder_in_channels: 1
shape_fc_in_channels: 64
shape_fc_in_index: 4
shape_fc_channels: 64
shape_fc_num_convs: 1
shape_fc_concat_input: False
shape_fc_dropout_ratio: 0.1
shape_fc_num_classes: 24
shape_fc_align_corners: False
pretrained_parsing_gen: ./pretrained_models/parsing_gen.pth

manual_seed: 2021
sample_steps: 256
Text2Human/configs/sampler.yml
ADDED
@@ -0,0 +1,83 @@
name: sampler
use_tb_logger: true
set_CUDA_VISIBLE_DEVICES: ~
gpu_ids: [3]

# dataset configs
batch_size: 4
num_workers: 1
train_img_dir: ./datasets/train_images
test_img_dir: ./datasets/test_images
segm_dir: ./datasets/segm
pose_dir: ./datasets/densepose
train_ann_file: ./datasets/texture_ann/train
val_ann_file: ./datasets/texture_ann/val
test_ann_file: ./datasets/texture_ann/test
downsample_factor: 2

# pretrained models
img_ae_path: ./pretrained_models/vqvae_top.pth
segm_ae_path: ./pretrained_models/parsing_token.pth

model_type: TransformerTextureAwareModel
# network configs

# image autoencoder
img_embed_dim: 256
img_n_embed: 1024
img_double_z: false
img_z_channels: 256
img_resolution: 512
img_in_channels: 3
img_out_ch: 3
img_ch: 128
img_ch_mult: [1, 1, 2, 2, 4]
img_num_res_blocks: 2
img_attn_resolutions: [32]
img_dropout: 0.0

# segmentation tokenization
segm_double_z: false
segm_z_channels: 32
segm_resolution: 512
segm_in_channels: 24
segm_out_ch: 24
segm_ch: 64
segm_ch_mult: [1, 1, 2, 2, 4]
segm_num_res_blocks: 1
segm_attn_resolutions: [16]
segm_dropout: 0.0
segm_num_segm_classes: 24
segm_n_embed: 1024
segm_embed_dim: 32

# sampler configs
codebook_size: 18432
segm_codebook_size: 1024
texture_codebook_size: 18
bert_n_emb: 512
bert_n_layers: 24
bert_n_head: 8
block_size: 512 # 32 x 16
latent_shape: [32, 16]
embd_pdrop: 0.0
resid_pdrop: 0.0
attn_pdrop: 0.0
num_head: 18

# loss configs
loss_type: reweighted_elbo
mask_schedule: random

sample_steps: 256

# training configs
val_freq: 5
print_freq: 100
weight_decay: 0
manual_seed: 2021
num_epochs: 100
lr: !!float 1e-4
lr_decay: step
gamma: 1.0
step: 50
Text2Human/configs/vqvae_bottom.yml
ADDED
@@ -0,0 +1,72 @@
name: vqvae_bottom
use_tb_logger: true
set_CUDA_VISIBLE_DEVICES: ~
gpu_ids: [3]

# dataset configs
batch_size: 4
num_workers: 4
train_img_dir: ./datasets/train_images
test_img_dir: ./datasets/test_images
segm_dir: ./datasets/segm
pose_dir: ./datasets/densepose
train_ann_file: ./datasets/texture_ann/train
val_ann_file: ./datasets/texture_ann/val
test_ann_file: ./datasets/texture_ann/test
downsample_factor: 2

model_type: HierarchyVQSpatialTextureAwareModel
# network configs
embed_dim: 256
n_embed: 1024
codebook_spatial_size: 2

# bottom level vqvae
bot_n_embed: 512
bot_double_z: false
bot_z_channels: 256
bot_resolution: 512
bot_in_channels: 3
bot_out_ch: 3
bot_ch: 128
bot_ch_mult: [1, 1, 2, 4]
bot_num_res_blocks: 2
bot_attn_resolutions: [64]
bot_dropout: 0.0

# top level vqgan
top_double_z: false
top_z_channels: 256
top_resolution: 512
top_in_channels: 3
top_out_ch: 3
top_ch: 128
top_ch_mult: [1, 1, 2, 2, 4]
top_num_res_blocks: 2
top_attn_resolutions: [32]
top_dropout: 0.0
top_vae_path: ./pretrained_models/vqvae_top.pth

fix_decoder: false

disc_layers: 3
disc_weight_max: 1
disc_start_step: 1
n_channels: 3
ndf: 64
nf: 128
perceptual_weight: 1.0

num_segm_classes: 24

# training configs
val_freq: 5
print_freq: 100
weight_decay: 0
manual_seed: 2021
num_epochs: 1000
lr: !!float 1.0e-04
lr_decay: step
gamma: 1.0
step: 50

Text2Human/configs/vqvae_top.yml
ADDED
@@ -0,0 +1,53 @@
name: vqvae_top
use_tb_logger: true
set_CUDA_VISIBLE_DEVICES: ~
gpu_ids: [3]

# dataset configs
batch_size: 4
num_workers: 4
train_img_dir: ./datasets/train_images
test_img_dir: ./datasets/test_images
segm_dir: ./datasets/segm
pose_dir: ./datasets/densepose
train_ann_file: ./datasets/texture_ann/train
val_ann_file: ./datasets/texture_ann/val
test_ann_file: ./datasets/texture_ann/test
downsample_factor: 2

model_type: VQImageSegmTextureModel
# network configs
embed_dim: 256
n_embed: 1024
double_z: false
z_channels: 256
resolution: 512
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 1, 2, 2, 4]
num_res_blocks: 2
attn_resolutions: [32]
dropout: 0.0

disc_layers: 3
disc_weight_max: 0
disc_start_step: 3000000000000000000000000001
n_channels: 3
ndf: 64
nf: 128
perceptual_weight: 1.0

num_segm_classes: 24


# training configs
val_freq: 5
print_freq: 100
weight_decay: 0
manual_seed: 2021
num_epochs: 1000
lr: !!float 1.0e-04
lr_decay: step
gamma: 1.0
step: 50
models/__init__.py
ADDED
@@ -0,0 +1,42 @@
import glob
import importlib
import logging
import os.path as osp

# automatically scan and import model modules
# scan all the files under the 'models' folder and collect files ending with
# '_model.py'
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [
    osp.splitext(osp.basename(v))[0]
    for v in glob.glob(f'{model_folder}/*_model.py')
]
# import all the model modules
_model_modules = [
    importlib.import_module(f'models.{file_name}')
    for file_name in model_filenames
]


def create_model(opt):
    """Create model.

    Args:
        opt (dict): Configuration. It contains:
            model_type (str): Model type.
    """
    model_type = opt['model_type']

    # dynamic instantiation
    for module in _model_modules:
        model_cls = getattr(module, model_type, None)
        if model_cls is not None:
            break
    if model_cls is None:
        raise ValueError(f'Model {model_type} is not found.')

    model = model_cls(opt)

    logger = logging.getLogger('base')
    logger.info(f'Model [{model.__class__.__name__}] is created.')
    return model
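A minimal usage sketch for `create_model`, assuming the package is importable as `models` and that one of the configs above has already been parsed into a dict (the YAML loading shown here is an assumption, not code from this upload):

# Illustrative only: combines a parsed config dict with models.create_model.
import yaml
from models import create_model

with open('Text2Human/configs/parsing_gen.yml', 'r') as f:
    opt = yaml.safe_load(f)

# create_model scans models/*_model.py for a class named opt['model_type']
# (here 'ParsingGenModel') and instantiates it with the full option dict.
model = create_model(opt)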
models/archs/__init__.py
ADDED
File without changes
models/archs/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (126 Bytes).
models/archs/__pycache__/fcn_arch.cpython-38.pyc
ADDED
Binary file (10.5 kB).
models/archs/__pycache__/shape_attr_embedding_arch.cpython-38.pyc
ADDED
Binary file (1.33 kB).
models/archs/__pycache__/transformer_arch.cpython-38.pyc
ADDED
Binary file (7.61 kB).
models/archs/__pycache__/unet_arch.cpython-38.pyc
ADDED
Binary file (21.9 kB).
models/archs/__pycache__/vqgan_arch.cpython-38.pyc
ADDED
Binary file (24.5 kB).
models/archs/fcn_arch.py
ADDED
@@ -0,0 +1,418 @@
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, normal_init
from mmseg.ops import resize


class BaseDecodeHead(nn.Module):
    """Base class for BaseDecodeHead.

    Args:
        in_channels (int|Sequence[int]): Input channels.
        channels (int): Channels after modules, before conv_seg.
        num_classes (int): Number of classes.
        dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
        conv_cfg (dict|None): Config of conv layers. Default: None.
        norm_cfg (dict|None): Config of norm layers. Default: None.
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU')
        in_index (int|Sequence[int]): Input feature index. Default: -1
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            'resize_concat': Multiple feature maps will be resize to the
                same size as first one and than concat together.
                Usually used in FCN head of HRNet.
            'multiple_select': Multiple feature maps will be bundle into
                a list and passed into decode head.
            None: Only one select feature map is allowed.
            Default: None.
        loss_decode (dict): Config of decode loss.
            Default: dict(type='CrossEntropyLoss').
        ignore_index (int | None): The label index to be ignored. When using
            masked BCE loss, ignore_index should be set to None. Default: 255
        sampler (dict|None): The config of segmentation map sampler.
            Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 ignore_index=255,
                 align_corners=False):
        super(BaseDecodeHead, self).__init__()
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.num_classes = num_classes
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.in_index = in_index

        self.ignore_index = ignore_index
        self.align_corners = align_corners

        self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None

    def extra_repr(self):
        """Extra repr."""
        s = f'input_transform={self.input_transform}, ' \
            f'ignore_index={self.ignore_index}, ' \
            f'align_corners={self.align_corners}'
        return s

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only single feature map
        will be selected. So in_channels and in_index must be of type int.
        When input_transform

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resize to the
                    same size as first one and than concat together.
                    Usually used in FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundle into
                    a list and passed into decode head.
                None: Only one select feature map is allowed.
        """

        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def init_weights(self):
        """Initialize weights of classification layer."""
        normal_init(self.conv_seg, mean=0, std=0.01)

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: The transformed inputs
        """

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    def forward(self, inputs):
        """Placeholder of forward function."""
        pass

    def cls_seg(self, feat):
        """Classify each pixel."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.conv_seg(feat)
        return output


class FCNHead(BaseDecodeHead):
    """Fully Convolution Networks for Semantic Segmentation.

    This head is implemented of `FCNNet <https://arxiv.org/abs/1411.4038>`_.

    Args:
        num_convs (int): Number of convs in the head. Default: 2.
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        concat_input (bool): Whether concat the input and output of convs
            before classification layer.
    """

    def __init__(self,
                 num_convs=2,
                 kernel_size=3,
                 concat_input=True,
                 **kwargs):
        assert num_convs >= 0
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size
        super(FCNHead, self).__init__(**kwargs)
        if num_convs == 0:
            assert self.in_channels == self.channels

        convs = []
        convs.append(
            ConvModule(
                self.in_channels,
                self.channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))
        for i in range(num_convs - 1):
            convs.append(
                ConvModule(
                    self.channels,
                    self.channels,
                    kernel_size=kernel_size,
                    padding=kernel_size // 2,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
        if num_convs == 0:
            self.convs = nn.Identity()
        else:
            self.convs = nn.Sequential(*convs)
        if self.concat_input:
            self.conv_cat = ConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        output = self.convs(x)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output


class MultiHeadFCNHead(nn.Module):
    """Fully Convolution Networks for Semantic Segmentation.

    This head is implemented of `FCNNet <https://arxiv.org/abs/1411.4038>`_.

    Args:
        num_convs (int): Number of convs in the head. Default: 2.
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        concat_input (bool): Whether concat the input and output of convs
            before classification layer.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 ignore_index=255,
                 align_corners=False,
                 num_convs=2,
                 kernel_size=3,
                 concat_input=True,
                 num_head=18,
                 **kwargs):
        super(MultiHeadFCNHead, self).__init__()
        assert num_convs >= 0
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.num_classes = num_classes
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.in_index = in_index
        self.num_head = num_head

        self.ignore_index = ignore_index
        self.align_corners = align_corners

        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)

        conv_seg_head_list = []
        for _ in range(self.num_head):
            conv_seg_head_list.append(
                nn.Conv2d(channels, num_classes, kernel_size=1))

        self.conv_seg_head_list = nn.ModuleList(conv_seg_head_list)

        self.init_weights()

        if num_convs == 0:
            assert self.in_channels == self.channels

        convs_list = []
        conv_cat_list = []

        for _ in range(self.num_head):
            convs = []
            convs.append(
                ConvModule(
                    self.in_channels,
                    self.channels,
                    kernel_size=kernel_size,
                    padding=kernel_size // 2,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
            for _ in range(num_convs - 1):
                convs.append(
                    ConvModule(
                        self.channels,
                        self.channels,
                        kernel_size=kernel_size,
                        padding=kernel_size // 2,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
            if num_convs == 0:
                convs_list.append(nn.Identity())
            else:
                convs_list.append(nn.Sequential(*convs))
            if self.concat_input:
                conv_cat_list.append(
                    ConvModule(
                        self.in_channels + self.channels,
                        self.channels,
                        kernel_size=kernel_size,
                        padding=kernel_size // 2,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))

        self.convs_list = nn.ModuleList(convs_list)
        self.conv_cat_list = nn.ModuleList(conv_cat_list)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)

        output_list = []
        for head_idx in range(self.num_head):
            output = self.convs_list[head_idx](x)
            if self.concat_input:
                output = self.conv_cat_list[head_idx](
                    torch.cat([x, output], dim=1))
            if self.dropout is not None:
                output = self.dropout(output)
            output = self.conv_seg_head_list[head_idx](output)
            output_list.append(output)

        return output_list

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only single feature map
        will be selected. So in_channels and in_index must be of type int.
        When input_transform

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resize to the
                    same size as first one and than concat together.
                    Usually used in FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundle into
                    a list and passed into decode head.
                None: Only one select feature map is allowed.
        """

        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def init_weights(self):
        """Initialize weights of classification layer."""
        for conv_seg_head in self.conv_seg_head_list:
            normal_init(conv_seg_head, mean=0, std=0.01)

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: The transformed inputs
        """

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs
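A small sketch of how `MultiHeadFCNHead` maps a feature map to per-head logits, using shapes consistent with the `# unet configs` block of the configs above; it requires mmcv/mmseg, and the dummy feature tensors and their 16x8 spatial size are assumptions for illustration only:

# Illustrative forward pass, not code from this upload.
import torch
from models.archs.fcn_arch import MultiHeadFCNHead

head = MultiHeadFCNHead(
    in_channels=64,        # fc_in_channels in the configs
    channels=64,           # fc_channels
    num_classes=512,       # fc_num_classes (codebook entries per head)
    num_convs=1,           # fc_num_convs
    concat_input=False,    # fc_concat_input
    dropout_ratio=0.1,     # fc_dropout_ratio
    num_head=18,           # num_head in the sampler configs
    in_index=4,            # fc_in_index: pick the fifth feature level
    align_corners=False)

# Five dummy feature levels; only index 4 is used because input_transform is None.
feats = [torch.randn(1, 64, 16, 8) for _ in range(5)]
logits_list = head(feats)                      # list of num_head tensors
print(len(logits_list), logits_list[0].shape)  # 18, torch.Size([1, 512, 16, 8])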
models/archs/shape_attr_embedding_arch.py
ADDED
@@ -0,0 +1,35 @@
import torch
import torch.nn.functional as F
from torch import nn


class ShapeAttrEmbedding(nn.Module):

    def __init__(self, dim, out_dim, cls_num_list):
        super(ShapeAttrEmbedding, self).__init__()

        for idx, cls_num in enumerate(cls_num_list):
            setattr(
                self, f'attr_{idx}',
                nn.Sequential(
                    nn.Linear(cls_num, dim), nn.LeakyReLU(),
                    nn.Linear(dim, dim)))
        self.cls_num_list = cls_num_list
        self.attr_num = len(cls_num_list)
        self.fusion = nn.Sequential(
            nn.Linear(dim * self.attr_num, out_dim), nn.LeakyReLU(),
            nn.Linear(out_dim, out_dim))

    def forward(self, attr):
        attr_embedding_list = []
        for idx in range(self.attr_num):
            attr_embed_fc = getattr(self, f'attr_{idx}')
            attr_embedding_list.append(
                attr_embed_fc(
                    F.one_hot(
                        attr[:, idx],
                        num_classes=self.cls_num_list[idx]).to(torch.float32)))
        attr_embedding = torch.cat(attr_embedding_list, dim=1)
        attr_embedding = self.fusion(attr_embedding)

        return attr_embedding
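`ShapeAttrEmbedding` one-hot encodes each shape attribute, embeds it with its own small MLP, and fuses the concatenated embeddings. A usage sketch with the `attr_class_num` / `embedder_dim` / `embedder_out_dim` values from parsing_gen.yml; the random attribute batch is an assumption:

# Illustrative usage of ShapeAttrEmbedding with the parsing_gen.yml settings.
import torch
from models.archs.shape_attr_embedding_arch import ShapeAttrEmbedding

cls_num_list = [2, 4, 6, 5, 4, 3, 5, 5, 3, 2, 2, 2, 2, 2, 2]  # attr_class_num
embedder = ShapeAttrEmbedding(dim=8, out_dim=128, cls_num_list=cls_num_list)

# One random label per attribute for a dummy batch of 4.
attr = torch.stack(
    [torch.randint(0, n, (4,)) for n in cls_num_list], dim=1)  # (4, 15)
embedding = embedder(attr)
print(embedding.shape)  # torch.Size([4, 128])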
models/archs/transformer_arch.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class CausalSelfAttention(nn.Module):
|
| 10 |
+
"""
|
| 11 |
+
A vanilla multi-head masked self-attention layer with a projection at the end.
|
| 12 |
+
It is possible to use torch.nn.MultiheadAttention here but I am including an
|
| 13 |
+
explicit implementation here to show that there is nothing too scary here.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def __init__(self, bert_n_emb, bert_n_head, attn_pdrop, resid_pdrop,
|
| 17 |
+
latent_shape, sampler):
|
| 18 |
+
super().__init__()
|
| 19 |
+
assert bert_n_emb % bert_n_head == 0
|
| 20 |
+
# key, query, value projections for all heads
|
| 21 |
+
self.key = nn.Linear(bert_n_emb, bert_n_emb)
|
| 22 |
+
self.query = nn.Linear(bert_n_emb, bert_n_emb)
|
| 23 |
+
self.value = nn.Linear(bert_n_emb, bert_n_emb)
|
| 24 |
+
# regularization
|
| 25 |
+
self.attn_drop = nn.Dropout(attn_pdrop)
|
| 26 |
+
self.resid_drop = nn.Dropout(resid_pdrop)
|
| 27 |
+
# output projection
|
| 28 |
+
self.proj = nn.Linear(bert_n_emb, bert_n_emb)
|
| 29 |
+
self.n_head = bert_n_head
|
| 30 |
+
self.causal = True if sampler == 'autoregressive' else False
|
| 31 |
+
if self.causal:
|
| 32 |
+
block_size = np.prod(latent_shape)
|
| 33 |
+
mask = torch.tril(torch.ones(block_size, block_size))
|
| 34 |
+
self.register_buffer("mask", mask.view(1, 1, block_size,
|
| 35 |
+
block_size))
|
| 36 |
+
|
| 37 |
+
def forward(self, x, layer_past=None):
|
| 38 |
+
B, T, C = x.size()
|
| 39 |
+
|
| 40 |
+
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
| 41 |
+
k = self.key(x).view(B, T, self.n_head,
|
| 42 |
+
C // self.n_head).transpose(1,
|
| 43 |
+
2) # (B, nh, T, hs)
|
| 44 |
+
q = self.query(x).view(B, T, self.n_head,
|
| 45 |
+
C // self.n_head).transpose(1,
|
| 46 |
+
2) # (B, nh, T, hs)
|
| 47 |
+
v = self.value(x).view(B, T, self.n_head,
|
| 48 |
+
C // self.n_head).transpose(1,
|
| 49 |
+
2) # (B, nh, T, hs)
|
| 50 |
+
|
| 51 |
+
present = torch.stack((k, v))
|
| 52 |
+
if self.causal and layer_past is not None:
|
| 53 |
+
past_key, past_value = layer_past
|
| 54 |
+
k = torch.cat((past_key, k), dim=-2)
|
| 55 |
+
v = torch.cat((past_value, v), dim=-2)
|
| 56 |
+
|
| 57 |
+
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
| 58 |
+
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
| 59 |
+
|
| 60 |
+
if self.causal and layer_past is None:
|
| 61 |
+
att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
|
| 62 |
+
|
| 63 |
+
att = F.softmax(att, dim=-1)
|
| 64 |
+
att = self.attn_drop(att)
|
| 65 |
+
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
| 66 |
+
# re-assemble all head outputs side by side
|
| 67 |
+
y = y.transpose(1, 2).contiguous().view(B, T, C)
|
| 68 |
+
|
| 69 |
+
# output projection
|
| 70 |
+
y = self.resid_drop(self.proj(y))
|
| 71 |
+
return y, present
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class Block(nn.Module):
|
| 75 |
+
""" an unassuming Transformer block """
|
| 76 |
+
|
| 77 |
+
def __init__(self, bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop,
|
| 78 |
+
latent_shape, sampler):
|
| 79 |
+
super().__init__()
|
| 80 |
+
self.ln1 = nn.LayerNorm(bert_n_emb)
|
| 81 |
+
self.ln2 = nn.LayerNorm(bert_n_emb)
|
| 82 |
+
self.attn = CausalSelfAttention(bert_n_emb, bert_n_head, attn_pdrop,
|
| 83 |
+
resid_pdrop, latent_shape, sampler)
|
| 84 |
+
self.mlp = nn.Sequential(
|
| 85 |
+
nn.Linear(bert_n_emb, 4 * bert_n_emb),
|
| 86 |
+
nn.GELU(), # nice
|
| 87 |
+
nn.Linear(4 * bert_n_emb, bert_n_emb),
|
| 88 |
+
nn.Dropout(resid_pdrop),
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
def forward(self, x, layer_past=None, return_present=False):
|
| 92 |
+
|
| 93 |
+
attn, present = self.attn(self.ln1(x), layer_past)
|
| 94 |
+
x = x + attn
|
| 95 |
+
x = x + self.mlp(self.ln2(x))
|
| 96 |
+
|
| 97 |
+
if layer_past is not None or return_present:
|
| 98 |
+
return x, present
|
| 99 |
+
return x
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class Transformer(nn.Module):
|
| 103 |
+
""" the full GPT language model, with a context size of block_size """
|
| 104 |
+
|
| 105 |
+
def __init__(self,
|
| 106 |
+
codebook_size,
|
| 107 |
+
segm_codebook_size,
|
| 108 |
+
bert_n_emb,
|
| 109 |
+
bert_n_layers,
|
| 110 |
+
bert_n_head,
|
| 111 |
+
block_size,
|
| 112 |
+
latent_shape,
|
| 113 |
+
embd_pdrop,
|
| 114 |
+
resid_pdrop,
|
| 115 |
+
attn_pdrop,
|
| 116 |
+
sampler='absorbing'):
|
| 117 |
+
super().__init__()
|
| 118 |
+
|
| 119 |
+
self.vocab_size = codebook_size + 1
|
| 120 |
+
self.n_embd = bert_n_emb
|
| 121 |
+
self.block_size = block_size
|
| 122 |
+
self.n_layers = bert_n_layers
|
| 123 |
+
self.codebook_size = codebook_size
|
| 124 |
+
self.segm_codebook_size = segm_codebook_size
|
| 125 |
+
self.causal = sampler == 'autoregressive'
|
| 126 |
+
if self.causal:
|
| 127 |
+
self.vocab_size = codebook_size
|
| 128 |
+
|
| 129 |
+
self.tok_emb = nn.Embedding(self.vocab_size, self.n_embd)
|
| 130 |
+
self.pos_emb = nn.Parameter(
|
| 131 |
+
torch.zeros(1, self.block_size, self.n_embd))
|
| 132 |
+
self.segm_emb = nn.Embedding(self.segm_codebook_size, self.n_embd)
|
| 133 |
+
self.start_tok = nn.Parameter(torch.zeros(1, 1, self.n_embd))
|
| 134 |
+
self.drop = nn.Dropout(embd_pdrop)
|
| 135 |
+
|
| 136 |
+
# transformer
|
| 137 |
+
self.blocks = nn.Sequential(*[
|
| 138 |
+
Block(bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop,
|
| 139 |
+
latent_shape, sampler) for _ in range(self.n_layers)
|
| 140 |
+
])
|
| 141 |
+
# decoder head
|
| 142 |
+
self.ln_f = nn.LayerNorm(self.n_embd)
|
| 143 |
+
self.head = nn.Linear(self.n_embd, self.codebook_size, bias=False)
|
| 144 |
+
|
| 145 |
+
def get_block_size(self):
|
| 146 |
+
return self.block_size
|
| 147 |
+
|
| 148 |
+
def _init_weights(self, module):
|
| 149 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
| 150 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
| 151 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
| 152 |
+
module.bias.data.zero_()
|
| 153 |
+
elif isinstance(module, nn.LayerNorm):
|
| 154 |
+
module.bias.data.zero_()
|
| 155 |
+
module.weight.data.fill_(1.0)
|
| 156 |
+
|
| 157 |
+
def forward(self, idx, segm_tokens, t=None):
|
| 158 |
+
# each index maps to a (learnable) vector
|
| 159 |
+
token_embeddings = self.tok_emb(idx)
|
| 160 |
+
|
| 161 |
+
segm_embeddings = self.segm_emb(segm_tokens)
|
| 162 |
+
|
| 163 |
+
if self.causal:
|
| 164 |
+
token_embeddings = torch.cat((self.start_tok.repeat(
|
| 165 |
+
token_embeddings.size(0), 1, 1), token_embeddings),
|
| 166 |
+
dim=1)
|
| 167 |
+
|
| 168 |
+
t = token_embeddings.shape[1]
|
| 169 |
+
assert t <= self.block_size, "Cannot forward, model block size is exhausted."
|
| 170 |
+
# each position maps to a (learnable) vector
|
| 171 |
+
|
| 172 |
+
position_embeddings = self.pos_emb[:, :t, :]
|
| 173 |
+
|
| 174 |
+
x = token_embeddings + position_embeddings + segm_embeddings
|
| 175 |
+
x = self.drop(x)
|
| 176 |
+
for block in self.blocks:
|
| 177 |
+
x = block(x)
|
| 178 |
+
x = self.ln_f(x)
|
| 179 |
+
logits = self.head(x)
|
| 180 |
+
|
| 181 |
+
return logits
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class TransformerMultiHead(nn.Module):
|
| 185 |
+
""" the full GPT language model, with a context size of block_size """
|
| 186 |
+
|
| 187 |
+
def __init__(self,
|
| 188 |
+
codebook_size,
|
| 189 |
+
segm_codebook_size,
|
| 190 |
+
texture_codebook_size,
|
| 191 |
+
bert_n_emb,
|
| 192 |
+
bert_n_layers,
|
| 193 |
+
bert_n_head,
|
| 194 |
+
block_size,
|
| 195 |
+
latent_shape,
|
| 196 |
+
embd_pdrop,
|
| 197 |
+
resid_pdrop,
|
| 198 |
+
attn_pdrop,
|
| 199 |
+
num_head,
|
| 200 |
+
sampler='absorbing'):
|
| 201 |
+
super().__init__()
|
| 202 |
+
|
| 203 |
+
self.vocab_size = codebook_size + 1
|
| 204 |
+
self.n_embd = bert_n_emb
|
| 205 |
+
self.block_size = block_size
|
| 206 |
+
self.n_layers = bert_n_layers
|
| 207 |
+
self.codebook_size = codebook_size
|
| 208 |
+
self.segm_codebook_size = segm_codebook_size
|
| 209 |
+
self.texture_codebook_size = texture_codebook_size
|
| 210 |
+
self.causal = sampler == 'autoregressive'
|
| 211 |
+
if self.causal:
|
| 212 |
+
self.vocab_size = codebook_size
|
| 213 |
+
|
| 214 |
+
self.tok_emb = nn.Embedding(self.vocab_size, self.n_embd)
|
| 215 |
+
self.pos_emb = nn.Parameter(
|
| 216 |
+
torch.zeros(1, self.block_size, self.n_embd))
|
| 217 |
+
self.segm_emb = nn.Embedding(self.segm_codebook_size, self.n_embd)
|
| 218 |
+
self.texture_emb = nn.Embedding(self.texture_codebook_size,
|
| 219 |
+
self.n_embd)
|
| 220 |
+
self.start_tok = nn.Parameter(torch.zeros(1, 1, self.n_embd))
|
| 221 |
+
self.drop = nn.Dropout(embd_pdrop)
|
| 222 |
+
|
| 223 |
+
# transformer
|
| 224 |
+
self.blocks = nn.Sequential(*[
|
| 225 |
+
Block(bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop,
|
| 226 |
+
latent_shape, sampler) for _ in range(self.n_layers)
|
| 227 |
+
])
|
| 228 |
+
# decoder head
|
| 229 |
+
self.num_head = num_head
|
| 230 |
+
self.head_class_num = codebook_size // self.num_head
|
| 231 |
+
self.ln_f = nn.LayerNorm(self.n_embd)
|
| 232 |
+
self.head_list = nn.ModuleList([
|
| 233 |
+
nn.Linear(self.n_embd, self.head_class_num, bias=False)
|
| 234 |
+
for _ in range(self.num_head)
|
| 235 |
+
])
|
| 236 |
+
|
| 237 |
+
def get_block_size(self):
|
| 238 |
+
return self.block_size
|
| 239 |
+
|
| 240 |
+
def _init_weights(self, module):
|
| 241 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
| 242 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
| 243 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
| 244 |
+
module.bias.data.zero_()
|
| 245 |
+
elif isinstance(module, nn.LayerNorm):
|
| 246 |
+
module.bias.data.zero_()
|
| 247 |
+
module.weight.data.fill_(1.0)
|
| 248 |
+
|
| 249 |
+
def forward(self, idx, segm_tokens, texture_tokens, t=None):
|
| 250 |
+
# each index maps to a (learnable) vector
|
| 251 |
+
token_embeddings = self.tok_emb(idx)
|
| 252 |
+
segm_embeddings = self.segm_emb(segm_tokens)
|
| 253 |
+
texture_embeddings = self.texture_emb(texture_tokens)
|
| 254 |
+
|
| 255 |
+
if self.causal:
|
| 256 |
+
token_embeddings = torch.cat((self.start_tok.repeat(
|
| 257 |
+
token_embeddings.size(0), 1, 1), token_embeddings),
|
| 258 |
+
dim=1)
|
| 259 |
+
|
| 260 |
+
t = token_embeddings.shape[1]
|
| 261 |
+
assert t <= self.block_size, "Cannot forward, model block size is exhausted."
|
| 262 |
+
# each position maps to a (learnable) vector
|
| 263 |
+
|
| 264 |
+
position_embeddings = self.pos_emb[:, :t, :]
|
| 265 |
+
|
| 266 |
+
x = token_embeddings + position_embeddings + segm_embeddings + texture_embeddings
|
| 267 |
+
x = self.drop(x)
|
| 268 |
+
for block in self.blocks:
|
| 269 |
+
x = block(x)
|
| 270 |
+
x = self.ln_f(x)
|
| 271 |
+
logits_list = [self.head_list[i](x) for i in range(self.num_head)]
|
| 272 |
+
|
| 273 |
+
return logits_list
|
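Before the U-Net backbone below, here is a self-contained sketch of how the forward passes above fuse their conditions: token, position, segmentation and (for the multi-head variant) texture embeddings are simply summed before entering the transformer blocks. All sizes here are made-up illustration values, not values taken from the configs.

import torch
import torch.nn as nn

n_embd, block_size = 512, 32 * 16            # assumed latent grid of 32 x 16 tokens
tok_emb = nn.Embedding(1024 + 1, n_embd)     # +1 for the mask token of the absorbing sampler
pos_emb = nn.Parameter(torch.zeros(1, block_size, n_embd))
segm_emb = nn.Embedding(32, n_embd)          # hypothetical segmentation codebook size
tex_emb = nn.Embedding(18, n_embd)           # hypothetical texture codebook size

idx = torch.randint(0, 1024, (2, block_size))
segm_tokens = torch.randint(0, 32, (2, block_size))
texture_tokens = torch.randint(0, 18, (2, block_size))

# Same fusion as Transformer.forward / TransformerMultiHead.forward: a plain element-wise sum.
x = tok_emb(idx) + pos_emb[:, :block_size, :] + segm_emb(segm_tokens) + tex_emb(texture_tokens)
print(x.shape)  # torch.Size([2, 512, 512])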
models/archs/unet_arch.py
ADDED
|
@@ -0,0 +1,693 @@
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.utils.checkpoint as cp
|
| 4 |
+
from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
|
| 5 |
+
build_norm_layer, build_upsample_layer, constant_init,
|
| 6 |
+
kaiming_init)
|
| 7 |
+
from mmcv.runner import load_checkpoint
|
| 8 |
+
from mmcv.utils.parrots_wrapper import _BatchNorm
|
| 9 |
+
from mmseg.utils import get_root_logger
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class UpConvBlock(nn.Module):
|
| 13 |
+
"""Upsample convolution block in decoder for UNet.
|
| 14 |
+
|
| 15 |
+
This upsample convolution block consists of one upsample module
|
| 16 |
+
followed by one convolution block. The upsample module expands the
|
| 17 |
+
high-level low-resolution feature map and the convolution block fuses
|
| 18 |
+
the upsampled high-level low-resolution feature map and the low-level
|
| 19 |
+
high-resolution feature map from encoder.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
conv_block (nn.Sequential): Sequential of convolutional layers.
|
| 23 |
+
in_channels (int): Number of input channels of the high-level low-resolution feature map from decoder.
|
| 24 |
+
skip_channels (int): Number of input channels of the low-level
|
| 25 |
+
high-resolution feature map from encoder.
|
| 26 |
+
out_channels (int): Number of output channels.
|
| 27 |
+
num_convs (int): Number of convolutional layers in the conv_block.
|
| 28 |
+
Default: 2.
|
| 29 |
+
stride (int): Stride of convolutional layer in conv_block. Default: 1.
|
| 30 |
+
dilation (int): Dilation rate of convolutional layer in conv_block.
|
| 31 |
+
Default: 1.
|
| 32 |
+
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
| 33 |
+
memory while slowing down the training speed. Default: False.
|
| 34 |
+
conv_cfg (dict | None): Config dict for convolution layer.
|
| 35 |
+
Default: None.
|
| 36 |
+
norm_cfg (dict | None): Config dict for normalization layer.
|
| 37 |
+
Default: dict(type='BN').
|
| 38 |
+
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
| 39 |
+
Default: dict(type='ReLU').
|
| 40 |
+
upsample_cfg (dict): The upsample config of the upsample module in
|
| 41 |
+
decoder. Default: dict(type='InterpConv'). If the size of
|
| 42 |
+
high-level feature map is the same as that of skip feature map
|
| 43 |
+
(low-level feature map from encoder), it does not need upsample the
|
| 44 |
+
high-level feature map and the upsample_cfg is None.
|
| 45 |
+
dcn (bool): Use deformable convolution in convolutional layer or not.
|
| 46 |
+
Default: None.
|
| 47 |
+
plugins (dict): plugins for convolutional layers. Default: None.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
def __init__(self,
|
| 51 |
+
conv_block,
|
| 52 |
+
in_channels,
|
| 53 |
+
skip_channels,
|
| 54 |
+
out_channels,
|
| 55 |
+
num_convs=2,
|
| 56 |
+
stride=1,
|
| 57 |
+
dilation=1,
|
| 58 |
+
with_cp=False,
|
| 59 |
+
conv_cfg=None,
|
| 60 |
+
norm_cfg=dict(type='BN'),
|
| 61 |
+
act_cfg=dict(type='ReLU'),
|
| 62 |
+
upsample_cfg=dict(type='InterpConv'),
|
| 63 |
+
dcn=None,
|
| 64 |
+
plugins=None):
|
| 65 |
+
super(UpConvBlock, self).__init__()
|
| 66 |
+
assert dcn is None, 'Not implemented yet.'
|
| 67 |
+
assert plugins is None, 'Not implemented yet.'
|
| 68 |
+
|
| 69 |
+
self.conv_block = conv_block(
|
| 70 |
+
in_channels=2 * skip_channels,
|
| 71 |
+
out_channels=out_channels,
|
| 72 |
+
num_convs=num_convs,
|
| 73 |
+
stride=stride,
|
| 74 |
+
dilation=dilation,
|
| 75 |
+
with_cp=with_cp,
|
| 76 |
+
conv_cfg=conv_cfg,
|
| 77 |
+
norm_cfg=norm_cfg,
|
| 78 |
+
act_cfg=act_cfg,
|
| 79 |
+
dcn=None,
|
| 80 |
+
plugins=None)
|
| 81 |
+
if upsample_cfg is not None:
|
| 82 |
+
self.upsample = build_upsample_layer(
|
| 83 |
+
cfg=upsample_cfg,
|
| 84 |
+
in_channels=in_channels,
|
| 85 |
+
out_channels=skip_channels,
|
| 86 |
+
with_cp=with_cp,
|
| 87 |
+
norm_cfg=norm_cfg,
|
| 88 |
+
act_cfg=act_cfg)
|
| 89 |
+
else:
|
| 90 |
+
self.upsample = ConvModule(
|
| 91 |
+
in_channels,
|
| 92 |
+
skip_channels,
|
| 93 |
+
kernel_size=1,
|
| 94 |
+
stride=1,
|
| 95 |
+
padding=0,
|
| 96 |
+
conv_cfg=conv_cfg,
|
| 97 |
+
norm_cfg=norm_cfg,
|
| 98 |
+
act_cfg=act_cfg)
|
| 99 |
+
|
| 100 |
+
def forward(self, skip, x):
|
| 101 |
+
"""Forward function."""
|
| 102 |
+
|
| 103 |
+
x = self.upsample(x)
|
| 104 |
+
out = torch.cat([skip, x], dim=1)
|
| 105 |
+
out = self.conv_block(out)
|
| 106 |
+
|
| 107 |
+
return out
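As a quick sanity check of the channel bookkeeping in UpConvBlock (a sketch with made-up channel counts; BasicConvBlock is the conv block defined just below): the upsample path maps in_channels to skip_channels, so the concatenation fed to conv_block always carries 2 * skip_channels channels.

import torch

block = UpConvBlock(conv_block=BasicConvBlock, in_channels=128,
                    skip_channels=64, out_channels=64)
skip = torch.randn(1, 64, 32, 32)   # low-level, high-resolution feature from the encoder
x = torch.randn(1, 128, 16, 16)     # high-level, low-resolution feature from the decoder
out = block(skip, x)                # upsample to 32x32 -> concat (128 ch) -> fuse to 64 ch
print(out.shape)                    # torch.Size([1, 64, 32, 32])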
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class BasicConvBlock(nn.Module):
|
| 111 |
+
"""Basic convolutional block for UNet.
|
| 112 |
+
|
| 113 |
+
This module consists of several plain convolutional layers.
|
| 114 |
+
|
| 115 |
+
Args:
|
| 116 |
+
in_channels (int): Number of input channels.
|
| 117 |
+
out_channels (int): Number of output channels.
|
| 118 |
+
num_convs (int): Number of convolutional layers. Default: 2.
|
| 119 |
+
stride (int): Whether to use stride convolution to downsample
|
| 120 |
+
the input feature map. If stride=2, it only uses stride convolution
|
| 121 |
+
in the first convolutional layer to downsample the input feature
|
| 122 |
+
map. Options are 1 or 2. Default: 1.
|
| 123 |
+
dilation (int): Whether to use dilated convolution to expand the
|
| 124 |
+
receptive field. Set dilation rate of each convolutional layer and
|
| 125 |
+
the dilation rate of the first convolutional layer is always 1.
|
| 126 |
+
Default: 1.
|
| 127 |
+
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
| 128 |
+
memory while slowing down the training speed. Default: False.
|
| 129 |
+
conv_cfg (dict | None): Config dict for convolution layer.
|
| 130 |
+
Default: None.
|
| 131 |
+
norm_cfg (dict | None): Config dict for normalization layer.
|
| 132 |
+
Default: dict(type='BN').
|
| 133 |
+
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
| 134 |
+
Default: dict(type='ReLU').
|
| 135 |
+
dcn (bool): Use deformable convolution in convolutional layer or not.
|
| 136 |
+
Default: None.
|
| 137 |
+
plugins (dict): plugins for convolutional layers. Default: None.
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
def __init__(self,
|
| 141 |
+
in_channels,
|
| 142 |
+
out_channels,
|
| 143 |
+
num_convs=2,
|
| 144 |
+
stride=1,
|
| 145 |
+
dilation=1,
|
| 146 |
+
with_cp=False,
|
| 147 |
+
conv_cfg=None,
|
| 148 |
+
norm_cfg=dict(type='BN'),
|
| 149 |
+
act_cfg=dict(type='ReLU'),
|
| 150 |
+
dcn=None,
|
| 151 |
+
plugins=None):
|
| 152 |
+
super(BasicConvBlock, self).__init__()
|
| 153 |
+
assert dcn is None, 'Not implemented yet.'
|
| 154 |
+
assert plugins is None, 'Not implemented yet.'
|
| 155 |
+
|
| 156 |
+
self.with_cp = with_cp
|
| 157 |
+
convs = []
|
| 158 |
+
for i in range(num_convs):
|
| 159 |
+
convs.append(
|
| 160 |
+
ConvModule(
|
| 161 |
+
in_channels=in_channels if i == 0 else out_channels,
|
| 162 |
+
out_channels=out_channels,
|
| 163 |
+
kernel_size=3,
|
| 164 |
+
stride=stride if i == 0 else 1,
|
| 165 |
+
dilation=1 if i == 0 else dilation,
|
| 166 |
+
padding=1 if i == 0 else dilation,
|
| 167 |
+
conv_cfg=conv_cfg,
|
| 168 |
+
norm_cfg=norm_cfg,
|
| 169 |
+
act_cfg=act_cfg))
|
| 170 |
+
|
| 171 |
+
self.convs = nn.Sequential(*convs)
|
| 172 |
+
|
| 173 |
+
def forward(self, x):
|
| 174 |
+
"""Forward function."""
|
| 175 |
+
|
| 176 |
+
if self.with_cp and x.requires_grad:
|
| 177 |
+
out = cp.checkpoint(self.convs, x)
|
| 178 |
+
else:
|
| 179 |
+
out = self.convs(x)
|
| 180 |
+
return out
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class DeconvModule(nn.Module):
|
| 184 |
+
"""Deconvolution upsample module in decoder for UNet (2X upsample).
|
| 185 |
+
|
| 186 |
+
This module uses deconvolution to upsample feature map in the decoder
|
| 187 |
+
of UNet.
|
| 188 |
+
|
| 189 |
+
Args:
|
| 190 |
+
in_channels (int): Number of input channels.
|
| 191 |
+
out_channels (int): Number of output channels.
|
| 192 |
+
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
| 193 |
+
memory while slowing down the training speed. Default: False.
|
| 194 |
+
norm_cfg (dict | None): Config dict for normalization layer.
|
| 195 |
+
Default: dict(type='BN').
|
| 196 |
+
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
| 197 |
+
Default: dict(type='ReLU').
|
| 198 |
+
kernel_size (int): Kernel size of the convolutional layer. Default: 4.
|
| 199 |
+
"""
|
| 200 |
+
|
| 201 |
+
def __init__(self,
|
| 202 |
+
in_channels,
|
| 203 |
+
out_channels,
|
| 204 |
+
with_cp=False,
|
| 205 |
+
norm_cfg=dict(type='BN'),
|
| 206 |
+
act_cfg=dict(type='ReLU'),
|
| 207 |
+
*,
|
| 208 |
+
kernel_size=4,
|
| 209 |
+
scale_factor=2):
|
| 210 |
+
super(DeconvModule, self).__init__()
|
| 211 |
+
|
| 212 |
+
assert (kernel_size - scale_factor >= 0) and\
|
| 213 |
+
(kernel_size - scale_factor) % 2 == 0,\
|
| 214 |
+
f'kernel_size should be greater than or equal to scale_factor '\
|
| 215 |
+
f'and (kernel_size - scale_factor) should be even numbers, '\
|
| 216 |
+
f'while the kernel size is {kernel_size} and scale_factor is '\
|
| 217 |
+
f'{scale_factor}.'
|
| 218 |
+
|
| 219 |
+
stride = scale_factor
|
| 220 |
+
padding = (kernel_size - scale_factor) // 2
|
| 221 |
+
self.with_cp = with_cp
|
| 222 |
+
deconv = nn.ConvTranspose2d(
|
| 223 |
+
in_channels,
|
| 224 |
+
out_channels,
|
| 225 |
+
kernel_size=kernel_size,
|
| 226 |
+
stride=stride,
|
| 227 |
+
padding=padding)
|
| 228 |
+
|
| 229 |
+
norm_name, norm = build_norm_layer(norm_cfg, out_channels)
|
| 230 |
+
activate = build_activation_layer(act_cfg)
|
| 231 |
+
self.deconv_upsamping = nn.Sequential(deconv, norm, activate)
|
| 232 |
+
|
| 233 |
+
def forward(self, x):
|
| 234 |
+
"""Forward function."""
|
| 235 |
+
|
| 236 |
+
if self.with_cp and x.requires_grad:
|
| 237 |
+
out = cp.checkpoint(self.deconv_upsamping, x)
|
| 238 |
+
else:
|
| 239 |
+
out = self.deconv_upsamping(x)
|
| 240 |
+
return out
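For the default kernel_size=4 and scale_factor=2 above, padding = (4 - 2) // 2 = 1 and the transposed convolution doubles the spatial size exactly: H_out = (H_in - 1) * stride - 2 * padding + kernel_size = 2 * H_in. A tiny check with arbitrary channel counts:

import torch

up = DeconvModule(in_channels=64, out_channels=32)  # kernel_size=4, scale_factor=2 by default
x = torch.randn(1, 64, 16, 16)
print(up(x).shape)  # torch.Size([1, 32, 32, 32]); (16 - 1) * 2 - 2 * 1 + 4 = 32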
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
@UPSAMPLE_LAYERS.register_module()
|
| 244 |
+
class InterpConv(nn.Module):
|
| 245 |
+
"""Interpolation upsample module in decoder for UNet.
|
| 246 |
+
|
| 247 |
+
This module uses interpolation to upsample feature map in the decoder
|
| 248 |
+
of UNet. It consists of one interpolation upsample layer and one
|
| 249 |
+
convolutional layer. It can be one interpolation upsample layer followed
|
| 250 |
+
by one convolutional layer (conv_first=False) or one convolutional layer
|
| 251 |
+
followed by one interpolation upsample layer (conv_first=True).
|
| 252 |
+
|
| 253 |
+
Args:
|
| 254 |
+
in_channels (int): Number of input channels.
|
| 255 |
+
out_channels (int): Number of output channels.
|
| 256 |
+
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
| 257 |
+
memory while slowing down the training speed. Default: False.
|
| 258 |
+
norm_cfg (dict | None): Config dict for normalization layer.
|
| 259 |
+
Default: dict(type='BN').
|
| 260 |
+
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
| 261 |
+
Default: dict(type='ReLU').
|
| 262 |
+
conv_cfg (dict | None): Config dict for convolution layer.
|
| 263 |
+
Default: None.
|
| 264 |
+
conv_first (bool): Whether convolutional layer or interpolation
|
| 265 |
+
upsample layer first. Default: False. It means interpolation
|
| 266 |
+
upsample layer followed by one convolutional layer.
|
| 267 |
+
kernel_size (int): Kernel size of the convolutional layer. Default: 1.
|
| 268 |
+
stride (int): Stride of the convolutional layer. Default: 1.
|
| 269 |
+
padding (int): Padding of the convolutional layer. Default: 0.
|
| 270 |
+
upsampe_cfg (dict): Interpolation config of the upsample layer.
|
| 271 |
+
Default: dict(
|
| 272 |
+
scale_factor=2, mode='bilinear', align_corners=False).
|
| 273 |
+
"""
|
| 274 |
+
|
| 275 |
+
def __init__(self,
|
| 276 |
+
in_channels,
|
| 277 |
+
out_channels,
|
| 278 |
+
with_cp=False,
|
| 279 |
+
norm_cfg=dict(type='BN'),
|
| 280 |
+
act_cfg=dict(type='ReLU'),
|
| 281 |
+
*,
|
| 282 |
+
conv_cfg=None,
|
| 283 |
+
conv_first=False,
|
| 284 |
+
kernel_size=1,
|
| 285 |
+
stride=1,
|
| 286 |
+
padding=0,
|
| 287 |
+
upsampe_cfg=dict(
|
| 288 |
+
scale_factor=2, mode='bilinear', align_corners=False)):
|
| 289 |
+
super(InterpConv, self).__init__()
|
| 290 |
+
|
| 291 |
+
self.with_cp = with_cp
|
| 292 |
+
conv = ConvModule(
|
| 293 |
+
in_channels,
|
| 294 |
+
out_channels,
|
| 295 |
+
kernel_size=kernel_size,
|
| 296 |
+
stride=stride,
|
| 297 |
+
padding=padding,
|
| 298 |
+
conv_cfg=conv_cfg,
|
| 299 |
+
norm_cfg=norm_cfg,
|
| 300 |
+
act_cfg=act_cfg)
|
| 301 |
+
upsample = nn.Upsample(**upsampe_cfg)
|
| 302 |
+
if conv_first:
|
| 303 |
+
self.interp_upsample = nn.Sequential(conv, upsample)
|
| 304 |
+
else:
|
| 305 |
+
self.interp_upsample = nn.Sequential(upsample, conv)
|
| 306 |
+
|
| 307 |
+
def forward(self, x):
|
| 308 |
+
"""Forward function."""
|
| 309 |
+
|
| 310 |
+
if self.with_cp and x.requires_grad:
|
| 311 |
+
out = cp.checkpoint(self.interp_upsample, x)
|
| 312 |
+
else:
|
| 313 |
+
out = self.interp_upsample(x)
|
| 314 |
+
return out
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
class UNet(nn.Module):
|
| 318 |
+
"""UNet backbone.
|
| 319 |
+
U-Net: Convolutional Networks for Biomedical Image Segmentation.
|
| 320 |
+
https://arxiv.org/pdf/1505.04597.pdf
|
| 321 |
+
|
| 322 |
+
Args:
|
| 323 |
+
in_channels (int): Number of input image channels. Default: 3.
|
| 324 |
+
base_channels (int): Number of base channels of each stage.
|
| 325 |
+
The output channels of the first stage. Default: 64.
|
| 326 |
+
num_stages (int): Number of stages in encoder, normally 5. Default: 5.
|
| 327 |
+
strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
|
| 328 |
+
len(strides) is equal to num_stages. Normally the stride of the
|
| 329 |
+
first stage in encoder is 1. If strides[i]=2, it uses stride
|
| 330 |
+
convolution to downsample in the corresponding encoder stage.
|
| 331 |
+
Default: (1, 1, 1, 1, 1).
|
| 332 |
+
enc_num_convs (Sequence[int]): Number of convolutional layers in the
|
| 333 |
+
convolution block of the corresponding encoder stage.
|
| 334 |
+
Default: (2, 2, 2, 2, 2).
|
| 335 |
+
dec_num_convs (Sequence[int]): Number of convolutional layers in the
|
| 336 |
+
convolution block of the corresponding decoder stage.
|
| 337 |
+
Default: (2, 2, 2, 2).
|
| 338 |
+
downsamples (Sequence[int]): Whether to use MaxPool to downsample the
|
| 339 |
+
feature map after the first stage of encoder
|
| 340 |
+
(stages: [1, num_stages)). If the corresponding encoder stage uses
|
| 341 |
+
stride convolution (strides[i]=2), it will never use MaxPool to
|
| 342 |
+
downsample, even downsamples[i-1]=True.
|
| 343 |
+
Default: (True, True, True, True).
|
| 344 |
+
enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
|
| 345 |
+
Default: (1, 1, 1, 1, 1).
|
| 346 |
+
dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
|
| 347 |
+
Default: (1, 1, 1, 1).
|
| 348 |
+
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
| 349 |
+
memory while slowing down the training speed. Default: False.
|
| 350 |
+
conv_cfg (dict | None): Config dict for convolution layer.
|
| 351 |
+
Default: None.
|
| 352 |
+
norm_cfg (dict | None): Config dict for normalization layer.
|
| 353 |
+
Default: dict(type='BN').
|
| 354 |
+
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
| 355 |
+
Default: dict(type='ReLU').
|
| 356 |
+
upsample_cfg (dict): The upsample config of the upsample module in
|
| 357 |
+
decoder. Default: dict(type='InterpConv').
|
| 358 |
+
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
| 359 |
+
freeze running stats (mean and var). Note: Effect on Batch Norm
|
| 360 |
+
and its variants only. Default: False.
|
| 361 |
+
dcn (bool): Use deformable convolution in convolutional layer or not.
|
| 362 |
+
Default: None.
|
| 363 |
+
plugins (dict): plugins for convolutional layers. Default: None.
|
| 364 |
+
|
| 365 |
+
Notice:
|
| 366 |
+
The input image size should be divisible by the whole downsample rate
|
| 367 |
+
of the encoder. More detail of the whole downsample rate can be found
|
| 368 |
+
in UNet._check_input_devisible.
|
| 369 |
+
|
| 370 |
+
"""
|
| 371 |
+
|
| 372 |
+
def __init__(self,
|
| 373 |
+
in_channels=3,
|
| 374 |
+
base_channels=64,
|
| 375 |
+
num_stages=5,
|
| 376 |
+
strides=(1, 1, 1, 1, 1),
|
| 377 |
+
enc_num_convs=(2, 2, 2, 2, 2),
|
| 378 |
+
dec_num_convs=(2, 2, 2, 2),
|
| 379 |
+
downsamples=(True, True, True, True),
|
| 380 |
+
enc_dilations=(1, 1, 1, 1, 1),
|
| 381 |
+
dec_dilations=(1, 1, 1, 1),
|
| 382 |
+
with_cp=False,
|
| 383 |
+
conv_cfg=None,
|
| 384 |
+
norm_cfg=dict(type='BN'),
|
| 385 |
+
act_cfg=dict(type='ReLU'),
|
| 386 |
+
upsample_cfg=dict(type='InterpConv'),
|
| 387 |
+
norm_eval=False,
|
| 388 |
+
dcn=None,
|
| 389 |
+
plugins=None):
|
| 390 |
+
super(UNet, self).__init__()
|
| 391 |
+
assert dcn is None, 'Not implemented yet.'
|
| 392 |
+
assert plugins is None, 'Not implemented yet.'
|
| 393 |
+
assert len(strides) == num_stages, \
|
| 394 |
+
'The length of strides should be equal to num_stages, '\
|
| 395 |
+
f'while the strides is {strides}, the length of '\
|
| 396 |
+
f'strides is {len(strides)}, and the num_stages is '\
|
| 397 |
+
f'{num_stages}.'
|
| 398 |
+
assert len(enc_num_convs) == num_stages, \
|
| 399 |
+
'The length of enc_num_convs should be equal to num_stages, '\
|
| 400 |
+
f'while the enc_num_convs is {enc_num_convs}, the length of '\
|
| 401 |
+
f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
|
| 402 |
+
f'{num_stages}.'
|
| 403 |
+
assert len(dec_num_convs) == (num_stages-1), \
|
| 404 |
+
'The length of dec_num_convs should be equal to (num_stages-1), '\
|
| 405 |
+
f'while the dec_num_convs is {dec_num_convs}, the length of '\
|
| 406 |
+
f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
|
| 407 |
+
f'{num_stages}.'
|
| 408 |
+
assert len(downsamples) == (num_stages-1), \
|
| 409 |
+
'The length of downsamples should be equal to (num_stages-1), '\
|
| 410 |
+
f'while the downsamples is {downsamples}, the length of '\
|
| 411 |
+
f'downsamples is {len(downsamples)}, and the num_stages is '\
|
| 412 |
+
f'{num_stages}.'
|
| 413 |
+
assert len(enc_dilations) == num_stages, \
|
| 414 |
+
'The length of enc_dilations should be equal to num_stages, '\
|
| 415 |
+
f'while the enc_dilations is {enc_dilations}, the length of '\
|
| 416 |
+
f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
|
| 417 |
+
f'{num_stages}.'
|
| 418 |
+
assert len(dec_dilations) == (num_stages-1), \
|
| 419 |
+
'The length of dec_dilations should be equal to (num_stages-1), '\
|
| 420 |
+
f'while the dec_dilations is {dec_dilations}, the length of '\
|
| 421 |
+
f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
|
| 422 |
+
f'{num_stages}.'
|
| 423 |
+
self.num_stages = num_stages
|
| 424 |
+
self.strides = strides
|
| 425 |
+
self.downsamples = downsamples
|
| 426 |
+
self.norm_eval = norm_eval
|
| 427 |
+
|
| 428 |
+
self.encoder = nn.ModuleList()
|
| 429 |
+
self.decoder = nn.ModuleList()
|
| 430 |
+
|
| 431 |
+
for i in range(num_stages):
|
| 432 |
+
enc_conv_block = []
|
| 433 |
+
if i != 0:
|
| 434 |
+
if strides[i] == 1 and downsamples[i - 1]:
|
| 435 |
+
enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
|
| 436 |
+
upsample = (strides[i] != 1 or downsamples[i - 1])
|
| 437 |
+
self.decoder.append(
|
| 438 |
+
UpConvBlock(
|
| 439 |
+
conv_block=BasicConvBlock,
|
| 440 |
+
in_channels=base_channels * 2**i,
|
| 441 |
+
skip_channels=base_channels * 2**(i - 1),
|
| 442 |
+
out_channels=base_channels * 2**(i - 1),
|
| 443 |
+
num_convs=dec_num_convs[i - 1],
|
| 444 |
+
stride=1,
|
| 445 |
+
dilation=dec_dilations[i - 1],
|
| 446 |
+
with_cp=with_cp,
|
| 447 |
+
conv_cfg=conv_cfg,
|
| 448 |
+
norm_cfg=norm_cfg,
|
| 449 |
+
act_cfg=act_cfg,
|
| 450 |
+
upsample_cfg=upsample_cfg if upsample else None,
|
| 451 |
+
dcn=None,
|
| 452 |
+
plugins=None))
|
| 453 |
+
|
| 454 |
+
enc_conv_block.append(
|
| 455 |
+
BasicConvBlock(
|
| 456 |
+
in_channels=in_channels,
|
| 457 |
+
out_channels=base_channels * 2**i,
|
| 458 |
+
num_convs=enc_num_convs[i],
|
| 459 |
+
stride=strides[i],
|
| 460 |
+
dilation=enc_dilations[i],
|
| 461 |
+
with_cp=with_cp,
|
| 462 |
+
conv_cfg=conv_cfg,
|
| 463 |
+
norm_cfg=norm_cfg,
|
| 464 |
+
act_cfg=act_cfg,
|
| 465 |
+
dcn=None,
|
| 466 |
+
plugins=None))
|
| 467 |
+
self.encoder.append((nn.Sequential(*enc_conv_block)))
|
| 468 |
+
in_channels = base_channels * 2**i
|
| 469 |
+
|
| 470 |
+
def forward(self, x):
|
| 471 |
+
enc_outs = []
|
| 472 |
+
|
| 473 |
+
for enc in self.encoder:
|
| 474 |
+
x = enc(x)
|
| 475 |
+
enc_outs.append(x)
|
| 476 |
+
dec_outs = [x]
|
| 477 |
+
for i in reversed(range(len(self.decoder))):
|
| 478 |
+
x = self.decoder[i](enc_outs[i], x)
|
| 479 |
+
dec_outs.append(x)
|
| 480 |
+
|
| 481 |
+
return dec_outs
|
| 482 |
+
|
| 483 |
+
def init_weights(self, pretrained=None):
|
| 484 |
+
"""Initialize the weights in backbone.
|
| 485 |
+
|
| 486 |
+
Args:
|
| 487 |
+
pretrained (str, optional): Path to pre-trained weights.
|
| 488 |
+
Defaults to None.
|
| 489 |
+
"""
|
| 490 |
+
if isinstance(pretrained, str):
|
| 491 |
+
logger = get_root_logger()
|
| 492 |
+
load_checkpoint(self, pretrained, strict=False, logger=logger)
|
| 493 |
+
elif pretrained is None:
|
| 494 |
+
for m in self.modules():
|
| 495 |
+
if isinstance(m, nn.Conv2d):
|
| 496 |
+
kaiming_init(m)
|
| 497 |
+
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
| 498 |
+
constant_init(m, 1)
|
| 499 |
+
else:
|
| 500 |
+
raise TypeError('pretrained must be a str or None')
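A minimal usage sketch of the UNet backbone above (all sizes are illustrative): with the default strides and downsamples the whole downsample rate is 2**(num_stages - 1), so the input here must be divisible by 16, and forward returns the decoder outputs from coarsest to finest.

import torch

net = UNet(in_channels=3, base_channels=16, num_stages=5)
net.init_weights()
x = torch.randn(1, 3, 64, 64)                 # 64 is divisible by 2**4 = 16
outs = net(x)
print([tuple(o.shape) for o in outs])
# [(1, 256, 4, 4), (1, 128, 8, 8), (1, 64, 16, 16), (1, 32, 32, 32), (1, 16, 64, 64)]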
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
class ShapeUNet(nn.Module):
|
| 504 |
+
"""ShapeUNet backbone with small modifications.
|
| 505 |
+
U-Net: Convolutional Networks for Biomedical Image Segmentation.
|
| 506 |
+
https://arxiv.org/pdf/1505.04597.pdf
|
| 507 |
+
|
| 508 |
+
Args:
|
| 509 |
+
in_channels (int): Number of input image channels. Default: 3.
|
| 510 |
+
base_channels (int): Number of base channels of each stage.
|
| 511 |
+
The output channels of the first stage. Default: 64.
|
| 512 |
+
num_stages (int): Number of stages in encoder, normally 5. Default: 5.
|
| 513 |
+
strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
|
| 514 |
+
len(strides) is equal to num_stages. Normally the stride of the
|
| 515 |
+
first stage in encoder is 1. If strides[i]=2, it uses stride
|
| 516 |
+
convolution to downsample in the corresponding encoder stage.
|
| 517 |
+
Default: (1, 1, 1, 1, 1).
|
| 518 |
+
enc_num_convs (Sequence[int]): Number of convolutional layers in the
|
| 519 |
+
convolution block of the corresponding encoder stage.
|
| 520 |
+
Default: (2, 2, 2, 2, 2).
|
| 521 |
+
dec_num_convs (Sequence[int]): Number of convolutional layers in the
|
| 522 |
+
convolution block of the corresponding decoder stage.
|
| 523 |
+
Default: (2, 2, 2, 2).
|
| 524 |
+
downsamples (Sequence[int]): Whether to use MaxPool to downsample the
|
| 525 |
+
feature map after the first stage of encoder
|
| 526 |
+
(stages: [1, num_stages)). If the corresponding encoder stage uses
|
| 527 |
+
stride convolution (strides[i]=2), it will never use MaxPool to
|
| 528 |
+
downsample, even downsamples[i-1]=True.
|
| 529 |
+
Default: (True, True, True, True).
|
| 530 |
+
enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
|
| 531 |
+
Default: (1, 1, 1, 1, 1).
|
| 532 |
+
dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
|
| 533 |
+
Default: (1, 1, 1, 1).
|
| 534 |
+
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
| 535 |
+
memory while slowing down the training speed. Default: False.
|
| 536 |
+
conv_cfg (dict | None): Config dict for convolution layer.
|
| 537 |
+
Default: None.
|
| 538 |
+
norm_cfg (dict | None): Config dict for normalization layer.
|
| 539 |
+
Default: dict(type='BN').
|
| 540 |
+
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
| 541 |
+
Default: dict(type='ReLU').
|
| 542 |
+
upsample_cfg (dict): The upsample config of the upsample module in
|
| 543 |
+
decoder. Default: dict(type='InterpConv').
|
| 544 |
+
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
| 545 |
+
freeze running stats (mean and var). Note: Effect on Batch Norm
|
| 546 |
+
and its variants only. Default: False.
|
| 547 |
+
dcn (bool): Use deformable convolution in convolutional layer or not.
|
| 548 |
+
Default: None.
|
| 549 |
+
plugins (dict): plugins for convolutional layers. Default: None.
|
| 550 |
+
|
| 551 |
+
Notice:
|
| 552 |
+
The input image size should be divisible by the whole downsample rate
|
| 553 |
+
of the encoder. More detail of the whole downsample rate can be found
|
| 554 |
+
in UNet._check_input_devisible.
|
| 555 |
+
|
| 556 |
+
"""
|
| 557 |
+
|
| 558 |
+
def __init__(self,
|
| 559 |
+
in_channels=3,
|
| 560 |
+
base_channels=64,
|
| 561 |
+
num_stages=5,
|
| 562 |
+
attr_embedding=128,
|
| 563 |
+
strides=(1, 1, 1, 1, 1),
|
| 564 |
+
enc_num_convs=(2, 2, 2, 2, 2),
|
| 565 |
+
dec_num_convs=(2, 2, 2, 2),
|
| 566 |
+
downsamples=(True, True, True, True),
|
| 567 |
+
enc_dilations=(1, 1, 1, 1, 1),
|
| 568 |
+
dec_dilations=(1, 1, 1, 1),
|
| 569 |
+
with_cp=False,
|
| 570 |
+
conv_cfg=None,
|
| 571 |
+
norm_cfg=dict(type='BN'),
|
| 572 |
+
act_cfg=dict(type='ReLU'),
|
| 573 |
+
upsample_cfg=dict(type='InterpConv'),
|
| 574 |
+
norm_eval=False,
|
| 575 |
+
dcn=None,
|
| 576 |
+
plugins=None):
|
| 577 |
+
super(ShapeUNet, self).__init__()
|
| 578 |
+
assert dcn is None, 'Not implemented yet.'
|
| 579 |
+
assert plugins is None, 'Not implemented yet.'
|
| 580 |
+
assert len(strides) == num_stages, \
|
| 581 |
+
'The length of strides should be equal to num_stages, '\
|
| 582 |
+
f'while the strides is {strides}, the length of '\
|
| 583 |
+
f'strides is {len(strides)}, and the num_stages is '\
|
| 584 |
+
f'{num_stages}.'
|
| 585 |
+
assert len(enc_num_convs) == num_stages, \
|
| 586 |
+
'The length of enc_num_convs should be equal to num_stages, '\
|
| 587 |
+
f'while the enc_num_convs is {enc_num_convs}, the length of '\
|
| 588 |
+
f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
|
| 589 |
+
f'{num_stages}.'
|
| 590 |
+
assert len(dec_num_convs) == (num_stages-1), \
|
| 591 |
+
'The length of dec_num_convs should be equal to (num_stages-1), '\
|
| 592 |
+
f'while the dec_num_convs is {dec_num_convs}, the length of '\
|
| 593 |
+
f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
|
| 594 |
+
f'{num_stages}.'
|
| 595 |
+
assert len(downsamples) == (num_stages-1), \
|
| 596 |
+
'The length of downsamples should be equal to (num_stages-1), '\
|
| 597 |
+
f'while the downsamples is {downsamples}, the length of '\
|
| 598 |
+
f'downsamples is {len(downsamples)}, and the num_stages is '\
|
| 599 |
+
f'{num_stages}.'
|
| 600 |
+
assert len(enc_dilations) == num_stages, \
|
| 601 |
+
'The length of enc_dilations should be equal to num_stages, '\
|
| 602 |
+
f'while the enc_dilations is {enc_dilations}, the length of '\
|
| 603 |
+
f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
|
| 604 |
+
f'{num_stages}.'
|
| 605 |
+
assert len(dec_dilations) == (num_stages-1), \
|
| 606 |
+
'The length of dec_dilations should be equal to (num_stages-1), '\
|
| 607 |
+
f'while the dec_dilations is {dec_dilations}, the length of '\
|
| 608 |
+
f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
|
| 609 |
+
f'{num_stages}.'
|
| 610 |
+
self.num_stages = num_stages
|
| 611 |
+
self.strides = strides
|
| 612 |
+
self.downsamples = downsamples
|
| 613 |
+
self.norm_eval = norm_eval
|
| 614 |
+
|
| 615 |
+
self.encoder = nn.ModuleList()
|
| 616 |
+
self.decoder = nn.ModuleList()
|
| 617 |
+
|
| 618 |
+
for i in range(num_stages):
|
| 619 |
+
enc_conv_block = []
|
| 620 |
+
if i != 0:
|
| 621 |
+
if strides[i] == 1 and downsamples[i - 1]:
|
| 622 |
+
enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
|
| 623 |
+
upsample = (strides[i] != 1 or downsamples[i - 1])
|
| 624 |
+
self.decoder.append(
|
| 625 |
+
UpConvBlock(
|
| 626 |
+
conv_block=BasicConvBlock,
|
| 627 |
+
in_channels=base_channels * 2**i,
|
| 628 |
+
skip_channels=base_channels * 2**(i - 1),
|
| 629 |
+
out_channels=base_channels * 2**(i - 1),
|
| 630 |
+
num_convs=dec_num_convs[i - 1],
|
| 631 |
+
stride=1,
|
| 632 |
+
dilation=dec_dilations[i - 1],
|
| 633 |
+
with_cp=with_cp,
|
| 634 |
+
conv_cfg=conv_cfg,
|
| 635 |
+
norm_cfg=norm_cfg,
|
| 636 |
+
act_cfg=act_cfg,
|
| 637 |
+
upsample_cfg=upsample_cfg if upsample else None,
|
| 638 |
+
dcn=None,
|
| 639 |
+
plugins=None))
|
| 640 |
+
|
| 641 |
+
enc_conv_block.append(
|
| 642 |
+
BasicConvBlock(
|
| 643 |
+
in_channels=in_channels + attr_embedding,
|
| 644 |
+
out_channels=base_channels * 2**i,
|
| 645 |
+
num_convs=enc_num_convs[i],
|
| 646 |
+
stride=strides[i],
|
| 647 |
+
dilation=enc_dilations[i],
|
| 648 |
+
with_cp=with_cp,
|
| 649 |
+
conv_cfg=conv_cfg,
|
| 650 |
+
norm_cfg=norm_cfg,
|
| 651 |
+
act_cfg=act_cfg,
|
| 652 |
+
dcn=None,
|
| 653 |
+
plugins=None))
|
| 654 |
+
self.encoder.append((nn.Sequential(*enc_conv_block)))
|
| 655 |
+
in_channels = base_channels * 2**i
|
| 656 |
+
|
| 657 |
+
def forward(self, x, attr_embedding):
|
| 658 |
+
enc_outs = []
|
| 659 |
+
Be, Ce = attr_embedding.size()
|
| 660 |
+
for enc in self.encoder:
|
| 661 |
+
_, _, H, W = x.size()
|
| 662 |
+
x = enc(
|
| 663 |
+
torch.cat([
|
| 664 |
+
x,
|
| 665 |
+
attr_embedding.view(Be, Ce, 1, 1).expand((Be, Ce, H, W))
|
| 666 |
+
],
|
| 667 |
+
dim=1))
|
| 668 |
+
enc_outs.append(x)
|
| 669 |
+
dec_outs = [x]
|
| 670 |
+
for i in reversed(range(len(self.decoder))):
|
| 671 |
+
x = self.decoder[i](enc_outs[i], x)
|
| 672 |
+
dec_outs.append(x)
|
| 673 |
+
|
| 674 |
+
return dec_outs
|
| 675 |
+
|
| 676 |
+
def init_weights(self, pretrained=None):
|
| 677 |
+
"""Initialize the weights in backbone.
|
| 678 |
+
|
| 679 |
+
Args:
|
| 680 |
+
pretrained (str, optional): Path to pre-trained weights.
|
| 681 |
+
Defaults to None.
|
| 682 |
+
"""
|
| 683 |
+
if isinstance(pretrained, str):
|
| 684 |
+
logger = get_root_logger()
|
| 685 |
+
load_checkpoint(self, pretrained, strict=False, logger=logger)
|
| 686 |
+
elif pretrained is None:
|
| 687 |
+
for m in self.modules():
|
| 688 |
+
if isinstance(m, nn.Conv2d):
|
| 689 |
+
kaiming_init(m)
|
| 690 |
+
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
| 691 |
+
constant_init(m, 1)
|
| 692 |
+
else:
|
| 693 |
+
raise TypeError('pretrained must be a str or None')
|
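ShapeUNet differs from UNet only in that a per-image attribute embedding is broadcast spatially and concatenated to the input of every encoder stage. A hedged usage sketch with made-up sizes (24 input channels standing in for a one-hot parsing map; 128-dimensional attribute embedding as in the default):

import torch

net = ShapeUNet(in_channels=24, base_channels=16, attr_embedding=128)
net.init_weights()
segm = torch.randn(1, 24, 64, 32)    # e.g. a one-hot human-parsing map
attr = torch.randn(1, 128)           # shape-attribute embedding, tiled over H x W inside forward
outs = net(segm, attr)
print(tuple(outs[-1].shape))         # (1, 16, 64, 32): finest decoder output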
models/archs/vqgan_arch.py
ADDED
|
@@ -0,0 +1,1203 @@
| 1 |
+
# pytorch_diffusion + derived encoder decoder
|
| 2 |
+
import math
|
| 3 |
+
from urllib.request import proxy_bypass
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class VectorQuantizer(nn.Module):
|
| 13 |
+
"""
|
| 14 |
+
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
|
| 15 |
+
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
| 19 |
+
# backwards compatibility we use the buggy version by default, but you can
|
| 20 |
+
# specify legacy=False to fix it.
|
| 21 |
+
def __init__(self,
|
| 22 |
+
n_e,
|
| 23 |
+
e_dim,
|
| 24 |
+
beta,
|
| 25 |
+
remap=None,
|
| 26 |
+
unknown_index="random",
|
| 27 |
+
sane_index_shape=False,
|
| 28 |
+
legacy=True):
|
| 29 |
+
super().__init__()
|
| 30 |
+
self.n_e = n_e
|
| 31 |
+
self.e_dim = e_dim
|
| 32 |
+
self.beta = beta
|
| 33 |
+
self.legacy = legacy
|
| 34 |
+
|
| 35 |
+
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
| 36 |
+
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
| 37 |
+
|
| 38 |
+
self.remap = remap
|
| 39 |
+
if self.remap is not None:
|
| 40 |
+
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
| 41 |
+
self.re_embed = self.used.shape[0]
|
| 42 |
+
self.unknown_index = unknown_index # "random" or "extra" or integer
|
| 43 |
+
if self.unknown_index == "extra":
|
| 44 |
+
self.unknown_index = self.re_embed
|
| 45 |
+
self.re_embed = self.re_embed + 1
|
| 46 |
+
print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
|
| 47 |
+
f"Using {self.unknown_index} for unknown indices.")
|
| 48 |
+
else:
|
| 49 |
+
self.re_embed = n_e
|
| 50 |
+
|
| 51 |
+
self.sane_index_shape = sane_index_shape
|
| 52 |
+
|
| 53 |
+
def remap_to_used(self, inds):
|
| 54 |
+
ishape = inds.shape
|
| 55 |
+
assert len(ishape) > 1
|
| 56 |
+
inds = inds.reshape(ishape[0], -1)
|
| 57 |
+
used = self.used.to(inds)
|
| 58 |
+
match = (inds[:, :, None] == used[None, None, ...]).long()
|
| 59 |
+
new = match.argmax(-1)
|
| 60 |
+
unknown = match.sum(2) < 1
|
| 61 |
+
if self.unknown_index == "random":
|
| 62 |
+
new[unknown] = torch.randint(
|
| 63 |
+
0, self.re_embed,
|
| 64 |
+
size=new[unknown].shape).to(device=new.device)
|
| 65 |
+
else:
|
| 66 |
+
new[unknown] = self.unknown_index
|
| 67 |
+
return new.reshape(ishape)
|
| 68 |
+
|
| 69 |
+
def unmap_to_all(self, inds):
|
| 70 |
+
ishape = inds.shape
|
| 71 |
+
assert len(ishape) > 1
|
| 72 |
+
inds = inds.reshape(ishape[0], -1)
|
| 73 |
+
used = self.used.to(inds)
|
| 74 |
+
if self.re_embed > self.used.shape[0]: # extra token
|
| 75 |
+
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
|
| 76 |
+
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
| 77 |
+
return back.reshape(ishape)
|
| 78 |
+
|
| 79 |
+
def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
|
| 80 |
+
assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
|
| 81 |
+
assert rescale_logits == False, "Only for interface compatible with Gumbel"
|
| 82 |
+
assert return_logits == False, "Only for interface compatible with Gumbel"
|
| 83 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
| 84 |
+
z = rearrange(z, 'b c h w -> b h w c').contiguous()
|
| 85 |
+
z_flattened = z.view(-1, self.e_dim)
|
| 86 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
| 87 |
+
|
| 88 |
+
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
|
| 89 |
+
torch.sum(self.embedding.weight**2, dim=1) - 2 * \
|
| 90 |
+
torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
|
| 91 |
+
|
| 92 |
+
min_encoding_indices = torch.argmin(d, dim=1)
|
| 93 |
+
z_q = self.embedding(min_encoding_indices).view(z.shape)
|
| 94 |
+
perplexity = None
|
| 95 |
+
min_encodings = None
|
| 96 |
+
|
| 97 |
+
# compute loss for embedding
|
| 98 |
+
if not self.legacy:
|
| 99 |
+
loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
|
| 100 |
+
torch.mean((z_q - z.detach()) ** 2)
|
| 101 |
+
else:
|
| 102 |
+
loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
|
| 103 |
+
torch.mean((z_q - z.detach()) ** 2)
|
| 104 |
+
|
| 105 |
+
# preserve gradients
|
| 106 |
+
z_q = z + (z_q - z).detach()
|
| 107 |
+
|
| 108 |
+
# reshape back to match original input shape
|
| 109 |
+
z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
|
| 110 |
+
|
| 111 |
+
if self.remap is not None:
|
| 112 |
+
min_encoding_indices = min_encoding_indices.reshape(
|
| 113 |
+
z.shape[0], -1) # add batch axis
|
| 114 |
+
min_encoding_indices = self.remap_to_used(min_encoding_indices)
|
| 115 |
+
min_encoding_indices = min_encoding_indices.reshape(-1,
|
| 116 |
+
1) # flatten
|
| 117 |
+
|
| 118 |
+
if self.sane_index_shape:
|
| 119 |
+
min_encoding_indices = min_encoding_indices.reshape(
|
| 120 |
+
z_q.shape[0], z_q.shape[2], z_q.shape[3])
|
| 121 |
+
|
| 122 |
+
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
| 123 |
+
|
| 124 |
+
def get_codebook_entry(self, indices, shape):
|
| 125 |
+
# shape specifying (batch, height, width, channel)
|
| 126 |
+
if self.remap is not None:
|
| 127 |
+
indices = indices.reshape(shape[0], -1) # add batch axis
|
| 128 |
+
indices = self.unmap_to_all(indices)
|
| 129 |
+
indices = indices.reshape(-1) # flatten again
|
| 130 |
+
|
| 131 |
+
# get quantized latent vectors
|
| 132 |
+
z_q = self.embedding(indices)
|
| 133 |
+
|
| 134 |
+
if shape is not None:
|
| 135 |
+
z_q = z_q.view(shape)
|
| 136 |
+
# reshape back to match original input shape
|
| 137 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
| 138 |
+
|
| 139 |
+
return z_q
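A short usage sketch of the quantizer above (shapes are assumptions): the channel dimension of the encoder feature must equal e_dim, the returned indices are the flattened nearest-codebook-entry ids, and z_q passes gradients back to z through the straight-through trick z + (z_q - z).detach().

import torch

vq = VectorQuantizer(n_e=1024, e_dim=256, beta=0.25)
z = torch.randn(2, 256, 32, 16)               # (b, e_dim, h, w) encoder features
z_q, loss, (_, _, indices) = vq(z)
print(z_q.shape, indices.shape)               # torch.Size([2, 256, 32, 16]) torch.Size([1024])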
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class VectorQuantizerTexture(nn.Module):
|
| 143 |
+
"""
|
| 144 |
+
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
|
| 145 |
+
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
|
| 146 |
+
"""
|
| 147 |
+
|
| 148 |
+
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
| 149 |
+
# backwards compatibility we use the buggy version by default, but you can
|
| 150 |
+
# specify legacy=False to fix it.
|
| 151 |
+
def __init__(self,
|
| 152 |
+
n_e,
|
| 153 |
+
e_dim,
|
| 154 |
+
beta,
|
| 155 |
+
remap=None,
|
| 156 |
+
unknown_index="random",
|
| 157 |
+
sane_index_shape=False,
|
| 158 |
+
legacy=True):
|
| 159 |
+
super().__init__()
|
| 160 |
+
self.n_e = n_e
|
| 161 |
+
self.e_dim = e_dim
|
| 162 |
+
self.beta = beta
|
| 163 |
+
self.legacy = legacy
|
| 164 |
+
|
| 165 |
+
# TODO: decide number of embeddings
|
| 166 |
+
self.embedding_list = nn.ModuleList(
|
| 167 |
+
[nn.Embedding(self.n_e, self.e_dim) for i in range(18)])
|
| 168 |
+
for embedding in self.embedding_list:
|
| 169 |
+
embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
| 170 |
+
|
| 171 |
+
self.remap = remap
|
| 172 |
+
if self.remap is not None:
|
| 173 |
+
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
| 174 |
+
self.re_embed = self.used.shape[0]
|
| 175 |
+
self.unknown_index = unknown_index # "random" or "extra" or integer
|
| 176 |
+
if self.unknown_index == "extra":
|
| 177 |
+
self.unknown_index = self.re_embed
|
| 178 |
+
self.re_embed = self.re_embed + 1
|
| 179 |
+
print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
|
| 180 |
+
f"Using {self.unknown_index} for unknown indices.")
|
| 181 |
+
else:
|
| 182 |
+
self.re_embed = n_e
|
| 183 |
+
|
| 184 |
+
self.sane_index_shape = sane_index_shape
|
| 185 |
+
|
| 186 |
+
def remap_to_used(self, inds):
|
| 187 |
+
ishape = inds.shape
|
| 188 |
+
assert len(ishape) > 1
|
| 189 |
+
inds = inds.reshape(ishape[0], -1)
|
| 190 |
+
used = self.used.to(inds)
|
| 191 |
+
match = (inds[:, :, None] == used[None, None, ...]).long()
|
| 192 |
+
new = match.argmax(-1)
|
| 193 |
+
unknown = match.sum(2) < 1
|
| 194 |
+
if self.unknown_index == "random":
|
| 195 |
+
new[unknown] = torch.randint(
|
| 196 |
+
0, self.re_embed,
|
| 197 |
+
size=new[unknown].shape).to(device=new.device)
|
| 198 |
+
else:
|
| 199 |
+
new[unknown] = self.unknown_index
|
| 200 |
+
return new.reshape(ishape)
|
| 201 |
+
|
| 202 |
+
def unmap_to_all(self, inds):
|
| 203 |
+
ishape = inds.shape
|
| 204 |
+
assert len(ishape) > 1
|
| 205 |
+
inds = inds.reshape(ishape[0], -1)
|
| 206 |
+
used = self.used.to(inds)
|
| 207 |
+
if self.re_embed > self.used.shape[0]: # extra token
|
| 208 |
+
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
|
| 209 |
+
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
| 210 |
+
return back.reshape(ishape)
|
| 211 |
+
|
| 212 |
+
def forward(self,
|
| 213 |
+
z,
|
| 214 |
+
segm_map,
|
| 215 |
+
temp=None,
|
| 216 |
+
rescale_logits=False,
|
| 217 |
+
return_logits=False):
|
| 218 |
+
assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
|
| 219 |
+
assert rescale_logits == False, "Only for interface compatible with Gumbel"
|
| 220 |
+
assert return_logits == False, "Only for interface compatible with Gumbel"
|
| 221 |
+
|
| 222 |
+
segm_map = F.interpolate(segm_map, size=z.size()[2:], mode='nearest')
|
| 223 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
| 224 |
+
z = rearrange(z, 'b c h w -> b h w c').contiguous()
|
| 225 |
+
z_flattened = z.view(-1, self.e_dim)
|
| 226 |
+
|
| 227 |
+
# flatten segm_map (b, h, w)
|
| 228 |
+
segm_map_flatten = segm_map.view(-1)
|
| 229 |
+
|
| 230 |
+
z_q = torch.zeros_like(z_flattened)
|
| 231 |
+
min_encoding_indices_list = []
|
| 232 |
+
min_encoding_indices_continual = torch.full(
|
| 233 |
+
segm_map_flatten.size(),
|
| 234 |
+
fill_value=-1,
|
| 235 |
+
dtype=torch.long,
|
| 236 |
+
device=segm_map_flatten.device)
|
| 237 |
+
for codebook_idx in range(18):
|
| 238 |
+
min_encoding_indices = torch.full(
|
| 239 |
+
segm_map_flatten.size(),
|
| 240 |
+
fill_value=-1,
|
| 241 |
+
dtype=torch.long,
|
| 242 |
+
device=segm_map_flatten.device)
|
| 243 |
+
if torch.sum(segm_map_flatten == codebook_idx) > 0:
|
| 244 |
+
z_selected = z_flattened[segm_map_flatten == codebook_idx]
|
| 245 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
| 246 |
+
d_selected = torch.sum(
|
| 247 |
+
z_selected**2, dim=1, keepdim=True) + torch.sum(
|
| 248 |
+
self.embedding_list[codebook_idx].weight**2,
|
| 249 |
+
dim=1) - 2 * torch.einsum(
|
| 250 |
+
'bd,dn->bn', z_selected,
|
| 251 |
+
rearrange(self.embedding_list[codebook_idx].weight,
|
| 252 |
+
'n d -> d n'))
|
| 253 |
+
min_encoding_indices_selected = torch.argmin(d_selected, dim=1)
|
| 254 |
+
z_q_selected = self.embedding_list[codebook_idx](
|
| 255 |
+
min_encoding_indices_selected)
|
| 256 |
+
z_q[segm_map_flatten == codebook_idx] = z_q_selected
|
| 257 |
+
min_encoding_indices[
|
| 258 |
+
segm_map_flatten ==
|
| 259 |
+
codebook_idx] = min_encoding_indices_selected
|
| 260 |
+
min_encoding_indices_continual[
|
| 261 |
+
segm_map_flatten ==
|
| 262 |
+
codebook_idx] = min_encoding_indices_selected + 1024 * codebook_idx
|
| 263 |
+
min_encoding_indices = min_encoding_indices.reshape(
|
| 264 |
+
z.shape[0], z.shape[1], z.shape[2])
|
| 265 |
+
min_encoding_indices_list.append(min_encoding_indices)
|
| 266 |
+
|
| 267 |
+
min_encoding_indices_continual = min_encoding_indices_continual.reshape(
|
| 268 |
+
z.shape[0], z.shape[1], z.shape[2])
|
| 269 |
+
z_q = z_q.view(z.shape)
|
| 270 |
+
perplexity = None
|
| 271 |
+
|
| 272 |
+
# compute loss for embedding
|
| 273 |
+
if not self.legacy:
|
| 274 |
+
loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
|
| 275 |
+
torch.mean((z_q - z.detach()) ** 2)
|
| 276 |
+
else:
|
| 277 |
+
loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
|
| 278 |
+
torch.mean((z_q - z.detach()) ** 2)
|
| 279 |
+
|
| 280 |
+
# preserve gradients
|
| 281 |
+
z_q = z + (z_q - z).detach()
|
| 282 |
+
|
| 283 |
+
# reshape back to match original input shape
|
| 284 |
+
z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
|
| 285 |
+
|
| 286 |
+
return z_q, loss, (perplexity, min_encoding_indices_continual,
|
| 287 |
+
min_encoding_indices_list)
|
| 288 |
+
|
| 289 |
+
def get_codebook_entry(self, indices_list, segm_map, shape):
|
| 290 |
+
# flatten segm_map (b, h, w)
|
| 291 |
+
segm_map = F.interpolate(
|
| 292 |
+
segm_map, size=(shape[1], shape[2]), mode='nearest')
|
| 293 |
+
segm_map_flatten = segm_map.view(-1)
|
| 294 |
+
|
| 295 |
+
z_q = torch.zeros((shape[0] * shape[1] * shape[2]),
|
| 296 |
+
self.e_dim).to(segm_map.device)
|
| 297 |
+
for codebook_idx in range(18):
|
| 298 |
+
if torch.sum(segm_map_flatten == codebook_idx) > 0:
|
| 299 |
+
min_encoding_indices_selected = indices_list[
|
| 300 |
+
codebook_idx].view(-1)[segm_map_flatten == codebook_idx]
|
| 301 |
+
z_q_selected = self.embedding_list[codebook_idx](
|
| 302 |
+
min_encoding_indices_selected)
|
| 303 |
+
z_q[segm_map_flatten == codebook_idx] = z_q_selected
|
| 304 |
+
|
| 305 |
+
z_q = z_q.view(shape)
|
| 306 |
+
# reshape back to match original input shape
|
| 307 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
| 308 |
+
|
| 309 |
+
return z_q
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
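# The texture-aware quantizer above routes every flattened spatial position to one of 18
# codebooks according to its segmentation label, and only computes distances for the
# positions that carry that label. A standalone sketch of that routing step (feature size,
# codebook size and the random inputs below are assumptions for illustration only):
import torch
import torch.nn as nn

e_dim, n_embed, n_labels = 256, 1024, 18              # assumed sizes
codebooks = nn.ModuleList([nn.Embedding(n_embed, e_dim) for _ in range(n_labels)])

z_flat = torch.randn(32 * 16, e_dim)                  # flattened (h * w) features
segm_flat = torch.randint(0, n_labels, (32 * 16,))    # one texture label per position

z_q = torch.zeros_like(z_flat)
for idx in range(n_labels):
    region = segm_flat == idx
    if region.any():
        d = torch.cdist(z_flat[region], codebooks[idx].weight)   # pairwise L2 distances
        z_q[region] = codebooks[idx](d.argmin(dim=1))            # nearest codeword per position
# The class above expands the same squared distance as z^2 + e^2 - 2*e*z; the argmin is
# unaffected by the square root, so both formulations select the same codewords.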
def sample_patches(inputs, patch_size=3, stride=1):
|
| 313 |
+
"""Extract sliding local patches from an input feature tensor.
|
| 314 |
+
The sampled patches are row-major.
|
| 315 |
+
Args:
|
| 316 |
+
inputs (Tensor): the input feature maps, shape: (n, c, h, w).
|
| 317 |
+
patch_size (int): the spatial size of sampled patches. Default: 3.
|
| 318 |
+
stride (int): the stride of sampling. Default: 1.
|
| 319 |
+
Returns:
|
| 320 |
+
patches (Tensor): extracted patches, shape: (n, c * patch_size *
|
| 321 |
+
patch_size, n_patches).
|
| 322 |
+
"""
|
| 323 |
+
|
| 324 |
+
patches = F.unfold(inputs, (patch_size, patch_size), stride=stride)
|
| 325 |
+
|
| 326 |
+
return patches
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
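# sample_patches above is a thin wrapper around F.unfold. The quantizer below calls it with
# patch_size == stride == spatial_size, so the patches tile the feature map without overlap
# and can be reassembled exactly by F.fold. A minimal standalone shape check (tensor sizes
# here are assumed purely for illustration):
import torch
import torch.nn.functional as F

x = torch.randn(2, 256, 32, 16)                        # (n, c, h, w)
patches = F.unfold(x, (2, 2), stride=2)                # (2, 256 * 2 * 2, 128) row-major patches
x_back = F.fold(patches, x.shape[2:], kernel_size=(2, 2), stride=2)
assert torch.equal(x_back, x)                          # non-overlapping tiles reassemble exactly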
class VectorQuantizerSpatialTextureAware(nn.Module):
|
| 330 |
+
"""
|
| 331 |
+
Texture-aware vector quantizer that quantizes non-overlapping spatial patches
(spatial_size x spatial_size) of the feature map, using a separate codebook for
each of the 18 texture classes given by the segmentation map.
|
| 333 |
+
"""
|
| 334 |
+
|
| 335 |
+
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
| 336 |
+
# backwards compatibility we use the buggy version by default, but you can
|
| 337 |
+
# specify legacy=False to fix it.
|
| 338 |
+
def __init__(self,
|
| 339 |
+
n_e,
|
| 340 |
+
e_dim,
|
| 341 |
+
beta,
|
| 342 |
+
spatial_size,
|
| 343 |
+
remap=None,
|
| 344 |
+
unknown_index="random",
|
| 345 |
+
sane_index_shape=False,
|
| 346 |
+
legacy=True):
|
| 347 |
+
super().__init__()
|
| 348 |
+
self.n_e = n_e
|
| 349 |
+
self.e_dim = e_dim * spatial_size * spatial_size
|
| 350 |
+
self.beta = beta
|
| 351 |
+
self.legacy = legacy
|
| 352 |
+
self.spatial_size = spatial_size
|
| 353 |
+
|
| 354 |
+
# TODO: decide number of embeddings
|
| 355 |
+
self.embedding_list = nn.ModuleList(
|
| 356 |
+
[nn.Embedding(self.n_e, self.e_dim) for i in range(18)])
|
| 357 |
+
for embedding in self.embedding_list:
|
| 358 |
+
embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
| 359 |
+
|
| 360 |
+
self.remap = remap
|
| 361 |
+
if self.remap is not None:
|
| 362 |
+
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
| 363 |
+
self.re_embed = self.used.shape[0]
|
| 364 |
+
self.unknown_index = unknown_index # "random" or "extra" or integer
|
| 365 |
+
if self.unknown_index == "extra":
|
| 366 |
+
self.unknown_index = self.re_embed
|
| 367 |
+
self.re_embed = self.re_embed + 1
|
| 368 |
+
print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
|
| 369 |
+
f"Using {self.unknown_index} for unknown indices.")
|
| 370 |
+
else:
|
| 371 |
+
self.re_embed = n_e
|
| 372 |
+
|
| 373 |
+
self.sane_index_shape = sane_index_shape
|
| 374 |
+
|
| 375 |
+
def forward(self,
|
| 376 |
+
z,
|
| 377 |
+
segm_map,
|
| 378 |
+
temp=None,
|
| 379 |
+
rescale_logits=False,
|
| 380 |
+
return_logits=False):
|
| 381 |
+
assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
|
| 382 |
+
assert rescale_logits == False, "Only for interface compatible with Gumbel"
|
| 383 |
+
assert return_logits == False, "Only for interface compatible with Gumbel"
|
| 384 |
+
|
| 385 |
+
segm_map = F.interpolate(
|
| 386 |
+
segm_map,
|
| 387 |
+
size=(z.size(2) // self.spatial_size,
|
| 388 |
+
z.size(3) // self.spatial_size),
|
| 389 |
+
mode='nearest')
|
| 390 |
+
|
| 391 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
| 392 |
+
# z = rearrange(z, 'b c h w -> b h w c').contiguous() ?
|
| 393 |
+
z_patches = sample_patches(
|
| 394 |
+
z, patch_size=self.spatial_size,
|
| 395 |
+
stride=self.spatial_size).permute(0, 2, 1)
|
| 396 |
+
z_patches_flattened = z_patches.reshape(-1, self.e_dim)
|
| 397 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
| 398 |
+
|
| 399 |
+
# flatten segm_map (b, h, w)
|
| 400 |
+
segm_map_flatten = segm_map.view(-1)
|
| 401 |
+
|
| 402 |
+
z_q = torch.zeros_like(z_patches_flattened)
|
| 403 |
+
min_encoding_indices_list = []
|
| 404 |
+
min_encoding_indices_continual = torch.full(
|
| 405 |
+
segm_map_flatten.size(),
|
| 406 |
+
fill_value=-1,
|
| 407 |
+
dtype=torch.long,
|
| 408 |
+
device=segm_map_flatten.device)
|
| 409 |
+
|
| 410 |
+
for codebook_idx in range(18):
|
| 411 |
+
min_encoding_indices = torch.full(
|
| 412 |
+
segm_map_flatten.size(),
|
| 413 |
+
fill_value=-1,
|
| 414 |
+
dtype=torch.long,
|
| 415 |
+
device=segm_map_flatten.device)
|
| 416 |
+
if torch.sum(segm_map_flatten == codebook_idx) > 0:
|
| 417 |
+
z_selected = z_patches_flattened[segm_map_flatten ==
|
| 418 |
+
codebook_idx]
|
| 419 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
| 420 |
+
d_selected = torch.sum(
|
| 421 |
+
z_selected**2, dim=1, keepdim=True) + torch.sum(
|
| 422 |
+
self.embedding_list[codebook_idx].weight**2,
|
| 423 |
+
dim=1) - 2 * torch.einsum(
|
| 424 |
+
'bd,dn->bn', z_selected,
|
| 425 |
+
rearrange(self.embedding_list[codebook_idx].weight,
|
| 426 |
+
'n d -> d n'))
|
| 427 |
+
min_encoding_indices_selected = torch.argmin(d_selected, dim=1)
|
| 428 |
+
z_q_selected = self.embedding_list[codebook_idx](
|
| 429 |
+
min_encoding_indices_selected)
|
| 430 |
+
z_q[segm_map_flatten == codebook_idx] = z_q_selected
|
| 431 |
+
min_encoding_indices[
|
| 432 |
+
segm_map_flatten ==
|
| 433 |
+
codebook_idx] = min_encoding_indices_selected
|
| 434 |
+
min_encoding_indices_continual[
|
| 435 |
+
segm_map_flatten ==
|
| 436 |
+
codebook_idx] = min_encoding_indices_selected + self.n_e * codebook_idx
|
| 437 |
+
min_encoding_indices = min_encoding_indices.reshape(
|
| 438 |
+
z_patches.shape[0], segm_map.shape[2], segm_map.shape[3])
|
| 439 |
+
min_encoding_indices_list.append(min_encoding_indices)
|
| 440 |
+
|
| 441 |
+
z_q = F.fold(
|
| 442 |
+
z_q.view(z_patches.shape).permute(0, 2, 1),
|
| 443 |
+
z.size()[2:],
|
| 444 |
+
kernel_size=(self.spatial_size, self.spatial_size),
|
| 445 |
+
stride=self.spatial_size)
|
| 446 |
+
|
| 447 |
+
perplexity = None
|
| 448 |
+
|
| 449 |
+
# compute loss for embedding
|
| 450 |
+
if not self.legacy:
|
| 451 |
+
loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
|
| 452 |
+
torch.mean((z_q - z.detach()) ** 2)
|
| 453 |
+
else:
|
| 454 |
+
loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
|
| 455 |
+
torch.mean((z_q - z.detach()) ** 2)
|
| 456 |
+
|
| 457 |
+
# preserve gradients
|
| 458 |
+
z_q = z + (z_q - z).detach()
|
| 459 |
+
|
| 460 |
+
return z_q, loss, (perplexity, min_encoding_indices_continual,
|
| 461 |
+
min_encoding_indices_list)
|
| 462 |
+
|
| 463 |
+
def get_codebook_entry(self, indices_list, segm_map, shape):
|
| 464 |
+
# flatten segm_map (b, h, w)
|
| 465 |
+
segm_map = F.interpolate(
|
| 466 |
+
segm_map, size=(shape[1], shape[2]), mode='nearest')
|
| 467 |
+
segm_map_flatten = segm_map.view(-1)
|
| 468 |
+
|
| 469 |
+
z_q = torch.zeros((shape[0] * shape[1] * shape[2]),
|
| 470 |
+
self.e_dim).to(segm_map.device)
|
| 471 |
+
for codebook_idx in range(18):
|
| 472 |
+
if torch.sum(segm_map_flatten == codebook_idx) > 0:
|
| 473 |
+
min_encoding_indices_selected = indices_list[
|
| 474 |
+
codebook_idx].view(-1)[segm_map_flatten == codebook_idx]
|
| 475 |
+
z_q_selected = self.embedding_list[codebook_idx](
|
| 476 |
+
min_encoding_indices_selected)
|
| 477 |
+
z_q[segm_map_flatten == codebook_idx] = z_q_selected
|
| 478 |
+
|
| 479 |
+
z_q = F.fold(
|
| 480 |
+
z_q.view(((shape[0], shape[1] * shape[2],
|
| 481 |
+
self.e_dim))).permute(0, 2, 1),
|
| 482 |
+
(shape[1] * self.spatial_size, shape[2] * self.spatial_size),
|
| 483 |
+
kernel_size=(self.spatial_size, self.spatial_size),
|
| 484 |
+
stride=self.spatial_size)
|
| 485 |
+
|
| 486 |
+
return z_q
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
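# Hedged usage sketch for the patch-wise quantizer above (all sizes are assumed for
# illustration and are not taken from any shipped config). With spatial_size=2, every
# 2x2 patch of features is quantized as one codeword, so the per-codebook index maps
# come back at half the feature resolution. Meant to be run as a standalone script:
import torch
from models.archs.vqgan_arch import VectorQuantizerSpatialTextureAware

quantizer = VectorQuantizerSpatialTextureAware(n_e=512, e_dim=256, beta=0.25, spatial_size=2)
z = torch.randn(1, 256, 32, 16)                            # encoder features (b, c, h, w)
segm_map = torch.randint(0, 18, (1, 1, 32, 16)).float()    # texture label per pixel
z_q, loss, (_, _, indices_list) = quantizer(z, segm_map)
print(z_q.shape)               # torch.Size([1, 256, 32, 16]) -- same shape as z
print(indices_list[0].shape)   # torch.Size([1, 16, 8])      -- one index per 2x2 patch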
def get_timestep_embedding(timesteps, embedding_dim):
|
| 490 |
+
"""
|
| 491 |
+
Build sinusoidal timestep embeddings, matching the implementation in Denoising
Diffusion Probabilistic Models (which in turn follows Fairseq / tensor2tensor).
It differs slightly from the description in Section 3.5 of "Attention Is All You Need".
|
| 496 |
+
"""
|
| 497 |
+
assert len(timesteps.shape) == 1
|
| 498 |
+
|
| 499 |
+
half_dim = embedding_dim // 2
|
| 500 |
+
emb = math.log(10000) / (half_dim - 1)
|
| 501 |
+
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
| 502 |
+
emb = emb.to(device=timesteps.device)
|
| 503 |
+
emb = timesteps.float()[:, None] * emb[None, :]
|
| 504 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
| 505 |
+
if embedding_dim % 2 == 1: # zero pad
|
| 506 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
| 507 |
+
return emb
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
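# get_timestep_embedding above pairs each timestep with sine/cosine waves at geometrically
# spaced frequencies: the first half of the channels carries sines, the second half cosines,
# and an odd embedding_dim is zero-padded by one channel. A quick standalone sanity check
# (sizes assumed for illustration):
import torch
from models.archs.vqgan_arch import get_timestep_embedding

emb = get_timestep_embedding(torch.arange(4), 128)
print(emb.shape)        # torch.Size([4, 128])
print(emb[0, :64])      # sin(0 * freq) -> all zeros for timestep 0
print(emb[0, 64:])      # cos(0 * freq) -> all ones for timestep 0
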
def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels):
    return torch.nn.GroupNorm(
        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
class Upsample(nn.Module):
|
| 521 |
+
|
| 522 |
+
def __init__(self, in_channels, with_conv):
|
| 523 |
+
super().__init__()
|
| 524 |
+
self.with_conv = with_conv
|
| 525 |
+
if self.with_conv:
|
| 526 |
+
self.conv = torch.nn.Conv2d(
|
| 527 |
+
in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
| 528 |
+
|
| 529 |
+
def forward(self, x):
|
| 530 |
+
x = torch.nn.functional.interpolate(
|
| 531 |
+
x, scale_factor=2.0, mode="nearest")
|
| 532 |
+
if self.with_conv:
|
| 533 |
+
x = self.conv(x)
|
| 534 |
+
return x
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
class Downsample(nn.Module):
|
| 538 |
+
|
| 539 |
+
def __init__(self, in_channels, with_conv):
|
| 540 |
+
super().__init__()
|
| 541 |
+
self.with_conv = with_conv
|
| 542 |
+
if self.with_conv:
|
| 543 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
| 544 |
+
self.conv = torch.nn.Conv2d(
|
| 545 |
+
in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
| 546 |
+
|
| 547 |
+
def forward(self, x):
|
| 548 |
+
if self.with_conv:
|
| 549 |
+
pad = (0, 1, 0, 1)
|
| 550 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
| 551 |
+
x = self.conv(x)
|
| 552 |
+
else:
|
| 553 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
| 554 |
+
return x
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
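# In the strided branch of Downsample above, the manual (left, right, top, bottom) padding of
# (0, 1, 0, 1) followed by a 3x3 / stride-2 convolution halves H and W while padding only on
# the bottom/right side. A standalone shape check with assumed sizes:
import torch
import torch.nn.functional as F

x = torch.randn(1, 128, 64, 64)
x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)                   # -> (1, 128, 65, 65)
conv = torch.nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0)
print(conv(x).shape)                                                   # torch.Size([1, 128, 32, 32])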
class ResnetBlock(nn.Module):
|
| 558 |
+
|
| 559 |
+
def __init__(self,
|
| 560 |
+
*,
|
| 561 |
+
in_channels,
|
| 562 |
+
out_channels=None,
|
| 563 |
+
conv_shortcut=False,
|
| 564 |
+
dropout,
|
| 565 |
+
temb_channels=512):
|
| 566 |
+
super().__init__()
|
| 567 |
+
self.in_channels = in_channels
|
| 568 |
+
out_channels = in_channels if out_channels is None else out_channels
|
| 569 |
+
self.out_channels = out_channels
|
| 570 |
+
self.use_conv_shortcut = conv_shortcut
|
| 571 |
+
|
| 572 |
+
self.norm1 = Normalize(in_channels)
|
| 573 |
+
self.conv1 = torch.nn.Conv2d(
|
| 574 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 575 |
+
if temb_channels > 0:
|
| 576 |
+
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
| 577 |
+
self.norm2 = Normalize(out_channels)
|
| 578 |
+
self.dropout = torch.nn.Dropout(dropout)
|
| 579 |
+
self.conv2 = torch.nn.Conv2d(
|
| 580 |
+
out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 581 |
+
if self.in_channels != self.out_channels:
|
| 582 |
+
if self.use_conv_shortcut:
|
| 583 |
+
self.conv_shortcut = torch.nn.Conv2d(
|
| 584 |
+
in_channels,
|
| 585 |
+
out_channels,
|
| 586 |
+
kernel_size=3,
|
| 587 |
+
stride=1,
|
| 588 |
+
padding=1)
|
| 589 |
+
else:
|
| 590 |
+
self.nin_shortcut = torch.nn.Conv2d(
|
| 591 |
+
in_channels,
|
| 592 |
+
out_channels,
|
| 593 |
+
kernel_size=1,
|
| 594 |
+
stride=1,
|
| 595 |
+
padding=0)
|
| 596 |
+
|
| 597 |
+
def forward(self, x, temb):
|
| 598 |
+
h = x
|
| 599 |
+
h = self.norm1(h)
|
| 600 |
+
h = nonlinearity(h)
|
| 601 |
+
h = self.conv1(h)
|
| 602 |
+
|
| 603 |
+
if temb is not None:
|
| 604 |
+
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
| 605 |
+
|
| 606 |
+
h = self.norm2(h)
|
| 607 |
+
h = nonlinearity(h)
|
| 608 |
+
h = self.dropout(h)
|
| 609 |
+
h = self.conv2(h)
|
| 610 |
+
|
| 611 |
+
if self.in_channels != self.out_channels:
|
| 612 |
+
if self.use_conv_shortcut:
|
| 613 |
+
x = self.conv_shortcut(x)
|
| 614 |
+
else:
|
| 615 |
+
x = self.nin_shortcut(x)
|
| 616 |
+
|
| 617 |
+
return x + h
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
class AttnBlock(nn.Module):
|
| 621 |
+
|
| 622 |
+
def __init__(self, in_channels):
|
| 623 |
+
super().__init__()
|
| 624 |
+
self.in_channels = in_channels
|
| 625 |
+
|
| 626 |
+
self.norm = Normalize(in_channels)
|
| 627 |
+
self.q = torch.nn.Conv2d(
|
| 628 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
| 629 |
+
self.k = torch.nn.Conv2d(
|
| 630 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
| 631 |
+
self.v = torch.nn.Conv2d(
|
| 632 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
| 633 |
+
self.proj_out = torch.nn.Conv2d(
|
| 634 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
| 635 |
+
|
| 636 |
+
def forward(self, x):
|
| 637 |
+
h_ = x
|
| 638 |
+
h_ = self.norm(h_)
|
| 639 |
+
q = self.q(h_)
|
| 640 |
+
k = self.k(h_)
|
| 641 |
+
v = self.v(h_)
|
| 642 |
+
|
| 643 |
+
# compute attention
|
| 644 |
+
b, c, h, w = q.shape
|
| 645 |
+
q = q.reshape(b, c, h * w)
|
| 646 |
+
q = q.permute(0, 2, 1) # b,hw,c
|
| 647 |
+
k = k.reshape(b, c, h * w) # b,c,hw
|
| 648 |
+
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
| 649 |
+
w_ = w_ * (int(c)**(-0.5))
|
| 650 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
| 651 |
+
|
| 652 |
+
# attend to values
|
| 653 |
+
v = v.reshape(b, c, h * w)
|
| 654 |
+
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
| 655 |
+
h_ = torch.bmm(
|
| 656 |
+
v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
| 657 |
+
h_ = h_.reshape(b, c, h, w)
|
| 658 |
+
|
| 659 |
+
h_ = self.proj_out(h_)
|
| 660 |
+
|
| 661 |
+
return x + h_
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
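# AttnBlock above is single-head self-attention over the h*w spatial positions, with 1x1
# convolutions producing q, k, v and a residual connection around the output projection.
# The bmm-based computation is equivalent to the familiar softmax(q k^T / sqrt(c)) v form;
# a compact einsum restatement for reference (illustrative only, not used by the model):
import torch

def attention_reference(q, k, v):
    # q, k, v: (b, c, h, w), as produced by the 1x1 convolutions.
    b, c, h, w = q.shape
    q = q.reshape(b, c, h * w).permute(0, 2, 1)                 # (b, hw, c)
    k = k.reshape(b, c, h * w)                                  # (b, c, hw)
    v = v.reshape(b, c, h * w)                                  # (b, c, hw)
    attn = torch.softmax(torch.bmm(q, k) * c ** -0.5, dim=2)    # (b, hw_query, hw_key)
    out = torch.einsum('bij,bcj->bci', attn, v)                 # weight values by attention
    return out.reshape(b, c, h, w)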
class Model(nn.Module):
|
| 665 |
+
|
| 666 |
+
def __init__(self,
|
| 667 |
+
*,
|
| 668 |
+
ch,
|
| 669 |
+
out_ch,
|
| 670 |
+
ch_mult=(1, 2, 4, 8),
|
| 671 |
+
num_res_blocks,
|
| 672 |
+
attn_resolutions,
|
| 673 |
+
dropout=0.0,
|
| 674 |
+
resamp_with_conv=True,
|
| 675 |
+
in_channels,
|
| 676 |
+
resolution,
|
| 677 |
+
use_timestep=True):
|
| 678 |
+
super().__init__()
|
| 679 |
+
self.ch = ch
|
| 680 |
+
self.temb_ch = self.ch * 4
|
| 681 |
+
self.num_resolutions = len(ch_mult)
|
| 682 |
+
self.num_res_blocks = num_res_blocks
|
| 683 |
+
self.resolution = resolution
|
| 684 |
+
self.in_channels = in_channels
|
| 685 |
+
|
| 686 |
+
self.use_timestep = use_timestep
|
| 687 |
+
if self.use_timestep:
|
| 688 |
+
# timestep embedding
|
| 689 |
+
self.temb = nn.Module()
|
| 690 |
+
self.temb.dense = nn.ModuleList([
|
| 691 |
+
torch.nn.Linear(self.ch, self.temb_ch),
|
| 692 |
+
torch.nn.Linear(self.temb_ch, self.temb_ch),
|
| 693 |
+
])
|
| 694 |
+
|
| 695 |
+
# downsampling
|
| 696 |
+
self.conv_in = torch.nn.Conv2d(
|
| 697 |
+
in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
| 698 |
+
|
| 699 |
+
curr_res = resolution
|
| 700 |
+
in_ch_mult = (1, ) + tuple(ch_mult)
|
| 701 |
+
self.down = nn.ModuleList()
|
| 702 |
+
for i_level in range(self.num_resolutions):
|
| 703 |
+
block = nn.ModuleList()
|
| 704 |
+
attn = nn.ModuleList()
|
| 705 |
+
block_in = ch * in_ch_mult[i_level]
|
| 706 |
+
block_out = ch * ch_mult[i_level]
|
| 707 |
+
for i_block in range(self.num_res_blocks):
|
| 708 |
+
block.append(
|
| 709 |
+
ResnetBlock(
|
| 710 |
+
in_channels=block_in,
|
| 711 |
+
out_channels=block_out,
|
| 712 |
+
temb_channels=self.temb_ch,
|
| 713 |
+
dropout=dropout))
|
| 714 |
+
block_in = block_out
|
| 715 |
+
if curr_res in attn_resolutions:
|
| 716 |
+
attn.append(AttnBlock(block_in))
|
| 717 |
+
down = nn.Module()
|
| 718 |
+
down.block = block
|
| 719 |
+
down.attn = attn
|
| 720 |
+
if i_level != self.num_resolutions - 1:
|
| 721 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
| 722 |
+
curr_res = curr_res // 2
|
| 723 |
+
self.down.append(down)
|
| 724 |
+
|
| 725 |
+
# middle
|
| 726 |
+
self.mid = nn.Module()
|
| 727 |
+
self.mid.block_1 = ResnetBlock(
|
| 728 |
+
in_channels=block_in,
|
| 729 |
+
out_channels=block_in,
|
| 730 |
+
temb_channels=self.temb_ch,
|
| 731 |
+
dropout=dropout)
|
| 732 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 733 |
+
self.mid.block_2 = ResnetBlock(
|
| 734 |
+
in_channels=block_in,
|
| 735 |
+
out_channels=block_in,
|
| 736 |
+
temb_channels=self.temb_ch,
|
| 737 |
+
dropout=dropout)
|
| 738 |
+
|
| 739 |
+
# upsampling
|
| 740 |
+
self.up = nn.ModuleList()
|
| 741 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 742 |
+
block = nn.ModuleList()
|
| 743 |
+
attn = nn.ModuleList()
|
| 744 |
+
block_out = ch * ch_mult[i_level]
|
| 745 |
+
skip_in = ch * ch_mult[i_level]
|
| 746 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 747 |
+
if i_block == self.num_res_blocks:
|
| 748 |
+
skip_in = ch * in_ch_mult[i_level]
|
| 749 |
+
block.append(
|
| 750 |
+
ResnetBlock(
|
| 751 |
+
in_channels=block_in + skip_in,
|
| 752 |
+
out_channels=block_out,
|
| 753 |
+
temb_channels=self.temb_ch,
|
| 754 |
+
dropout=dropout))
|
| 755 |
+
block_in = block_out
|
| 756 |
+
if curr_res in attn_resolutions:
|
| 757 |
+
attn.append(AttnBlock(block_in))
|
| 758 |
+
up = nn.Module()
|
| 759 |
+
up.block = block
|
| 760 |
+
up.attn = attn
|
| 761 |
+
if i_level != 0:
|
| 762 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
| 763 |
+
curr_res = curr_res * 2
|
| 764 |
+
self.up.insert(0, up) # prepend to get consistent order
|
| 765 |
+
|
| 766 |
+
# end
|
| 767 |
+
self.norm_out = Normalize(block_in)
|
| 768 |
+
self.conv_out = torch.nn.Conv2d(
|
| 769 |
+
block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
| 770 |
+
|
| 771 |
+
def forward(self, x, t=None):
|
| 772 |
+
#assert x.shape[2] == x.shape[3] == self.resolution
|
| 773 |
+
|
| 774 |
+
if self.use_timestep:
|
| 775 |
+
# timestep embedding
|
| 776 |
+
assert t is not None
|
| 777 |
+
temb = get_timestep_embedding(t, self.ch)
|
| 778 |
+
temb = self.temb.dense[0](temb)
|
| 779 |
+
temb = nonlinearity(temb)
|
| 780 |
+
temb = self.temb.dense[1](temb)
|
| 781 |
+
else:
|
| 782 |
+
temb = None
|
| 783 |
+
|
| 784 |
+
# downsampling
|
| 785 |
+
hs = [self.conv_in(x)]
|
| 786 |
+
for i_level in range(self.num_resolutions):
|
| 787 |
+
for i_block in range(self.num_res_blocks):
|
| 788 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
| 789 |
+
if len(self.down[i_level].attn) > 0:
|
| 790 |
+
h = self.down[i_level].attn[i_block](h)
|
| 791 |
+
hs.append(h)
|
| 792 |
+
if i_level != self.num_resolutions - 1:
|
| 793 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
| 794 |
+
|
| 795 |
+
# middle
|
| 796 |
+
h = hs[-1]
|
| 797 |
+
h = self.mid.block_1(h, temb)
|
| 798 |
+
h = self.mid.attn_1(h)
|
| 799 |
+
h = self.mid.block_2(h, temb)
|
| 800 |
+
|
| 801 |
+
# upsampling
|
| 802 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 803 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 804 |
+
h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()],
|
| 805 |
+
dim=1), temb)
|
| 806 |
+
if len(self.up[i_level].attn) > 0:
|
| 807 |
+
h = self.up[i_level].attn[i_block](h)
|
| 808 |
+
if i_level != 0:
|
| 809 |
+
h = self.up[i_level].upsample(h)
|
| 810 |
+
|
| 811 |
+
# end
|
| 812 |
+
h = self.norm_out(h)
|
| 813 |
+
h = nonlinearity(h)
|
| 814 |
+
h = self.conv_out(h)
|
| 815 |
+
return h
|
| 816 |
+
|
| 817 |
+
|
| 818 |
+
class Encoder(nn.Module):
|
| 819 |
+
|
| 820 |
+
def __init__(self,
|
| 821 |
+
ch,
|
| 822 |
+
num_res_blocks,
|
| 823 |
+
attn_resolutions,
|
| 824 |
+
in_channels,
|
| 825 |
+
resolution,
|
| 826 |
+
z_channels,
|
| 827 |
+
ch_mult=(1, 2, 4, 8),
|
| 828 |
+
dropout=0.0,
|
| 829 |
+
resamp_with_conv=True,
|
| 830 |
+
double_z=True):
|
| 831 |
+
super().__init__()
|
| 832 |
+
self.ch = ch
|
| 833 |
+
self.temb_ch = 0
|
| 834 |
+
self.num_resolutions = len(ch_mult)
|
| 835 |
+
self.num_res_blocks = num_res_blocks
|
| 836 |
+
self.resolution = resolution
|
| 837 |
+
self.in_channels = in_channels
|
| 838 |
+
|
| 839 |
+
# downsampling
|
| 840 |
+
self.conv_in = torch.nn.Conv2d(
|
| 841 |
+
in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
| 842 |
+
|
| 843 |
+
curr_res = resolution
|
| 844 |
+
in_ch_mult = (1, ) + tuple(ch_mult)
|
| 845 |
+
self.down = nn.ModuleList()
|
| 846 |
+
for i_level in range(self.num_resolutions):
|
| 847 |
+
block = nn.ModuleList()
|
| 848 |
+
attn = nn.ModuleList()
|
| 849 |
+
block_in = ch * in_ch_mult[i_level]
|
| 850 |
+
block_out = ch * ch_mult[i_level]
|
| 851 |
+
for i_block in range(self.num_res_blocks):
|
| 852 |
+
block.append(
|
| 853 |
+
ResnetBlock(
|
| 854 |
+
in_channels=block_in,
|
| 855 |
+
out_channels=block_out,
|
| 856 |
+
temb_channels=self.temb_ch,
|
| 857 |
+
dropout=dropout))
|
| 858 |
+
block_in = block_out
|
| 859 |
+
if curr_res in attn_resolutions:
|
| 860 |
+
attn.append(AttnBlock(block_in))
|
| 861 |
+
down = nn.Module()
|
| 862 |
+
down.block = block
|
| 863 |
+
down.attn = attn
|
| 864 |
+
if i_level != self.num_resolutions - 1:
|
| 865 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
| 866 |
+
curr_res = curr_res // 2
|
| 867 |
+
self.down.append(down)
|
| 868 |
+
|
| 869 |
+
# middle
|
| 870 |
+
self.mid = nn.Module()
|
| 871 |
+
self.mid.block_1 = ResnetBlock(
|
| 872 |
+
in_channels=block_in,
|
| 873 |
+
out_channels=block_in,
|
| 874 |
+
temb_channels=self.temb_ch,
|
| 875 |
+
dropout=dropout)
|
| 876 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 877 |
+
self.mid.block_2 = ResnetBlock(
|
| 878 |
+
in_channels=block_in,
|
| 879 |
+
out_channels=block_in,
|
| 880 |
+
temb_channels=self.temb_ch,
|
| 881 |
+
dropout=dropout)
|
| 882 |
+
|
| 883 |
+
# end
|
| 884 |
+
self.norm_out = Normalize(block_in)
|
| 885 |
+
self.conv_out = torch.nn.Conv2d(
|
| 886 |
+
block_in,
|
| 887 |
+
2 * z_channels if double_z else z_channels,
|
| 888 |
+
kernel_size=3,
|
| 889 |
+
stride=1,
|
| 890 |
+
padding=1)
|
| 891 |
+
|
| 892 |
+
def forward(self, x):
|
| 893 |
+
#assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
|
| 894 |
+
|
| 895 |
+
# timestep embedding
|
| 896 |
+
temb = None
|
| 897 |
+
|
| 898 |
+
# downsampling
|
| 899 |
+
hs = [self.conv_in(x)]
|
| 900 |
+
for i_level in range(self.num_resolutions):
|
| 901 |
+
for i_block in range(self.num_res_blocks):
|
| 902 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
| 903 |
+
if len(self.down[i_level].attn) > 0:
|
| 904 |
+
h = self.down[i_level].attn[i_block](h)
|
| 905 |
+
hs.append(h)
|
| 906 |
+
if i_level != self.num_resolutions - 1:
|
| 907 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
| 908 |
+
|
| 909 |
+
# middle
|
| 910 |
+
h = hs[-1]
|
| 911 |
+
h = self.mid.block_1(h, temb)
|
| 912 |
+
h = self.mid.attn_1(h)
|
| 913 |
+
h = self.mid.block_2(h, temb)
|
| 914 |
+
|
| 915 |
+
# end
|
| 916 |
+
h = self.norm_out(h)
|
| 917 |
+
h = nonlinearity(h)
|
| 918 |
+
h = self.conv_out(h)
|
| 919 |
+
return h
|
| 920 |
+
|
| 921 |
+
|
| 922 |
+
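# The Encoder halves the spatial resolution at every level except the last, so the latent is
# resolution / 2 ** (len(ch_mult) - 1), with z_channels (or 2 * z_channels when double_z=True)
# output channels. A hedged standalone instantiation with assumed hyperparameters (not values
# taken from any shipped config):
import torch
from models.archs.vqgan_arch import Encoder

enc = Encoder(ch=64, num_res_blocks=1, attn_resolutions=[16], in_channels=3,
              resolution=256, z_channels=256, ch_mult=(1, 2, 4), double_z=False)
print(enc(torch.randn(1, 3, 256, 256)).shape)   # torch.Size([1, 256, 64, 64])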
class Decoder(nn.Module):
|
| 923 |
+
|
| 924 |
+
def __init__(self,
|
| 925 |
+
in_channels,
|
| 926 |
+
resolution,
|
| 927 |
+
z_channels,
|
| 928 |
+
ch,
|
| 929 |
+
out_ch,
|
| 930 |
+
num_res_blocks,
|
| 931 |
+
attn_resolutions,
|
| 932 |
+
ch_mult=(1, 2, 4, 8),
|
| 933 |
+
dropout=0.0,
|
| 934 |
+
resamp_with_conv=True,
|
| 935 |
+
give_pre_end=False):
|
| 936 |
+
super().__init__()
|
| 937 |
+
self.ch = ch
|
| 938 |
+
self.temb_ch = 0
|
| 939 |
+
self.num_resolutions = len(ch_mult)
|
| 940 |
+
self.num_res_blocks = num_res_blocks
|
| 941 |
+
self.resolution = resolution
|
| 942 |
+
self.in_channels = in_channels
|
| 943 |
+
self.give_pre_end = give_pre_end
|
| 944 |
+
|
| 945 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
| 946 |
+
in_ch_mult = (1, ) + tuple(ch_mult)
|
| 947 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
| 948 |
+
curr_res = resolution // 2**(self.num_resolutions - 1)
|
| 949 |
+
self.z_shape = (1, z_channels, curr_res, curr_res // 2)
|
| 950 |
+
print("Working with z of shape {} = {} dimensions.".format(
|
| 951 |
+
self.z_shape, np.prod(self.z_shape)))
|
| 952 |
+
|
| 953 |
+
# z to block_in
|
| 954 |
+
self.conv_in = torch.nn.Conv2d(
|
| 955 |
+
z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
| 956 |
+
|
| 957 |
+
# middle
|
| 958 |
+
self.mid = nn.Module()
|
| 959 |
+
self.mid.block_1 = ResnetBlock(
|
| 960 |
+
in_channels=block_in,
|
| 961 |
+
out_channels=block_in,
|
| 962 |
+
temb_channels=self.temb_ch,
|
| 963 |
+
dropout=dropout)
|
| 964 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 965 |
+
self.mid.block_2 = ResnetBlock(
|
| 966 |
+
in_channels=block_in,
|
| 967 |
+
out_channels=block_in,
|
| 968 |
+
temb_channels=self.temb_ch,
|
| 969 |
+
dropout=dropout)
|
| 970 |
+
|
| 971 |
+
# upsampling
|
| 972 |
+
self.up = nn.ModuleList()
|
| 973 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 974 |
+
block = nn.ModuleList()
|
| 975 |
+
attn = nn.ModuleList()
|
| 976 |
+
block_out = ch * ch_mult[i_level]
|
| 977 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 978 |
+
block.append(
|
| 979 |
+
ResnetBlock(
|
| 980 |
+
in_channels=block_in,
|
| 981 |
+
out_channels=block_out,
|
| 982 |
+
temb_channels=self.temb_ch,
|
| 983 |
+
dropout=dropout))
|
| 984 |
+
block_in = block_out
|
| 985 |
+
if curr_res in attn_resolutions:
|
| 986 |
+
attn.append(AttnBlock(block_in))
|
| 987 |
+
up = nn.Module()
|
| 988 |
+
up.block = block
|
| 989 |
+
up.attn = attn
|
| 990 |
+
if i_level != 0:
|
| 991 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
| 992 |
+
curr_res = curr_res * 2
|
| 993 |
+
self.up.insert(0, up) # prepend to get consistent order
|
| 994 |
+
|
| 995 |
+
# end
|
| 996 |
+
self.norm_out = Normalize(block_in)
|
| 997 |
+
self.conv_out = torch.nn.Conv2d(
|
| 998 |
+
block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
| 999 |
+
|
| 1000 |
+
def forward(self, z, bot_h=None):
|
| 1001 |
+
#assert z.shape[1:] == self.z_shape[1:]
|
| 1002 |
+
self.last_z_shape = z.shape
|
| 1003 |
+
|
| 1004 |
+
# timestep embedding
|
| 1005 |
+
temb = None
|
| 1006 |
+
|
| 1007 |
+
# z to block_in
|
| 1008 |
+
h = self.conv_in(z)
|
| 1009 |
+
|
| 1010 |
+
# middle
|
| 1011 |
+
h = self.mid.block_1(h, temb)
|
| 1012 |
+
h = self.mid.attn_1(h)
|
| 1013 |
+
h = self.mid.block_2(h, temb)
|
| 1014 |
+
|
| 1015 |
+
# upsampling
|
| 1016 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 1017 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 1018 |
+
h = self.up[i_level].block[i_block](h, temb)
|
| 1019 |
+
if len(self.up[i_level].attn) > 0:
|
| 1020 |
+
h = self.up[i_level].attn[i_block](h)
|
| 1021 |
+
if i_level != 0:
|
| 1022 |
+
h = self.up[i_level].upsample(h)
|
| 1023 |
+
if i_level == 4 and bot_h is not None:
|
| 1024 |
+
h += bot_h
|
| 1025 |
+
|
| 1026 |
+
# end
|
| 1027 |
+
if self.give_pre_end:
|
| 1028 |
+
return h
|
| 1029 |
+
|
| 1030 |
+
h = self.norm_out(h)
|
| 1031 |
+
h = nonlinearity(h)
|
| 1032 |
+
h = self.conv_out(h)
|
| 1033 |
+
return h
|
| 1034 |
+
|
| 1035 |
+
def get_feature_top(self, z):
|
| 1036 |
+
#assert z.shape[1:] == self.z_shape[1:]
|
| 1037 |
+
self.last_z_shape = z.shape
|
| 1038 |
+
|
| 1039 |
+
# timestep embedding
|
| 1040 |
+
temb = None
|
| 1041 |
+
|
| 1042 |
+
# z to block_in
|
| 1043 |
+
h = self.conv_in(z)
|
| 1044 |
+
|
| 1045 |
+
# middle
|
| 1046 |
+
h = self.mid.block_1(h, temb)
|
| 1047 |
+
h = self.mid.attn_1(h)
|
| 1048 |
+
h = self.mid.block_2(h, temb)
|
| 1049 |
+
|
| 1050 |
+
# upsampling
|
| 1051 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 1052 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 1053 |
+
h = self.up[i_level].block[i_block](h, temb)
|
| 1054 |
+
if len(self.up[i_level].attn) > 0:
|
| 1055 |
+
h = self.up[i_level].attn[i_block](h)
|
| 1056 |
+
if i_level != 0:
|
| 1057 |
+
h = self.up[i_level].upsample(h)
|
| 1058 |
+
if i_level == 4:
|
| 1059 |
+
return h
|
| 1060 |
+
|
| 1061 |
+
def get_feature_middle(self, z, mid_h):
|
| 1062 |
+
#assert z.shape[1:] == self.z_shape[1:]
|
| 1063 |
+
self.last_z_shape = z.shape
|
| 1064 |
+
|
| 1065 |
+
# timestep embedding
|
| 1066 |
+
temb = None
|
| 1067 |
+
|
| 1068 |
+
# z to block_in
|
| 1069 |
+
h = self.conv_in(z)
|
| 1070 |
+
|
| 1071 |
+
# middle
|
| 1072 |
+
h = self.mid.block_1(h, temb)
|
| 1073 |
+
h = self.mid.attn_1(h)
|
| 1074 |
+
h = self.mid.block_2(h, temb)
|
| 1075 |
+
|
| 1076 |
+
# upsampling
|
| 1077 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 1078 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 1079 |
+
h = self.up[i_level].block[i_block](h, temb)
|
| 1080 |
+
if len(self.up[i_level].attn) > 0:
|
| 1081 |
+
h = self.up[i_level].attn[i_block](h)
|
| 1082 |
+
if i_level != 0:
|
| 1083 |
+
h = self.up[i_level].upsample(h)
|
| 1084 |
+
if i_level == 4:
|
| 1085 |
+
h += mid_h
|
| 1086 |
+
if i_level == 3:
|
| 1087 |
+
return h
|
| 1088 |
+
|
| 1089 |
+
|
| 1090 |
+
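# Note on the hierarchy: Decoder.forward optionally takes bot_h and adds it to the upsampling
# activations at pyramid level 4, so fine-grained bottom-level codes can refine the image
# decoded from the coarse top-level codes (the hierarchy models pass the DecoderRes output in
# as bot_h). get_feature_top and get_feature_middle expose the corresponding intermediate
# activations so that other modules can produce a feature map of matching shape to inject.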
class DecoderRes(nn.Module):
|
| 1091 |
+
|
| 1092 |
+
def __init__(self,
|
| 1093 |
+
in_channels,
|
| 1094 |
+
resolution,
|
| 1095 |
+
z_channels,
|
| 1096 |
+
ch,
|
| 1097 |
+
num_res_blocks,
|
| 1098 |
+
ch_mult=(1, 2, 4, 8),
|
| 1099 |
+
dropout=0.0,
|
| 1100 |
+
give_pre_end=False):
|
| 1101 |
+
super().__init__()
|
| 1102 |
+
self.ch = ch
|
| 1103 |
+
self.temb_ch = 0
|
| 1104 |
+
self.num_resolutions = len(ch_mult)
|
| 1105 |
+
self.num_res_blocks = num_res_blocks
|
| 1106 |
+
self.resolution = resolution
|
| 1107 |
+
self.in_channels = in_channels
|
| 1108 |
+
self.give_pre_end = give_pre_end
|
| 1109 |
+
|
| 1110 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
| 1111 |
+
in_ch_mult = (1, ) + tuple(ch_mult)
|
| 1112 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
| 1113 |
+
curr_res = resolution // 2**(self.num_resolutions - 1)
|
| 1114 |
+
self.z_shape = (1, z_channels, curr_res, curr_res // 2)
|
| 1115 |
+
print("Working with z of shape {} = {} dimensions.".format(
|
| 1116 |
+
self.z_shape, np.prod(self.z_shape)))
|
| 1117 |
+
|
| 1118 |
+
# z to block_in
|
| 1119 |
+
self.conv_in = torch.nn.Conv2d(
|
| 1120 |
+
z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
| 1121 |
+
|
| 1122 |
+
# middle
|
| 1123 |
+
self.mid = nn.Module()
|
| 1124 |
+
self.mid.block_1 = ResnetBlock(
|
| 1125 |
+
in_channels=block_in,
|
| 1126 |
+
out_channels=block_in,
|
| 1127 |
+
temb_channels=self.temb_ch,
|
| 1128 |
+
dropout=dropout)
|
| 1129 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 1130 |
+
self.mid.block_2 = ResnetBlock(
|
| 1131 |
+
in_channels=block_in,
|
| 1132 |
+
out_channels=block_in,
|
| 1133 |
+
temb_channels=self.temb_ch,
|
| 1134 |
+
dropout=dropout)
|
| 1135 |
+
|
| 1136 |
+
def forward(self, z):
|
| 1137 |
+
#assert z.shape[1:] == self.z_shape[1:]
|
| 1138 |
+
self.last_z_shape = z.shape
|
| 1139 |
+
|
| 1140 |
+
# timestep embedding
|
| 1141 |
+
temb = None
|
| 1142 |
+
|
| 1143 |
+
# z to block_in
|
| 1144 |
+
h = self.conv_in(z)
|
| 1145 |
+
|
| 1146 |
+
# middle
|
| 1147 |
+
h = self.mid.block_1(h, temb)
|
| 1148 |
+
h = self.mid.attn_1(h)
|
| 1149 |
+
h = self.mid.block_2(h, temb)
|
| 1150 |
+
|
| 1151 |
+
return h
|
| 1152 |
+
|
| 1153 |
+
|
| 1154 |
+
# patch based discriminator
|
| 1155 |
+
class Discriminator(nn.Module):
|
| 1156 |
+
|
| 1157 |
+
def __init__(self, nc, ndf, n_layers=3):
|
| 1158 |
+
super().__init__()
|
| 1159 |
+
|
| 1160 |
+
layers = [
|
| 1161 |
+
nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1),
|
| 1162 |
+
nn.LeakyReLU(0.2, True)
|
| 1163 |
+
]
|
| 1164 |
+
ndf_mult = 1
|
| 1165 |
+
ndf_mult_prev = 1
|
| 1166 |
+
for n in range(1,
|
| 1167 |
+
n_layers): # gradually increase the number of filters
|
| 1168 |
+
ndf_mult_prev = ndf_mult
|
| 1169 |
+
ndf_mult = min(2**n, 8)
|
| 1170 |
+
layers += [
|
| 1171 |
+
nn.Conv2d(
|
| 1172 |
+
ndf * ndf_mult_prev,
|
| 1173 |
+
ndf * ndf_mult,
|
| 1174 |
+
kernel_size=4,
|
| 1175 |
+
stride=2,
|
| 1176 |
+
padding=1,
|
| 1177 |
+
bias=False),
|
| 1178 |
+
nn.BatchNorm2d(ndf * ndf_mult),
|
| 1179 |
+
nn.LeakyReLU(0.2, True)
|
| 1180 |
+
]
|
| 1181 |
+
|
| 1182 |
+
ndf_mult_prev = ndf_mult
|
| 1183 |
+
ndf_mult = min(2**n_layers, 8)
|
| 1184 |
+
|
| 1185 |
+
layers += [
|
| 1186 |
+
nn.Conv2d(
|
| 1187 |
+
ndf * ndf_mult_prev,
|
| 1188 |
+
ndf * ndf_mult,
|
| 1189 |
+
kernel_size=4,
|
| 1190 |
+
stride=1,
|
| 1191 |
+
padding=1,
|
| 1192 |
+
bias=False),
|
| 1193 |
+
nn.BatchNorm2d(ndf * ndf_mult),
|
| 1194 |
+
nn.LeakyReLU(0.2, True)
|
| 1195 |
+
]
|
| 1196 |
+
|
| 1197 |
+
layers += [
|
| 1198 |
+
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)
|
| 1199 |
+
] # output 1 channel prediction map
|
| 1200 |
+
self.main = nn.Sequential(*layers)
|
| 1201 |
+
|
| 1202 |
+
def forward(self, x):
|
| 1203 |
+
return self.main(x)
|
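# The Discriminator above is a PatchGAN: a stack of stride-2 convolutions ending in a
# 1-channel map, so it scores overlapping image patches rather than emitting a single scalar
# per image. Standalone shape check with an assumed 512x256 input:
import torch
from models.archs.vqgan_arch import Discriminator

disc = Discriminator(nc=3, ndf=64, n_layers=3)
print(disc(torch.randn(1, 3, 512, 256)).shape)   # torch.Size([1, 1, 62, 30]) -- per-patch logits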
models/hierarchy_inference_model.py
ADDED
|
@@ -0,0 +1,363 @@
| 1 |
+
import logging
|
| 2 |
+
import math
|
| 3 |
+
from collections import OrderedDict
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torchvision.utils import save_image
|
| 8 |
+
|
| 9 |
+
from models.archs.fcn_arch import MultiHeadFCNHead
|
| 10 |
+
from models.archs.unet_arch import UNet
|
| 11 |
+
from models.archs.vqgan_arch import (Decoder, DecoderRes, Encoder,
|
| 12 |
+
VectorQuantizerSpatialTextureAware,
|
| 13 |
+
VectorQuantizerTexture)
|
| 14 |
+
from models.losses.accuracy import accuracy
|
| 15 |
+
from models.losses.cross_entropy_loss import CrossEntropyLoss
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger('base')
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class VQGANTextureAwareSpatialHierarchyInferenceModel():
|
| 21 |
+
|
| 22 |
+
def __init__(self, opt):
|
| 23 |
+
self.opt = opt
|
| 24 |
+
self.device = torch.device('cuda')
|
| 25 |
+
self.is_train = opt['is_train']
|
| 26 |
+
|
| 27 |
+
self.top_encoder = Encoder(
|
| 28 |
+
ch=opt['top_ch'],
|
| 29 |
+
num_res_blocks=opt['top_num_res_blocks'],
|
| 30 |
+
attn_resolutions=opt['top_attn_resolutions'],
|
| 31 |
+
ch_mult=opt['top_ch_mult'],
|
| 32 |
+
in_channels=opt['top_in_channels'],
|
| 33 |
+
resolution=opt['top_resolution'],
|
| 34 |
+
z_channels=opt['top_z_channels'],
|
| 35 |
+
double_z=opt['top_double_z'],
|
| 36 |
+
dropout=opt['top_dropout']).to(self.device)
|
| 37 |
+
self.decoder = Decoder(
|
| 38 |
+
in_channels=opt['top_in_channels'],
|
| 39 |
+
resolution=opt['top_resolution'],
|
| 40 |
+
z_channels=opt['top_z_channels'],
|
| 41 |
+
ch=opt['top_ch'],
|
| 42 |
+
out_ch=opt['top_out_ch'],
|
| 43 |
+
num_res_blocks=opt['top_num_res_blocks'],
|
| 44 |
+
attn_resolutions=opt['top_attn_resolutions'],
|
| 45 |
+
ch_mult=opt['top_ch_mult'],
|
| 46 |
+
dropout=opt['top_dropout'],
|
| 47 |
+
resamp_with_conv=True,
|
| 48 |
+
give_pre_end=False).to(self.device)
|
| 49 |
+
self.top_quantize = VectorQuantizerTexture(
|
| 50 |
+
1024, opt['embed_dim'], beta=0.25).to(self.device)
|
| 51 |
+
self.top_quant_conv = torch.nn.Conv2d(opt["top_z_channels"],
|
| 52 |
+
opt['embed_dim'],
|
| 53 |
+
1).to(self.device)
|
| 54 |
+
self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
|
| 55 |
+
opt["top_z_channels"],
|
| 56 |
+
1).to(self.device)
|
| 57 |
+
self.load_top_pretrain_models()
|
| 58 |
+
|
| 59 |
+
self.bot_encoder = Encoder(
|
| 60 |
+
ch=opt['bot_ch'],
|
| 61 |
+
num_res_blocks=opt['bot_num_res_blocks'],
|
| 62 |
+
attn_resolutions=opt['bot_attn_resolutions'],
|
| 63 |
+
ch_mult=opt['bot_ch_mult'],
|
| 64 |
+
in_channels=opt['bot_in_channels'],
|
| 65 |
+
resolution=opt['bot_resolution'],
|
| 66 |
+
z_channels=opt['bot_z_channels'],
|
| 67 |
+
double_z=opt['bot_double_z'],
|
| 68 |
+
dropout=opt['bot_dropout']).to(self.device)
|
| 69 |
+
self.bot_decoder_res = DecoderRes(
|
| 70 |
+
in_channels=opt['bot_in_channels'],
|
| 71 |
+
resolution=opt['bot_resolution'],
|
| 72 |
+
z_channels=opt['bot_z_channels'],
|
| 73 |
+
ch=opt['bot_ch'],
|
| 74 |
+
num_res_blocks=opt['bot_num_res_blocks'],
|
| 75 |
+
ch_mult=opt['bot_ch_mult'],
|
| 76 |
+
dropout=opt['bot_dropout'],
|
| 77 |
+
give_pre_end=False).to(self.device)
|
| 78 |
+
self.bot_quantize = VectorQuantizerSpatialTextureAware(
|
| 79 |
+
opt['bot_n_embed'],
|
| 80 |
+
opt['embed_dim'],
|
| 81 |
+
beta=0.25,
|
| 82 |
+
spatial_size=opt['codebook_spatial_size']).to(self.device)
|
| 83 |
+
self.bot_quant_conv = torch.nn.Conv2d(opt["bot_z_channels"],
|
| 84 |
+
opt['embed_dim'],
|
| 85 |
+
1).to(self.device)
|
| 86 |
+
self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
|
| 87 |
+
opt["bot_z_channels"],
|
| 88 |
+
1).to(self.device)
|
| 89 |
+
|
| 90 |
+
self.load_bot_pretrain_network()
|
| 91 |
+
|
| 92 |
+
self.guidance_encoder = UNet(
|
| 93 |
+
in_channels=opt['encoder_in_channels']).to(self.device)
|
| 94 |
+
self.index_decoder = MultiHeadFCNHead(
|
| 95 |
+
in_channels=opt['fc_in_channels'],
|
| 96 |
+
in_index=opt['fc_in_index'],
|
| 97 |
+
channels=opt['fc_channels'],
|
| 98 |
+
num_convs=opt['fc_num_convs'],
|
| 99 |
+
concat_input=opt['fc_concat_input'],
|
| 100 |
+
dropout_ratio=opt['fc_dropout_ratio'],
|
| 101 |
+
num_classes=opt['fc_num_classes'],
|
| 102 |
+
align_corners=opt['fc_align_corners'],
|
| 103 |
+
num_head=18).to(self.device)
|
| 104 |
+
|
| 105 |
+
self.init_training_settings()
|
| 106 |
+
|
| 107 |
+
def init_training_settings(self):
|
| 108 |
+
optim_params = []
|
| 109 |
+
for v in self.guidance_encoder.parameters():
|
| 110 |
+
if v.requires_grad:
|
| 111 |
+
optim_params.append(v)
|
| 112 |
+
for v in self.index_decoder.parameters():
|
| 113 |
+
if v.requires_grad:
|
| 114 |
+
optim_params.append(v)
|
| 115 |
+
# set up optimizers
|
| 116 |
+
if self.opt['optimizer'] == 'Adam':
|
| 117 |
+
self.optimizer = torch.optim.Adam(
|
| 118 |
+
optim_params,
|
| 119 |
+
self.opt['lr'],
|
| 120 |
+
weight_decay=self.opt['weight_decay'])
|
| 121 |
+
elif self.opt['optimizer'] == 'SGD':
|
| 122 |
+
self.optimizer = torch.optim.SGD(
|
| 123 |
+
optim_params,
|
| 124 |
+
self.opt['lr'],
|
| 125 |
+
momentum=self.opt['momentum'],
|
| 126 |
+
weight_decay=self.opt['weight_decay'])
|
| 127 |
+
self.log_dict = OrderedDict()
|
| 128 |
+
if self.opt['loss_function'] == 'cross_entropy':
|
| 129 |
+
self.loss_func = CrossEntropyLoss().to(self.device)
|
| 130 |
+
|
| 131 |
+
def load_top_pretrain_models(self):
|
| 132 |
+
# load pretrained vqgan for segmentation mask
|
| 133 |
+
top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
|
| 134 |
+
self.top_encoder.load_state_dict(
|
| 135 |
+
top_vae_checkpoint['encoder'], strict=True)
|
| 136 |
+
self.decoder.load_state_dict(
|
| 137 |
+
top_vae_checkpoint['decoder'], strict=True)
|
| 138 |
+
self.top_quantize.load_state_dict(
|
| 139 |
+
top_vae_checkpoint['quantize'], strict=True)
|
| 140 |
+
self.top_quant_conv.load_state_dict(
|
| 141 |
+
top_vae_checkpoint['quant_conv'], strict=True)
|
| 142 |
+
self.top_post_quant_conv.load_state_dict(
|
| 143 |
+
top_vae_checkpoint['post_quant_conv'], strict=True)
|
| 144 |
+
self.top_encoder.eval()
|
| 145 |
+
self.top_quantize.eval()
|
| 146 |
+
self.top_quant_conv.eval()
|
| 147 |
+
self.top_post_quant_conv.eval()
|
| 148 |
+
|
| 149 |
+
def load_bot_pretrain_network(self):
|
| 150 |
+
checkpoint = torch.load(self.opt['bot_vae_path'])
|
| 151 |
+
self.bot_encoder.load_state_dict(
|
| 152 |
+
checkpoint['bot_encoder'], strict=True)
|
| 153 |
+
self.bot_decoder_res.load_state_dict(
|
| 154 |
+
checkpoint['bot_decoder_res'], strict=True)
|
| 155 |
+
self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
|
| 156 |
+
self.bot_quantize.load_state_dict(
|
| 157 |
+
checkpoint['bot_quantize'], strict=True)
|
| 158 |
+
self.bot_quant_conv.load_state_dict(
|
| 159 |
+
checkpoint['bot_quant_conv'], strict=True)
|
| 160 |
+
self.bot_post_quant_conv.load_state_dict(
|
| 161 |
+
checkpoint['bot_post_quant_conv'], strict=True)
|
| 162 |
+
|
| 163 |
+
self.bot_encoder.eval()
|
| 164 |
+
self.bot_decoder_res.eval()
|
| 165 |
+
self.decoder.eval()
|
| 166 |
+
self.bot_quantize.eval()
|
| 167 |
+
self.bot_quant_conv.eval()
|
| 168 |
+
self.bot_post_quant_conv.eval()
|
| 169 |
+
|
| 170 |
+
def top_encode(self, x, mask):
|
| 171 |
+
h = self.top_encoder(x)
|
| 172 |
+
h = self.top_quant_conv(h)
|
| 173 |
+
quant, _, _ = self.top_quantize(h, mask)
|
| 174 |
+
quant = self.top_post_quant_conv(quant)
|
| 175 |
+
|
| 176 |
+
return quant, quant
|
| 177 |
+
|
| 178 |
+
def feed_data(self, data):
|
| 179 |
+
self.image = data['image'].to(self.device)
|
| 180 |
+
self.texture_mask = data['texture_mask'].float().to(self.device)
|
| 181 |
+
self.get_gt_indices()
|
| 182 |
+
|
| 183 |
+
self.texture_tokens = F.interpolate(
|
| 184 |
+
self.texture_mask, size=(32, 16),
|
| 185 |
+
mode='nearest').view(self.image.size(0), -1).long()
|
| 186 |
+
|
| 187 |
+
def bot_encode(self, x, mask):
|
| 188 |
+
h = self.bot_encoder(x)
|
| 189 |
+
h = self.bot_quant_conv(h)
|
| 190 |
+
_, _, (_, _, indices_list) = self.bot_quantize(h, mask)
|
| 191 |
+
|
| 192 |
+
return indices_list
|
| 193 |
+
|
| 194 |
+
def get_gt_indices(self):
|
| 195 |
+
self.quant_t, self.feature_t = self.top_encode(self.image,
|
| 196 |
+
self.texture_mask)
|
| 197 |
+
self.gt_indices_list = self.bot_encode(self.image, self.texture_mask)
|
| 198 |
+
|
| 199 |
+
def index_to_image(self, index_bottom_list, texture_mask):
|
| 200 |
+
quant_b = self.bot_quantize.get_codebook_entry(
|
| 201 |
+
index_bottom_list, texture_mask,
|
| 202 |
+
(index_bottom_list[0].size(0), index_bottom_list[0].size(1),
|
| 203 |
+
index_bottom_list[0].size(2),
|
| 204 |
+
self.opt["bot_z_channels"])) #.permute(0, 3, 1, 2)
|
| 205 |
+
quant_b = self.bot_post_quant_conv(quant_b)
|
| 206 |
+
bot_dec_res = self.bot_decoder_res(quant_b)
|
| 207 |
+
|
| 208 |
+
dec = self.decoder(self.quant_t, bot_h=bot_dec_res)
|
| 209 |
+
|
| 210 |
+
return dec
|
| 211 |
+
|
| 212 |
+
def get_vis(self, pred_img_index, rec_img_index, texture_mask, save_path):
|
| 213 |
+
rec_img = self.index_to_image(rec_img_index, texture_mask)
|
| 214 |
+
pred_img = self.index_to_image(pred_img_index, texture_mask)
|
| 215 |
+
|
| 216 |
+
base_img = self.decoder(self.quant_t)
|
| 217 |
+
img_cat = torch.cat([
|
| 218 |
+
self.image,
|
| 219 |
+
rec_img,
|
| 220 |
+
base_img,
|
| 221 |
+
pred_img,
|
| 222 |
+
], dim=3).detach()
|
| 223 |
+
img_cat = ((img_cat + 1) / 2)
|
| 224 |
+
img_cat = img_cat.clamp_(0, 1)
|
| 225 |
+
save_image(img_cat, save_path, nrow=1, padding=4)
|
| 226 |
+
|
| 227 |
+
def optimize_parameters(self):
|
| 228 |
+
self.guidance_encoder.train()
|
| 229 |
+
self.index_decoder.train()
|
| 230 |
+
|
| 231 |
+
self.feature_enc = self.guidance_encoder(self.feature_t)
|
| 232 |
+
self.memory_logits_list = self.index_decoder(self.feature_enc)
|
| 233 |
+
|
| 234 |
+
loss = 0
|
| 235 |
+
for i in range(18):
|
| 236 |
+
loss += self.loss_func(
|
| 237 |
+
self.memory_logits_list[i],
|
| 238 |
+
self.gt_indices_list[i],
|
| 239 |
+
ignore_index=-1)
|
| 240 |
+
|
| 241 |
+
self.optimizer.zero_grad()
|
| 242 |
+
loss.backward()
|
| 243 |
+
self.optimizer.step()
|
| 244 |
+
|
| 245 |
+
self.log_dict['loss_total'] = loss
|
| 246 |
+
|
| 247 |
+
def inference(self, data_loader, save_dir):
|
| 248 |
+
self.guidance_encoder.eval()
|
| 249 |
+
self.index_decoder.eval()
|
| 250 |
+
|
| 251 |
+
acc = 0
|
| 252 |
+
num = 0
|
| 253 |
+
|
| 254 |
+
for _, data in enumerate(data_loader):
|
| 255 |
+
self.feed_data(data)
|
| 256 |
+
img_name = data['img_name']
|
| 257 |
+
|
| 258 |
+
num += self.image.size(0)
|
| 259 |
+
|
| 260 |
+
texture_mask_flatten = self.texture_tokens.view(-1)
|
| 261 |
+
min_encodings_indices_list = [
|
| 262 |
+
torch.full(
|
| 263 |
+
texture_mask_flatten.size(),
|
| 264 |
+
fill_value=-1,
|
| 265 |
+
dtype=torch.long,
|
| 266 |
+
device=texture_mask_flatten.device) for _ in range(18)
|
| 267 |
+
]
|
| 268 |
+
with torch.no_grad():
|
| 269 |
+
self.feature_enc = self.guidance_encoder(self.feature_t)
|
| 270 |
+
memory_logits_list = self.index_decoder(self.feature_enc)
|
| 271 |
+
# memory_indices_pred = memory_logits.argmax(dim=1)
|
| 272 |
+
batch_acc = 0
|
| 273 |
+
for codebook_idx, memory_logits in enumerate(memory_logits_list):
|
| 274 |
+
region_of_interest = texture_mask_flatten == codebook_idx
|
| 275 |
+
if torch.sum(region_of_interest) > 0:
|
| 276 |
+
memory_indices_pred = memory_logits.argmax(dim=1).view(-1)
|
| 277 |
+
batch_acc += torch.sum(
|
| 278 |
+
memory_indices_pred[region_of_interest] ==
|
| 279 |
+
self.gt_indices_list[codebook_idx].view(
|
| 280 |
+
-1)[region_of_interest])
|
| 281 |
+
memory_indices_pred = memory_indices_pred
|
| 282 |
+
min_encodings_indices_list[codebook_idx][
|
| 283 |
+
region_of_interest] = memory_indices_pred[
|
| 284 |
+
region_of_interest]
|
| 285 |
+
min_encodings_indices_return_list = [
|
| 286 |
+
min_encodings_indices.view(self.gt_indices_list[0].size())
|
| 287 |
+
for min_encodings_indices in min_encodings_indices_list
|
| 288 |
+
]
|
| 289 |
+
batch_acc = batch_acc / self.gt_indices_list[codebook_idx].numel(
|
| 290 |
+
) * self.image.size(0)
|
| 291 |
+
acc += batch_acc
|
| 292 |
+
self.get_vis(min_encodings_indices_return_list,
|
| 293 |
+
self.gt_indices_list, self.texture_mask,
|
| 294 |
+
f'{save_dir}/{img_name[0]}')
|
| 295 |
+
|
| 296 |
+
self.guidance_encoder.train()
|
| 297 |
+
self.index_decoder.train()
|
| 298 |
+
return (acc / num).item()
|
| 299 |
+
|
| 300 |
+
def load_network(self):
|
| 301 |
+
checkpoint = torch.load(self.opt['pretrained_models'])
|
| 302 |
+
self.guidance_encoder.load_state_dict(
|
| 303 |
+
checkpoint['guidance_encoder'], strict=True)
|
| 304 |
+
self.guidance_encoder.eval()
|
| 305 |
+
|
| 306 |
+
self.index_decoder.load_state_dict(
|
| 307 |
+
checkpoint['index_decoder'], strict=True)
|
| 308 |
+
self.index_decoder.eval()
|
| 309 |
+
|
| 310 |
+
def save_network(self, save_path):
|
| 311 |
+
"""Save networks.
|
| 312 |
+
|
| 313 |
+
Args:
    save_path (str): Path to save the checkpoint to.
|
| 317 |
+
"""
|
| 318 |
+
|
| 319 |
+
save_dict = {}
|
| 320 |
+
save_dict['guidance_encoder'] = self.guidance_encoder.state_dict()
|
| 321 |
+
save_dict['index_decoder'] = self.index_decoder.state_dict()
|
| 322 |
+
|
| 323 |
+
torch.save(save_dict, save_path)
|
| 324 |
+
|
| 325 |
+
def update_learning_rate(self, epoch):
|
| 326 |
+
"""Update learning rate.
|
| 327 |
+
|
| 328 |
+
Args:
    epoch (int): Current epoch, used to compute the decayed learning rate.
|
| 332 |
+
"""
|
| 333 |
+
lr = self.optimizer.param_groups[0]['lr']
|
| 334 |
+
|
| 335 |
+
if self.opt['lr_decay'] == 'step':
|
| 336 |
+
lr = self.opt['lr'] * (
|
| 337 |
+
self.opt['gamma']**(epoch // self.opt['step']))
|
| 338 |
+
elif self.opt['lr_decay'] == 'cos':
|
| 339 |
+
lr = self.opt['lr'] * (
|
| 340 |
+
1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
|
| 341 |
+
elif self.opt['lr_decay'] == 'linear':
|
| 342 |
+
lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
|
| 343 |
+
elif self.opt['lr_decay'] == 'linear2exp':
|
| 344 |
+
if epoch < self.opt['turning_point'] + 1:
|
| 345 |
+
# linear decay reaches 95% at the turning point (i.e. 5% of the initial
# lr remains there); hence the factor 1 / 0.95 = 1.0526
|
| 347 |
+
lr = self.opt['lr'] * (
|
| 348 |
+
1 - epoch / int(self.opt['turning_point'] * 1.0526))
|
| 349 |
+
else:
|
| 350 |
+
lr *= self.opt['gamma']
|
| 351 |
+
elif self.opt['lr_decay'] == 'schedule':
|
| 352 |
+
if epoch in self.opt['schedule']:
|
| 353 |
+
lr *= self.opt['gamma']
|
| 354 |
+
else:
|
| 355 |
+
raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
|
| 356 |
+
# set learning rate
|
| 357 |
+
for param_group in self.optimizer.param_groups:
|
| 358 |
+
param_group['lr'] = lr
|
| 359 |
+
|
| 360 |
+
return lr
|
| 361 |
+
|
| 362 |
+
def get_current_log(self):
|
| 363 |
+
return self.log_dict
|
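# Hedged sketch of how the inference model above is typically driven; the loaders, epoch
# count and checkpoint paths are assumptions for illustration, not part of this repository:
def train_index_prediction(model, train_loader, val_loader, num_epochs, save_dir):
    # model: a constructed VQGANTextureAwareSpatialHierarchyInferenceModel
    for epoch in range(num_epochs):
        lr = model.update_learning_rate(epoch)
        for batch in train_loader:          # dicts with 'image', 'texture_mask', 'img_name'
            model.feed_data(batch)
            model.optimize_parameters()
        acc = model.inference(val_loader, save_dir)   # codebook-index accuracy; also saves grids
        print(epoch, lr, acc, model.get_current_log()['loss_total'].item())
        model.save_network(f'{save_dir}/epoch_{epoch}.pth')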
models/hierarchy_vqgan_model.py
ADDED
|
@@ -0,0 +1,374 @@
| 1 |
+
import math
|
| 2 |
+
import sys
|
| 3 |
+
from collections import OrderedDict
|
| 4 |
+
|
| 5 |
+
sys.path.append('..')
|
| 6 |
+
import lpips
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torchvision.utils import save_image
|
| 10 |
+
|
| 11 |
+
from models.archs.vqgan_arch import (Decoder, DecoderRes, Discriminator,
|
| 12 |
+
Encoder,
|
| 13 |
+
VectorQuantizerSpatialTextureAware,
|
| 14 |
+
VectorQuantizerTexture)
|
| 15 |
+
from models.losses.vqgan_loss import (DiffAugment, adopt_weight,
|
| 16 |
+
calculate_adaptive_weight, hinge_d_loss)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class HierarchyVQSpatialTextureAwareModel():
|
| 20 |
+
|
| 21 |
+
def __init__(self, opt):
|
| 22 |
+
self.opt = opt
|
| 23 |
+
self.device = torch.device('cuda')
|
| 24 |
+
self.top_encoder = Encoder(
|
| 25 |
+
ch=opt['top_ch'],
|
| 26 |
+
num_res_blocks=opt['top_num_res_blocks'],
|
| 27 |
+
attn_resolutions=opt['top_attn_resolutions'],
|
| 28 |
+
ch_mult=opt['top_ch_mult'],
|
| 29 |
+
in_channels=opt['top_in_channels'],
|
| 30 |
+
resolution=opt['top_resolution'],
|
| 31 |
+
z_channels=opt['top_z_channels'],
|
| 32 |
+
double_z=opt['top_double_z'],
|
| 33 |
+
dropout=opt['top_dropout']).to(self.device)
|
| 34 |
+
self.decoder = Decoder(
|
| 35 |
+
in_channels=opt['top_in_channels'],
|
| 36 |
+
resolution=opt['top_resolution'],
|
| 37 |
+
z_channels=opt['top_z_channels'],
|
| 38 |
+
ch=opt['top_ch'],
|
| 39 |
+
out_ch=opt['top_out_ch'],
|
| 40 |
+
num_res_blocks=opt['top_num_res_blocks'],
|
| 41 |
+
attn_resolutions=opt['top_attn_resolutions'],
|
| 42 |
+
ch_mult=opt['top_ch_mult'],
|
| 43 |
+
dropout=opt['top_dropout'],
|
| 44 |
+
resamp_with_conv=True,
|
| 45 |
+
give_pre_end=False).to(self.device)
|
| 46 |
+
self.top_quantize = VectorQuantizerTexture(
|
| 47 |
+
1024, opt['embed_dim'], beta=0.25).to(self.device)
|
| 48 |
+
self.top_quant_conv = torch.nn.Conv2d(opt["top_z_channels"],
|
| 49 |
+
opt['embed_dim'],
|
| 50 |
+
1).to(self.device)
|
| 51 |
+
self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
|
| 52 |
+
opt["top_z_channels"],
|
| 53 |
+
1).to(self.device)
|
| 54 |
+
self.load_top_pretrain_models()
|
| 55 |
+
|
| 56 |
+
self.bot_encoder = Encoder(
|
| 57 |
+
ch=opt['bot_ch'],
|
| 58 |
+
num_res_blocks=opt['bot_num_res_blocks'],
|
| 59 |
+
attn_resolutions=opt['bot_attn_resolutions'],
|
| 60 |
+
ch_mult=opt['bot_ch_mult'],
|
| 61 |
+
in_channels=opt['bot_in_channels'],
|
| 62 |
+
resolution=opt['bot_resolution'],
|
| 63 |
+
z_channels=opt['bot_z_channels'],
|
| 64 |
+
double_z=opt['bot_double_z'],
|
| 65 |
+
dropout=opt['bot_dropout']).to(self.device)
|
| 66 |
+
self.bot_decoder_res = DecoderRes(
|
| 67 |
+
in_channels=opt['bot_in_channels'],
|
| 68 |
+
resolution=opt['bot_resolution'],
|
| 69 |
+
z_channels=opt['bot_z_channels'],
|
| 70 |
+
ch=opt['bot_ch'],
|
| 71 |
+
num_res_blocks=opt['bot_num_res_blocks'],
|
| 72 |
+
ch_mult=opt['bot_ch_mult'],
|
| 73 |
+
dropout=opt['bot_dropout'],
|
| 74 |
+
give_pre_end=False).to(self.device)
|
| 75 |
+
self.bot_quantize = VectorQuantizerSpatialTextureAware(
|
| 76 |
+
opt['bot_n_embed'],
|
| 77 |
+
opt['embed_dim'],
|
| 78 |
+
beta=0.25,
|
| 79 |
+
spatial_size=opt['codebook_spatial_size']).to(self.device)
|
| 80 |
+
self.bot_quant_conv = torch.nn.Conv2d(opt["bot_z_channels"],
|
| 81 |
+
opt['embed_dim'],
|
| 82 |
+
1).to(self.device)
|
| 83 |
+
self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
|
| 84 |
+
opt["bot_z_channels"],
|
| 85 |
+
1).to(self.device)
|
| 86 |
+
|
| 87 |
+
self.disc = Discriminator(
|
| 88 |
+
opt['n_channels'], opt['ndf'],
|
| 89 |
+
n_layers=opt['disc_layers']).to(self.device)
|
| 90 |
+
self.perceptual = lpips.LPIPS(net="vgg").to(self.device)
|
| 91 |
+
self.perceptual_weight = opt['perceptual_weight']
|
| 92 |
+
self.disc_start_step = opt['disc_start_step']
|
| 93 |
+
self.disc_weight_max = opt['disc_weight_max']
|
| 94 |
+
self.diff_aug = opt['diff_aug']
|
| 95 |
+
self.policy = "color,translation"
|
| 96 |
+
|
| 97 |
+
self.load_discriminator_models()
|
| 98 |
+
|
| 99 |
+
self.disc.train()
|
| 100 |
+
|
| 101 |
+
self.fix_decoder = opt['fix_decoder']
|
| 102 |
+
|
| 103 |
+
self.init_training_settings()
|
| 104 |
+
|
| 105 |
+
def load_top_pretrain_models(self):
|
| 106 |
+
# load the pretrained top-level vqgan
|
| 107 |
+
top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
|
| 108 |
+
self.top_encoder.load_state_dict(
|
| 109 |
+
top_vae_checkpoint['encoder'], strict=True)
|
| 110 |
+
self.decoder.load_state_dict(
|
| 111 |
+
top_vae_checkpoint['decoder'], strict=True)
|
| 112 |
+
self.top_quantize.load_state_dict(
|
| 113 |
+
top_vae_checkpoint['quantize'], strict=True)
|
| 114 |
+
self.top_quant_conv.load_state_dict(
|
| 115 |
+
top_vae_checkpoint['quant_conv'], strict=True)
|
| 116 |
+
self.top_post_quant_conv.load_state_dict(
|
| 117 |
+
top_vae_checkpoint['post_quant_conv'], strict=True)
|
| 118 |
+
self.top_encoder.eval()
|
| 119 |
+
self.top_quantize.eval()
|
| 120 |
+
self.top_quant_conv.eval()
|
| 121 |
+
self.top_post_quant_conv.eval()
|
| 122 |
+
|
| 123 |
+
def init_training_settings(self):
|
| 124 |
+
self.log_dict = OrderedDict()
|
| 125 |
+
self.configure_optimizers()
|
| 126 |
+
|
| 127 |
+
def configure_optimizers(self):
|
| 128 |
+
optim_params = []
|
| 129 |
+
for v in self.bot_encoder.parameters():
|
| 130 |
+
if v.requires_grad:
|
| 131 |
+
optim_params.append(v)
|
| 132 |
+
for v in self.bot_decoder_res.parameters():
|
| 133 |
+
if v.requires_grad:
|
| 134 |
+
optim_params.append(v)
|
| 135 |
+
for v in self.bot_quantize.parameters():
|
| 136 |
+
if v.requires_grad:
|
| 137 |
+
optim_params.append(v)
|
| 138 |
+
for v in self.bot_quant_conv.parameters():
|
| 139 |
+
if v.requires_grad:
|
| 140 |
+
optim_params.append(v)
|
| 141 |
+
for v in self.bot_post_quant_conv.parameters():
|
| 142 |
+
if v.requires_grad:
|
| 143 |
+
optim_params.append(v)
|
| 144 |
+
if not self.fix_decoder:
|
| 145 |
+
for name, v in self.decoder.named_parameters():
|
| 146 |
+
if v.requires_grad:
|
| 147 |
+
if 'up.0' in name:
|
| 148 |
+
optim_params.append(v)
|
| 149 |
+
if 'up.1' in name:
|
| 150 |
+
optim_params.append(v)
|
| 151 |
+
if 'up.2' in name:
|
| 152 |
+
optim_params.append(v)
|
| 153 |
+
if 'up.3' in name:
|
| 154 |
+
optim_params.append(v)
|
| 155 |
+
|
| 156 |
+
self.optimizer = torch.optim.Adam(optim_params, lr=self.opt['lr'])
|
| 157 |
+
|
| 158 |
+
self.disc_optimizer = torch.optim.Adam(
|
| 159 |
+
self.disc.parameters(), lr=self.opt['lr'])
|
| 160 |
+
|
| 161 |
+
def load_discriminator_models(self):
|
| 162 |
+
# load the discriminator weights from the pretrained top-level vqgan checkpoint
|
| 163 |
+
top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
|
| 164 |
+
self.disc.load_state_dict(
|
| 165 |
+
top_vae_checkpoint['discriminator'], strict=True)
|
| 166 |
+
|
| 167 |
+
def save_network(self, save_path):
|
| 168 |
+
"""Save networks.
|
| 169 |
+
"""
|
| 170 |
+
|
| 171 |
+
save_dict = {}
|
| 172 |
+
save_dict['bot_encoder'] = self.bot_encoder.state_dict()
|
| 173 |
+
save_dict['bot_decoder_res'] = self.bot_decoder_res.state_dict()
|
| 174 |
+
save_dict['decoder'] = self.decoder.state_dict()
|
| 175 |
+
save_dict['bot_quantize'] = self.bot_quantize.state_dict()
|
| 176 |
+
save_dict['bot_quant_conv'] = self.bot_quant_conv.state_dict()
|
| 177 |
+
save_dict['bot_post_quant_conv'] = self.bot_post_quant_conv.state_dict(
|
| 178 |
+
)
|
| 179 |
+
save_dict['discriminator'] = self.disc.state_dict()
|
| 180 |
+
torch.save(save_dict, save_path)
|
| 181 |
+
|
| 182 |
+
def load_network(self):
|
| 183 |
+
checkpoint = torch.load(self.opt['pretrained_models'])
|
| 184 |
+
self.bot_encoder.load_state_dict(
|
| 185 |
+
checkpoint['bot_encoder'], strict=True)
|
| 186 |
+
self.bot_decoder_res.load_state_dict(
|
| 187 |
+
checkpoint['bot_decoder_res'], strict=True)
|
| 188 |
+
self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
|
| 189 |
+
self.bot_quantize.load_state_dict(
|
| 190 |
+
checkpoint['bot_quantize'], strict=True)
|
| 191 |
+
self.bot_quant_conv.load_state_dict(
|
| 192 |
+
checkpoint['bot_quant_conv'], strict=True)
|
| 193 |
+
self.bot_post_quant_conv.load_state_dict(
|
| 194 |
+
checkpoint['bot_post_quant_conv'], strict=True)
|
| 195 |
+
|
| 196 |
+
def optimize_parameters(self, data, step):
|
| 197 |
+
self.bot_encoder.train()
|
| 198 |
+
self.bot_decoder_res.train()
|
| 199 |
+
if not self.fix_decoder:
|
| 200 |
+
self.decoder.train()
|
| 201 |
+
self.bot_quantize.train()
|
| 202 |
+
self.bot_quant_conv.train()
|
| 203 |
+
self.bot_post_quant_conv.train()
|
| 204 |
+
|
| 205 |
+
loss, d_loss = self.training_step(data, step)
|
| 206 |
+
self.optimizer.zero_grad()
|
| 207 |
+
loss.backward()
|
| 208 |
+
self.optimizer.step()
|
| 209 |
+
|
| 210 |
+
if step > self.disc_start_step:
|
| 211 |
+
self.disc_optimizer.zero_grad()
|
| 212 |
+
d_loss.backward()
|
| 213 |
+
self.disc_optimizer.step()
|
| 214 |
+
|
| 215 |
+
def top_encode(self, x, mask):
|
| 216 |
+
h = self.top_encoder(x)
|
| 217 |
+
h = self.top_quant_conv(h)
|
| 218 |
+
quant, _, _ = self.top_quantize(h, mask)
|
| 219 |
+
quant = self.top_post_quant_conv(quant)
|
| 220 |
+
return quant
|
| 221 |
+
|
| 222 |
+
def bot_encode(self, x, mask):
|
| 223 |
+
h = self.bot_encoder(x)
|
| 224 |
+
h = self.bot_quant_conv(h)
|
| 225 |
+
quant, emb_loss, info = self.bot_quantize(h, mask)
|
| 226 |
+
quant = self.bot_post_quant_conv(quant)
|
| 227 |
+
bot_dec_res = self.bot_decoder_res(quant)
|
| 228 |
+
return bot_dec_res, emb_loss, info
|
| 229 |
+
|
| 230 |
+
def decode(self, quant_top, bot_dec_res):
|
| 231 |
+
dec = self.decoder(quant_top, bot_h=bot_dec_res)
|
| 232 |
+
return dec
|
| 233 |
+
|
| 234 |
+
def forward_step(self, input, mask):
|
| 235 |
+
with torch.no_grad():
|
| 236 |
+
quant_top = self.top_encode(input, mask)
|
| 237 |
+
bot_dec_res, diff, _ = self.bot_encode(input, mask)
|
| 238 |
+
dec = self.decode(quant_top, bot_dec_res)
|
| 239 |
+
return dec, diff
|
| 240 |
+
|
| 241 |
+
def feed_data(self, data):
|
| 242 |
+
x = data['image'].float().to(self.device)
|
| 243 |
+
mask = data['texture_mask'].float().to(self.device)
|
| 244 |
+
|
| 245 |
+
return x, mask
|
| 246 |
+
|
| 247 |
+
def training_step(self, data, step):
|
| 248 |
+
x, mask = self.feed_data(data)
|
| 249 |
+
xrec, codebook_loss = self.forward_step(x, mask)
|
| 250 |
+
|
| 251 |
+
# get recon/perceptual loss
|
| 252 |
+
recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
|
| 253 |
+
p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
|
| 254 |
+
nll_loss = recon_loss + self.perceptual_weight * p_loss
|
| 255 |
+
nll_loss = torch.mean(nll_loss)
|
| 256 |
+
|
| 257 |
+
# augment for input to discriminator
|
| 258 |
+
if self.diff_aug:
|
| 259 |
+
xrec = DiffAugment(xrec, policy=self.policy)
|
| 260 |
+
|
| 261 |
+
# update generator
|
| 262 |
+
logits_fake = self.disc(xrec)
|
| 263 |
+
g_loss = -torch.mean(logits_fake)
|
| 264 |
+
last_layer = self.decoder.conv_out.weight
|
| 265 |
+
d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
|
| 266 |
+
self.disc_weight_max)
|
| 267 |
+
d_weight *= adopt_weight(1, step, self.disc_start_step)
|
| 268 |
+
loss = nll_loss + d_weight * g_loss + codebook_loss
|
| 269 |
+
|
| 270 |
+
self.log_dict["loss"] = loss
|
| 271 |
+
self.log_dict["l1"] = recon_loss.mean().item()
|
| 272 |
+
self.log_dict["perceptual"] = p_loss.mean().item()
|
| 273 |
+
self.log_dict["nll_loss"] = nll_loss.item()
|
| 274 |
+
self.log_dict["g_loss"] = g_loss.item()
|
| 275 |
+
self.log_dict["d_weight"] = d_weight
|
| 276 |
+
self.log_dict["codebook_loss"] = codebook_loss.item()
|
| 277 |
+
|
| 278 |
+
if step > self.disc_start_step:
|
| 279 |
+
if self.diff_aug:
|
| 280 |
+
logits_real = self.disc(
|
| 281 |
+
DiffAugment(x.contiguous().detach(), policy=self.policy))
|
| 282 |
+
else:
|
| 283 |
+
logits_real = self.disc(x.contiguous().detach())
|
| 284 |
+
logits_fake = self.disc(xrec.contiguous().detach(
|
| 285 |
+
)) # detach so that the generator isn't also updated
|
| 286 |
+
d_loss = hinge_d_loss(logits_real, logits_fake)
|
| 287 |
+
self.log_dict["d_loss"] = d_loss
|
| 288 |
+
else:
|
| 289 |
+
d_loss = None
|
| 290 |
+
|
| 291 |
+
return loss, d_loss
|
| 292 |
+
|
| 293 |
+
@torch.no_grad()
|
| 294 |
+
def inference(self, data_loader, save_dir):
|
| 295 |
+
self.bot_encoder.eval()
|
| 296 |
+
self.bot_decoder_res.eval()
|
| 297 |
+
self.decoder.eval()
|
| 298 |
+
self.bot_quantize.eval()
|
| 299 |
+
self.bot_quant_conv.eval()
|
| 300 |
+
self.bot_post_quant_conv.eval()
|
| 301 |
+
|
| 302 |
+
loss_total = 0
|
| 303 |
+
num = 0
|
| 304 |
+
|
| 305 |
+
for _, data in enumerate(data_loader):
|
| 306 |
+
img_name = data['img_name'][0]
|
| 307 |
+
x, mask = self.feed_data(data)
|
| 308 |
+
xrec, _ = self.forward_step(x, mask)
|
| 309 |
+
|
| 310 |
+
recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
|
| 311 |
+
p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
|
| 312 |
+
nll_loss = recon_loss + self.perceptual_weight * p_loss
|
| 313 |
+
nll_loss = torch.mean(nll_loss)
|
| 314 |
+
loss_total += nll_loss
|
| 315 |
+
|
| 316 |
+
num += x.size(0)
|
| 317 |
+
|
| 318 |
+
if x.shape[1] > 3:
|
| 319 |
+
# colorize with random projection
|
| 320 |
+
assert xrec.shape[1] > 3
|
| 321 |
+
# convert logits to indices
|
| 322 |
+
xrec = torch.argmax(xrec, dim=1, keepdim=True)
|
| 323 |
+
xrec = F.one_hot(xrec, num_classes=x.shape[1])
|
| 324 |
+
xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
|
| 325 |
+
x = self.to_rgb(x)
|
| 326 |
+
xrec = self.to_rgb(xrec)
|
| 327 |
+
|
| 328 |
+
img_cat = torch.cat([x, xrec], dim=3).detach()
|
| 329 |
+
img_cat = ((img_cat + 1) / 2)
|
| 330 |
+
img_cat = img_cat.clamp_(0, 1)
|
| 331 |
+
save_image(
|
| 332 |
+
img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)
|
| 333 |
+
|
| 334 |
+
return (loss_total / num).item()
|
| 335 |
+
|
| 336 |
+
def get_current_log(self):
|
| 337 |
+
return self.log_dict
|
| 338 |
+
|
| 339 |
+
def update_learning_rate(self, epoch):
|
| 340 |
+
"""Update learning rate.
|
| 341 |
+
|
| 342 |
+
Args:
|
| 343 |
+
epoch (int): Current epoch, used to compute the decayed
|
| 344 |
+
learning rate according to the ``lr_decay`` schedule
|
| 345 |
+
specified in the options.
|
| 346 |
+
"""
|
| 347 |
+
lr = self.optimizer.param_groups[0]['lr']
|
| 348 |
+
|
| 349 |
+
if self.opt['lr_decay'] == 'step':
|
| 350 |
+
lr = self.opt['lr'] * (
|
| 351 |
+
self.opt['gamma']**(epoch // self.opt['step']))
|
| 352 |
+
elif self.opt['lr_decay'] == 'cos':
|
| 353 |
+
lr = self.opt['lr'] * (
|
| 354 |
+
1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
|
| 355 |
+
elif self.opt['lr_decay'] == 'linear':
|
| 356 |
+
lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
|
| 357 |
+
elif self.opt['lr_decay'] == 'linear2exp':
|
| 358 |
+
if epoch < self.opt['turning_point'] + 1:
|
| 359 |
+
# learning rate decay as 95%
|
| 360 |
+
# at the turning point (1 / 95% = 1.0526)
|
| 361 |
+
lr = self.opt['lr'] * (
|
| 362 |
+
1 - epoch / int(self.opt['turning_point'] * 1.0526))
|
| 363 |
+
else:
|
| 364 |
+
lr *= self.opt['gamma']
|
| 365 |
+
elif self.opt['lr_decay'] == 'schedule':
|
| 366 |
+
if epoch in self.opt['schedule']:
|
| 367 |
+
lr *= self.opt['gamma']
|
| 368 |
+
else:
|
| 369 |
+
raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
|
| 370 |
+
# set learning rate
|
| 371 |
+
for param_group in self.optimizer.param_groups:
|
| 372 |
+
param_group['lr'] = lr
|
| 373 |
+
|
| 374 |
+
return lr
|
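Note: HierarchyVQSpatialTextureAwareModel ships without a driver in this file; the sketch below only illustrates how its public methods (optimize_parameters, inference, update_learning_rate, save_network) could be wired together. The opt dictionary, the data loaders and the output paths are assumptions for illustration, not part of this commit.

# Hypothetical training loop (a sketch, not part of this commit).
# Batches are assumed to be dicts with 'image' and 'texture_mask'; validation batches also 'img_name'.
from models.hierarchy_vqgan_model import HierarchyVQSpatialTextureAwareModel

def train_hierarchy_vqgan(opt, train_loader, val_loader, num_epochs):
    model = HierarchyVQSpatialTextureAwareModel(opt)
    step = 0
    for epoch in range(num_epochs):
        lr = model.update_learning_rate(epoch)
        for batch in train_loader:
            # generator update every step; the discriminator update kicks in after disc_start_step
            model.optimize_parameters(batch, step)
            step += 1
        val_nll = model.inference(val_loader, save_dir=f'./results/epoch_{epoch:03d}')
        print(f'epoch {epoch}: lr={lr:.2e}, val nll={val_nll:.4f}')
        model.save_network(f'./experiments/hierarchy_vqgan_{epoch:03d}.pth')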
models/losses/__init__.py
ADDED
|
File without changes
|
models/losses/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (127 Bytes).
|
|
|
models/losses/__pycache__/accuracy.cpython-38.pyc
ADDED
|
Binary file (2.02 kB).
|
|
|
models/losses/__pycache__/cross_entropy_loss.cpython-38.pyc
ADDED
|
Binary file (6.76 kB).
|
|
|
models/losses/__pycache__/segmentation_loss.cpython-38.pyc
ADDED
|
Binary file (1.3 kB).
|
|
|
models/losses/__pycache__/vqgan_loss.cpython-38.pyc
ADDED
|
Binary file (3.65 kB).
|
|
|
models/losses/accuracy.py
ADDED
|
@@ -0,0 +1,46 @@
|
| 1 |
+
def accuracy(pred, target, topk=1, thresh=None):
|
| 2 |
+
"""Calculate accuracy according to the prediction and target.
|
| 3 |
+
|
| 4 |
+
Args:
|
| 5 |
+
pred (torch.Tensor): The model prediction, shape (N, num_class, ...)
|
| 6 |
+
target (torch.Tensor): The target of each prediction, shape (N, ...)
|
| 7 |
+
topk (int | tuple[int], optional): If the predictions in ``topk``
|
| 8 |
+
matches the target, the predictions will be regarded as
|
| 9 |
+
correct ones. Defaults to 1.
|
| 10 |
+
thresh (float, optional): If not None, predictions with scores under
|
| 11 |
+
this threshold are considered incorrect. Default to None.
|
| 12 |
+
|
| 13 |
+
Returns:
|
| 14 |
+
float | tuple[float]: If the input ``topk`` is a single integer,
|
| 15 |
+
the function will return a single float as accuracy. If
|
| 16 |
+
``topk`` is a tuple containing multiple integers, the
|
| 17 |
+
function will return a tuple containing accuracies of
|
| 18 |
+
each ``topk`` number.
|
| 19 |
+
"""
|
| 20 |
+
assert isinstance(topk, (int, tuple))
|
| 21 |
+
if isinstance(topk, int):
|
| 22 |
+
topk = (topk, )
|
| 23 |
+
return_single = True
|
| 24 |
+
else:
|
| 25 |
+
return_single = False
|
| 26 |
+
|
| 27 |
+
maxk = max(topk)
|
| 28 |
+
if pred.size(0) == 0:
|
| 29 |
+
accu = [pred.new_tensor(0.) for i in range(len(topk))]
|
| 30 |
+
return accu[0] if return_single else accu
|
| 31 |
+
assert pred.ndim == target.ndim + 1
|
| 32 |
+
assert pred.size(0) == target.size(0)
|
| 33 |
+
assert maxk <= pred.size(1), \
|
| 34 |
+
f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
|
| 35 |
+
pred_value, pred_label = pred.topk(maxk, dim=1)
|
| 36 |
+
# transpose to shape (maxk, N, ...)
|
| 37 |
+
pred_label = pred_label.transpose(0, 1)
|
| 38 |
+
correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label))
|
| 39 |
+
if thresh is not None:
|
| 40 |
+
# Only prediction values larger than thresh are counted as correct
|
| 41 |
+
correct = correct & (pred_value > thresh).t()
|
| 42 |
+
res = []
|
| 43 |
+
for k in topk:
|
| 44 |
+
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
|
| 45 |
+
res.append(correct_k.mul_(100.0 / target.numel()))
|
| 46 |
+
return res[0] if return_single else res
|
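For reference, the helper above can be exercised on its own; the shapes and values below are arbitrary examples for a plain classification case, not part of this commit.

# Example usage of accuracy() (illustrative shapes only).
import torch
from models.losses.accuracy import accuracy

pred = torch.randn(4, 10)             # (N, num_class) logits
target = torch.randint(0, 10, (4,))   # (N,) integer labels
top1 = accuracy(pred, target, topk=1)             # a single tensor when topk is an int
top1, top5 = accuracy(pred, target, topk=(1, 5))  # a tuple when topk is a tuple
print(top1.item(), top5.item())                   # percentages in [0, 100]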
models/losses/cross_entropy_loss.py
ADDED
|
@@ -0,0 +1,246 @@
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def reduce_loss(loss, reduction):
|
| 7 |
+
"""Reduce loss as specified.
|
| 8 |
+
|
| 9 |
+
Args:
|
| 10 |
+
loss (Tensor): Elementwise loss tensor.
|
| 11 |
+
reduction (str): Options are "none", "mean" and "sum".
|
| 12 |
+
|
| 13 |
+
Return:
|
| 14 |
+
Tensor: Reduced loss tensor.
|
| 15 |
+
"""
|
| 16 |
+
reduction_enum = F._Reduction.get_enum(reduction)
|
| 17 |
+
# none: 0, elementwise_mean:1, sum: 2
|
| 18 |
+
if reduction_enum == 0:
|
| 19 |
+
return loss
|
| 20 |
+
elif reduction_enum == 1:
|
| 21 |
+
return loss.mean()
|
| 22 |
+
elif reduction_enum == 2:
|
| 23 |
+
return loss.sum()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
|
| 27 |
+
"""Apply element-wise weight and reduce loss.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
loss (Tensor): Element-wise loss.
|
| 31 |
+
weight (Tensor): Element-wise weights.
|
| 32 |
+
reduction (str): Same as built-in losses of PyTorch.
|
| 33 |
+
avg_factor (float): Average factor when computing the mean of losses.
|
| 34 |
+
|
| 35 |
+
Returns:
|
| 36 |
+
Tensor: Processed loss values.
|
| 37 |
+
"""
|
| 38 |
+
# if weight is specified, apply element-wise weight
|
| 39 |
+
if weight is not None:
|
| 40 |
+
assert weight.dim() == loss.dim()
|
| 41 |
+
if weight.dim() > 1:
|
| 42 |
+
assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
|
| 43 |
+
loss = loss * weight
|
| 44 |
+
|
| 45 |
+
# if avg_factor is not specified, just reduce the loss
|
| 46 |
+
if avg_factor is None:
|
| 47 |
+
loss = reduce_loss(loss, reduction)
|
| 48 |
+
else:
|
| 49 |
+
# if reduction is mean, then average the loss by avg_factor
|
| 50 |
+
if reduction == 'mean':
|
| 51 |
+
loss = loss.sum() / avg_factor
|
| 52 |
+
# if reduction is 'none', then do nothing, otherwise raise an error
|
| 53 |
+
elif reduction != 'none':
|
| 54 |
+
raise ValueError('avg_factor can not be used with reduction="sum"')
|
| 55 |
+
return loss
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def cross_entropy(pred,
|
| 59 |
+
label,
|
| 60 |
+
weight=None,
|
| 61 |
+
class_weight=None,
|
| 62 |
+
reduction='mean',
|
| 63 |
+
avg_factor=None,
|
| 64 |
+
ignore_index=-100):
|
| 65 |
+
"""The wrapper function for :func:`F.cross_entropy`"""
|
| 66 |
+
# class_weight is a manual rescaling weight given to each class.
|
| 67 |
+
# If given, has to be a Tensor of size C element-wise losses
|
| 68 |
+
loss = F.cross_entropy(
|
| 69 |
+
pred,
|
| 70 |
+
label,
|
| 71 |
+
weight=class_weight,
|
| 72 |
+
reduction='none',
|
| 73 |
+
ignore_index=ignore_index)
|
| 74 |
+
|
| 75 |
+
# apply weights and do the reduction
|
| 76 |
+
if weight is not None:
|
| 77 |
+
weight = weight.float()
|
| 78 |
+
loss = weight_reduce_loss(
|
| 79 |
+
loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
|
| 80 |
+
|
| 81 |
+
return loss
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
|
| 85 |
+
"""Expand onehot labels to match the size of prediction."""
|
| 86 |
+
bin_labels = labels.new_zeros(target_shape)
|
| 87 |
+
valid_mask = (labels >= 0) & (labels != ignore_index)
|
| 88 |
+
inds = torch.nonzero(valid_mask, as_tuple=True)
|
| 89 |
+
|
| 90 |
+
if inds[0].numel() > 0:
|
| 91 |
+
if labels.dim() == 3:
|
| 92 |
+
bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
|
| 93 |
+
else:
|
| 94 |
+
bin_labels[inds[0], labels[valid_mask]] = 1
|
| 95 |
+
|
| 96 |
+
valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
|
| 97 |
+
if label_weights is None:
|
| 98 |
+
bin_label_weights = valid_mask
|
| 99 |
+
else:
|
| 100 |
+
bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
|
| 101 |
+
bin_label_weights *= valid_mask
|
| 102 |
+
|
| 103 |
+
return bin_labels, bin_label_weights
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def binary_cross_entropy(pred,
|
| 107 |
+
label,
|
| 108 |
+
weight=None,
|
| 109 |
+
reduction='mean',
|
| 110 |
+
avg_factor=None,
|
| 111 |
+
class_weight=None,
|
| 112 |
+
ignore_index=255):
|
| 113 |
+
"""Calculate the binary CrossEntropy loss.
|
| 114 |
+
|
| 115 |
+
Args:
|
| 116 |
+
pred (torch.Tensor): The prediction with shape (N, 1).
|
| 117 |
+
label (torch.Tensor): The learning label of the prediction.
|
| 118 |
+
weight (torch.Tensor, optional): Sample-wise loss weight.
|
| 119 |
+
reduction (str, optional): The method used to reduce the loss.
|
| 120 |
+
Options are "none", "mean" and "sum".
|
| 121 |
+
avg_factor (int, optional): Average factor that is used to average
|
| 122 |
+
the loss. Defaults to None.
|
| 123 |
+
class_weight (list[float], optional): The weight for each class.
|
| 124 |
+
ignore_index (int | None): The label index to be ignored. Default: 255
|
| 125 |
+
|
| 126 |
+
Returns:
|
| 127 |
+
torch.Tensor: The calculated loss
|
| 128 |
+
"""
|
| 129 |
+
if pred.dim() != label.dim():
|
| 130 |
+
assert (pred.dim() == 2 and label.dim() == 1) or (
|
| 131 |
+
pred.dim() == 4 and label.dim() == 3), \
|
| 132 |
+
'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
|
| 133 |
+
'H, W], label shape [N, H, W] are supported'
|
| 134 |
+
label, weight = _expand_onehot_labels(label, weight, pred.shape,
|
| 135 |
+
ignore_index)
|
| 136 |
+
|
| 137 |
+
# weighted element-wise losses
|
| 138 |
+
if weight is not None:
|
| 139 |
+
weight = weight.float()
|
| 140 |
+
loss = F.binary_cross_entropy_with_logits(
|
| 141 |
+
pred, label.float(), pos_weight=class_weight, reduction='none')
|
| 142 |
+
# do the reduction for the weighted loss
|
| 143 |
+
loss = weight_reduce_loss(
|
| 144 |
+
loss, weight, reduction=reduction, avg_factor=avg_factor)
|
| 145 |
+
|
| 146 |
+
return loss
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def mask_cross_entropy(pred,
|
| 150 |
+
target,
|
| 151 |
+
label,
|
| 152 |
+
reduction='mean',
|
| 153 |
+
avg_factor=None,
|
| 154 |
+
class_weight=None,
|
| 155 |
+
ignore_index=None):
|
| 156 |
+
"""Calculate the CrossEntropy loss for masks.
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
pred (torch.Tensor): The prediction with shape (N, C), C is the number
|
| 160 |
+
of classes.
|
| 161 |
+
target (torch.Tensor): The learning label of the prediction.
|
| 162 |
+
label (torch.Tensor): ``label`` indicates the class label of the mask's
|
| 163 |
+
corresponding object. It is used to select the mask of the class
|
| 164 |
+
that the object belongs to when the mask prediction is not
|
| 165 |
+
class-agnostic.
|
| 166 |
+
reduction (str, optional): The method used to reduce the loss.
|
| 167 |
+
Options are "none", "mean" and "sum".
|
| 168 |
+
avg_factor (int, optional): Average factor that is used to average
|
| 169 |
+
the loss. Defaults to None.
|
| 170 |
+
class_weight (list[float], optional): The weight for each class.
|
| 171 |
+
ignore_index (None): Placeholder, to be consistent with other loss.
|
| 172 |
+
Default: None.
|
| 173 |
+
|
| 174 |
+
Returns:
|
| 175 |
+
torch.Tensor: The calculated loss
|
| 176 |
+
"""
|
| 177 |
+
assert ignore_index is None, 'BCE loss does not support ignore_index'
|
| 178 |
+
# TODO: handle these two reserved arguments
|
| 179 |
+
assert reduction == 'mean' and avg_factor is None
|
| 180 |
+
num_rois = pred.size()[0]
|
| 181 |
+
inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
|
| 182 |
+
pred_slice = pred[inds, label].squeeze(1)
|
| 183 |
+
return F.binary_cross_entropy_with_logits(
|
| 184 |
+
pred_slice, target, weight=class_weight, reduction='mean')[None]
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class CrossEntropyLoss(nn.Module):
|
| 188 |
+
"""CrossEntropyLoss.
|
| 189 |
+
|
| 190 |
+
Args:
|
| 191 |
+
use_sigmoid (bool, optional): Whether the prediction uses sigmoid
|
| 192 |
+
or softmax. Defaults to False.
|
| 193 |
+
use_mask (bool, optional): Whether to use mask cross entropy loss.
|
| 194 |
+
Defaults to False.
|
| 195 |
+
reduction (str, optional): . Defaults to 'mean'.
|
| 196 |
+
Options are "none", "mean" and "sum".
|
| 197 |
+
class_weight (list[float], optional): Weight of each class.
|
| 198 |
+
Defaults to None.
|
| 199 |
+
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
|
| 200 |
+
"""
|
| 201 |
+
|
| 202 |
+
def __init__(self,
|
| 203 |
+
use_sigmoid=False,
|
| 204 |
+
use_mask=False,
|
| 205 |
+
reduction='mean',
|
| 206 |
+
class_weight=None,
|
| 207 |
+
loss_weight=1.0):
|
| 208 |
+
super(CrossEntropyLoss, self).__init__()
|
| 209 |
+
assert (use_sigmoid is False) or (use_mask is False)
|
| 210 |
+
self.use_sigmoid = use_sigmoid
|
| 211 |
+
self.use_mask = use_mask
|
| 212 |
+
self.reduction = reduction
|
| 213 |
+
self.loss_weight = loss_weight
|
| 214 |
+
self.class_weight = class_weight
|
| 215 |
+
|
| 216 |
+
if self.use_sigmoid:
|
| 217 |
+
self.cls_criterion = binary_cross_entropy
|
| 218 |
+
elif self.use_mask:
|
| 219 |
+
self.cls_criterion = mask_cross_entropy
|
| 220 |
+
else:
|
| 221 |
+
self.cls_criterion = cross_entropy
|
| 222 |
+
|
| 223 |
+
def forward(self,
|
| 224 |
+
cls_score,
|
| 225 |
+
label,
|
| 226 |
+
weight=None,
|
| 227 |
+
avg_factor=None,
|
| 228 |
+
reduction_override=None,
|
| 229 |
+
**kwargs):
|
| 230 |
+
"""Forward function."""
|
| 231 |
+
assert reduction_override in (None, 'none', 'mean', 'sum')
|
| 232 |
+
reduction = (
|
| 233 |
+
reduction_override if reduction_override else self.reduction)
|
| 234 |
+
if self.class_weight is not None:
|
| 235 |
+
class_weight = cls_score.new_tensor(self.class_weight)
|
| 236 |
+
else:
|
| 237 |
+
class_weight = None
|
| 238 |
+
loss_cls = self.loss_weight * self.cls_criterion(
|
| 239 |
+
cls_score,
|
| 240 |
+
label,
|
| 241 |
+
weight,
|
| 242 |
+
class_weight=class_weight,
|
| 243 |
+
reduction=reduction,
|
| 244 |
+
avg_factor=avg_factor,
|
| 245 |
+
**kwargs)
|
| 246 |
+
return loss_cls
|
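A minimal, self-contained example of the default (softmax) branch of the loss above; the segmentation-style shapes are arbitrary and only illustrate the call signature.

# Example usage of CrossEntropyLoss (illustrative shapes only).
import torch
from models.losses.cross_entropy_loss import CrossEntropyLoss

criterion = CrossEntropyLoss(use_sigmoid=False, use_mask=False, loss_weight=1.0)
seg_logits = torch.randn(2, 24, 64, 32)           # (N, C, H, W) raw scores
seg_labels = torch.randint(0, 24, (2, 64, 32))    # (N, H, W) class indices
loss = criterion(seg_logits, seg_labels)          # scalar, reduction='mean' by default
print(loss.item())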
models/losses/segmentation_loss.py
ADDED
|
@@ -0,0 +1,25 @@
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class BCELoss(nn.Module):
|
| 6 |
+
|
| 7 |
+
def forward(self, prediction, target):
|
| 8 |
+
loss = F.binary_cross_entropy_with_logits(prediction, target)
|
| 9 |
+
return loss, {}
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class BCELossWithQuant(nn.Module):
|
| 13 |
+
|
| 14 |
+
def __init__(self, codebook_weight=1.):
|
| 15 |
+
super().__init__()
|
| 16 |
+
self.codebook_weight = codebook_weight
|
| 17 |
+
|
| 18 |
+
def forward(self, qloss, target, prediction, split):
|
| 19 |
+
bce_loss = F.binary_cross_entropy_with_logits(prediction, target)
|
| 20 |
+
loss = bce_loss + self.codebook_weight * qloss
|
| 21 |
+
return loss, {
|
| 22 |
+
"{}/total_loss".format(split): loss.clone().detach().mean(),
|
| 23 |
+
"{}/bce_loss".format(split): bce_loss.detach().mean(),
|
| 24 |
+
"{}/quant_loss".format(split): qloss.detach().mean()
|
| 25 |
+
}
|
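A small usage sketch for BCELossWithQuant; the tensors, the quantization loss value and the split name are placeholders for illustration only.

# Example usage of BCELossWithQuant (illustrative values only).
import torch
from models.losses.segmentation_loss import BCELossWithQuant

criterion = BCELossWithQuant(codebook_weight=1.0)
prediction = torch.randn(2, 1, 32, 16)                 # raw logits
target = torch.randint(0, 2, (2, 1, 32, 16)).float()   # binary mask
qloss = torch.tensor(0.05)                             # codebook/commitment loss from a VQ layer
loss, log = criterion(qloss, target, prediction, split='train')
print(loss.item(), log['train/bce_loss'].item())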
models/losses/vqgan_loss.py
ADDED
|
@@ -0,0 +1,114 @@
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def calculate_adaptive_weight(recon_loss, g_loss, last_layer, disc_weight_max):
|
| 6 |
+
recon_grads = torch.autograd.grad(
|
| 7 |
+
recon_loss, last_layer, retain_graph=True)[0]
|
| 8 |
+
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
| 9 |
+
|
| 10 |
+
d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
|
| 11 |
+
d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
|
| 12 |
+
return d_weight
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def adopt_weight(weight, global_step, threshold=0, value=0.):
|
| 16 |
+
if global_step < threshold:
|
| 17 |
+
weight = value
|
| 18 |
+
return weight
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@torch.jit.script
|
| 22 |
+
def hinge_d_loss(logits_real, logits_fake):
|
| 23 |
+
loss_real = torch.mean(F.relu(1. - logits_real))
|
| 24 |
+
loss_fake = torch.mean(F.relu(1. + logits_fake))
|
| 25 |
+
d_loss = 0.5 * (loss_real + loss_fake)
|
| 26 |
+
return d_loss
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def DiffAugment(x, policy='', channels_first=True):
|
| 30 |
+
if policy:
|
| 31 |
+
if not channels_first:
|
| 32 |
+
x = x.permute(0, 3, 1, 2)
|
| 33 |
+
for p in policy.split(','):
|
| 34 |
+
for f in AUGMENT_FNS[p]:
|
| 35 |
+
x = f(x)
|
| 36 |
+
if not channels_first:
|
| 37 |
+
x = x.permute(0, 2, 3, 1)
|
| 38 |
+
x = x.contiguous()
|
| 39 |
+
return x
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def rand_brightness(x):
|
| 43 |
+
x = x + (
|
| 44 |
+
torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
|
| 45 |
+
return x
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def rand_saturation(x):
|
| 49 |
+
x_mean = x.mean(dim=1, keepdim=True)
|
| 50 |
+
x = (x - x_mean) * (torch.rand(
|
| 51 |
+
x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
|
| 52 |
+
return x
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def rand_contrast(x):
|
| 56 |
+
x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
|
| 57 |
+
x = (x - x_mean) * (torch.rand(
|
| 58 |
+
x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
|
| 59 |
+
return x
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def rand_translation(x, ratio=0.125):
|
| 63 |
+
shift_x, shift_y = int(x.size(2) * ratio +
|
| 64 |
+
0.5), int(x.size(3) * ratio + 0.5)
|
| 65 |
+
translation_x = torch.randint(
|
| 66 |
+
-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
|
| 67 |
+
translation_y = torch.randint(
|
| 68 |
+
-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
|
| 69 |
+
grid_batch, grid_x, grid_y = torch.meshgrid(
|
| 70 |
+
torch.arange(x.size(0), dtype=torch.long, device=x.device),
|
| 71 |
+
torch.arange(x.size(2), dtype=torch.long, device=x.device),
|
| 72 |
+
torch.arange(x.size(3), dtype=torch.long, device=x.device),
|
| 73 |
+
)
|
| 74 |
+
grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
|
| 75 |
+
grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
|
| 76 |
+
x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
|
| 77 |
+
x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x,
|
| 78 |
+
grid_y].permute(0, 3, 1, 2)
|
| 79 |
+
return x
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def rand_cutout(x, ratio=0.5):
|
| 83 |
+
cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
|
| 84 |
+
offset_x = torch.randint(
|
| 85 |
+
0,
|
| 86 |
+
x.size(2) + (1 - cutout_size[0] % 2),
|
| 87 |
+
size=[x.size(0), 1, 1],
|
| 88 |
+
device=x.device)
|
| 89 |
+
offset_y = torch.randint(
|
| 90 |
+
0,
|
| 91 |
+
x.size(3) + (1 - cutout_size[1] % 2),
|
| 92 |
+
size=[x.size(0), 1, 1],
|
| 93 |
+
device=x.device)
|
| 94 |
+
grid_batch, grid_x, grid_y = torch.meshgrid(
|
| 95 |
+
torch.arange(x.size(0), dtype=torch.long, device=x.device),
|
| 96 |
+
torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
|
| 97 |
+
torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
|
| 98 |
+
)
|
| 99 |
+
grid_x = torch.clamp(
|
| 100 |
+
grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
|
| 101 |
+
grid_y = torch.clamp(
|
| 102 |
+
grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
|
| 103 |
+
mask = torch.ones(
|
| 104 |
+
x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
|
| 105 |
+
mask[grid_batch, grid_x, grid_y] = 0
|
| 106 |
+
x = x * mask.unsqueeze(1)
|
| 107 |
+
return x
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
AUGMENT_FNS = {
|
| 111 |
+
'color': [rand_brightness, rand_saturation, rand_contrast],
|
| 112 |
+
'translation': [rand_translation],
|
| 113 |
+
'cutout': [rand_cutout],
|
| 114 |
+
}
|
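The helpers above are combined inside the model classes of this commit; the snippet below only illustrates their interfaces with random stand-in tensors (there is no real discriminator here, and the step values are arbitrary).

# Interface illustration for DiffAugment, hinge_d_loss and adopt_weight (stand-in tensors only).
import torch
from models.losses.vqgan_loss import DiffAugment, adopt_weight, hinge_d_loss

real = torch.rand(4, 3, 64, 32) * 2 - 1
fake = torch.rand(4, 3, 64, 32) * 2 - 1

# the same differentiable augmentation policy is applied to both discriminator inputs
real_aug = DiffAugment(real, policy='color,translation')
fake_aug = DiffAugment(fake, policy='color,translation')

# stand-ins for disc(real_aug) / disc(fake_aug) outputs
logits_real = real_aug.mean(dim=(1, 2, 3))
logits_fake = fake_aug.mean(dim=(1, 2, 3))
d_loss = hinge_d_loss(logits_real, logits_fake)

# the adversarial term is switched on only after disc_start_step
d_weight = adopt_weight(1.0, global_step=100, threshold=3000)  # -> 0.0 before the threshold
print(d_loss.item(), d_weight)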
models/parsing_gen_model.py
ADDED
|
@@ -0,0 +1,220 @@
|
| 1 |
+
import logging
|
| 2 |
+
import math
|
| 3 |
+
from collections import OrderedDict
|
| 4 |
+
|
| 5 |
+
import mmcv
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from torchvision.utils import save_image
|
| 9 |
+
|
| 10 |
+
from models.archs.fcn_arch import FCNHead
|
| 11 |
+
from models.archs.shape_attr_embedding_arch import ShapeAttrEmbedding
|
| 12 |
+
from models.archs.unet_arch import ShapeUNet
|
| 13 |
+
from models.losses.accuracy import accuracy
|
| 14 |
+
from models.losses.cross_entropy_loss import CrossEntropyLoss
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger('base')
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class ParsingGenModel():
|
| 20 |
+
"""Paring Generation model.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(self, opt):
|
| 24 |
+
self.opt = opt
|
| 25 |
+
self.device = torch.device('cuda')
|
| 26 |
+
self.is_train = opt['is_train']
|
| 27 |
+
|
| 28 |
+
self.attr_embedder = ShapeAttrEmbedding(
|
| 29 |
+
dim=opt['embedder_dim'],
|
| 30 |
+
out_dim=opt['embedder_out_dim'],
|
| 31 |
+
cls_num_list=opt['attr_class_num']).to(self.device)
|
| 32 |
+
self.parsing_encoder = ShapeUNet(
|
| 33 |
+
in_channels=opt['encoder_in_channels']).to(self.device)
|
| 34 |
+
self.parsing_decoder = FCNHead(
|
| 35 |
+
in_channels=opt['fc_in_channels'],
|
| 36 |
+
in_index=opt['fc_in_index'],
|
| 37 |
+
channels=opt['fc_channels'],
|
| 38 |
+
num_convs=opt['fc_num_convs'],
|
| 39 |
+
concat_input=opt['fc_concat_input'],
|
| 40 |
+
dropout_ratio=opt['fc_dropout_ratio'],
|
| 41 |
+
num_classes=opt['fc_num_classes'],
|
| 42 |
+
align_corners=opt['fc_align_corners'],
|
| 43 |
+
).to(self.device)
|
| 44 |
+
|
| 45 |
+
self.init_training_settings()
|
| 46 |
+
|
| 47 |
+
self.palette = [[0, 0, 0], [255, 250, 250], [220, 220, 220],
|
| 48 |
+
[250, 235, 215], [255, 250, 205], [211, 211, 211],
|
| 49 |
+
[70, 130, 180], [127, 255, 212], [0, 100, 0],
|
| 50 |
+
[50, 205, 50], [255, 255, 0], [245, 222, 179],
|
| 51 |
+
[255, 140, 0], [255, 0, 0], [16, 78, 139],
|
| 52 |
+
[144, 238, 144], [50, 205, 174], [50, 155, 250],
|
| 53 |
+
[160, 140, 88], [213, 140, 88], [90, 140, 90],
|
| 54 |
+
[185, 210, 205], [130, 165, 180], [225, 141, 151]]
|
| 55 |
+
|
| 56 |
+
def init_training_settings(self):
|
| 57 |
+
optim_params = []
|
| 58 |
+
for v in self.attr_embedder.parameters():
|
| 59 |
+
if v.requires_grad:
|
| 60 |
+
optim_params.append(v)
|
| 61 |
+
for v in self.parsing_encoder.parameters():
|
| 62 |
+
if v.requires_grad:
|
| 63 |
+
optim_params.append(v)
|
| 64 |
+
for v in self.parsing_decoder.parameters():
|
| 65 |
+
if v.requires_grad:
|
| 66 |
+
optim_params.append(v)
|
| 67 |
+
# set up optimizers
|
| 68 |
+
self.optimizer = torch.optim.Adam(
|
| 69 |
+
optim_params,
|
| 70 |
+
self.opt['lr'],
|
| 71 |
+
weight_decay=self.opt['weight_decay'])
|
| 72 |
+
self.log_dict = OrderedDict()
|
| 73 |
+
self.entropy_loss = CrossEntropyLoss().to(self.device)
|
| 74 |
+
|
| 75 |
+
def feed_data(self, data):
|
| 76 |
+
self.pose = data['densepose'].to(self.device)
|
| 77 |
+
self.attr = data['attr'].to(self.device)
|
| 78 |
+
self.segm = data['segm'].to(self.device)
|
| 79 |
+
|
| 80 |
+
def optimize_parameters(self):
|
| 81 |
+
self.attr_embedder.train()
|
| 82 |
+
self.parsing_encoder.train()
|
| 83 |
+
self.parsing_decoder.train()
|
| 84 |
+
|
| 85 |
+
self.attr_embedding = self.attr_embedder(self.attr)
|
| 86 |
+
self.pose_enc = self.parsing_encoder(self.pose, self.attr_embedding)
|
| 87 |
+
self.seg_logits = self.parsing_decoder(self.pose_enc)
|
| 88 |
+
|
| 89 |
+
loss = self.entropy_loss(self.seg_logits, self.segm)
|
| 90 |
+
|
| 91 |
+
self.optimizer.zero_grad()
|
| 92 |
+
loss.backward()
|
| 93 |
+
self.optimizer.step()
|
| 94 |
+
|
| 95 |
+
self.log_dict['loss_total'] = loss
|
| 96 |
+
|
| 97 |
+
def get_vis(self, save_path):
|
| 98 |
+
img_cat = torch.cat([
|
| 99 |
+
self.pose,
|
| 100 |
+
self.segm,
|
| 101 |
+
], dim=3).detach()
|
| 102 |
+
img_cat = ((img_cat + 1) / 2)
|
| 103 |
+
|
| 104 |
+
img_cat = img_cat.clamp_(0, 1)
|
| 105 |
+
|
| 106 |
+
save_image(img_cat, save_path, nrow=1, padding=4)
|
| 107 |
+
|
| 108 |
+
def inference(self, data_loader, save_dir):
|
| 109 |
+
self.attr_embedder.eval()
|
| 110 |
+
self.parsing_encoder.eval()
|
| 111 |
+
self.parsing_decoder.eval()
|
| 112 |
+
|
| 113 |
+
acc = 0
|
| 114 |
+
num = 0
|
| 115 |
+
|
| 116 |
+
for _, data in enumerate(data_loader):
|
| 117 |
+
pose = data['densepose'].to(self.device)
|
| 118 |
+
attr = data['attr'].to(self.device)
|
| 119 |
+
segm = data['segm'].to(self.device)
|
| 120 |
+
img_name = data['img_name']
|
| 121 |
+
|
| 122 |
+
num += pose.size(0)
|
| 123 |
+
with torch.no_grad():
|
| 124 |
+
attr_embedding = self.attr_embedder(attr)
|
| 125 |
+
pose_enc = self.parsing_encoder(pose, attr_embedding)
|
| 126 |
+
seg_logits = self.parsing_decoder(pose_enc)
|
| 127 |
+
seg_pred = seg_logits.argmax(dim=1)
|
| 128 |
+
acc += accuracy(seg_logits, segm)
|
| 129 |
+
palette_label = self.palette_result(segm.cpu().numpy())
|
| 130 |
+
palette_pred = self.palette_result(seg_pred.cpu().numpy())
|
| 131 |
+
pose_numpy = ((pose[0] + 1) / 2. * 255.).expand(
|
| 132 |
+
3,
|
| 133 |
+
pose[0].size(1),
|
| 134 |
+
pose[0].size(2),
|
| 135 |
+
).cpu().numpy().clip(0, 255).astype(np.uint8).transpose(1, 2, 0)
|
| 136 |
+
concat_result = np.concatenate(
|
| 137 |
+
(pose_numpy, palette_pred, palette_label), axis=1)
|
| 138 |
+
mmcv.imwrite(concat_result, f'{save_dir}/{img_name[0]}')
|
| 139 |
+
|
| 140 |
+
self.attr_embedder.train()
|
| 141 |
+
self.parsing_encoder.train()
|
| 142 |
+
self.parsing_decoder.train()
|
| 143 |
+
return (acc / num).item()
|
| 144 |
+
|
| 145 |
+
def get_current_log(self):
|
| 146 |
+
return self.log_dict
|
| 147 |
+
|
| 148 |
+
def update_learning_rate(self, epoch):
|
| 149 |
+
"""Update learning rate.
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
epoch (int): Current epoch, used to compute the decayed
|
| 153 |
+
learning rate according to the ``lr_decay`` schedule
|
| 154 |
+
specified in the options.
|
| 155 |
+
"""
|
| 156 |
+
lr = self.optimizer.param_groups[0]['lr']
|
| 157 |
+
|
| 158 |
+
if self.opt['lr_decay'] == 'step':
|
| 159 |
+
lr = self.opt['lr'] * (
|
| 160 |
+
self.opt['gamma']**(epoch // self.opt['step']))
|
| 161 |
+
elif self.opt['lr_decay'] == 'cos':
|
| 162 |
+
lr = self.opt['lr'] * (
|
| 163 |
+
1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
|
| 164 |
+
elif self.opt['lr_decay'] == 'linear':
|
| 165 |
+
lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
|
| 166 |
+
elif self.opt['lr_decay'] == 'linear2exp':
|
| 167 |
+
if epoch < self.opt['turning_point'] + 1:
|
| 168 |
+
# learning rate decay as 95%
|
| 169 |
+
# at the turning point (1 / 95% = 1.0526)
|
| 170 |
+
lr = self.opt['lr'] * (
|
| 171 |
+
1 - epoch / int(self.opt['turning_point'] * 1.0526))
|
| 172 |
+
else:
|
| 173 |
+
lr *= self.opt['gamma']
|
| 174 |
+
elif self.opt['lr_decay'] == 'schedule':
|
| 175 |
+
if epoch in self.opt['schedule']:
|
| 176 |
+
lr *= self.opt['gamma']
|
| 177 |
+
else:
|
| 178 |
+
raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
|
| 179 |
+
# set learning rate
|
| 180 |
+
for param_group in self.optimizer.param_groups:
|
| 181 |
+
param_group['lr'] = lr
|
| 182 |
+
|
| 183 |
+
return lr
|
| 184 |
+
|
| 185 |
+
def save_network(self, save_path):
|
| 186 |
+
"""Save networks.
|
| 187 |
+
"""
|
| 188 |
+
|
| 189 |
+
save_dict = {}
|
| 190 |
+
save_dict['embedder'] = self.attr_embedder.state_dict()
|
| 191 |
+
save_dict['encoder'] = self.parsing_encoder.state_dict()
|
| 192 |
+
save_dict['decoder'] = self.parsing_decoder.state_dict()
|
| 193 |
+
|
| 194 |
+
torch.save(save_dict, save_path)
|
| 195 |
+
|
| 196 |
+
def load_network(self):
|
| 197 |
+
checkpoint = torch.load(self.opt['pretrained_parsing_gen'])
|
| 198 |
+
|
| 199 |
+
self.attr_embedder.load_state_dict(checkpoint['embedder'], strict=True)
|
| 200 |
+
self.attr_embedder.eval()
|
| 201 |
+
|
| 202 |
+
self.parsing_encoder.load_state_dict(
|
| 203 |
+
checkpoint['encoder'], strict=True)
|
| 204 |
+
self.parsing_encoder.eval()
|
| 205 |
+
|
| 206 |
+
self.parsing_decoder.load_state_dict(
|
| 207 |
+
checkpoint['decoder'], strict=True)
|
| 208 |
+
self.parsing_decoder.eval()
|
| 209 |
+
|
| 210 |
+
def palette_result(self, result):
|
| 211 |
+
seg = result[0]
|
| 212 |
+
palette = np.array(self.palette)
|
| 213 |
+
assert palette.shape[1] == 3
|
| 214 |
+
assert len(palette.shape) == 2
|
| 215 |
+
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
|
| 216 |
+
for label, color in enumerate(palette):
|
| 217 |
+
color_seg[seg == label, :] = color
|
| 218 |
+
# convert to BGR
|
| 219 |
+
color_seg = color_seg[..., ::-1]
|
| 220 |
+
return color_seg
|
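As with the other model classes in this commit, ParsingGenModel is meant to be driven by an external script; the sketch below shows one possible wiring under assumed opt/dataloader conventions and is not part of this commit.

# Hypothetical training loop for ParsingGenModel (a sketch only).
# Batches are assumed to provide 'densepose', 'attr' and 'segm'; validation batches also 'img_name'.
from models.parsing_gen_model import ParsingGenModel

def train_parsing_generator(opt, train_loader, val_loader, num_epochs):
    model = ParsingGenModel(opt)
    for epoch in range(num_epochs):
        lr = model.update_learning_rate(epoch)
        for batch in train_loader:
            model.feed_data(batch)
            model.optimize_parameters()
        acc = model.inference(val_loader, save_dir=f'./results/parsing_epoch_{epoch:03d}')
        print(f'epoch {epoch}: lr={lr:.2e}, val acc={acc:.2f}')
        model.save_network(f'./experiments/parsing_gen_{epoch:03d}.pth')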
models/sample_model.py
ADDED
|
@@ -0,0 +1,498 @@
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.distributions as dists
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torchvision.utils import save_image
|
| 8 |
+
|
| 9 |
+
from models.archs.fcn_arch import FCNHead, MultiHeadFCNHead
|
| 10 |
+
from models.archs.shape_attr_embedding_arch import ShapeAttrEmbedding
|
| 11 |
+
from models.archs.transformer_arch import TransformerMultiHead
|
| 12 |
+
from models.archs.unet_arch import ShapeUNet, UNet
|
| 13 |
+
from models.archs.vqgan_arch import (Decoder, DecoderRes, Encoder,
|
| 14 |
+
VectorQuantizer,
|
| 15 |
+
VectorQuantizerSpatialTextureAware,
|
| 16 |
+
VectorQuantizerTexture)
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger('base')
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class BaseSampleModel():
|
| 22 |
+
"""Base Model"""
|
| 23 |
+
|
| 24 |
+
def __init__(self, opt):
|
| 25 |
+
self.opt = opt
|
| 26 |
+
self.device = torch.device('cuda')
|
| 27 |
+
|
| 28 |
+
# hierarchical VQVAE
|
| 29 |
+
self.decoder = Decoder(
|
| 30 |
+
in_channels=opt['top_in_channels'],
|
| 31 |
+
resolution=opt['top_resolution'],
|
| 32 |
+
z_channels=opt['top_z_channels'],
|
| 33 |
+
ch=opt['top_ch'],
|
| 34 |
+
out_ch=opt['top_out_ch'],
|
| 35 |
+
num_res_blocks=opt['top_num_res_blocks'],
|
| 36 |
+
attn_resolutions=opt['top_attn_resolutions'],
|
| 37 |
+
ch_mult=opt['top_ch_mult'],
|
| 38 |
+
dropout=opt['top_dropout'],
|
| 39 |
+
resamp_with_conv=True,
|
| 40 |
+
give_pre_end=False).to(self.device)
|
| 41 |
+
self.top_quantize = VectorQuantizerTexture(
|
| 42 |
+
1024, opt['embed_dim'], beta=0.25).to(self.device)
|
| 43 |
+
self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
|
| 44 |
+
opt["top_z_channels"],
|
| 45 |
+
1).to(self.device)
|
| 46 |
+
self.load_top_pretrain_models()
|
| 47 |
+
|
| 48 |
+
self.bot_decoder_res = DecoderRes(
|
| 49 |
+
in_channels=opt['bot_in_channels'],
|
| 50 |
+
resolution=opt['bot_resolution'],
|
| 51 |
+
z_channels=opt['bot_z_channels'],
|
| 52 |
+
ch=opt['bot_ch'],
|
| 53 |
+
num_res_blocks=opt['bot_num_res_blocks'],
|
| 54 |
+
ch_mult=opt['bot_ch_mult'],
|
| 55 |
+
dropout=opt['bot_dropout'],
|
| 56 |
+
give_pre_end=False).to(self.device)
|
| 57 |
+
self.bot_quantize = VectorQuantizerSpatialTextureAware(
|
| 58 |
+
opt['bot_n_embed'],
|
| 59 |
+
opt['embed_dim'],
|
| 60 |
+
beta=0.25,
|
| 61 |
+
spatial_size=opt['bot_codebook_spatial_size']).to(self.device)
|
| 62 |
+
self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
|
| 63 |
+
opt["bot_z_channels"],
|
| 64 |
+
1).to(self.device)
|
| 65 |
+
self.load_bot_pretrain_network()
|
| 66 |
+
|
| 67 |
+
# top -> bot prediction
|
| 68 |
+
self.index_pred_guidance_encoder = UNet(
|
| 69 |
+
in_channels=opt['index_pred_encoder_in_channels']).to(self.device)
|
| 70 |
+
self.index_pred_decoder = MultiHeadFCNHead(
|
| 71 |
+
in_channels=opt['index_pred_fc_in_channels'],
|
| 72 |
+
in_index=opt['index_pred_fc_in_index'],
|
| 73 |
+
channels=opt['index_pred_fc_channels'],
|
| 74 |
+
num_convs=opt['index_pred_fc_num_convs'],
|
| 75 |
+
concat_input=opt['index_pred_fc_concat_input'],
|
| 76 |
+
dropout_ratio=opt['index_pred_fc_dropout_ratio'],
|
| 77 |
+
num_classes=opt['index_pred_fc_num_classes'],
|
| 78 |
+
align_corners=opt['index_pred_fc_align_corners'],
|
| 79 |
+
num_head=18).to(self.device)
|
| 80 |
+
self.load_index_pred_network()
|
| 81 |
+
|
| 82 |
+
# VAE for segmentation mask
|
| 83 |
+
self.segm_encoder = Encoder(
|
| 84 |
+
ch=opt['segm_ch'],
|
| 85 |
+
num_res_blocks=opt['segm_num_res_blocks'],
|
| 86 |
+
attn_resolutions=opt['segm_attn_resolutions'],
|
| 87 |
+
ch_mult=opt['segm_ch_mult'],
|
| 88 |
+
in_channels=opt['segm_in_channels'],
|
| 89 |
+
resolution=opt['segm_resolution'],
|
| 90 |
+
z_channels=opt['segm_z_channels'],
|
| 91 |
+
double_z=opt['segm_double_z'],
|
| 92 |
+
dropout=opt['segm_dropout']).to(self.device)
|
| 93 |
+
self.segm_quantizer = VectorQuantizer(
|
| 94 |
+
opt['segm_n_embed'],
|
| 95 |
+
opt['segm_embed_dim'],
|
| 96 |
+
beta=0.25,
|
| 97 |
+
sane_index_shape=True).to(self.device)
|
| 98 |
+
self.segm_quant_conv = torch.nn.Conv2d(opt["segm_z_channels"],
|
| 99 |
+
opt['segm_embed_dim'],
|
| 100 |
+
1).to(self.device)
|
| 101 |
+
self.load_pretrained_segm_token()
|
| 102 |
+
|
| 103 |
+
# define sampler
|
| 104 |
+
self.sampler_fn = TransformerMultiHead(
|
| 105 |
+
codebook_size=opt['codebook_size'],
|
| 106 |
+
segm_codebook_size=opt['segm_codebook_size'],
|
| 107 |
+
texture_codebook_size=opt['texture_codebook_size'],
|
| 108 |
+
bert_n_emb=opt['bert_n_emb'],
|
| 109 |
+
bert_n_layers=opt['bert_n_layers'],
|
| 110 |
+
bert_n_head=opt['bert_n_head'],
|
| 111 |
+
block_size=opt['block_size'],
|
| 112 |
+
latent_shape=opt['latent_shape'],
|
| 113 |
+
embd_pdrop=opt['embd_pdrop'],
|
| 114 |
+
resid_pdrop=opt['resid_pdrop'],
|
| 115 |
+
attn_pdrop=opt['attn_pdrop'],
|
| 116 |
+
num_head=opt['num_head']).to(self.device)
|
| 117 |
+
self.load_sampler_pretrained_network()
|
| 118 |
+
|
| 119 |
+
self.shape = tuple(opt['latent_shape'])
|
| 120 |
+
|
| 121 |
+
self.mask_id = opt['codebook_size']
|
| 122 |
+
self.sample_steps = opt['sample_steps']
|
| 123 |
+
|
| 124 |
+
def load_top_pretrain_models(self):
|
| 125 |
+
# load pretrained vqgan
|
| 126 |
+
top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
|
| 127 |
+
|
| 128 |
+
self.decoder.load_state_dict(
|
| 129 |
+
top_vae_checkpoint['decoder'], strict=True)
|
| 130 |
+
self.top_quantize.load_state_dict(
|
| 131 |
+
top_vae_checkpoint['quantize'], strict=True)
|
| 132 |
+
self.top_post_quant_conv.load_state_dict(
|
| 133 |
+
top_vae_checkpoint['post_quant_conv'], strict=True)
|
| 134 |
+
|
| 135 |
+
self.decoder.eval()
|
| 136 |
+
self.top_quantize.eval()
|
| 137 |
+
self.top_post_quant_conv.eval()
|
| 138 |
+
|
| 139 |
+
def load_bot_pretrain_network(self):
|
| 140 |
+
checkpoint = torch.load(self.opt['bot_vae_path'])
|
| 141 |
+
self.bot_decoder_res.load_state_dict(
|
| 142 |
+
checkpoint['bot_decoder_res'], strict=True)
|
| 143 |
+
self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
|
| 144 |
+
self.bot_quantize.load_state_dict(
|
| 145 |
+
checkpoint['bot_quantize'], strict=True)
|
| 146 |
+
self.bot_post_quant_conv.load_state_dict(
|
| 147 |
+
checkpoint['bot_post_quant_conv'], strict=True)
|
| 148 |
+
|
| 149 |
+
self.bot_decoder_res.eval()
|
| 150 |
+
self.decoder.eval()
|
| 151 |
+
self.bot_quantize.eval()
|
| 152 |
+
self.bot_post_quant_conv.eval()
|
| 153 |
+
|
| 154 |
+
def load_pretrained_segm_token(self):
|
| 155 |
+
# load pretrained vqgan for segmentation mask
|
| 156 |
+
segm_token_checkpoint = torch.load(self.opt['segm_token_path'])
|
| 157 |
+
self.segm_encoder.load_state_dict(
|
| 158 |
+
segm_token_checkpoint['encoder'], strict=True)
|
| 159 |
+
self.segm_quantizer.load_state_dict(
|
| 160 |
+
segm_token_checkpoint['quantize'], strict=True)
|
| 161 |
+
self.segm_quant_conv.load_state_dict(
|
| 162 |
+
segm_token_checkpoint['quant_conv'], strict=True)
|
| 163 |
+
|
| 164 |
+
self.segm_encoder.eval()
|
| 165 |
+
self.segm_quantizer.eval()
|
| 166 |
+
self.segm_quant_conv.eval()
|
| 167 |
+
|
| 168 |
+
def load_index_pred_network(self):
|
| 169 |
+
checkpoint = torch.load(self.opt['pretrained_index_network'])
|
| 170 |
+
        self.index_pred_guidance_encoder.load_state_dict(
            checkpoint['guidance_encoder'], strict=True)
        self.index_pred_decoder.load_state_dict(
            checkpoint['index_decoder'], strict=True)

        self.index_pred_guidance_encoder.eval()
        self.index_pred_decoder.eval()

    def load_sampler_pretrained_network(self):
        checkpoint = torch.load(self.opt['pretrained_sampler'])
        self.sampler_fn.load_state_dict(checkpoint, strict=True)
        self.sampler_fn.eval()

    def bot_index_prediction(self, feature_top, texture_mask):
        self.index_pred_guidance_encoder.eval()
        self.index_pred_decoder.eval()

        texture_mask_flatten = F.interpolate(
            texture_mask, (32, 16), mode='nearest').view(-1).long()

        min_encodings_indices_list = [
            torch.full(
                texture_mask_flatten.size(),
                fill_value=-1,
                dtype=torch.long,
                device=texture_mask_flatten.device) for _ in range(18)
        ]
        with torch.no_grad():
            feature_enc = self.index_pred_guidance_encoder(feature_top)
            memory_logits_list = self.index_pred_decoder(feature_enc)
            for codebook_idx, memory_logits in enumerate(memory_logits_list):
                region_of_interest = texture_mask_flatten == codebook_idx
                if torch.sum(region_of_interest) > 0:
                    memory_indices_pred = memory_logits.argmax(dim=1).view(-1)
                    min_encodings_indices_list[codebook_idx][
                        region_of_interest] = memory_indices_pred[
                            region_of_interest]
        min_encodings_indices_return_list = [
            min_encodings_indices.view((1, 32, 16))
            for min_encodings_indices in min_encodings_indices_list
        ]

        return min_encodings_indices_return_list

    def sample_and_refine(self, save_dir=None, img_name=None):
        # sample 32x16 features indices
        sampled_top_indices_list = self.sample_fn(
            temp=1, sample_steps=self.sample_steps)

        for sample_idx in range(self.batch_size):
            sample_indices = [
                sampled_indices_cur[sample_idx:sample_idx + 1]
                for sampled_indices_cur in sampled_top_indices_list
            ]
            top_quant = self.top_quantize.get_codebook_entry(
                sample_indices, self.texture_mask[sample_idx:sample_idx + 1],
                (sample_indices[0].size(0), self.shape[0], self.shape[1],
                 self.opt["top_z_channels"]))

            top_quant = self.top_post_quant_conv(top_quant)

            bot_indices_list = self.bot_index_prediction(
                top_quant, self.texture_mask[sample_idx:sample_idx + 1])

            quant_bot = self.bot_quantize.get_codebook_entry(
                bot_indices_list, self.texture_mask[sample_idx:sample_idx + 1],
                (bot_indices_list[0].size(0), bot_indices_list[0].size(1),
                 bot_indices_list[0].size(2),
                 self.opt["bot_z_channels"]))  # .permute(0, 3, 1, 2)
            quant_bot = self.bot_post_quant_conv(quant_bot)
            bot_dec_res = self.bot_decoder_res(quant_bot)

            dec = self.decoder(top_quant, bot_h=bot_dec_res)

            dec = ((dec + 1) / 2)
            dec = dec.clamp_(0, 1)
            if save_dir is None and img_name is None:
                return dec
            else:
                save_image(
                    dec,
                    f'{save_dir}/{img_name[sample_idx]}',
                    nrow=1,
                    padding=4)

    def sample_fn(self, temp=1.0, sample_steps=None):
        self.sampler_fn.eval()

        x_t = torch.ones((self.batch_size, np.prod(self.shape)),
                         device=self.device).long() * self.mask_id
        unmasked = torch.zeros_like(x_t, device=self.device).bool()
        sample_steps = list(range(1, sample_steps + 1))

        texture_tokens = F.interpolate(
            self.texture_mask, (32, 16),
            mode='nearest').view(self.batch_size, -1).long()

        texture_mask_flatten = texture_tokens.view(-1)

        # min_encodings_indices_list would be used to visualize the image
        min_encodings_indices_list = [
            torch.full(
                texture_mask_flatten.size(),
                fill_value=-1,
                dtype=torch.long,
                device=texture_mask_flatten.device) for _ in range(18)
        ]

        for t in reversed(sample_steps):
            t = torch.full((self.batch_size, ),
                           t,
                           device=self.device,
                           dtype=torch.long)

            # where to unmask
            changes = torch.rand(
                x_t.shape, device=self.device) < 1 / t.float().unsqueeze(-1)
            # don't unmask somewhere already unmasked
            changes = torch.bitwise_xor(changes,
                                        torch.bitwise_and(changes, unmasked))
            # update mask with changes
            unmasked = torch.bitwise_or(unmasked, changes)

            x_0_logits_list = self.sampler_fn(
                x_t, self.segm_tokens, texture_tokens, t=t)

            changes_flatten = changes.view(-1)
            ori_shape = x_t.shape  # [b, h*w]
            x_t = x_t.view(-1)  # [b*h*w]
            for codebook_idx, x_0_logits in enumerate(x_0_logits_list):
                if torch.sum(texture_mask_flatten[changes_flatten] ==
                             codebook_idx) > 0:
                    # scale by temperature
                    x_0_logits = x_0_logits / temp
                    x_0_dist = dists.Categorical(logits=x_0_logits)
                    x_0_hat = x_0_dist.sample().long()
                    x_0_hat = x_0_hat.view(-1)

                    # only replace the changed indices with corresponding codebook_idx
                    changes_segm = torch.bitwise_and(
                        changes_flatten, texture_mask_flatten == codebook_idx)

                    # x_t would be the input to the transformer, so the index range should be continual one
                    x_t[changes_segm] = x_0_hat[
                        changes_segm] + 1024 * codebook_idx
                    min_encodings_indices_list[codebook_idx][
                        changes_segm] = x_0_hat[changes_segm]

            x_t = x_t.view(ori_shape)  # [b, h*w]

        min_encodings_indices_return_list = [
            min_encodings_indices.view(ori_shape)
            for min_encodings_indices in min_encodings_indices_list
        ]

        self.sampler_fn.train()

        return min_encodings_indices_return_list

    @torch.no_grad()
    def get_quantized_segm(self, segm):
        segm_one_hot = F.one_hot(
            segm.squeeze(1).long(),
            num_classes=self.opt['segm_num_segm_classes']).permute(
                0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        encoded_segm_mask = self.segm_encoder(segm_one_hot)
        encoded_segm_mask = self.segm_quant_conv(encoded_segm_mask)
        _, _, [_, _, segm_tokens] = self.segm_quantizer(encoded_segm_mask)

        return segm_tokens

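# Standalone sketch (illustrative, not part of sample_model.py): the unmasking
# schedule used in `sample_fn` above. At reverse step t each still-masked
# position is revealed with probability 1/t, positions already unmasked are
# excluded, and at t == 1 everything left is revealed. Shapes and the step
# count below are toy values chosen only for this sketch.
#
# import torch
#
# T = 16                       # assumed number of sample steps
# x = torch.full((2, 8), -1)   # toy token grid; -1 plays the role of mask_id
# unmasked = torch.zeros_like(x).bool()
# for t in reversed(range(1, T + 1)):
#     changes = torch.rand(x.shape) < 1.0 / t
#     changes = changes & ~unmasked      # never re-reveal a position
#     unmasked = unmasked | changes
#     x[changes] = 0                     # stand-in for sampled codebook indices
# assert unmasked.all()                  # every position revealed by t == 1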

class SampleFromParsingModel(BaseSampleModel):
    """SampleFromParsing model.
    """

    def feed_data(self, data):
        self.segm = data['segm'].to(self.device)
        self.texture_mask = data['texture_mask'].to(self.device)
        self.batch_size = self.segm.size(0)

        self.segm_tokens = self.get_quantized_segm(self.segm)
        self.segm_tokens = self.segm_tokens.view(self.batch_size, -1)

    def inference(self, data_loader, save_dir):
        for _, data in enumerate(data_loader):
            img_name = data['img_name']
            self.feed_data(data)
            with torch.no_grad():
                self.sample_and_refine(save_dir, img_name)

class SampleFromPoseModel(BaseSampleModel):
    """SampleFromPose model.
    """

    def __init__(self, opt):
        super().__init__(opt)
        # pose-to-parsing
        self.shape_attr_embedder = ShapeAttrEmbedding(
            dim=opt['shape_embedder_dim'],
            out_dim=opt['shape_embedder_out_dim'],
            cls_num_list=opt['shape_attr_class_num']).to(self.device)
        self.shape_parsing_encoder = ShapeUNet(
            in_channels=opt['shape_encoder_in_channels']).to(self.device)
        self.shape_parsing_decoder = FCNHead(
            in_channels=opt['shape_fc_in_channels'],
            in_index=opt['shape_fc_in_index'],
            channels=opt['shape_fc_channels'],
            num_convs=opt['shape_fc_num_convs'],
            concat_input=opt['shape_fc_concat_input'],
            dropout_ratio=opt['shape_fc_dropout_ratio'],
            num_classes=opt['shape_fc_num_classes'],
            align_corners=opt['shape_fc_align_corners'],
        ).to(self.device)
        self.load_shape_generation_models()

        self.palette = [[0, 0, 0], [255, 250, 250], [220, 220, 220],
                        [250, 235, 215], [255, 250, 205], [211, 211, 211],
                        [70, 130, 180], [127, 255, 212], [0, 100, 0],
                        [50, 205, 50], [255, 255, 0], [245, 222, 179],
                        [255, 140, 0], [255, 0, 0], [16, 78, 139],
                        [144, 238, 144], [50, 205, 174], [50, 155, 250],
                        [160, 140, 88], [213, 140, 88], [90, 140, 90],
                        [185, 210, 205], [130, 165, 180], [225, 141, 151]]

    def load_shape_generation_models(self):
        checkpoint = torch.load(self.opt['pretrained_parsing_gen'])

        self.shape_attr_embedder.load_state_dict(
            checkpoint['embedder'], strict=True)
        self.shape_attr_embedder.eval()

        self.shape_parsing_encoder.load_state_dict(
            checkpoint['encoder'], strict=True)
        self.shape_parsing_encoder.eval()

        self.shape_parsing_decoder.load_state_dict(
            checkpoint['decoder'], strict=True)
        self.shape_parsing_decoder.eval()

    def feed_data(self, data):
        self.pose = data['densepose'].to(self.device)
        self.batch_size = self.pose.size(0)

        self.shape_attr = data['shape_attr'].to(self.device)
        self.upper_fused_attr = data['upper_fused_attr'].to(self.device)
        self.lower_fused_attr = data['lower_fused_attr'].to(self.device)
        self.outer_fused_attr = data['outer_fused_attr'].to(self.device)

    def inference(self, data_loader, save_dir):
        for _, data in enumerate(data_loader):
            img_name = data['img_name']
            self.feed_data(data)
            with torch.no_grad():
                self.generate_parsing_map()
                self.generate_quantized_segm()
                self.generate_texture_map()
                self.sample_and_refine(save_dir, img_name)

    def generate_parsing_map(self):
        with torch.no_grad():
            attr_embedding = self.shape_attr_embedder(self.shape_attr)
            pose_enc = self.shape_parsing_encoder(self.pose, attr_embedding)
            seg_logits = self.shape_parsing_decoder(pose_enc)
            self.segm = seg_logits.argmax(dim=1)
            self.segm = self.segm.unsqueeze(1)

    def generate_quantized_segm(self):
        self.segm_tokens = self.get_quantized_segm(self.segm)
        self.segm_tokens = self.segm_tokens.view(self.batch_size, -1)

    def generate_texture_map(self):
        upper_cls = [1., 4.]
        lower_cls = [3., 5., 21.]
        outer_cls = [2.]

        mask_batch = []
        for idx in range(self.batch_size):
            mask = torch.zeros_like(self.segm[idx])
            upper_fused_attr = self.upper_fused_attr[idx]
            lower_fused_attr = self.lower_fused_attr[idx]
            outer_fused_attr = self.outer_fused_attr[idx]
            if upper_fused_attr != 17:
                for cls in upper_cls:
                    mask[self.segm[idx] == cls] = upper_fused_attr + 1

            if lower_fused_attr != 17:
                for cls in lower_cls:
                    mask[self.segm[idx] == cls] = lower_fused_attr + 1

            if outer_fused_attr != 17:
                for cls in outer_cls:
                    mask[self.segm[idx] == cls] = outer_fused_attr + 1

            mask_batch.append(mask)
        self.texture_mask = torch.stack(mask_batch, dim=0).to(torch.float32)

    def feed_pose_data(self, pose_img):
        # for ui demo

        self.pose = pose_img.to(self.device)
        self.batch_size = self.pose.size(0)

    def feed_shape_attributes(self, shape_attr):
        # for ui demo

        self.shape_attr = shape_attr.to(self.device)

    def feed_texture_attributes(self, texture_attr):
        # for ui demo

        self.upper_fused_attr = texture_attr[0].unsqueeze(0).to(self.device)
        self.lower_fused_attr = texture_attr[1].unsqueeze(0).to(self.device)
        self.outer_fused_attr = texture_attr[2].unsqueeze(0).to(self.device)

    def palette_result(self, result):

        seg = result[0]
        palette = np.array(self.palette)
        assert palette.shape[1] == 3
        assert len(palette.shape) == 2
        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
        for label, color in enumerate(palette):
            color_seg[seg == label, :] = color
        # convert to BGR
        # color_seg = color_seg[..., ::-1]
        return color_seg
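# Standalone sketch (illustrative, not part of sample_model.py): the texture-map
# rule implemented in `generate_texture_map` above. Pixels belonging to the
# upper/lower/outer garment classes are filled with (fused attribute id + 1),
# and attribute value 17 means "no texture selected". Class ids and shapes are
# toy values for this sketch only.
#
# import torch
#
# def build_texture_mask(segm, upper_attr, lower_attr, outer_attr):
#     upper_cls, lower_cls, outer_cls = [1., 4.], [3., 5., 21.], [2.]
#     mask = torch.zeros_like(segm)
#     for attr, cls_list in ((upper_attr, upper_cls), (lower_attr, lower_cls),
#                            (outer_attr, outer_cls)):
#         if attr != 17:
#             for cls in cls_list:
#                 mask[segm == cls] = attr + 1
#     return mask
#
# toy_segm = torch.tensor([[1., 3., 2.], [4., 5., 0.]])
# print(build_texture_mask(toy_segm, upper_attr=6, lower_attr=17, outer_attr=2))
# # upper pixels become 7, outer pixels become 3, lower pixels stay 0 (attr 17 = skip)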
models/transformer_model.py
ADDED
@@ -0,0 +1,482 @@
import logging
import math
from collections import OrderedDict

import numpy as np
import torch
import torch.distributions as dists
import torch.nn.functional as F
from torchvision.utils import save_image

from models.archs.transformer_arch import TransformerMultiHead
from models.archs.vqgan_arch import (Decoder, Encoder, VectorQuantizer,
                                     VectorQuantizerTexture)

logger = logging.getLogger('base')


class TransformerTextureAwareModel():
    """Texture-Aware Diffusion based Transformer model.
    """

    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device('cuda')
        self.is_train = opt['is_train']

        # VQVAE for image
        self.img_encoder = Encoder(
            ch=opt['img_ch'],
            num_res_blocks=opt['img_num_res_blocks'],
            attn_resolutions=opt['img_attn_resolutions'],
            ch_mult=opt['img_ch_mult'],
            in_channels=opt['img_in_channels'],
            resolution=opt['img_resolution'],
            z_channels=opt['img_z_channels'],
            double_z=opt['img_double_z'],
            dropout=opt['img_dropout']).to(self.device)
        self.img_decoder = Decoder(
            in_channels=opt['img_in_channels'],
            resolution=opt['img_resolution'],
            z_channels=opt['img_z_channels'],
            ch=opt['img_ch'],
            out_ch=opt['img_out_ch'],
            num_res_blocks=opt['img_num_res_blocks'],
            attn_resolutions=opt['img_attn_resolutions'],
            ch_mult=opt['img_ch_mult'],
            dropout=opt['img_dropout'],
            resamp_with_conv=True,
            give_pre_end=False).to(self.device)
        self.img_quantizer = VectorQuantizerTexture(
            opt['img_n_embed'], opt['img_embed_dim'],
            beta=0.25).to(self.device)
        self.img_quant_conv = torch.nn.Conv2d(opt["img_z_channels"],
                                              opt['img_embed_dim'],
                                              1).to(self.device)
        self.img_post_quant_conv = torch.nn.Conv2d(opt['img_embed_dim'],
                                                   opt["img_z_channels"],
                                                   1).to(self.device)
        self.load_pretrained_image_vae()

        # VAE for segmentation mask
        self.segm_encoder = Encoder(
            ch=opt['segm_ch'],
            num_res_blocks=opt['segm_num_res_blocks'],
            attn_resolutions=opt['segm_attn_resolutions'],
            ch_mult=opt['segm_ch_mult'],
            in_channels=opt['segm_in_channels'],
            resolution=opt['segm_resolution'],
            z_channels=opt['segm_z_channels'],
            double_z=opt['segm_double_z'],
            dropout=opt['segm_dropout']).to(self.device)
        self.segm_quantizer = VectorQuantizer(
            opt['segm_n_embed'],
            opt['segm_embed_dim'],
            beta=0.25,
            sane_index_shape=True).to(self.device)
        self.segm_quant_conv = torch.nn.Conv2d(opt["segm_z_channels"],
                                               opt['segm_embed_dim'],
                                               1).to(self.device)
        self.load_pretrained_segm_vae()

        # define sampler
        self._denoise_fn = TransformerMultiHead(
            codebook_size=opt['codebook_size'],
            segm_codebook_size=opt['segm_codebook_size'],
            texture_codebook_size=opt['texture_codebook_size'],
            bert_n_emb=opt['bert_n_emb'],
            bert_n_layers=opt['bert_n_layers'],
            bert_n_head=opt['bert_n_head'],
            block_size=opt['block_size'],
            latent_shape=opt['latent_shape'],
            embd_pdrop=opt['embd_pdrop'],
            resid_pdrop=opt['resid_pdrop'],
            attn_pdrop=opt['attn_pdrop'],
            num_head=opt['num_head']).to(self.device)

        self.num_classes = opt['codebook_size']
        self.shape = tuple(opt['latent_shape'])
        self.num_timesteps = 1000

        self.mask_id = opt['codebook_size']
        self.loss_type = opt['loss_type']
        self.mask_schedule = opt['mask_schedule']

        self.sample_steps = opt['sample_steps']

        self.init_training_settings()

    def load_pretrained_image_vae(self):
        # load pretrained vqgan for the image tokens
        img_ae_checkpoint = torch.load(self.opt['img_ae_path'])
        self.img_encoder.load_state_dict(
            img_ae_checkpoint['encoder'], strict=True)
        self.img_decoder.load_state_dict(
            img_ae_checkpoint['decoder'], strict=True)
        self.img_quantizer.load_state_dict(
            img_ae_checkpoint['quantize'], strict=True)
        self.img_quant_conv.load_state_dict(
            img_ae_checkpoint['quant_conv'], strict=True)
        self.img_post_quant_conv.load_state_dict(
            img_ae_checkpoint['post_quant_conv'], strict=True)
        self.img_encoder.eval()
        self.img_decoder.eval()
        self.img_quantizer.eval()
        self.img_quant_conv.eval()
        self.img_post_quant_conv.eval()

    def load_pretrained_segm_vae(self):
        # load pretrained vqgan for segmentation mask
        segm_ae_checkpoint = torch.load(self.opt['segm_ae_path'])
        self.segm_encoder.load_state_dict(
            segm_ae_checkpoint['encoder'], strict=True)
        self.segm_quantizer.load_state_dict(
            segm_ae_checkpoint['quantize'], strict=True)
        self.segm_quant_conv.load_state_dict(
            segm_ae_checkpoint['quant_conv'], strict=True)
        self.segm_encoder.eval()
        self.segm_quantizer.eval()
        self.segm_quant_conv.eval()

    def init_training_settings(self):
        optim_params = []
        for v in self._denoise_fn.parameters():
            if v.requires_grad:
                optim_params.append(v)
        # set up optimizer
        self.optimizer = torch.optim.Adam(
            optim_params,
            self.opt['lr'],
            weight_decay=self.opt['weight_decay'])
        self.log_dict = OrderedDict()

    @torch.no_grad()
    def get_quantized_img(self, image, texture_mask):
        encoded_img = self.img_encoder(image)
        encoded_img = self.img_quant_conv(encoded_img)

        # img_tokens_input is the continual index for the input of transformer
        # img_tokens_gt_list is the index for 18 texture-aware codebooks respectively
        _, _, [_, img_tokens_input, img_tokens_gt_list
               ] = self.img_quantizer(encoded_img, texture_mask)

        # reshape the tokens
        b = image.size(0)
        img_tokens_input = img_tokens_input.view(b, -1)
        img_tokens_gt_return_list = [
            img_tokens_gt.view(b, -1) for img_tokens_gt in img_tokens_gt_list
        ]

        return img_tokens_input, img_tokens_gt_return_list

    @torch.no_grad()
    def decode(self, quant):
        quant = self.img_post_quant_conv(quant)
        dec = self.img_decoder(quant)
        return dec

    @torch.no_grad()
    def decode_image_indices(self, indices_list, texture_mask):
        quant = self.img_quantizer.get_codebook_entry(
            indices_list, texture_mask,
            (indices_list[0].size(0), self.shape[0], self.shape[1],
             self.opt["img_z_channels"]))
        dec = self.decode(quant)

        return dec

    def sample_time(self, b, device, method='uniform'):
        if method == 'importance':
            if not (self.Lt_count > 10).all():
                return self.sample_time(b, device, method='uniform')

            Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001
            Lt_sqrt[0] = Lt_sqrt[1]  # Overwrite decoder term with L1.
            pt_all = Lt_sqrt / Lt_sqrt.sum()

            t = torch.multinomial(pt_all, num_samples=b, replacement=True)

            pt = pt_all.gather(dim=0, index=t)

            return t, pt

        elif method == 'uniform':
            t = torch.randint(
                1, self.num_timesteps + 1, (b, ), device=device).long()
            pt = torch.ones_like(t).float() / self.num_timesteps
            return t, pt

        else:
            raise ValueError

    def q_sample(self, x_0, x_0_gt_list, t):
        # samples q(x_t | x_0)
        # randomly set tokens to mask with probability t/T
        # x_t, x_0_ignore = x_0.clone(), x_0.clone()
        x_t = x_0.clone()

        mask = torch.rand_like(x_t.float()) < (
            t.float().unsqueeze(-1) / self.num_timesteps)
        x_t[mask] = self.mask_id
        # x_0_ignore[torch.bitwise_not(mask)] = -1

        # for every gt token list, we also need to do the mask
        x_0_gt_ignore_list = []
        for x_0_gt in x_0_gt_list:
            x_0_gt_ignore = x_0_gt.clone()
            x_0_gt_ignore[torch.bitwise_not(mask)] = -1
            x_0_gt_ignore_list.append(x_0_gt_ignore)

        return x_t, x_0_gt_ignore_list, mask

    def _train_loss(self, x_0, x_0_gt_list):
        b, device = x_0.size(0), x_0.device

        # choose what time steps to compute loss at
        t, pt = self.sample_time(b, device, 'uniform')

        # make x noisy and denoise
        if self.mask_schedule == 'random':
            x_t, x_0_gt_ignore_list, mask = self.q_sample(
                x_0=x_0, x_0_gt_list=x_0_gt_list, t=t)
        else:
            raise NotImplementedError

        # sample p(x_0 | x_t)
        x_0_hat_logits_list = self._denoise_fn(
            x_t, self.segm_tokens, self.texture_tokens, t=t)

        # Always compute ELBO for comparison purposes
        cross_entropy_loss = 0
        for x_0_hat_logits, x_0_gt_ignore in zip(x_0_hat_logits_list,
                                                 x_0_gt_ignore_list):
            cross_entropy_loss += F.cross_entropy(
                x_0_hat_logits.permute(0, 2, 1),
                x_0_gt_ignore,
                ignore_index=-1,
                reduction='none').sum(1)
        vb_loss = cross_entropy_loss / t
        vb_loss = vb_loss / pt
        vb_loss = vb_loss / (math.log(2) * x_0.shape[1:].numel())
        if self.loss_type == 'elbo':
            loss = vb_loss
        elif self.loss_type == 'mlm':
            denom = mask.float().sum(1)
            denom[denom == 0] = 1  # prevent divide by 0 errors.
            loss = cross_entropy_loss / denom
        elif self.loss_type == 'reweighted_elbo':
            weight = (1 - (t / self.num_timesteps))
            loss = weight * cross_entropy_loss
            loss = loss / (math.log(2) * x_0.shape[1:].numel())
        else:
            raise ValueError

        return loss.mean(), vb_loss.mean()

    def feed_data(self, data):
        self.image = data['image'].to(self.device)
        self.segm = data['segm'].to(self.device)
        self.texture_mask = data['texture_mask'].to(self.device)
        self.input_indices, self.gt_indices_list = self.get_quantized_img(
            self.image, self.texture_mask)

        self.texture_tokens = F.interpolate(
            self.texture_mask, size=self.shape,
            mode='nearest').view(self.image.size(0), -1).long()

        self.segm_tokens = self.get_quantized_segm(self.segm)
        self.segm_tokens = self.segm_tokens.view(self.image.size(0), -1)

    def optimize_parameters(self):
        self._denoise_fn.train()

        loss, vb_loss = self._train_loss(self.input_indices,
                                         self.gt_indices_list)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.log_dict['loss'] = loss
        self.log_dict['vb_loss'] = vb_loss

        self._denoise_fn.eval()

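# Standalone sketch (illustrative, not part of transformer_model.py): the
# absorbing-mask forward process that `q_sample` above implements. Each token
# is independently replaced by the mask id with probability t / T. Vocabulary
# size, mask id, timesteps and shapes are toy values for this sketch only.
#
# import torch
#
# T, mask_id = 1000, 1024
# x_0 = torch.randint(0, 1024, (2, 16))      # toy ground-truth token grid
# t = torch.tensor([250, 750])               # one timestep per sample
# mask = torch.rand_like(x_0.float()) < (t.float().unsqueeze(-1) / T)
# x_t = x_0.clone()
# x_t[mask] = mask_id                        # corrupt: masked positions absorb
# x_0_ignore = x_0.clone()
# x_0_ignore[~mask] = -1                     # loss targets only on masked positions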
    @torch.no_grad()
    def get_quantized_segm(self, segm):
        segm_one_hot = F.one_hot(
            segm.squeeze(1).long(),
            num_classes=self.opt['segm_num_segm_classes']).permute(
                0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        encoded_segm_mask = self.segm_encoder(segm_one_hot)
        encoded_segm_mask = self.segm_quant_conv(encoded_segm_mask)
        _, _, [_, _, segm_tokens] = self.segm_quantizer(encoded_segm_mask)

        return segm_tokens

    def sample_fn(self, temp=1.0, sample_steps=None):
        self._denoise_fn.eval()

        b, device = self.image.size(0), 'cuda'
        x_t = torch.ones(
            (b, np.prod(self.shape)), device=device).long() * self.mask_id
        unmasked = torch.zeros_like(x_t, device=device).bool()
        sample_steps = list(range(1, sample_steps + 1))

        texture_mask_flatten = self.texture_tokens.view(-1)

        # min_encodings_indices_list would be used to visualize the image
        min_encodings_indices_list = [
            torch.full(
                texture_mask_flatten.size(),
                fill_value=-1,
                dtype=torch.long,
                device=texture_mask_flatten.device) for _ in range(18)
        ]

        for t in reversed(sample_steps):
            print(f'Sample timestep {t:4d}', end='\r')
            t = torch.full((b, ), t, device=device, dtype=torch.long)

            # where to unmask
            changes = torch.rand(
                x_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
            # don't unmask somewhere already unmasked
            changes = torch.bitwise_xor(changes,
                                        torch.bitwise_and(changes, unmasked))
            # update mask with changes
            unmasked = torch.bitwise_or(unmasked, changes)

            x_0_logits_list = self._denoise_fn(
                x_t, self.segm_tokens, self.texture_tokens, t=t)

            changes_flatten = changes.view(-1)
            ori_shape = x_t.shape  # [b, h*w]
            x_t = x_t.view(-1)  # [b*h*w]
            for codebook_idx, x_0_logits in enumerate(x_0_logits_list):
                if torch.sum(texture_mask_flatten[changes_flatten] ==
                             codebook_idx) > 0:
                    # scale by temperature
                    x_0_logits = x_0_logits / temp
                    x_0_dist = dists.Categorical(logits=x_0_logits)
                    x_0_hat = x_0_dist.sample().long()
                    x_0_hat = x_0_hat.view(-1)

                    # only replace the changed indices with corresponding codebook_idx
                    changes_segm = torch.bitwise_and(
                        changes_flatten, texture_mask_flatten == codebook_idx)

                    # x_t would be the input to the transformer, so the index range should be continual one
                    x_t[changes_segm] = x_0_hat[
                        changes_segm] + 1024 * codebook_idx
                    min_encodings_indices_list[codebook_idx][
                        changes_segm] = x_0_hat[changes_segm]

            x_t = x_t.view(ori_shape)  # [b, h*w]

        min_encodings_indices_return_list = [
            min_encodings_indices.view(ori_shape)
            for min_encodings_indices in min_encodings_indices_list
        ]

        self._denoise_fn.train()

        return min_encodings_indices_return_list

    def get_vis(self, image, gt_indices, predicted_indices, texture_mask,
                save_path):
        # original image
        ori_img = self.decode_image_indices(gt_indices, texture_mask)
        # pred image
        pred_img = self.decode_image_indices(predicted_indices, texture_mask)
        img_cat = torch.cat([
            image,
            ori_img,
            pred_img,
        ], dim=3).detach()
        img_cat = ((img_cat + 1) / 2)
        img_cat = img_cat.clamp_(0, 1)
        save_image(img_cat, save_path, nrow=1, padding=4)

    def inference(self, data_loader, save_dir):
        self._denoise_fn.eval()

        for _, data in enumerate(data_loader):
            img_name = data['img_name']
            self.feed_data(data)
            b = self.image.size(0)
            with torch.no_grad():
                sampled_indices_list = self.sample_fn(
                    temp=1, sample_steps=self.sample_steps)
            for idx in range(b):
                self.get_vis(self.image[idx:idx + 1], [
                    gt_indices[idx:idx + 1]
                    for gt_indices in self.gt_indices_list
                ], [
                    sampled_indices[idx:idx + 1]
                    for sampled_indices in sampled_indices_list
                ], self.texture_mask[idx:idx + 1],
                             f'{save_dir}/{img_name[idx]}')

        self._denoise_fn.train()

    def get_current_log(self):
        return self.log_dict

    def update_learning_rate(self, epoch, iters=None):
        """Update the learning rate.

        Args:
            epoch (int): Current epoch.
            iters (int): Current iteration within the epoch. Only used by the
                'warm_up' decay mode. Default: None.
        """
        lr = self.optimizer.param_groups[0]['lr']

        if self.opt['lr_decay'] == 'step':
            lr = self.opt['lr'] * (
                self.opt['gamma']**(epoch // self.opt['step']))
        elif self.opt['lr_decay'] == 'cos':
            lr = self.opt['lr'] * (
                1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
        elif self.opt['lr_decay'] == 'linear':
            lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
        elif self.opt['lr_decay'] == 'linear2exp':
            if epoch < self.opt['turning_point'] + 1:
                # the learning rate decays by 5% per epoch until the turning
                # point (1 / 95% = 1.0526)
                lr = self.opt['lr'] * (
                    1 - epoch / int(self.opt['turning_point'] * 1.0526))
            else:
                lr *= self.opt['gamma']
        elif self.opt['lr_decay'] == 'schedule':
            if epoch in self.opt['schedule']:
                lr *= self.opt['gamma']
        elif self.opt['lr_decay'] == 'warm_up':
            if iters <= self.opt['warmup_iters']:
                lr = self.opt['lr'] * float(iters) / self.opt['warmup_iters']
            else:
                lr = self.opt['lr']
        else:
            raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
        # set learning rate
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        return lr

    def save_network(self, net, save_path):
        """Save networks.

        Args:
            net (nn.Module): Network to be saved.
            save_path (str): Path to the checkpoint file.
        """
        state_dict = net.state_dict()
        torch.save(state_dict, save_path)

    def load_network(self):
        checkpoint = torch.load(self.opt['pretrained_sampler'])
        self._denoise_fn.load_state_dict(checkpoint, strict=True)
        self._denoise_fn.eval()
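# Quick numeric check (illustrative, not part of the file above): the 'cos'
# branch of `update_learning_rate` follows
# lr(epoch) = base_lr * (1 + cos(pi * epoch / num_epochs)) / 2,
# i.e. it starts at base_lr and decays smoothly to 0 at the final epoch.
# The base_lr and num_epochs values here are arbitrary.
#
# import math
#
# base_lr, num_epochs = 1e-4, 100
# for epoch in (0, 50, 100):
#     lr = base_lr * (1 + math.cos(math.pi * epoch / num_epochs)) / 2
#     print(epoch, round(lr, 8))   # 1e-4 at epoch 0, 5e-5 at the midpoint, 0.0 at the end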
models/vqgan_model.py
ADDED
@@ -0,0 +1,551 @@
import math
import sys
from collections import OrderedDict

sys.path.append('..')
import lpips
import torch
import torch.nn.functional as F
from torchvision.utils import save_image

from models.archs.vqgan_arch import (Decoder, Discriminator, Encoder,
                                     VectorQuantizer, VectorQuantizerTexture)
from models.losses.segmentation_loss import BCELossWithQuant
from models.losses.vqgan_loss import (DiffAugment, adopt_weight,
                                      calculate_adaptive_weight, hinge_d_loss)


class VQModel():

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.device = torch.device('cuda')
        self.encoder = Encoder(
            ch=opt['ch'],
            num_res_blocks=opt['num_res_blocks'],
            attn_resolutions=opt['attn_resolutions'],
            ch_mult=opt['ch_mult'],
            in_channels=opt['in_channels'],
            resolution=opt['resolution'],
            z_channels=opt['z_channels'],
            double_z=opt['double_z'],
            dropout=opt['dropout']).to(self.device)
        self.decoder = Decoder(
            in_channels=opt['in_channels'],
            resolution=opt['resolution'],
            z_channels=opt['z_channels'],
            ch=opt['ch'],
            out_ch=opt['out_ch'],
            num_res_blocks=opt['num_res_blocks'],
            attn_resolutions=opt['attn_resolutions'],
            ch_mult=opt['ch_mult'],
            dropout=opt['dropout'],
            resamp_with_conv=True,
            give_pre_end=False).to(self.device)
        self.quantize = VectorQuantizer(
            opt['n_embed'], opt['embed_dim'], beta=0.25).to(self.device)
        self.quant_conv = torch.nn.Conv2d(opt["z_channels"], opt['embed_dim'],
                                          1).to(self.device)
        self.post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
                                               opt["z_channels"],
                                               1).to(self.device)

    def init_training_settings(self):
        self.loss = BCELossWithQuant()
        self.log_dict = OrderedDict()
        self.configure_optimizers()

    def save_network(self, save_path):
        """Save networks.

        Args:
            save_path (str): Path to the checkpoint file.
        """

        save_dict = {}
        save_dict['encoder'] = self.encoder.state_dict()
        save_dict['decoder'] = self.decoder.state_dict()
        save_dict['quantize'] = self.quantize.state_dict()
        save_dict['quant_conv'] = self.quant_conv.state_dict()
        save_dict['post_quant_conv'] = self.post_quant_conv.state_dict()
        save_dict['discriminator'] = self.disc.state_dict()
        torch.save(save_dict, save_path)

    def load_network(self):
        checkpoint = torch.load(self.opt['pretrained_models'])
        self.encoder.load_state_dict(checkpoint['encoder'], strict=True)
        self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
        self.quantize.load_state_dict(checkpoint['quantize'], strict=True)
        self.quant_conv.load_state_dict(checkpoint['quant_conv'], strict=True)
        self.post_quant_conv.load_state_dict(
            checkpoint['post_quant_conv'], strict=True)

    def optimize_parameters(self, data, current_iter):
        self.encoder.train()
        self.decoder.train()
        self.quantize.train()
        self.quant_conv.train()
        self.post_quant_conv.train()

        loss = self.training_step(data)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward_step(self, input):
        quant, diff, _ = self.encode(input)
        dec = self.decode(quant)
        return dec, diff

    def feed_data(self, data):
        x = data['segm']
        x = F.one_hot(x, num_classes=self.opt['num_segm_classes'])

        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
        return x.float().to(self.device)

    def get_current_log(self):
        return self.log_dict

    def update_learning_rate(self, epoch):
        """Update the learning rate.

        Args:
            epoch (int): Current epoch.
        """
        lr = self.optimizer.param_groups[0]['lr']

        if self.opt['lr_decay'] == 'step':
            lr = self.opt['lr'] * (
                self.opt['gamma']**(epoch // self.opt['step']))
        elif self.opt['lr_decay'] == 'cos':
            lr = self.opt['lr'] * (
                1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
        elif self.opt['lr_decay'] == 'linear':
            lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
        elif self.opt['lr_decay'] == 'linear2exp':
            if epoch < self.opt['turning_point'] + 1:
                # the learning rate decays by 5% per epoch until the turning
                # point (1 / 95% = 1.0526)
                lr = self.opt['lr'] * (
                    1 - epoch / int(self.opt['turning_point'] * 1.0526))
            else:
                lr *= self.opt['gamma']
        elif self.opt['lr_decay'] == 'schedule':
            if epoch in self.opt['schedule']:
                lr *= self.opt['gamma']
        else:
            raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
        # set learning rate
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        return lr

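# Standalone sketch (illustrative, not part of vqgan_model.py): the actual
# quantizer used by `VQModel` is `VectorQuantizer` from
# models/archs/vqgan_arch.py; the snippet below is only a toy nearest-codeword
# lookup showing what "quantize" means for the features produced by
# `quant_conv`. Codebook size and dimensions are toy values.
#
# import torch
#
# n_embed, embed_dim = 8, 4
# codebook = torch.randn(n_embed, embed_dim)        # toy codebook
# z = torch.randn(2, embed_dim, 16, 16)             # toy encoder output after quant_conv
# z_flat = z.permute(0, 2, 3, 1).reshape(-1, embed_dim)
# dist = torch.cdist(z_flat, codebook)              # L2 distance to every codeword
# indices = dist.argmin(dim=1)                      # nearest codeword per spatial location
# z_q = codebook[indices].view(2, 16, 16, embed_dim).permute(0, 3, 1, 2)
# assert z_q.shape == z.shape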

class VQSegmentationModel(VQModel):

    def __init__(self, opt):
        super().__init__(opt)
        self.colorize = torch.randn(3, opt['num_segm_classes'], 1,
                                    1).to(self.device)

        self.init_training_settings()

    def configure_optimizers(self):
        self.optimizer = torch.optim.Adam(
            list(self.encoder.parameters()) + list(self.decoder.parameters()) +
            list(self.quantize.parameters()) +
            list(self.quant_conv.parameters()) +
            list(self.post_quant_conv.parameters()),
            lr=self.opt['lr'],
            betas=(0.5, 0.9))

    def training_step(self, data):
        x = self.feed_data(data)
        xrec, qloss = self.forward_step(x)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
        self.log_dict.update(log_dict_ae)
        return aeloss

    def to_rgb(self, x):
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x

    @torch.no_grad()
    def inference(self, data_loader, save_dir):
        self.encoder.eval()
        self.decoder.eval()
        self.quantize.eval()
        self.quant_conv.eval()
        self.post_quant_conv.eval()

        loss_total = 0
        loss_bce = 0
        loss_quant = 0
        num = 0

        for _, data in enumerate(data_loader):
            img_name = data['img_name'][0]
            x = self.feed_data(data)
            xrec, qloss = self.forward_step(x)
            _, log_dict_ae = self.loss(qloss, x, xrec, split="val")

            loss_total += log_dict_ae['val/total_loss']
            loss_bce += log_dict_ae['val/bce_loss']
            loss_quant += log_dict_ae['val/quant_loss']

            num += x.size(0)

            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                # convert logits to indices
                xrec = torch.argmax(xrec, dim=1, keepdim=True)
                xrec = F.one_hot(xrec, num_classes=x.shape[1])
                xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)

            img_cat = torch.cat([x, xrec], dim=3).detach()
            img_cat = ((img_cat + 1) / 2)
            img_cat = img_cat.clamp_(0, 1)
            save_image(
                img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)

        return (loss_total / num).item(), (loss_bce /
                                           num).item(), (loss_quant /
                                                         num).item()

class VQImageModel(VQModel):

    def __init__(self, opt):
        super().__init__(opt)
        self.disc = Discriminator(
            opt['n_channels'], opt['ndf'],
            n_layers=opt['disc_layers']).to(self.device)
        self.perceptual = lpips.LPIPS(net="vgg").to(self.device)
        self.perceptual_weight = opt['perceptual_weight']
        self.disc_start_step = opt['disc_start_step']
        self.disc_weight_max = opt['disc_weight_max']
        self.diff_aug = opt['diff_aug']
        self.policy = "color,translation"

        self.disc.train()

        self.init_training_settings()

    def feed_data(self, data):
        x = data['image']

        return x.float().to(self.device)

    def init_training_settings(self):
        self.log_dict = OrderedDict()
        self.configure_optimizers()

    def configure_optimizers(self):
        self.optimizer = torch.optim.Adam(
            list(self.encoder.parameters()) + list(self.decoder.parameters()) +
            list(self.quantize.parameters()) +
            list(self.quant_conv.parameters()) +
            list(self.post_quant_conv.parameters()),
            lr=self.opt['lr'])

        self.disc_optimizer = torch.optim.Adam(
            self.disc.parameters(), lr=self.opt['lr'])

    def training_step(self, data, step):
        x = self.feed_data(data)
        xrec, codebook_loss = self.forward_step(x)

        # get recon/perceptual loss
        recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
        p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
        nll_loss = recon_loss + self.perceptual_weight * p_loss
        nll_loss = torch.mean(nll_loss)

        # augment for input to discriminator
        if self.diff_aug:
            xrec = DiffAugment(xrec, policy=self.policy)

        # update generator
        logits_fake = self.disc(xrec)
        g_loss = -torch.mean(logits_fake)
        last_layer = self.decoder.conv_out.weight
        d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
                                             self.disc_weight_max)
        d_weight *= adopt_weight(1, step, self.disc_start_step)
        loss = nll_loss + d_weight * g_loss + codebook_loss

        self.log_dict["loss"] = loss
        self.log_dict["l1"] = recon_loss.mean().item()
        self.log_dict["perceptual"] = p_loss.mean().item()
        self.log_dict["nll_loss"] = nll_loss.item()
        self.log_dict["g_loss"] = g_loss.item()
        self.log_dict["d_weight"] = d_weight
        self.log_dict["codebook_loss"] = codebook_loss.item()

        if step > self.disc_start_step:
            if self.diff_aug:
                logits_real = self.disc(
                    DiffAugment(x.contiguous().detach(), policy=self.policy))
            else:
                logits_real = self.disc(x.contiguous().detach())
            logits_fake = self.disc(xrec.contiguous().detach(
            ))  # detach so that the generator isn't also updated
            d_loss = hinge_d_loss(logits_real, logits_fake)
            self.log_dict["d_loss"] = d_loss
        else:
            d_loss = None

        return loss, d_loss

    def optimize_parameters(self, data, step):
        self.encoder.train()
        self.decoder.train()
        self.quantize.train()
        self.quant_conv.train()
        self.post_quant_conv.train()

        loss, d_loss = self.training_step(data, step)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if step > self.disc_start_step:
            self.disc_optimizer.zero_grad()
            d_loss.backward()
            self.disc_optimizer.step()

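# Reference sketch (illustrative, not part of vqgan_model.py): `hinge_d_loss`
# and `calculate_adaptive_weight` used in `training_step` above are imported
# from models/losses/vqgan_loss.py, which is not shown here. The definitions
# below are the standard VQGAN-style formulations of those two pieces, written
# out for orientation; the repository's own implementation may differ.
#
# import torch
# import torch.nn.functional as F
#
# def hinge_d_loss_sketch(logits_real, logits_fake):
#     # discriminator hinge loss: push real logits above +1, fake logits below -1
#     return 0.5 * (torch.mean(F.relu(1. - logits_real)) +
#                   torch.mean(F.relu(1. + logits_fake)))
#
# def adaptive_weight_sketch(nll_loss, g_loss, last_layer, max_weight):
#     # balance the GAN term against the reconstruction term by comparing
#     # gradient magnitudes at the decoder's last layer
#     nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
#     g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
#     weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
#     return torch.clamp(weight, 0.0, max_weight).detach()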
    @torch.no_grad()
    def inference(self, data_loader, save_dir):
        self.encoder.eval()
        self.decoder.eval()
        self.quantize.eval()
        self.quant_conv.eval()
        self.post_quant_conv.eval()

        loss_total = 0
        num = 0

        for _, data in enumerate(data_loader):
            img_name = data['img_name'][0]
            x = self.feed_data(data)
            xrec, _ = self.forward_step(x)

            recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
            p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
            nll_loss = recon_loss + self.perceptual_weight * p_loss
            nll_loss = torch.mean(nll_loss)
            loss_total += nll_loss

            num += x.size(0)

            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                # convert logits to indices
                xrec = torch.argmax(xrec, dim=1, keepdim=True)
                xrec = F.one_hot(xrec, num_classes=x.shape[1])
                xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)

            img_cat = torch.cat([x, xrec], dim=3).detach()
            img_cat = ((img_cat + 1) / 2)
            img_cat = img_cat.clamp_(0, 1)
            save_image(
                img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)

        return (loss_total / num).item()

class VQImageSegmTextureModel(VQImageModel):

    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device('cuda')
        self.encoder = Encoder(
            ch=opt['ch'],
            num_res_blocks=opt['num_res_blocks'],
            attn_resolutions=opt['attn_resolutions'],
            ch_mult=opt['ch_mult'],
            in_channels=opt['in_channels'],
            resolution=opt['resolution'],
            z_channels=opt['z_channels'],
            double_z=opt['double_z'],
            dropout=opt['dropout']).to(self.device)
        self.decoder = Decoder(
            in_channels=opt['in_channels'],
            resolution=opt['resolution'],
            z_channels=opt['z_channels'],
            ch=opt['ch'],
            out_ch=opt['out_ch'],
            num_res_blocks=opt['num_res_blocks'],
            attn_resolutions=opt['attn_resolutions'],
            ch_mult=opt['ch_mult'],
            dropout=opt['dropout'],
            resamp_with_conv=True,
            give_pre_end=False).to(self.device)
        self.quantize = VectorQuantizerTexture(
            opt['n_embed'], opt['embed_dim'], beta=0.25).to(self.device)
        self.quant_conv = torch.nn.Conv2d(opt["z_channels"], opt['embed_dim'],
                                          1).to(self.device)
        self.post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
                                               opt["z_channels"],
                                               1).to(self.device)

        self.disc = Discriminator(
            opt['n_channels'], opt['ndf'],
            n_layers=opt['disc_layers']).to(self.device)
        self.perceptual = lpips.LPIPS(net="vgg").to(self.device)
        self.perceptual_weight = opt['perceptual_weight']
        self.disc_start_step = opt['disc_start_step']
        self.disc_weight_max = opt['disc_weight_max']
        self.diff_aug = opt['diff_aug']
        self.policy = "color,translation"

        self.disc.train()

        self.init_training_settings()

    def feed_data(self, data):
        x = data['image'].float().to(self.device)
        mask = data['texture_mask'].float().to(self.device)

        return x, mask

    def training_step(self, data, step):
        x, mask = self.feed_data(data)
        xrec, codebook_loss = self.forward_step(x, mask)

        # get recon/perceptual loss
        recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
        p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
        nll_loss = recon_loss + self.perceptual_weight * p_loss
        nll_loss = torch.mean(nll_loss)

        # augment for input to discriminator
        if self.diff_aug:
            xrec = DiffAugment(xrec, policy=self.policy)

        # update generator
        logits_fake = self.disc(xrec)
        g_loss = -torch.mean(logits_fake)
        last_layer = self.decoder.conv_out.weight
        d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
                                             self.disc_weight_max)
        d_weight *= adopt_weight(1, step, self.disc_start_step)
        loss = nll_loss + d_weight * g_loss + codebook_loss

        self.log_dict["loss"] = loss
        self.log_dict["l1"] = recon_loss.mean().item()
        self.log_dict["perceptual"] = p_loss.mean().item()
        self.log_dict["nll_loss"] = nll_loss.item()
        self.log_dict["g_loss"] = g_loss.item()
        self.log_dict["d_weight"] = d_weight
        self.log_dict["codebook_loss"] = codebook_loss.item()

        if step > self.disc_start_step:
            if self.diff_aug:
                logits_real = self.disc(
                    DiffAugment(x.contiguous().detach(), policy=self.policy))
            else:
                logits_real = self.disc(x.contiguous().detach())
            logits_fake = self.disc(xrec.contiguous().detach(
            ))  # detach so that the generator isn't also updated
            d_loss = hinge_d_loss(logits_real, logits_fake)
            self.log_dict["d_loss"] = d_loss
        else:
            d_loss = None

        return loss, d_loss

    @torch.no_grad()
    def inference(self, data_loader, save_dir):
        self.encoder.eval()
        self.decoder.eval()
        self.quantize.eval()
        self.quant_conv.eval()
        self.post_quant_conv.eval()

        loss_total = 0
        num = 0

        for _, data in enumerate(data_loader):
            img_name = data['img_name'][0]
            x, mask = self.feed_data(data)
            xrec, _ = self.forward_step(x, mask)

            recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
            p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
            nll_loss = recon_loss + self.perceptual_weight * p_loss
            nll_loss = torch.mean(nll_loss)
            loss_total += nll_loss

            num += x.size(0)

            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                # convert logits to indices
                xrec = torch.argmax(xrec, dim=1, keepdim=True)
                xrec = F.one_hot(xrec, num_classes=x.shape[1])
                xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)

            img_cat = torch.cat([x, xrec], dim=3).detach()
            img_cat = ((img_cat + 1) / 2)
            img_cat = img_cat.clamp_(0, 1)
            save_image(
                img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)

        return (loss_total / num).item()

    def encode(self, x, mask):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h, mask)
        return quant, emb_loss, info

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward_step(self, input, mask):
        quant, diff, _ = self.encode(input, mask)
        dec = self.decode(quant)
        return dec, diff
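# Standalone sketch (illustrative, not part of vqgan_model.py): a toy,
# self-contained version of the two-optimizer update pattern used by
# `optimize_parameters` above. The autoencoder side is always updated, while
# the discriminator only starts training after `disc_start_step`. The models,
# losses and data here are stand-ins, not the classes defined in this file.
#
# import torch
#
# gen = torch.nn.Linear(4, 4)
# disc = torch.nn.Linear(4, 1)
# opt_g = torch.optim.Adam(gen.parameters(), lr=1e-4)
# opt_d = torch.optim.Adam(disc.parameters(), lr=1e-4)
# disc_start_step = 5
#
# for step in range(10):
#     x = torch.randn(8, 4)
#     xrec = gen(x)
#     g_loss = (x - xrec).abs().mean() - disc(xrec).mean()   # recon + negated GAN term
#     opt_g.zero_grad(); g_loss.backward(); opt_g.step()
#     if step > disc_start_step:
#         d_loss = disc(xrec.detach()).mean() - disc(x).mean()  # stand-in discriminator loss
#         opt_d.zero_grad(); d_loss.backward(); opt_d.step()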