mahmed10 committed on
Commit 31ca592 · 1 Parent(s): 13dd0a8

initial upload

Files changed (4)
  1. README.md +185 -0
  2. get_parser.py +121 -0
  3. train.py +381 -0
  4. validation.ipynb +364 -0
README.md ADDED
@@ -0,0 +1,185 @@
1
+ # CAM-Seg: A Continuous-valued Embedding Approach for Semantic Image Generation
2
+
3
+ **Official PyTorch Implementation**
4
+
5
+ This is a PyTorch/GPU implementation of the paper [CAM-Seg: A Continuous-valued Embedding Approach for Semantic Image Generation](https://arxiv.org/abs/2503.15617)
6
+
7
+ ```
8
+ @article{ahmed2025cam,
9
+ title={CAM-Seg: A Continuous-valued Embedding Approach for Semantic Image Generation},
10
+ author={Ahmed, Masud and Hasan, Zahid and Haque, Syed Arefinul and Faridee, Abu Zaher Md and Purushotham, Sanjay and You, Suya and Roy, Nirmalya},
11
+ journal={arXiv preprint arXiv:2503.15617},
12
+ year={2025}
13
+ }
14
+ ```
15
+
16
+ ## Abstract
17
+ Traditional transformer-based semantic segmentation relies on quantized embeddings. However, our analysis reveals that autoencoder accuracy on segmentation masks using quantized embeddings (e.g., VQ-VAE) is 8\% lower than with continuous-valued embeddings (e.g., KL-VAE). Motivated by this, we propose a continuous-valued embedding framework for semantic segmentation. By reformulating semantic mask generation as a continuous image-to-embedding diffusion process, our approach eliminates the need for discrete latent representations while preserving fine-grained spatial and semantic details. Our key contribution is a diffusion-guided autoregressive transformer that learns a continuous semantic embedding space by modeling long-range dependencies in image features. Our framework is a unified architecture combining a VAE encoder for continuous feature extraction, a diffusion-guided transformer for conditioned embedding generation, and a VAE decoder for semantic mask reconstruction. This setting enables zero-shot domain adaptation through the continuity of the embedding space. Experiments across diverse datasets (e.g., Cityscapes and domain-shifted variants) demonstrate state-of-the-art robustness to distribution shifts, including adverse weather (e.g., fog, snow) and viewpoint variations. Our model also exhibits strong noise resilience, achieving robust performance ($\approx$ 95\% AP compared to baseline) under Gaussian noise, moderate motion blur, and moderate brightness/contrast variations, while experiencing only a moderate impact ($\approx$ 90\% AP compared to baseline) from 50\% salt-and-pepper noise and saturation and hue shifts.
18
+
19
+ ## Result
20
+ Trained on the Cityscapes dataset and tested on the SemanticKITTI, ACDC, and CAD-EdgeTune datasets:
21
+ <p align="center">
22
+ <img src="demo/qualitative.png" width="720">
23
+ </p>
24
+
25
+ Quantitative results of semantic segmentation under various noise conditions
26
+ <p align="center">
27
+ <table>
28
+ <tr>
29
+ <td align="center"><img src="demo/saltpepper_noise.png" width="200"/><br>Salt & Pepper Noise</td>
30
+ <td align="center"><img src="demo/motion_blur.png" width="200"/><br>Motion Blur</td>
31
+ <td align="center"><img src="demo/gaussian_noise.png" width="200"/><br>Gaussian Noise</td>
32
+ <td align="center"><img src="demo/gaussian_blur.png" width="200"/><br>Gaussian Blur</td>
33
+ </tr>
34
+ <tr>
35
+ <td align="center"><img src="demo/brightness.png" width="200"/><br>Brightness Variation</td>
36
+ <td align="center"><img src="demo/contrast.png" width="200"/><br>Contrast Variation</td>
37
+ <td align="center"><img src="demo/saturation.png" width="200"/><br>Saturation Variation</td>
38
+ <td align="center"><img src="demo/hue.png" width="200"/><br>Hue Variation</td>
39
+ </tr>
40
+ </table>
41
+ </p>
42
+
43
+ ## Prerequisite
44
+ To install the docker environment, first edit the `docker_env/Makefile`:
45
+ ```
46
+ IMAGE=img_name/dl-aio
47
+ CONTAINER=containter_name
48
+ AVAILABLE_GPUS='0,1,2,3'
49
+ LOCAL_JUPYTER_PORT=18888
50
+ LOCAL_TENSORBOARD_PORT=18006
51
+ PASSWORD=yourpassword
52
+ WORKSPACE=workspace_directory
53
+ ```
54
+ - Edit `IMAGE`, `CONTAINER`, `AVAILABLE_GPUS`, `LOCAL_JUPYTER_PORT`, `LOCAL_TENSORBOARD_PORT`, `PASSWORD`, and `WORKSPACE` to match your setup.
55
+
56
+ 1. For the first run, execute the following commands in a terminal:
57
+ ```
58
+ cd docker_env
59
+ make docker-build
60
+ make docker-run
61
+ ```
62
+ 2. For subsequent use of the Docker environment:
63
+ - To stop the environment: `make docker-stop`
64
+ - To resume the environment: `make docker-resume`
65
+
66
+ To start coding, open `ip_address:jupyter_port` in a web browser, e.g., `http://localhost:18888`.
67
+
68
+ ## Dataset
69
+ Four datasets are used in this work:
70
+ 1. [Cityscapes Dataset](https://www.cityscapes-dataset.com/)
71
+ 2. [KITTI Dataset](https://www.cvlibs.net/datasets/kitti/eval_step.php)
72
+ 3. [ACDC Dataset](https://acdc.vision.ee.ethz.ch/)
73
+ 4. [CAD-EdgeTune Dataset](https://ieee-dataport.org/documents/cad-edgetune)
74
+
75
+ **Modify the `trainlist.txt` and `vallist.txt` files to change the train and test split.**
76
+
77
+ ### Dataset structure
78
+ - Cityscapes Dataset
79
+ ```
80
+ |-CityScapes
81
+ |----leftImg8bit #contains the RGB images
82
+ |----gtFine #contains semantic segmentation labels
83
+ |----trainlist.txt #image list used for training
84
+ |----vallist.txt #image list used for testing
85
+ |----cityscape.yaml #configuration file for Cityscapes dataset
86
+ ```
87
+
88
+ - ACDC Dataset
89
+ ```
90
+ |-ACDC
91
+ |----rgb_anon #contains the RGB images
92
+ |----gt #contains semantic segmentation labels
93
+ |----vallist_fog.txt #image list used for testing fog data
94
+ |----vallist_rain.txt #image list used for testing rain data
95
+ |----vallist_snow.txt #image list used for testing snow data
96
+ |----acdc.yaml #configuration file for ACDC dataset
97
+ ```
98
+
99
+ ## Weights
100
+ To download the pretrained weights, visit the [Hugging Face repo](https://huggingface.co/mahmed10/CAM-Seg).
101
+ - **LDM model**: the pretrained model from Rombach et al.'s Latent Diffusion Models is used ([link](https://huggingface.co/mahmed10/CAM-Seg/resolve/main/pretrained_models/vae/modelf16.ckpt)).
102
+ - **MAR model**: the following MAR model is used:
103
+
104
+ |Training Data|Model|Params|Link|
105
+ |-------------|-----|------|----|
106
+ |Cityscapes | Mar-base| 217M|[link](https://huggingface.co/mahmed10/CAM-Seg/resolve/main/pretrained_models/mar/city768.16.pth)|
107
+
108
+
109
+ Download these weight files and organize them as follows:
110
+ ```
111
+ |-pretrained_models
112
+ |----mar
113
+ |--------city768.16.pth
114
+ |----vae
115
+ |--------modelf16.ckpt
116
+ ```
117
+
118
+ **Alternative: code to download the pretrained weights automatically**
119
+ ```
120
+ import os
121
+ import requests
122
+
123
+ # Define URLs and file paths
124
+ files_to_download = {
125
+     "https://huggingface.co/mahmed10/CAM-Seg/resolve/main/pretrained_models/vae/modelf16.ckpt":
126
+         "pretrained_models/vae/modelf16.ckpt",
127
+     "https://huggingface.co/mahmed10/CAM-Seg/resolve/main/pretrained_models/mar/city768.16.pth":
128
+         "pretrained_models/mar/city768.16.pth"
129
+ }
130
+
131
+ for url, path in files_to_download.items():
132
+     os.makedirs(os.path.dirname(path), exist_ok=True)
133
+
134
+     print(f"Downloading from {url}...")
135
+     response = requests.get(url, stream=True)
136
+     if response.status_code == 200:
137
+         with open(path, 'wb') as f:
138
+             for chunk in response.iter_content(chunk_size=8192):
139
+                 f.write(chunk)
140
+         print(f"Saved to {path}")
141
+     else:
142
+         print(f"Failed to download from {url}, status code {response.status_code}")
143
+ ```
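
After downloading, it can help to confirm that the files landed in the layout shown above. The following is a minimal sketch and not part of the repository: the helper name `missing_weights` is hypothetical, and only the two relative paths are taken from this README.

```python
from pathlib import Path

# Relative paths taken from the weight layout described in the README.
EXPECTED_WEIGHTS = (
    "pretrained_models/vae/modelf16.ckpt",
    "pretrained_models/mar/city768.16.pth",
)


def missing_weights(root="."):
    """Return the expected weight files that are absent or empty under `root`.

    Hypothetical helper for illustration; it only inspects the filesystem
    and does not validate the contents of the checkpoints.
    """
    root = Path(root)
    missing = []
    for rel in EXPECTED_WEIGHTS:
        path = root / rel
        if not path.is_file() or path.stat().st_size == 0:
            missing.append(rel)
    return missing


if __name__ == "__main__":
    absent = missing_weights()
    if absent:
        print("Missing weight files:", ", ".join(absent))
    else:
        print("All expected weight files are present.")
```

Running the check from the repository root before training or validation gives an early, readable failure instead of a checkpoint-loading error later on.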
144
+
145
+ ## Validation
146
+ Open the `validation.ipynb` file.
147
+
148
+ Edit **Block 6** to select which dataset to use for validation:
149
+
150
+ ```
151
+ dataset_train = cityscapes.CityScapes('dataset/CityScapes/vallist.txt', data_set= 'val', transform=transform_train,seed=36, img_size=768)
152
+ # dataset_train = umbc.UMBC('dataset/UMBC/all.txt', data_set= 'val', transform=transform_train,seed=36, img_size=768)
153
+ # dataset_train = acdc.ACDC('dataset/ACDC/vallist_fog.txt', data_set= 'val', transform=transform_train,seed=36, img_size=768)
154
+ # dataset_train = semantickitti.SemanticKITTI('dataset/SemanticKitti/vallist.txt', data_set= 'val', transform=transform_train, seed=36, img_size=768)
155
+ ```
156
+
157
+ Run all the blocks
158
+
159
+ ## Training
160
+
161
+ ### From Scratch
162
+
163
+ Run the following command in a terminal:
164
+ ```
165
+ torchrun --nproc_per_node=4 train.py
166
+ ```
167
+
168
+ Checkpoints are saved in the `output_dir/year.month.day.hour.min` folder, e.g., `output_dir/2025.05.09.02.27`.
169
+
170
+ ### Resume Training
171
+
172
+ Run the following command in a terminal:
173
+ ```
174
+ torchrun --nproc_per_node=4 train.py --resume year.month.day.hour.min
175
+ ```
176
+
177
+ For example:
178
+ ```
179
+ torchrun --nproc_per_node=4 train.py --resume 2025.05.09.02.27
180
+ ```
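
The run-folder naming and the resume lookup described above can be sketched as follows. This mirrors what `train.py` in this commit does (a `%Y.%m.%d.%H.%M` timestamp for new runs, and the last entry of the sorted `checkpoint*.pth` glob on resume), but the helper names `run_name` and `latest_checkpoint` are hypothetical, and the example file name `checkpoint-00009.pth` is an assumption about how `misc.save_model` names files, which this diff does not show.

```python
import datetime
import glob
import os


def run_name(now=None):
    """Build a run-folder name in the year.month.day.hour.min format."""
    now = now or datetime.datetime.now()
    return now.strftime("%Y.%m.%d.%H.%M")


def latest_checkpoint(output_dir, run):
    """Return the last checkpoint path in lexicographic order, or None.

    Lexicographic order matches epoch order only when the epoch part of
    the file name is zero-padded, as train.py does with str(epoch).zfill(5).
    """
    pattern = os.path.join(output_dir, run, "checkpoint*.pth")
    matches = sorted(glob.glob(pattern))
    return matches[-1] if matches else None


if __name__ == "__main__":
    name = run_name(datetime.datetime(2025, 5, 9, 2, 27))
    print(name)
    print(latest_checkpoint("output_dir", name))
```

The value passed to `--resume` is therefore just the folder name, not a path to a specific checkpoint file.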
181
+
182
+ ## Acknowledgement
183
+ The code is built on top of the following codebases:
184
+ 1. [latent-diffusion](https://github.com/CompVis/latent-diffusion)
185
+ 2. [mar](https://github.com/LTH14/mar)
get_parser.py ADDED
@@ -0,0 +1,121 @@
1
+ import argparse
2
+ from pathlib import Path
3
+ import yaml
4
+ def get_args_parser():
5
+ parser = argparse.ArgumentParser('MAR training with Diffusion Loss', add_help=False)
6
+ parser.add_argument('--batch_size', default=16, type=int,
7
+ help='Batch size per GPU (effective batch size is batch_size * # gpus)')
8
+ parser.add_argument('--epochs', default=2000, type=int)
9
+
10
+ # Model parameters
11
+ parser.add_argument('--model', default='mar_base', type=str, metavar='MODEL',
12
+ help='Name of model to train')
13
+ parser.add_argument('--ckpt_path', default="pretrained_models/mar/city768.16.pth", type=str,
14
+ help='model checkpoint path')
15
+
16
+ # VAE parameters
17
+ parser.add_argument('--img_size', default=768, type=int,
18
+ help='images input size')
19
+ parser.add_argument('--vae_path', default="pretrained_models/vae/modelf16.ckpt", type=str,
20
+ help='vae checkpoint path')
21
+ parser.add_argument('--vae_embed_dim', default=16, type=int,
22
+ help='vae output embedding dimension')
23
+ parser.add_argument('--vae_stride', default=16, type=int,
24
+ help='tokenizer stride, default use KL16')
25
+ parser.add_argument('--patch_size', default=1, type=int,
26
+ help='number of tokens to group as a patch.')
27
+ parser.add_argument('--config', default="ldm/config.yaml", type=str,
28
+ help='vae model configuration file')
29
+
30
+ # Generation parameters
31
+ parser.add_argument('--num_iter', default=64, type=int,
32
+ help='number of autoregressive iterations to generate an image')
33
+ parser.add_argument('--num_images', default=3000, type=int,
34
+ help='number of images to generate')
35
+ parser.add_argument('--cfg', default=1.0, type=float, help="classifier-free guidance")
36
+ parser.add_argument('--cfg_schedule', default="linear", type=str)
37
+ parser.add_argument('--label_drop_prob', default=0.1, type=float)
38
+ parser.add_argument('--eval_freq', type=int, default=40, help='evaluation frequency')
39
+ parser.add_argument('--save_last_freq', type=int, default=5, help='save last frequency')
40
+ parser.add_argument('--online_eval', action='store_true')
41
+ parser.add_argument('--evaluate', action='store_true')
42
+ parser.add_argument('--eval_bsz', type=int, default=64, help='generation batch size')
43
+
44
+ # Optimizer parameters
45
+ parser.add_argument('--weight_decay', type=float, default=0.02,
46
+ help='weight decay (default: 0.02)')
47
+
48
+ parser.add_argument('--grad_checkpointing', action='store_true')
49
+ parser.add_argument('--lr', type=float, default=None, metavar='LR',
50
+ help='learning rate (absolute lr)')
51
+ parser.add_argument('--blr', type=float, default=1e-4, metavar='LR',
52
+ help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
53
+ parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
54
+ help='lower lr bound for cyclic schedulers that hit 0')
55
+ parser.add_argument('--lr_schedule', type=str, default='constant',
56
+ help='learning rate schedule')
57
+ parser.add_argument('--warmup_epochs', type=int, default=100, metavar='N',
58
+ help='epochs to warmup LR')
59
+ parser.add_argument('--ema_rate', default=0.9999, type=float)
60
+
61
+ # MAR params
62
+ parser.add_argument('--mask_ratio_min', type=float, default=0.7,
63
+ help='Minimum mask ratio')
64
+ parser.add_argument('--grad_clip', type=float, default=3.0,
65
+ help='Gradient clip')
66
+ parser.add_argument('--attn_dropout', type=float, default=0.1,
67
+ help='attention dropout')
68
+ parser.add_argument('--proj_dropout', type=float, default=0.1,
69
+ help='projection dropout')
70
+ parser.add_argument('--buffer_size', type=int, default=64)
71
+
72
+ # Diffusion Loss params
73
+ parser.add_argument('--diffloss_d', type=int, default=6)
74
+ parser.add_argument('--diffloss_w', type=int, default=1024)
75
+ parser.add_argument('--num_sampling_steps', type=str, default="100")
76
+ parser.add_argument('--diffusion_batch_mul', type=int, default=4)
77
+ parser.add_argument('--temperature', default=1.0, type=float, help='diffusion loss sampling temperature')
78
+
79
+ # Dataset parameters
80
+ parser.add_argument('--output_dir', default='./output_dir',
81
+ help='path where to save, empty for no saving')
82
+ parser.add_argument('--log_dir', default='./output_dir',
83
+ help='path where to write tensorboard logs')
84
+ parser.add_argument('--device', default='cuda',
85
+ help='device to use for training / testing')
86
+ parser.add_argument('--seed', default=1, type=int)
87
+ parser.add_argument('--resume', default=None,#'pretrained_models/mar/mar_base',
88
+ help='resume from checkpoint')
89
+
90
+ parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
91
+ help='start epoch')
92
+ parser.add_argument('--num_workers', default=10, type=int)
93
+ parser.add_argument('--pin_mem', action='store_true',
94
+ help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
95
+ parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
96
+ parser.set_defaults(pin_mem=True)
97
+
98
+ # distributed training parameters
99
+ parser.add_argument('--world_size', default=1, type=int,
100
+ help='number of distributed processes')
101
+ parser.add_argument('--local_rank', default=-1, type=int)
102
+ parser.add_argument('--dist_on_itp', action='store_true')
103
+ parser.add_argument('--dist_url', default='env://',
104
+ help='url used to set up distributed training')
105
+
106
+ # caching latents
107
+ parser.add_argument('--use_cached', action='store_true', dest='use_cached',
108
+ help='Use cached latents')
109
+ parser.set_defaults(use_cached=False)
110
+ parser.add_argument('--cached_path', default='', help='path to cached latents')
111
+
112
+ return parser
113
+
114
+ args = get_args_parser()
115
+ args = args.parse_args()
116
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
117
+ args.log_dir = args.output_dir
118
+
119
+ with open(args.config, "r") as f:
120
+ config = yaml.safe_load(f)
121
+ args.ddconfig = config["ddconfig"]
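
The `--pin_mem` / `--no_pin_mem` pair above is a common `argparse` pattern: two flags write to the same destination, and `set_defaults` decides the value when neither is given. The following is a reduced sketch of that pattern, not the repository's parser; `build_parser` is a hypothetical name and only two of the arguments are reproduced. Passing an explicit list to `parse_args` is shown because it avoids reading `sys.argv`, which can be convenient in a notebook.

```python
import argparse


def build_parser():
    """Reduced, hypothetical parser showing the paired-flag pattern."""
    parser = argparse.ArgumentParser("paired-flag sketch", add_help=False)
    parser.add_argument("--batch_size", default=16, type=int)
    # Both flags share dest="pin_mem"; one sets it True, the other False.
    parser.add_argument("--pin_mem", action="store_true")
    parser.add_argument("--no_pin_mem", action="store_false", dest="pin_mem")
    parser.set_defaults(pin_mem=True)
    return parser


# An empty list yields the defaults without touching sys.argv.
defaults = build_parser().parse_args([])
overridden = build_parser().parse_args(["--batch_size", "4", "--no_pin_mem"])

print(defaults.batch_size, defaults.pin_mem)      # 16 True
print(overridden.batch_size, overridden.pin_mem)  # 4 False
```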
train.py ADDED
@@ -0,0 +1,381 @@
1
+ import argparse
2
+ import datetime
3
+ import numpy as np
4
+ import os
5
+ import time
6
+ from pathlib import Path
7
+ import yaml
8
+ import glob
9
+
10
+ import torch
11
+ import torch.backends.cudnn as cudnn
12
+ from torch.utils.tensorboard import SummaryWriter
13
+ import torchvision.transforms as transforms
14
+ import torchvision.datasets as datasets
15
+ from data import cityscapes
16
+
17
+ from util.crop import center_crop_arr
18
+ import util.misc as misc
19
+ from util.misc import NativeScalerWithGradNormCount as NativeScaler
20
+ from util.loader import CachedFolder
21
+
22
+ from models.vae import AutoencoderKL
23
+ from models import mar
24
+ import copy
25
+ from tqdm import tqdm
26
+
27
+ import util.lr_sched as lr_sched
28
+
29
+ import logging
30
+
31
+
32
+
33
+ def update_ema(target_params, source_params, rate=0.99):
34
+     """
35
+     Update target parameters to be closer to those of source parameters using
36
+     an exponential moving average.
37
+
38
+     :param target_params: the target parameter sequence.
39
+     :param source_params: the source parameter sequence.
40
+     :param rate: the EMA rate (closer to 1 means slower).
41
+     """
42
+     for targ, src in zip(target_params, source_params):
43
+         targ.detach().mul_(rate).add_(src, alpha=1 - rate)
44
+
45
+ def logger_file(path):
46
+     logger = logging.getLogger()
47
+     logger.setLevel(logging.DEBUG)
48
+     # delay=True postpones opening the file until the first record is written
49
+     handler = logging.FileHandler(path, "w", encoding=None, delay=True)
50
+     handler.setLevel(logging.INFO)
51
+     formatter = logging.Formatter("%(message)s")
52
+     handler.setFormatter(formatter)
53
+     logger.addHandler(handler)
54
+     return logger
54
+
55
+
56
+ def get_args_parser():
57
+ parser = argparse.ArgumentParser('MAR training with Diffusion Loss', add_help=False)
58
+ parser.add_argument('--batch_size', default=2, type=int,
59
+ help='Batch size per GPU (effective batch size is batch_size * # gpus)')
60
+ parser.add_argument('--epochs', default=2000, type=int)
61
+
62
+ # Model parameters
63
+ parser.add_argument('--model', default='mar_base', type=str, metavar='MODEL',
64
+ help='Name of model to train')
65
+ parser.add_argument('--ckpt_path', default="pretrained_models/mar/city768.16.pth", type=str,
66
+ help='model checkpoint path')
67
+
68
+ # VAE parameters
69
+ parser.add_argument('--img_size', default=768, type=int,
70
+ help='images input size')
71
+ parser.add_argument('--vae_path', default="pretrained_models/vae/modelf16.ckpt", type=str,
72
+ help='vae checkpoint path')
73
+ parser.add_argument('--vae_embed_dim', default=16, type=int,
74
+ help='vae output embedding dimension')
75
+ parser.add_argument('--vae_stride', default=16, type=int,
76
+ help='tokenizer stride, default use KL16')
77
+ parser.add_argument('--patch_size', default=1, type=int,
78
+ help='number of tokens to group as a patch.')
79
+ parser.add_argument('--config', default="ldm/config.yaml", type=str,
80
+ help='vae model configuration file')
81
+
82
+ # Generation parameters
83
+ parser.add_argument('--num_iter', default=64, type=int,
84
+ help='number of autoregressive iterations to generate an image')
85
+ parser.add_argument('--num_images', default=3000, type=int,
86
+ help='number of images to generate')
87
+ parser.add_argument('--cfg', default=1.0, type=float, help="classifier-free guidance")
88
+ parser.add_argument('--cfg_schedule', default="linear", type=str)
89
+ parser.add_argument('--label_drop_prob', default=0.1, type=float)
90
+ parser.add_argument('--eval_freq', type=int, default=40, help='evaluation frequency')
91
+ parser.add_argument('--save_last_freq', type=int, default=5, help='save last frequency')
92
+ parser.add_argument('--online_eval', action='store_true')
93
+ parser.add_argument('--evaluate', action='store_true')
94
+ parser.add_argument('--eval_bsz', type=int, default=64, help='generation batch size')
95
+
96
+ # Optimizer parameters
97
+ parser.add_argument('--weight_decay', type=float, default=0.02,
98
+ help='weight decay (default: 0.02)')
99
+
100
+ parser.add_argument('--grad_checkpointing', action='store_true')
101
+ parser.add_argument('--lr', type=float, default=None, metavar='LR',
102
+ help='learning rate (absolute lr)')
103
+ parser.add_argument('--blr', type=float, default=1e-4, metavar='LR',
104
+ help='base learning rate; used as-is for the absolute lr when --lr is not set')
105
+ parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
106
+ help='lower lr bound for cyclic schedulers that hit 0')
107
+ parser.add_argument('--lr_schedule', type=str, default='constant',
108
+ help='learning rate schedule')
109
+ parser.add_argument('--warmup_epochs', type=int, default=100, metavar='N',
110
+ help='epochs to warmup LR')
111
+ parser.add_argument('--ema_rate', default=0.9999, type=float)
112
+
113
+ # MAR params
114
+ parser.add_argument('--mask_ratio_min', type=float, default=0.7,
115
+ help='Minimum mask ratio')
116
+ parser.add_argument('--grad_clip', type=float, default=3.0,
117
+ help='Gradient clip')
118
+ parser.add_argument('--attn_dropout', type=float, default=0.1,
119
+ help='attention dropout')
120
+ parser.add_argument('--proj_dropout', type=float, default=0.1,
121
+ help='projection dropout')
122
+ parser.add_argument('--buffer_size', type=int, default=64)
123
+
124
+ # Diffusion Loss params
125
+ parser.add_argument('--diffloss_d', type=int, default=6)
126
+ parser.add_argument('--diffloss_w', type=int, default=1024)
127
+ parser.add_argument('--num_sampling_steps', type=str, default="100")
128
+ parser.add_argument('--diffusion_batch_mul', type=int, default=4)
129
+ parser.add_argument('--temperature', default=1.0, type=float, help='diffusion loss sampling temperature')
130
+
131
+ # Dataset parameters
132
+ parser.add_argument('--output_dir', default='./output_dir',
133
+ help='path where to save, empty for no saving')
134
+ parser.add_argument('--log_dir', default='./output_dir',
135
+ help='path where to write tensorboard logs')
136
+ parser.add_argument('--device', default='cuda',
137
+ help='device to use for training / testing')
138
+ parser.add_argument('--seed', default=1, type=int)
139
+ parser.add_argument('--resume', default=None,
140
+ help='resume from checkpoint')
141
+
142
+ parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
143
+ help='start epoch')
144
+ parser.add_argument('--num_workers', default=10, type=int)
145
+ parser.add_argument('--pin_mem', action='store_true',
146
+ help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
147
+ parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
148
+ parser.set_defaults(pin_mem=True)
149
+
150
+ # distributed training parameters
151
+ parser.add_argument('--world_size', default=1, type=int,
152
+ help='number of distributed processes')
153
+ parser.add_argument('--local_rank', default=-1, type=int)
154
+ parser.add_argument('--dist_on_itp', action='store_true')
155
+ parser.add_argument('--dist_url', default='env://',
156
+ help='url used to set up distributed training')
157
+
158
+ # caching latents
159
+ parser.add_argument('--use_cached', action='store_true', dest='use_cached',
160
+ help='Use cached latents')
161
+ parser.set_defaults(use_cached=False)
162
+ parser.add_argument('--cached_path', default='', help='path to cached latents')
163
+
164
+ return parser
165
+
166
+
167
+ def main(args):
168
+ misc.init_distributed_mode(args)
169
+
170
+ print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
171
+ print("{}".format(args).replace(', ', ',\n'))
172
+
173
+ device = torch.device(args.device)
174
+
175
+ # fix the seed for reproducibility
176
+ seed = args.seed + misc.get_rank()
177
+ torch.manual_seed(seed)
178
+ np.random.seed(seed)
179
+
180
+ cudnn.benchmark = True
181
+
182
+ num_tasks = misc.get_world_size()
183
+ global_rank = misc.get_rank()
184
+
185
+ log_writer = None
186
+
187
+ # augmentation following DiT and ADM
188
+ transform_train = transforms.Compose([
189
+ transforms.ToTensor(),
190
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
191
+ ])
192
+
193
+ dataset_train = cityscapes.CityScapes('dataset/CityScapes/trainlist.txt', transform=transform_train, img_size=args.img_size)
194
+
195
+ sampler_train = torch.utils.data.DistributedSampler(
196
+ dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
197
+ )
198
+ print("Sampler_train = %s" % str(sampler_train))
199
+
200
+ data_loader_train = torch.utils.data.DataLoader(
201
+ dataset_train, sampler=sampler_train,
202
+ batch_size=args.batch_size,
203
+ num_workers=args.num_workers,
204
+ pin_memory=args.pin_mem,
205
+ drop_last=True,
206
+ )
207
+
208
+ # define the vae and mar model
209
+ with open(args.config, "r") as f:
210
+ config = yaml.safe_load(f)
211
+ args.ddconfig = config["ddconfig"]
212
+ print('config: ', config)
213
+
214
+ vae = AutoencoderKL(
215
+ ddconfig=args.ddconfig,
216
+ embed_dim=args.vae_embed_dim,
217
+ ckpt_path=args.vae_path
218
+ ).cuda().eval()
219
+
220
+ for param in vae.parameters():
221
+ param.requires_grad = False
222
+
223
+ model = mar.__dict__[args.model](
224
+ img_size=args.img_size,
225
+ vae_stride=args.vae_stride,
226
+ patch_size=args.patch_size,
227
+ vae_embed_dim=args.vae_embed_dim,
228
+ mask_ratio_min=args.mask_ratio_min,
229
+ label_drop_prob=args.label_drop_prob,
230
+ attn_dropout=args.attn_dropout,
231
+ proj_dropout=args.proj_dropout,
232
+ buffer_size=args.buffer_size,
233
+ diffloss_d=args.diffloss_d,
234
+ diffloss_w=args.diffloss_w,
235
+ num_sampling_steps=args.num_sampling_steps,
236
+ diffusion_batch_mul=args.diffusion_batch_mul,
237
+ grad_checkpointing=args.grad_checkpointing,
238
+ )
239
+
240
+ if args.ckpt_path:
241
+ checkpoint = torch.load(args.ckpt_path, map_location='cpu')
242
+ model.load_state_dict(checkpoint['model'])
243
+
244
+ print("Model = %s" % str(model))
245
+ # following timm: set wd as 0 for bias and norm layers
246
+ n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
247
+ print("Number of trainable parameters: {}M".format(n_params / 1e6))
248
+
249
+ model.to(device)
250
+ model_without_ddp = model
251
+
252
+ eff_batch_size = args.batch_size * misc.get_world_size()
253
+
254
+ if args.lr is None: # only base_lr is specified
255
+ args.lr = args.blr
256
+
257
+ print("base lr: %.2e" % args.blr)
258
+ print("actual lr: %.2e" % args.lr)
259
+ print("effective batch size: %d" % eff_batch_size)
260
+
261
+ if args.distributed:
262
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
263
+ model_without_ddp = model.module
264
+
265
+ # no weight decay on bias, norm layers, and diffloss MLP
266
+ param_groups = misc.add_weight_decay(model_without_ddp, args.weight_decay)
267
+ optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
268
+ print(optimizer)
269
+ loss_scaler = NativeScaler()
270
+
271
+ # resume training
272
+ if args.resume and glob.glob(os.path.join(args.output_dir, args.resume, 'checkpoint*.pth')):
273
+ try:
274
+ checkpoint = torch.load(sorted(glob.glob(os.path.join(args.output_dir, args.resume, 'checkpoint*.pth')))[-1], map_location='cpu')
275
+ model.load_state_dict(checkpoint['model'])
276
+ except Exception:
277
+ checkpoint = torch.load(sorted(glob.glob(os.path.join(args.output_dir, args.resume, 'checkpoint*.pth')))[-2], map_location='cpu')
278
+ model.load_state_dict(checkpoint['model'])
279
+ state_dict = {key.replace("module.", ""): value for key, value in checkpoint['model'].items()}
280
+ model_without_ddp.load_state_dict(state_dict)
281
+ model_params = list(model_without_ddp.parameters())
282
+ ema_params = copy.deepcopy(model_params)
283
+ ema_state_dict = {key.replace("module.", ""): value for key, value in checkpoint['model_ema'].items()}
284
+ ema_params = [ema_state_dict[name].cuda() for name, _ in model_without_ddp.named_parameters()]
285
+ print("Resume checkpoint %s" % args.resume)
286
+ if 'optimizer' in checkpoint and 'epoch' in checkpoint:
287
+ optimizer.load_state_dict(checkpoint['optimizer'])
288
+ args.start_epoch = checkpoint['epoch'] + 1
289
+ if 'scaler' in checkpoint:
290
+ loss_scaler.load_state_dict(checkpoint['scaler'])
291
+ print("With optim & sched!")
292
+ del checkpoint
293
+
294
+ args.output_dir = os.path.join(args.output_dir, args.resume)
295
+
296
+ logger = logger_file(args.log_dir+'/'+args.resume+'.log')
297
+ if os.path.exists(args.log_dir+'/'+args.resume+'.log'):
298
+ with open(args.log_dir+'/'+args.resume+'.log', 'r') as infile:
299
+ for line in infile:
300
+ logger.info(line.rstrip())
301
+ else:
302
+ logger.info("All the arguments")
303
+ for k, v in vars(args).items():
304
+ logger.info(f"{k}: {v}")
305
+ logger.info("\n\n Loss information")
306
+
307
+
308
+
309
+ else:
310
+ model_params = list(model_without_ddp.parameters())
311
+ ema_params = copy.deepcopy(model_params)
312
+ print("Training from scratch")
313
+ args.resume = datetime.datetime.now().strftime("%Y.%m.%d.%H.%M")
314
+ args.output_dir = os.path.join(args.output_dir, args.resume)
315
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
316
+
317
+ logger = logger_file(args.log_dir+'/'+args.resume+'.log')
318
+ logger.info("All the arguments")
319
+ for k, v in vars(args).items():
320
+ logger.info(f"{k}: {v}")
321
+ logger.info("\n\n Loss information")
322
+
323
+
324
+ print(f"Start training for {args.epochs} epochs")
325
+ start_time = time.time()
326
+     for epoch in tqdm(range(args.start_epoch, args.epochs), desc="Training Progress"):
327
+         # reshuffle the distributed sampler at the start of every epoch
328
+         if args.distributed:
329
+             data_loader_train.sampler.set_epoch(epoch)
330
+
331
+
332
+
333
+         model.train(True)
334
+ metric_logger = misc.MetricLogger(delimiter=" ")
335
+ metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
336
+ header = 'Epoch: [{}]'.format(epoch)
337
+ print_freq = 20
338
+
339
+ optimizer.zero_grad()
340
+
341
+ for data_iter_step, (samples, labels, _) in enumerate(data_loader_train):
342
+ lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader_train) + epoch, args)
343
+ samples = samples.to(device, non_blocking=True)
344
+ labels = labels.to(device, non_blocking=True)
345
+
346
+ with torch.no_grad():
347
+ posterior_x = vae.encode(samples)
348
+ posterior_y = vae.encode(labels)
349
+ x = posterior_x.sample().mul_(0.2325)
350
+ y = posterior_y.sample().mul_(0.2325)
351
+ with torch.cuda.amp.autocast():
352
+ loss = model(x,y)
353
+ loss_value = loss.item()
354
+ loss_scaler(loss, optimizer, clip_grad=args.grad_clip, parameters=model.parameters(), update_grad=True)
355
+ optimizer.zero_grad()
356
+ torch.cuda.synchronize()
357
+
358
+ update_ema(ema_params, model_params, rate=args.ema_rate)
359
+ metric_logger.update(loss=loss_value)
360
+
361
+ lr = optimizer.param_groups[0]["lr"]
362
+ metric_logger.update(lr=lr)
363
+
364
+ loss_value_reduce = misc.all_reduce_mean(loss_value)
365
+ metric_logger.synchronize_between_processes()
366
+ logger.info(f"epoch: {epoch:4d}, Averaged stats: {metric_logger}")
367
+ if (epoch+1)% args.save_last_freq == 0:
368
+ misc.save_model(args=args, model=model, model_without_ddp=model, optimizer=optimizer,
369
+ loss_scaler=loss_scaler, epoch=epoch, ema_params=ema_params, epoch_name=str(epoch).zfill(5))
370
+
371
+ total_time = time.time() - start_time
372
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
373
+ print('Training time {}'.format(total_time_str))
374
+
375
+
376
+ if __name__ == '__main__':
377
+ args = get_args_parser()
378
+ args = args.parse_args()
379
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
380
+ Path(args.log_dir).mkdir(parents=True, exist_ok=True)
381
+ main(args)
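
The training loop above calls `update_ema(ema_params, model_params, rate=args.ema_rate)` after every optimizer step, but the helper's body is not part of this excerpt. The sketch below shows the exponential-moving-average rule that such a helper typically applies, written for plain Python floats so it runs without PyTorch. The function body and the toy values are assumptions for illustration, not the repository's implementation.

```python
def update_ema(ema_params, model_params, rate=0.99):
    """Blend the current weights into the running average, in place.

    Assumed rule: ema <- rate * ema + (1 - rate) * current.
    The real helper operates on tensors; plain floats are used here.
    """
    for i, (ema, cur) in enumerate(zip(ema_params, model_params)):
        ema_params[i] = rate * ema + (1.0 - rate) * cur
    return ema_params


# Toy example: the first averaged weight drifts toward the live weight 1.0.
ema = [0.0, 1.0]
weights = [1.0, 1.0]
for _ in range(3):
    update_ema(ema, weights, rate=0.5)

print(ema)  # [0.875, 1.0]
```

A rate close to 1 makes the averaged weights change slowly, which is why the update runs after every step rather than once per epoch.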
validation.ipynb ADDED
@@ -0,0 +1,364 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "c524f796-e657-4a59-abcf-540531a38995",
7
+ "metadata": {
8
+ "tags": []
9
+ },
10
+ "outputs": [],
11
+ "source": [
12
+ "%run get_parser.py"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": null,
18
+ "id": "4c1cf01e-8229-4d28-bcb2-01c07fa641c2",
19
+ "metadata": {
20
+ "tags": []
21
+ },
22
+ "outputs": [],
23
+ "source": [
24
+ "import os\n",
25
+ "import requests\n",
26
+ "\n",
27
+ "# Define URLs and file paths\n",
28
+ "files_to_download = {\n",
29
+ " \"https://huggingface.co/mahmed10/CAM-Seg/resolve/main/pretrained_models/vae/modelf16.ckpt\":\n",
30
+ " \"pretrained_models/vae/modelf16.ckpt\",\n",
31
+ " \"https://huggingface.co/mahmed10/CAM-Seg/resolve/main/pretrained_models/mar/city768.16.pth\":\n",
32
+ " \"pretrained_models/mar/city768.16.pth\"\n",
33
+ "}\n",
34
+ "\n",
35
+ "for url, path in files_to_download.items():\n",
36
+ " os.makedirs(os.path.dirname(path), exist_ok=True)\n",
37
+ " \n",
38
+ " if os.path.exists(path):\n",
39
+ " print(f\"File already exists: {path}, skipping download.\")\n",
40
+ " continue\n",
41
+ "\n",
42
+ " print(f\"Downloading from {url}...\")\n",
43
+ " response = requests.get(url, stream=True)\n",
44
+ " if response.status_code == 200:\n",
45
+ " with open(path, 'wb') as f:\n",
46
+ " for chunk in response.iter_content(chunk_size=8192):\n",
47
+ " f.write(chunk)\n",
48
+ " print(f\"Saved to {path}\")\n",
49
+ " else:\n",
50
+ " print(f\"Failed to download from {url}, status code {response.status_code}\")"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": 3,
56
+ "id": "3a7ac93b-1cbc-45f3-8ec5-8e8257a39786",
57
+ "metadata": {
58
+ "tags": []
59
+ },
60
+ "outputs": [],
61
+ "source": [
62
+ "import numpy as np\n",
63
+ "from tqdm import tqdm\n",
64
+ "from PIL import Image\n",
65
+ "import yaml\n",
66
+ "import math\n",
67
+ "\n",
68
+ "import torch\n",
69
+ "import torch.backends.cudnn as cudnn\n",
70
+ "import torchvision.transforms as transforms\n",
71
+ "\n",
72
+ "from data import cityscapes\n",
73
+ "import util.misc as misc\n",
74
+ "\n",
75
+ "from models.vae import AutoencoderKL\n",
76
+ "from models import mar"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": 4,
82
+ "id": "e2bde6fd-9b39-40fd-8d4d-d0a5f9c8217a",
83
+ "metadata": {
84
+ "tags": []
85
+ },
86
+ "outputs": [],
87
+ "source": [
88
+ "def mask_by_order(mask_len, order, bsz, seq_len):\n",
89
+ " masking = torch.zeros(bsz, seq_len).cuda()\n",
90
+ " masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()], src=torch.ones(bsz, seq_len).cuda()).bool()\n",
91
+ " return masking\n",
92
+ "\n",
93
+ "def fast_hist(pred, label, n):\n",
94
+ " k = (label >= 0) & (label < n)\n",
95
+ " bin_count = np.bincount(\n",
96
+ " n * label[k].astype(int) + pred[k], minlength=n ** 2)\n",
97
+ " return bin_count[:n ** 2].reshape(n, n)\n",
98
+ "\n",
99
+ "color_pallete = np.round(np.array([\n",
100
+ " 0, 0, 0,\n",
101
+ " 128, 64, 128,\n",
102
+ " 244, 35, 232,\n",
103
+ " 70, 70, 70,\n",
104
+ " 102, 102, 156,\n",
105
+ " 190, 153, 153,\n",
106
+ " 153, 153, 153,\n",
107
+ " 250, 170, 30,\n",
108
+ " 220, 220, 0,\n",
109
+ " 107, 142, 35,\n",
110
+ " 152, 251, 152,\n",
111
+ " 0, 130, 180,\n",
112
+ " 220, 20, 60,\n",
113
+ " 255, 0, 0,\n",
114
+ " 0, 0, 142,\n",
115
+ " 0, 0, 70,\n",
116
+ " 0, 60, 100,\n",
117
+ " 0, 80, 100,\n",
118
+ " 0, 0, 230,\n",
119
+ " 119, 11, 32,\n",
120
+ " ])/255.0, 4)\n",
121
+ "\n",
122
+ "color_pallete = color_pallete.reshape(-1, 3)"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "code",
127
+ "execution_count": 5,
128
+ "id": "c189ac7b-ccff-4745-af56-460ec88770b4",
129
+ "metadata": {
130
+ "tags": []
131
+ },
132
+ "outputs": [],
133
+ "source": [
134
+ "device = torch.device(args.device)\n",
135
+ "device = torch.device('cuda:0')\n",
136
+ "args.batch_size = 1\n",
137
+ "\n",
138
+ "# fix the seed for reproducibility\n",
139
+ "seed = args.seed + misc.get_rank()\n",
140
+ "torch.manual_seed(seed)\n",
141
+ "np.random.seed(seed)\n",
142
+ "\n",
143
+ "cudnn.benchmark = True\n",
144
+ "\n",
145
+ "num_tasks = misc.get_world_size()\n",
146
+ "global_rank = misc.get_rank()"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "execution_count": 6,
152
+ "id": "28d13453-a3ac-4d2e-8906-0c179e85c2f9",
153
+ "metadata": {
154
+ "tags": []
155
+ },
156
+ "outputs": [],
157
+ "source": [
158
+ "transform_train = transforms.Compose([\n",
159
+ " transforms.ToTensor(),\n",
160
+ " transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])\n",
161
+ "])\n",
162
+ "\n",
163
+ "dataset_train = cityscapes.CityScapes('dataset/CityScapes/vallist.txt', data_set= 'val', transform=transform_train, seed=args.seed, img_size=args.img_size)\n",
164
+ "# dataset_train = umbc.UMBC('dataset/UMBC/all.txt', data_set= 'val', transform=transform_train, seed=args.seed, img_size=args.img_size)\n",
165
+ "# dataset_train = acdc.ACDC('dataset/ACDC/vallist_fog.txt', data_set= 'val', transform=transform_train, seed=args.seed, img_size=args.img_size)\n",
166
+ "# dataset_train = semantickitti.SemanticKITTI('dataset/SemanticKitti/vallist.txt', data_set= 'val', transform=transform_train, seed=args.seed, img_size=args.img_size)\n",
167
+ "\n",
168
+ "\n",
169
+ "sampler_train = torch.utils.data.DistributedSampler(dataset_train, num_replicas=1, rank=0, shuffle=False)\n",
170
+ "\n",
171
+ "data_loader_train = torch.utils.data.DataLoader(\n",
172
+ " dataset_train, sampler=sampler_train,\n",
173
+ " batch_size=args.batch_size,\n",
174
+ " num_workers=args.num_workers,\n",
175
+ " pin_memory=args.pin_mem,\n",
176
+ " drop_last=True,\n",
177
+ ")"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": null,
183
+ "id": "2e22d231-02db-4586-b489-01a97314aed9",
184
+ "metadata": {
185
+ "tags": []
186
+ },
187
+ "outputs": [],
188
+ "source": [
189
+ "vae = AutoencoderKL(\n",
190
+ " ddconfig=args.ddconfig,\n",
191
+ " embed_dim=args.vae_embed_dim,\n",
192
+ " ckpt_path=args.vae_path\n",
193
+ ").to(device).eval()\n",
194
+ "\n",
195
+ "for param in vae.parameters():\n",
196
+ " param.requires_grad = False\n",
197
+ " \n",
198
+ "model = mar.mar_base(\n",
199
+ " img_size=args.img_size,\n",
200
+ " vae_stride=args.vae_stride,\n",
201
+ " patch_size=args.patch_size,\n",
202
+ " vae_embed_dim=args.vae_embed_dim,\n",
203
+ " mask_ratio_min=args.mask_ratio_min,\n",
204
+ " label_drop_prob=args.label_drop_prob,\n",
205
+ " attn_dropout=args.attn_dropout,\n",
206
+ " proj_dropout=args.proj_dropout,\n",
207
+ " buffer_size=args.buffer_size,\n",
208
+ " diffloss_d=args.diffloss_d,\n",
209
+ " diffloss_w=args.diffloss_w,\n",
210
+ " num_sampling_steps=args.num_sampling_steps,\n",
211
+ " diffusion_batch_mul=args.diffusion_batch_mul,\n",
212
+ " grad_checkpointing=args.grad_checkpointing,\n",
213
+ ")\n",
214
+ "\n",
215
+ "n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
216
+ "print(\"Number of trainable parameters: {}M\".format(n_params / 1e6))\n",
217
+ "\n",
218
+ "\n",
219
+ "checkpoint = torch.load(args.ckpt_path, map_location='cpu')\n",
220
+ "model.load_state_dict(checkpoint['model'])\n",
221
+ "model.to(device)\n",
222
+ "\n",
223
+ "eff_batch_size = args.batch_size * misc.get_world_size()\n",
224
+ "\n",
225
+ "print(\"effective batch size: %d\" % eff_batch_size)"
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "code",
230
+ "execution_count": 8,
231
+ "id": "4c83c0eb-35a5-4241-b869-d52eb6cd31e0",
232
+ "metadata": {
233
+ "tags": []
234
+ },
235
+ "outputs": [
236
+ {
237
+ "name": "stderr",
238
+ "output_type": "stream",
239
+ "text": [
240
+ "Training Progress: 100%|██████████| 500/500 [13:11<00:00, 1.58s/it]"
241
+ ]
242
+ },
243
+ {
244
+ "name": "stdout",
245
+ "output_type": "stream",
246
+ "text": [
247
+ "road : 98.06\n",
248
+ "sidewalk : 86.32\n",
249
+ "building : 89.23\n",
250
+ "wall : 47.44\n",
251
+ "fence : 43.78\n",
252
+ "pole : 60.14\n",
253
+ "tlight : 63.16\n",
254
+ "tsign : 82.48\n",
255
+ "vtation : 92.72\n",
256
+ "terrain : 80.45\n",
257
+ "sky : 95.99\n",
258
+ "person : 70.83\n",
259
+ "rider : 64.25\n",
260
+ "car : 94.06\n",
261
+ "truck : 44.90\n",
262
+ "bus : 66.81\n",
263
+ "train : 44.04\n",
264
+ "motorcycle : 47.34\n",
265
+ "bicycle : 62.50\n",
266
+ "Avg Pre : 70.24\n"
267
+ ]
268
+ },
269
+ {
270
+ "name": "stderr",
271
+ "output_type": "stream",
272
+ "text": [
273
+ "\n"
274
+ ]
275
+ }
276
+ ],
277
+ "source": [
278
+ "hist = []\n",
279
+ "model.eval()\n",
280
+ "for data_iter_step, (samples, labels, path) in enumerate(tqdm(data_loader_train, desc=\"Training Progress\")):\n",
281
+ " samples = samples.to(device, non_blocking=True)\n",
282
+ " labels = labels.to(device, non_blocking=True)\n",
283
+ "\n",
284
+ " with torch.no_grad():\n",
285
+ " posterior_x = vae.encode(samples)\n",
286
+ " posterior_y = vae.encode(labels)\n",
287
+ " x = posterior_x.sample().mul_(0.2325)\n",
288
+ " y = posterior_y.sample().mul_(0.2325)\n",
289
+ " x = model.patchify(x)\n",
290
+ " y = model.patchify(y)\n",
291
+ " gt_latents = y.clone().detach()\n",
292
+ " cfg_iter = 1.0\n",
293
+ " temperature = 1.0\n",
294
+ " mask_actual = torch.cat([torch.zeros(args.batch_size, model.seq_len), torch.ones(args.batch_size, model.seq_len)], dim=1).cuda()\n",
295
+ " tokens = torch.zeros(args.batch_size, model.seq_len, model.token_embed_dim).cuda()\n",
296
+ "\n",
297
+ " with torch.no_grad():\n",
298
+ " x1 = model.forward_mae_encoder(x, mask_actual, tokens)\n",
299
+ " z = model.forward_mae_decoder(x1, mask_actual)\n",
300
+ " z = z[0]\n",
301
+ " sampled_token_latent = model.diffloss.sample(z, temperature, cfg_iter)\n",
302
+ "\n",
303
+ " tokens[0] = sampled_token_latent[model.seq_len:]\n",
304
+ " tokens = model.unpatchify(tokens)\n",
305
+ " \n",
306
+ " sampled_images = vae.decode(tokens / 0.2325)\n",
307
+ " \n",
308
+ " image_tensor = labels[0] \n",
309
+ " image_tensor = image_tensor * 0.5 + 0.5\n",
310
+ " gt_np = image_tensor.permute(1, 2, 0).cpu().numpy()\n",
311
+ " H, W, _ = gt_np.shape\n",
312
+ " pixels = gt_np.reshape(-1, 3)\n",
313
+ " distances = np.linalg.norm(pixels[:, None, :] - color_pallete[None, :, :], axis=2)\n",
314
+ " output = np.argmin(distances, axis=1)\n",
315
+ " gt = output.reshape(H, W)\n",
316
+ " \n",
317
+ " image_tensor = sampled_images[0]\n",
318
+ " image_tensor = image_tensor * 0.5 + 0.5 \n",
319
+ " ss_np = image_tensor.permute(1, 2, 0).cpu().numpy()\n",
320
+ " H, W, _ = ss_np.shape\n",
321
+ " pixels = ss_np.reshape(-1, 3)\n",
322
+ " distances = np.linalg.norm(pixels[:, None, :] - color_pallete[None, :, :], axis=2)\n",
323
+ " output = np.argmin(distances, axis=1)\n",
324
+ " output = output.reshape(H, W)\n",
325
+ " \n",
326
+ " hist.append(fast_hist(output.reshape(-1), gt.reshape(-1), 20))\n",
327
+ "\n",
328
+ "cm = np.sum(hist, axis=0)\n",
329
+ "\n",
330
+ "epsilon = 1e-10\n",
331
+ "class_precision = np.diag(cm[1:,1:]) / (np.sum(cm[1:,1:], axis=0) + epsilon)\n",
332
+ "class_names = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'tlight', 'tsign', \n",
333
+ " 'vtation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', \n",
334
+ " 'motorcycle', 'bicycle']\n",
335
+ "\n",
336
+ "for i in range(len(class_names)):\n",
337
+ " print(f\"{class_names[i]:<12}: {class_precision[i]*100:6.2f}\")\n",
338
+ "average_precision = np.mean(class_precision)\n",
339
+ "print(f\"{'Avg Pre':<12}: {average_precision*100:6.2f}\")"
340
+ ]
341
+ }
342
+ ],
343
+ "metadata": {
344
+ "kernelspec": {
345
+ "display_name": "Python 3 (ipykernel)",
346
+ "language": "python",
347
+ "name": "python3"
348
+ },
349
+ "language_info": {
350
+ "codemirror_mode": {
351
+ "name": "ipython",
352
+ "version": 3
353
+ },
354
+ "file_extension": ".py",
355
+ "mimetype": "text/x-python",
356
+ "name": "python",
357
+ "nbconvert_exporter": "python",
358
+ "pygments_lexer": "ipython3",
359
+ "version": "3.8.10"
360
+ }
361
+ },
362
+ "nbformat": 4,
363
+ "nbformat_minor": 5
364
+ }
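
The last notebook cell maps every pixel to its nearest palette color, accumulates a confusion matrix with `fast_hist`, and reports per-class precision while skipping class 0. The sketch below reproduces those three steps in plain Python on a toy three-color palette, so it can be run without NumPy or a GPU. The helper names `nearest_class`, `confusion_matrix`, and `class_precision`, and the toy pixel values, are hypothetical and only mirror the logic of the cell.

```python
def nearest_class(pixel, palette):
    """Index of the palette color closest to `pixel` (squared Euclidean distance)."""
    return min(
        range(len(palette)),
        key=lambda i: sum((a - b) ** 2 for a, b in zip(pixel, palette[i])),
    )


def confusion_matrix(pred, label, n):
    """Rows are ground-truth classes, columns are predictions (same layout as fast_hist)."""
    cm = [[0] * n for _ in range(n)]
    for p, g in zip(pred, label):
        if 0 <= g < n:  # ignore labels outside [0, n)
            cm[g][p] += 1
    return cm


def class_precision(cm, eps=1e-10):
    """Per-class precision over classes 1..n-1, dropping row and column 0 as the notebook does."""
    n = len(cm)
    result = []
    for c in range(1, n):
        column_total = sum(cm[r][c] for r in range(1, n))
        result.append(cm[c][c] / (column_total + eps))
    return result


# Toy palette (assumed values): class 0 is the ignored background color.
palette = [(0, 0, 0), (128, 64, 128), (244, 35, 232)]
gt_pixels = [(0, 0, 0), (128, 64, 128), (128, 64, 128),
             (244, 35, 232), (244, 35, 232), (244, 35, 232)]
pred_pixels = [(4, 2, 1), (120, 70, 130), (230, 40, 220),
               (240, 30, 235), (250, 36, 228), (130, 60, 125)]

gt = [nearest_class(p, palette) for p in gt_pixels]
pred = [nearest_class(p, palette) for p in pred_pixels]
cm = confusion_matrix(pred, gt, len(palette))
precision = class_precision(cm)

print([round(p, 4) for p in precision])  # [0.5, 0.6667]
```

Because the denominator sums each column, the score is precision (correct predictions over all predictions of that class), which matches the "Avg Pre" label in the notebook output.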