initial upload
- README.md +185 -0
- get_parser.py +121 -0
- train.py +381 -0
- validation.ipynb +364 -0
README.md
ADDED
@@ -0,0 +1,185 @@
# CAM-Seg: A Continuous-valued Embedding Approach for Semantic Image Generation

**Official PyTorch Implementation**

This is a PyTorch/GPU implementation of the paper [CAM-Seg: A Continuous-valued Embedding Approach for Semantic Image Generation](https://arxiv.org/abs/2503.15617)

```
@article{ahmed2025cam,
  title={CAM-Seg: A Continuous-valued Embedding Approach for Semantic Image Generation},
  author={Ahmed, Masud and Hasan, Zahid and Haque, Syed Arefinul and Faridee, Abu Zaher Md and Purushotham, Sanjay and You, Suya and Roy, Nirmalya},
  journal={arXiv preprint arXiv:2503.15617},
  year={2025}
}
```

## Abstract
Traditional transformer-based semantic segmentation relies on quantized embeddings. However, our analysis reveals that autoencoder accuracy on segmentation masks using quantized embeddings (e.g. VQ-VAE) is 8\% lower than with continuous-valued embeddings (e.g. KL-VAE). Motivated by this, we propose a continuous-valued embedding framework for semantic segmentation. By reformulating semantic mask generation as a continuous image-to-embedding diffusion process, our approach eliminates the need for discrete latent representations while preserving fine-grained spatial and semantic details. Our key contribution is a diffusion-guided autoregressive transformer that learns a continuous semantic embedding space by modeling long-range dependencies in image features. Our framework contains a unified architecture combining a VAE encoder for continuous feature extraction, a diffusion-guided transformer for conditioned embedding generation, and a VAE decoder for semantic mask reconstruction. Our setting facilitates zero-shot domain adaptation capabilities enabled by the continuity of the embedding space. Experiments across diverse datasets (e.g., Cityscapes and domain-shifted variants) demonstrate state-of-the-art robustness to distribution shifts, including adverse weather (e.g., fog, snow) and viewpoint variations. Our model also exhibits strong noise resilience, achieving robust performance ($\approx$ 95\% AP compared to baseline) under Gaussian noise, moderate motion blur, and moderate brightness/contrast variations, while experiencing only a moderate impact ($\approx$ 90\% AP compared to baseline) from 50\% salt and pepper noise, saturation and hue shifts.

## Result
Trained on the Cityscapes dataset and tested on the SemanticKITTI, ACDC, and CAD-EdgeTune datasets
<p align="center">
  <img src="demo/qualitative.png" width="720">
</p>

Quantitative results of semantic segmentation under various noise conditions
<p align="center">
<table>
  <tr>
    <td align="center"><img src="demo/saltpepper_noise.png" width="200"/><br>Salt & Pepper Noise</td>
    <td align="center"><img src="demo/motion_blur.png" width="200"/><br>Motion Blur</td>
    <td align="center"><img src="demo/gaussian_noise.png" width="200"/><br>Gaussian Noise</td>
    <td align="center"><img src="demo/gaussian_blur.png" width="200"/><br>Gaussian Blur</td>
  </tr>
  <tr>
    <td align="center"><img src="demo/brightness.png" width="200"/><br>Brightness Variation</td>
    <td align="center"><img src="demo/contrast.png" width="200"/><br>Contrast Variation</td>
    <td align="center"><img src="demo/saturation.png" width="200"/><br>Saturation Variation</td>
    <td align="center"><img src="demo/hue.png" width="200"/><br>Hue Variation</td>
  </tr>
</table>
</p>

## Prerequisite
To install the docker environment, first edit `docker_env/Makefile`:
```
IMAGE=img_name/dl-aio
CONTAINER=container_name
AVAILABLE_GPUS='0,1,2,3'
LOCAL_JUPYTER_PORT=18888
LOCAL_TENSORBOARD_PORT=18006
PASSWORD=yourpassword
WORKSPACE=workspace_directory
```
- Edit `img_name`, `container_name`, `available_gpus`, `jupyter_port`, `tensorboard_port`, `password`, and `workspace_directory`.

1. For the first run, execute the following commands in a terminal:
```
cd docker_env
make docker-build
make docker-run
```
2. For further use of the docker environment:
   - To stop the environment: `make docker-stop`
   - To resume the environment: `make docker-resume`

For coding, open a web browser at `ip_address:jupyter_port`, e.g. `http://localhost:18888`

## Dataset
Four datasets are used in this work:
1. [Cityscapes Dataset](https://www.cityscapes-dataset.com/)
2. [KITTI Dataset](https://www.cvlibs.net/datasets/kitti/eval_step.php)
3. [ACDC Dataset](https://acdc.vision.ee.ethz.ch/)
4. [CAD-EdgeTune Dataset](https://ieee-dataport.org/documents/cad-edgetune)

**Modify the trainlist and vallist files to edit the train/test split.**

### Dataset structure
- Cityscapes Dataset
```
|-CityScapes
|----leftImg8bit       # contains the RGB images
|----gtFine            # contains semantic segmentation labels
|----trainlist.txt     # image list used for training
|----vallist.txt       # image list used for testing
|----cityscape.yaml    # configuration file for the Cityscapes dataset
```

- ACDC Dataset
```
|-ACDC
|----rgb_anon          # contains the RGB images
|----gt                # contains semantic segmentation labels
|----vallist_fog.txt   # image list used for testing fog data
|----vallist_rain.txt  # image list used for testing rain data
|----vallist_snow.txt  # image list used for testing snow data
|----acdc.yaml         # configuration file for the ACDC dataset
```

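A quick way to confirm this layout (a minimal sketch; it only checks for the entries listed above and assumes the datasets sit under the `dataset/` root used by `train.py` and `validation.ipynb`):
```
import os

# Expected entries per dataset root, taken from the directory trees above.
expected = {
    "dataset/CityScapes": ["leftImg8bit", "gtFine", "trainlist.txt", "vallist.txt", "cityscape.yaml"],
    "dataset/ACDC": ["rgb_anon", "gt", "vallist_fog.txt", "vallist_rain.txt", "vallist_snow.txt", "acdc.yaml"],
}

for root, names in expected.items():
    for name in names:
        path = os.path.join(root, name)
        print(f"{path}: {'OK' if os.path.exists(path) else 'MISSING'}")
```
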
## Weights
To download the pretrained weights, please visit the [Hugging Face Repo](https://huggingface.co/mahmed10/CAM-Seg)
- **LDM model** The pretrained model from Rombach et al.'s Latent Diffusion Models is used [Link](https://huggingface.co/mahmed10/CAM-Seg/resolve/main/pretrained_models/vae/modelf16.ckpt)
- **MAR model** The following MAR model is used

|Training Data|Model|Params|Link|
|-------------|-----|------|----|
|Cityscapes|Mar-base|217M|[link](https://huggingface.co/mahmed10/CAM-Seg/resolve/main/pretrained_models/mar/city768.16.pth)|


Download these weight files and organize them as follows
```
|-pretrained_models
|----mar
|--------city768.16.pth
|----vae
|--------modelf16.ckpt
```

**Alternatively, use the following code to automatically download the pretrained weights**
```
import os
import requests

# Define URLs and file paths
files_to_download = {
    "https://huggingface.co/mahmed10/CAM-Seg/resolve/main/pretrained_models/vae/modelf16.ckpt":
        "pretrained_models/vae/modelf16.ckpt",
    "https://huggingface.co/mahmed10/CAM-Seg/resolve/main/pretrained_models/mar/city768.16.pth":
        "pretrained_models/mar/city768.16.pth"
}

for url, path in files_to_download.items():
    os.makedirs(os.path.dirname(path), exist_ok=True)

    print(f"Downloading from {url}...")
    response = requests.get(url, stream=True)
    if response.status_code == 200:
        with open(path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        print(f"Saved to {path}")
    else:
        print(f"Failed to download from {url}, status code {response.status_code}")
```

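After downloading, a quick sanity check can confirm the files are usable (a minimal sketch; it assumes the paths above and that the MAR checkpoint stores its weights under the `'model'` key, which is what `train.py` and `validation.ipynb` read):
```
import os
import torch

# Inspect the MAR checkpoint on CPU; the training/validation code reads checkpoint['model'].
ckpt = torch.load("pretrained_models/mar/city768.16.pth", map_location="cpu")
print(list(ckpt.keys()))
print(len(ckpt["model"]), "tensors in 'model'")

# The VAE checkpoint is handed to AutoencoderKL via ckpt_path; just confirm it is present.
print(os.path.getsize("pretrained_models/vae/modelf16.ckpt"), "bytes")
```
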
## Validation
Open the `validation.ipynb` file

Edit **Block 6** to select which dataset to use for validation

```
dataset_train = cityscapes.CityScapes('dataset/CityScapes/vallist.txt', data_set='val', transform=transform_train, seed=36, img_size=768)
# dataset_train = umbc.UMBC('dataset/UMBC/all.txt', data_set='val', transform=transform_train, seed=36, img_size=768)
# dataset_train = acdc.ACDC('dataset/ACDC/vallist_fog.txt', data_set='val', transform=transform_train, seed=36, img_size=768)
# dataset_train = semantickitti.SemanticKITTI('dataset/SemanticKitti/vallist.txt', data_set='val', transform=transform_train, seed=36, img_size=768)
```

Run all the blocks

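For reference, the core step the notebook runs for each image looks roughly like this (condensed from the evaluation loop in `validation.ipynb`; it is not self-contained, since `vae`, `model`, and the input batch `samples` come from the earlier notebook cells):
```
# Encode the RGB image, let the MAR model generate the semantic embedding,
# then decode it back to a colour-coded segmentation mask.
with torch.no_grad():
    x = model.patchify(vae.encode(samples).sample().mul_(0.2325))    # image -> continuous embedding

    # Fully masked target: every output token is generated by the model.
    mask = torch.cat([torch.zeros(1, model.seq_len), torch.ones(1, model.seq_len)], dim=1).cuda()
    tokens = torch.zeros(1, model.seq_len, model.token_embed_dim).cuda()

    z = model.forward_mae_decoder(model.forward_mae_encoder(x, mask, tokens), mask)[0]
    sampled = model.diffloss.sample(z, 1.0, 1.0)    # (temperature, cfg), as in the notebook

    tokens[0] = sampled[model.seq_len:]
    pred_mask = vae.decode(model.unpatchify(tokens) / 0.2325)    # decoded semantic mask image
```
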
## Training

### From Scratch

Run the following command in a terminal
```
torchrun --nproc_per_node=4 train.py
```

It saves checkpoints in an `output_dir/year.month.day.hour.min` folder, e.g. `output_dir/2025.05.09.02.27`

### Resume Training

Run the following command in a terminal
```
torchrun --nproc_per_node=4 train.py --resume year.month.day.hour.min
```

Here is an example
```
torchrun --nproc_per_node=4 train.py --resume 2025.05.09.02.27
```

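To resume the most recent run without typing the timestamp by hand, a small helper can print the command (a sketch, assuming run folders under `output_dir` follow the `year.month.day.hour.min` naming above, which sorts chronologically):
```
import os

runs = sorted(d for d in os.listdir("output_dir") if os.path.isdir(os.path.join("output_dir", d)))
if runs:
    print(f"torchrun --nproc_per_node=4 train.py --resume {runs[-1]}")
```
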
## Acknowledgement
The code is developed on top of the following codebases:
1. [latent-diffusion](https://github.com/CompVis/latent-diffusion)
2. [mar](https://github.com/LTH14/mar)
get_parser.py
ADDED
@@ -0,0 +1,121 @@
import argparse
from pathlib import Path
import yaml


def get_args_parser():
    parser = argparse.ArgumentParser('MAR training with Diffusion Loss', add_help=False)
    parser.add_argument('--batch_size', default=16, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * # gpus)')
    parser.add_argument('--epochs', default=2000, type=int)

    # Model parameters
    parser.add_argument('--model', default='mar_base', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--ckpt_path', default="pretrained_models/mar/city768.16.pth", type=str,
                        help='model checkpoint path')

    # VAE parameters
    parser.add_argument('--img_size', default=768, type=int,
                        help='images input size')
    parser.add_argument('--vae_path', default="pretrained_models/vae/modelf16.ckpt", type=str,
                        help='vae checkpoint path')
    parser.add_argument('--vae_embed_dim', default=16, type=int,
                        help='vae output embedding dimension')
    parser.add_argument('--vae_stride', default=16, type=int,
                        help='tokenizer stride, default use KL16')
    parser.add_argument('--patch_size', default=1, type=int,
                        help='number of tokens to group as a patch.')
    parser.add_argument('--config', default="ldm/config.yaml", type=str,
                        help='vae model configuration file')

    # Generation parameters
    parser.add_argument('--num_iter', default=64, type=int,
                        help='number of autoregressive iterations to generate an image')
    parser.add_argument('--num_images', default=3000, type=int,
                        help='number of images to generate')
    parser.add_argument('--cfg', default=1.0, type=float, help="classifier-free guidance")
    parser.add_argument('--cfg_schedule', default="linear", type=str)
    parser.add_argument('--label_drop_prob', default=0.1, type=float)
    parser.add_argument('--eval_freq', type=int, default=40, help='evaluation frequency')
    parser.add_argument('--save_last_freq', type=int, default=5, help='save last frequency')
    parser.add_argument('--online_eval', action='store_true')
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--eval_bsz', type=int, default=64, help='generation batch size')

    # Optimizer parameters
    parser.add_argument('--weight_decay', type=float, default=0.02,
                        help='weight decay (default: 0.02)')

    parser.add_argument('--grad_checkpointing', action='store_true')
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=1e-4, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--lr_schedule', type=str, default='constant',
                        help='learning rate schedule')
    parser.add_argument('--warmup_epochs', type=int, default=100, metavar='N',
                        help='epochs to warmup LR')
    parser.add_argument('--ema_rate', default=0.9999, type=float)

    # MAR params
    parser.add_argument('--mask_ratio_min', type=float, default=0.7,
                        help='Minimum mask ratio')
    parser.add_argument('--grad_clip', type=float, default=3.0,
                        help='Gradient clip')
    parser.add_argument('--attn_dropout', type=float, default=0.1,
                        help='attention dropout')
    parser.add_argument('--proj_dropout', type=float, default=0.1,
                        help='projection dropout')
    parser.add_argument('--buffer_size', type=int, default=64)

    # Diffusion Loss params
    parser.add_argument('--diffloss_d', type=int, default=6)
    parser.add_argument('--diffloss_w', type=int, default=1024)
    parser.add_argument('--num_sampling_steps', type=str, default="100")
    parser.add_argument('--diffusion_batch_mul', type=int, default=4)
    parser.add_argument('--temperature', default=1.0, type=float, help='diffusion loss sampling temperature')

    # Dataset parameters
    parser.add_argument('--output_dir', default='./output_dir',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default='./output_dir',
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--resume', default=None,  # e.g. 'pretrained_models/mar/mar_base'
                        help='resume from checkpoint')

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    # caching latents
    parser.add_argument('--use_cached', action='store_true', dest='use_cached',
                        help='Use cached latents')
    parser.set_defaults(use_cached=False)
    parser.add_argument('--cached_path', default='', help='path to cached latents')

    return parser


# Parsed at import/%run time so that validation.ipynb can pick up `args` directly.
args = get_args_parser()
args = args.parse_args()
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
args.log_dir = args.output_dir

with open(args.config, "r") as f:
    config = yaml.safe_load(f)
args.ddconfig = config["ddconfig"]
train.py
ADDED
@@ -0,0 +1,381 @@
import argparse
import datetime
import numpy as np
import os
import time
from pathlib import Path
import yaml
import glob

import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from data import cityscapes

from util.crop import center_crop_arr
import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler
from util.loader import CachedFolder

from models.vae import AutoencoderKL
from models import mar
import copy
from tqdm import tqdm

import util.lr_sched as lr_sched

import logging


def update_ema(target_params, source_params, rate=0.99):
    """
    Update target parameters to be closer to those of source parameters using
    an exponential moving average.

    :param target_params: the target parameter sequence.
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    for targ, src in zip(target_params, source_params):
        targ.detach().mul_(rate).add_(src, alpha=1 - rate)


def logger_file(path):
    # file-only logger used to record arguments and per-epoch loss statistics
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    handler = logging.FileHandler(path, "w", encoding=None, delay=True)
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter("%(message)s")
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    return logger


def get_args_parser():
    parser = argparse.ArgumentParser('MAR training with Diffusion Loss', add_help=False)
    parser.add_argument('--batch_size', default=2, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * # gpus)')
    parser.add_argument('--epochs', default=2000, type=int)

    # Model parameters
    parser.add_argument('--model', default='mar_base', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--ckpt_path', default="pretrained_models/mar/city768.16.pth", type=str,
                        help='model checkpoint path')

    # VAE parameters
    parser.add_argument('--img_size', default=768, type=int,
                        help='images input size')
    parser.add_argument('--vae_path', default="pretrained_models/vae/modelf16.ckpt", type=str,
                        help='vae checkpoint path')
    parser.add_argument('--vae_embed_dim', default=16, type=int,
                        help='vae output embedding dimension')
    parser.add_argument('--vae_stride', default=16, type=int,
                        help='tokenizer stride, default use KL16')
    parser.add_argument('--patch_size', default=1, type=int,
                        help='number of tokens to group as a patch.')
    parser.add_argument('--config', default="ldm/config.yaml", type=str,
                        help='vae model configuration file')

    # Generation parameters
    parser.add_argument('--num_iter', default=64, type=int,
                        help='number of autoregressive iterations to generate an image')
    parser.add_argument('--num_images', default=3000, type=int,
                        help='number of images to generate')
    parser.add_argument('--cfg', default=1.0, type=float, help="classifier-free guidance")
    parser.add_argument('--cfg_schedule', default="linear", type=str)
    parser.add_argument('--label_drop_prob', default=0.1, type=float)
    parser.add_argument('--eval_freq', type=int, default=40, help='evaluation frequency')
    parser.add_argument('--save_last_freq', type=int, default=5, help='save last frequency')
    parser.add_argument('--online_eval', action='store_true')
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--eval_bsz', type=int, default=64, help='generation batch size')

    # Optimizer parameters
    parser.add_argument('--weight_decay', type=float, default=0.02,
                        help='weight decay (default: 0.02)')

    parser.add_argument('--grad_checkpointing', action='store_true')
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=1e-4, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--lr_schedule', type=str, default='constant',
                        help='learning rate schedule')
    parser.add_argument('--warmup_epochs', type=int, default=100, metavar='N',
                        help='epochs to warmup LR')
    parser.add_argument('--ema_rate', default=0.9999, type=float)

    # MAR params
    parser.add_argument('--mask_ratio_min', type=float, default=0.7,
                        help='Minimum mask ratio')
    parser.add_argument('--grad_clip', type=float, default=3.0,
                        help='Gradient clip')
    parser.add_argument('--attn_dropout', type=float, default=0.1,
                        help='attention dropout')
    parser.add_argument('--proj_dropout', type=float, default=0.1,
                        help='projection dropout')
    parser.add_argument('--buffer_size', type=int, default=64)

    # Diffusion Loss params
    parser.add_argument('--diffloss_d', type=int, default=6)
    parser.add_argument('--diffloss_w', type=int, default=1024)
    parser.add_argument('--num_sampling_steps', type=str, default="100")
    parser.add_argument('--diffusion_batch_mul', type=int, default=4)
    parser.add_argument('--temperature', default=1.0, type=float, help='diffusion loss sampling temperature')

    # Dataset parameters
    parser.add_argument('--output_dir', default='./output_dir',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default='./output_dir',
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--resume', default=None,
                        help='resume from checkpoint')

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    # caching latents
    parser.add_argument('--use_cached', action='store_true', dest='use_cached',
                        help='Use cached latents')
    parser.set_defaults(use_cached=False)
    parser.add_argument('--cached_path', default='', help='path to cached latents')

    return parser


def main(args):
    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()

    log_writer = None

    # augmentation following DiT and ADM
    transform_train = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])

    dataset_train = cityscapes.CityScapes('dataset/CityScapes/trainlist.txt', transform=transform_train, img_size=args.img_size)

    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    print("Sampler_train = %s" % str(sampler_train))

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    # define the vae and mar model
    with open(args.config, "r") as f:
        config = yaml.safe_load(f)
    args.ddconfig = config["ddconfig"]
    print('config: ', config)

    vae = AutoencoderKL(
        ddconfig=args.ddconfig,
        embed_dim=args.vae_embed_dim,
        ckpt_path=args.vae_path
    ).cuda().eval()

    for param in vae.parameters():
        param.requires_grad = False

    model = mar.__dict__[args.model](
        img_size=args.img_size,
        vae_stride=args.vae_stride,
        patch_size=args.patch_size,
        vae_embed_dim=args.vae_embed_dim,
        mask_ratio_min=args.mask_ratio_min,
        label_drop_prob=args.label_drop_prob,
        attn_dropout=args.attn_dropout,
        proj_dropout=args.proj_dropout,
        buffer_size=args.buffer_size,
        diffloss_d=args.diffloss_d,
        diffloss_w=args.diffloss_w,
        num_sampling_steps=args.num_sampling_steps,
        diffusion_batch_mul=args.diffusion_batch_mul,
        grad_checkpointing=args.grad_checkpointing,
    )

    if args.ckpt_path:
        checkpoint = torch.load(args.ckpt_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])

    print("Model = %s" % str(model))
    # following timm: set wd as 0 for bias and norm layers
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Number of trainable parameters: {}M".format(n_params / 1e6))

    model.to(device)
    model_without_ddp = model

    eff_batch_size = args.batch_size * misc.get_world_size()

    if args.lr is None:  # only base_lr is specified
        args.lr = args.blr

    print("base lr: %.2e" % args.blr)
    print("actual lr: %.2e" % args.lr)
    print("effective batch size: %d" % eff_batch_size)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    # no weight decay on bias, norm layers, and diffloss MLP
    param_groups = misc.add_weight_decay(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    print(optimizer)
    loss_scaler = NativeScaler()

    # resume training
    if args.resume and glob.glob(os.path.join(args.output_dir, args.resume, 'checkpoint*.pth')):
        try:
            checkpoint = torch.load(sorted(glob.glob(os.path.join(args.output_dir, args.resume, 'checkpoint*.pth')))[-1], map_location='cpu')
            model.load_state_dict(checkpoint['model'])
        except Exception:
            # fall back to the previous checkpoint if the latest one cannot be loaded
            checkpoint = torch.load(sorted(glob.glob(os.path.join(args.output_dir, args.resume, 'checkpoint*.pth')))[-2], map_location='cpu')
            model.load_state_dict(checkpoint['model'])
        state_dict = {key.replace("module.", ""): value for key, value in checkpoint['model'].items()}
        model_without_ddp.load_state_dict(state_dict)
        model_params = list(model_without_ddp.parameters())
        ema_params = copy.deepcopy(model_params)
        ema_state_dict = {key.replace("module.", ""): value for key, value in checkpoint['model_ema'].items()}
        ema_params = [ema_state_dict[name].cuda() for name, _ in model_without_ddp.named_parameters()]
        print("Resume checkpoint %s" % args.resume)
        if 'optimizer' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            print("With optim & sched!")
        del checkpoint

        args.output_dir = os.path.join(args.output_dir, args.resume)

        logger = logger_file(args.log_dir + '/' + args.resume + '.log')
        if os.path.exists(args.log_dir + '/' + args.resume + '.log'):
            with open(args.log_dir + '/' + args.resume + '.log', 'r') as infile:
                for line in infile:
                    logger.info(line.rstrip())
        else:
            logger.info("All the arguments")
            for k, v in vars(args).items():
                logger.info(f"{k}: {v}")
            logger.info("\n\n Loss information")

    else:
        model_params = list(model_without_ddp.parameters())
        ema_params = copy.deepcopy(model_params)
        print("Training from scratch")
        args.resume = datetime.datetime.now().strftime("%Y.%m.%d.%H.%M")
        args.output_dir = os.path.join(args.output_dir, args.resume)
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

        logger = logger_file(args.log_dir + '/' + args.resume + '.log')
        logger.info("All the arguments")
        for k, v in vars(args).items():
            logger.info(f"{k}: {v}")
        logger.info("\n\n Loss information")

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    for epoch in tqdm(range(args.start_epoch, args.epochs), desc="Training Progress"):
        if args.distributed:
            # reshuffle the distributed sampler every epoch
            data_loader_train.sampler.set_epoch(epoch)

        model.train(True)
        metric_logger = misc.MetricLogger(delimiter="  ")
        metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
        header = 'Epoch: [{}]'.format(epoch)
        print_freq = 20

        optimizer.zero_grad()

        for data_iter_step, (samples, labels, _) in enumerate(data_loader_train):
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader_train) + epoch, args)
            samples = samples.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            with torch.no_grad():
                posterior_x = vae.encode(samples)
                posterior_y = vae.encode(labels)
                x = posterior_x.sample().mul_(0.2325)
                y = posterior_y.sample().mul_(0.2325)
            with torch.cuda.amp.autocast():
                loss = model(x, y)
            loss_value = loss.item()
            loss_scaler(loss, optimizer, clip_grad=args.grad_clip, parameters=model.parameters(), update_grad=True)
            optimizer.zero_grad()
            torch.cuda.synchronize()

            update_ema(ema_params, model_params, rate=args.ema_rate)
            metric_logger.update(loss=loss_value)

            lr = optimizer.param_groups[0]["lr"]
            metric_logger.update(lr=lr)

            loss_value_reduce = misc.all_reduce_mean(loss_value)
        metric_logger.synchronize_between_processes()
        logger.info(f"epoch: {epoch:4d}, Averaged stats: {metric_logger}")
        if (epoch + 1) % args.save_last_freq == 0:
            misc.save_model(args=args, model=model, model_without_ddp=model, optimizer=optimizer,
                            loss_scaler=loss_scaler, epoch=epoch, ema_params=ema_params, epoch_name=str(epoch).zfill(5))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    args = get_args_parser()
    args = args.parse_args()
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.log_dir).mkdir(parents=True, exist_ok=True)
    main(args)
validation.ipynb
ADDED
@@ -0,0 +1,364 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c524f796-e657-4a59-abcf-540531a38995",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%run get_parser.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c1cf01e-8229-4d28-bcb2-01c07fa641c2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import requests\n",
    "\n",
    "# Define URLs and file paths\n",
    "files_to_download = {\n",
    "    \"https://huggingface.co/mahmed10/CAM-Seg/resolve/main/pretrained_models/vae/modelf16.ckpt\":\n",
    "        \"pretrained_models/vae/modelf16.ckpt\",\n",
    "    \"https://huggingface.co/mahmed10/CAM-Seg/resolve/main/pretrained_models/mar/city768.16.pth\":\n",
    "        \"pretrained_models/mar/city768.16.pth\"\n",
    "}\n",
    "\n",
    "for url, path in files_to_download.items():\n",
    "    os.makedirs(os.path.dirname(path), exist_ok=True)\n",
    "\n",
    "    if os.path.exists(path):\n",
    "        print(f\"File already exists: {path} - skipping download.\")\n",
    "        continue\n",
    "\n",
    "    print(f\"Downloading from {url}...\")\n",
    "    response = requests.get(url, stream=True)\n",
    "    if response.status_code == 200:\n",
    "        with open(path, 'wb') as f:\n",
    "            for chunk in response.iter_content(chunk_size=8192):\n",
    "                f.write(chunk)\n",
    "        print(f\"Saved to {path}\")\n",
    "    else:\n",
    "        print(f\"Failed to download from {url}, status code {response.status_code}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3a7ac93b-1cbc-45f3-8ec5-8e8257a39786",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from PIL import Image\n",
    "import yaml\n",
    "import math\n",
    "\n",
    "import torch\n",
    "import torch.backends.cudnn as cudnn\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "from data import cityscapes\n",
    "import util.misc as misc\n",
    "\n",
    "from models.vae import AutoencoderKL\n",
    "from models import mar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e2bde6fd-9b39-40fd-8d4d-d0a5f9c8217a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def mask_by_order(mask_len, order, bsz, seq_len):\n",
    "    masking = torch.zeros(bsz, seq_len).cuda()\n",
    "    masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()], src=torch.ones(bsz, seq_len).cuda()).bool()\n",
    "    return masking\n",
    "\n",
    "def fast_hist(pred, label, n):\n",
    "    k = (label >= 0) & (label < n)\n",
    "    bin_count = np.bincount(\n",
    "        n * label[k].astype(int) + pred[k], minlength=n ** 2)\n",
    "    return bin_count[:n ** 2].reshape(n, n)\n",
    "\n",
    "color_pallete = np.round(np.array([\n",
    "    0, 0, 0,\n",
    "    128, 64, 128,\n",
    "    244, 35, 232,\n",
    "    70, 70, 70,\n",
    "    102, 102, 156,\n",
    "    190, 153, 153,\n",
    "    153, 153, 153,\n",
    "    250, 170, 30,\n",
    "    220, 220, 0,\n",
    "    107, 142, 35,\n",
    "    152, 251, 152,\n",
    "    0, 130, 180,\n",
    "    220, 20, 60,\n",
    "    255, 0, 0,\n",
    "    0, 0, 142,\n",
    "    0, 0, 70,\n",
    "    0, 60, 100,\n",
    "    0, 80, 100,\n",
    "    0, 0, 230,\n",
    "    119, 11, 32,\n",
    "    ])/255.0, 4)\n",
    "\n",
    "color_pallete = color_pallete.reshape(-1, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c189ac7b-ccff-4745-af56-460ec88770b4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "device = torch.device(args.device)\n",
    "device = torch.device('cuda:0')\n",
    "args.batch_size = 1\n",
    "\n",
    "# fix the seed for reproducibility\n",
    "seed = args.seed + misc.get_rank()\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "\n",
    "cudnn.benchmark = True\n",
    "\n",
    "num_tasks = misc.get_world_size()\n",
    "global_rank = misc.get_rank()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "28d13453-a3ac-4d2e-8906-0c179e85c2f9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "transform_train = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])\n",
    "])\n",
    "\n",
    "dataset_train = cityscapes.CityScapes('dataset/CityScapes/vallist.txt', data_set='val', transform=transform_train, seed=args.seed, img_size=args.img_size)\n",
    "# dataset_train = umbc.UMBC('dataset/UMBC/all.txt', data_set='val', transform=transform_train, seed=args.seed, img_size=args.img_size)\n",
    "# dataset_train = acdc.ACDC('dataset/ACDC/vallist_fog.txt', data_set='val', transform=transform_train, seed=args.seed, img_size=args.img_size)\n",
    "# dataset_train = semantickitti.SemanticKITTI('dataset/SemanticKitti/vallist.txt', data_set='val', transform=transform_train, seed=args.seed, img_size=args.img_size)\n",
    "\n",
    "\n",
    "sampler_train = torch.utils.data.DistributedSampler(dataset_train, num_replicas=1, rank=0, shuffle=False)\n",
    "\n",
    "data_loader_train = torch.utils.data.DataLoader(\n",
    "    dataset_train, sampler=sampler_train,\n",
    "    batch_size=args.batch_size,\n",
    "    num_workers=args.num_workers,\n",
    "    pin_memory=args.pin_mem,\n",
    "    drop_last=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e22d231-02db-4586-b489-01a97314aed9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "vae = AutoencoderKL(\n",
    "    ddconfig=args.ddconfig,\n",
    "    embed_dim=args.vae_embed_dim,\n",
    "    ckpt_path=args.vae_path\n",
    ").to(device).eval()\n",
    "\n",
    "for param in vae.parameters():\n",
    "    param.requires_grad = False\n",
    "\n",
    "model = mar.mar_base(\n",
    "    img_size=args.img_size,\n",
    "    vae_stride=args.vae_stride,\n",
    "    patch_size=args.patch_size,\n",
    "    vae_embed_dim=args.vae_embed_dim,\n",
    "    mask_ratio_min=args.mask_ratio_min,\n",
    "    label_drop_prob=args.label_drop_prob,\n",
    "    attn_dropout=args.attn_dropout,\n",
    "    proj_dropout=args.proj_dropout,\n",
    "    buffer_size=args.buffer_size,\n",
    "    diffloss_d=args.diffloss_d,\n",
    "    diffloss_w=args.diffloss_w,\n",
    "    num_sampling_steps=args.num_sampling_steps,\n",
    "    diffusion_batch_mul=args.diffusion_batch_mul,\n",
    "    grad_checkpointing=args.grad_checkpointing,\n",
    ")\n",
    "\n",
    "n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "print(\"Number of trainable parameters: {}M\".format(n_params / 1e6))\n",
    "\n",
    "\n",
    "checkpoint = torch.load(args.ckpt_path, map_location='cpu')\n",
    "model.load_state_dict(checkpoint['model'])\n",
    "model.to(device)\n",
    "\n",
    "eff_batch_size = args.batch_size * misc.get_world_size()\n",
    "\n",
    "print(\"effective batch size: %d\" % eff_batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4c83c0eb-35a5-4241-b869-d52eb6cd31e0",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training Progress: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [13:11<00:00,  1.58s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "road        :  98.06\n",
      "sidewalk    :  86.32\n",
      "building    :  89.23\n",
      "wall        :  47.44\n",
      "fence       :  43.78\n",
      "pole        :  60.14\n",
      "tlight      :  63.16\n",
      "tsign       :  82.48\n",
      "vtation     :  92.72\n",
      "terrain     :  80.45\n",
      "sky         :  95.99\n",
      "person      :  70.83\n",
      "rider       :  64.25\n",
      "car         :  94.06\n",
      "truck       :  44.90\n",
      "bus         :  66.81\n",
      "train       :  44.04\n",
      "motorcycle  :  47.34\n",
      "bicycle     :  62.50\n",
      "Avg Pre     :  70.24\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "hist = []\n",
    "model.eval()\n",
    "for data_iter_step, (samples, labels, path) in enumerate(tqdm(data_loader_train, desc=\"Training Progress\")):\n",
    "    samples = samples.to(device, non_blocking=True)\n",
    "    labels = labels.to(device, non_blocking=True)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        posterior_x = vae.encode(samples)\n",
    "        posterior_y = vae.encode(labels)\n",
    "        x = posterior_x.sample().mul_(0.2325)\n",
    "        y = posterior_y.sample().mul_(0.2325)\n",
    "        x = model.patchify(x)\n",
    "        y = model.patchify(y)\n",
    "    gt_latents = y.clone().detach()\n",
    "    cfg_iter = 1.0\n",
    "    temperature = 1.0\n",
    "    mask_actual = torch.cat([torch.zeros(args.batch_size, model.seq_len), torch.ones(args.batch_size, model.seq_len)], dim=1).cuda()\n",
    "    tokens = torch.zeros(args.batch_size, model.seq_len, model.token_embed_dim).cuda()\n",
    "\n",
    "    with torch.no_grad():\n",
    "        x1 = model.forward_mae_encoder(x, mask_actual, tokens)\n",
    "        z = model.forward_mae_decoder(x1, mask_actual)\n",
    "        z = z[0]\n",
    "        sampled_token_latent = model.diffloss.sample(z, temperature, cfg_iter)\n",
    "\n",
    "    tokens[0] = sampled_token_latent[model.seq_len:]\n",
    "    tokens = model.unpatchify(tokens)\n",
    "\n",
    "    sampled_images = vae.decode(tokens / 0.2325)\n",
    "\n",
    "    image_tensor = labels[0]\n",
    "    image_tensor = image_tensor * 0.5 + 0.5\n",
    "    gt_np = image_tensor.permute(1, 2, 0).cpu().numpy()\n",
    "    H, W, _ = gt_np.shape\n",
    "    pixels = gt_np.reshape(-1, 3)\n",
    "    distances = np.linalg.norm(pixels[:, None, :] - color_pallete[None, :, :], axis=2)\n",
    "    output = np.argmin(distances, axis=1)\n",
    "    gt = output.reshape(H, W)\n",
    "\n",
    "    image_tensor = sampled_images[0]\n",
    "    image_tensor = image_tensor * 0.5 + 0.5\n",
    "    ss_np = image_tensor.permute(1, 2, 0).cpu().numpy()\n",
    "    H, W, _ = ss_np.shape\n",
    "    pixels = ss_np.reshape(-1, 3)\n",
    "    distances = np.linalg.norm(pixels[:, None, :] - color_pallete[None, :, :], axis=2)\n",
    "    output = np.argmin(distances, axis=1)\n",
    "    output = output.reshape(H, W)\n",
    "\n",
    "    hist.append(fast_hist(output.reshape(-1), gt.reshape(-1), 20))\n",
    "\n",
    "cm = np.sum(hist, axis=0)\n",
    "\n",
    "epsilon = 1e-10\n",
    "class_precision = np.diag(cm[1:,1:]) / (np.sum(cm[1:,1:], axis=0) + epsilon)\n",
    "class_names = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'tlight', 'tsign', \n",
    "               'vtation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', \n",
    "               'motorcycle', 'bicycle']\n",
    "\n",
    "for i in range(len(class_names)):\n",
    "    print(f\"{class_names[i]:<12}: {class_precision[i]*100:6.2f}\")\n",
    "average_precision = np.mean(class_precision)\n",
    "print(f\"{'Avg Pre':<12}: {average_precision*100:6.2f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}