mikonvergence committed on
Commit
82f1234
·
1 Parent(s): f7d38e5

github code incorporated

app.py CHANGED
@@ -11,7 +11,7 @@ with gr.Blocks(theme=theme) as demo:
11
  gr.Markdown("# 🔵 COP-GEN-Beta: Unified Generative Modelling of COPernicus Imagery Thumbnails")
12
  gr.Markdown("### Miguel Espinosa, Valerio Marsocci, Yuru Jia, Elliot J. Crowley, Mikolaj Czerkawski")
13
  gr.Markdown('[[Website](https://miquel-espinosa.github.io/cop-gen-beta/)] [[GitHub](https://github.com/miquel-espinosa/COP-GEN-Beta)] [[Model](https://huggingface.co/mespinosami/COP-GEN-Beta)] [[Dataset](https://huggingface.co/Major-TOM)]')
14
- gr.Markdown('> ## ⚠️ NOTE: This is a prototype Beta model of COP-GEN. It is based on image thumbnails of Major TOM and does not yet support raw source data. The hillshade visualisation is used for elevation. The full model COP-GEN is coming soon.')
15
 
16
  with gr.Column(elem_classes="Main app"):
17
 
 
11
  gr.Markdown("# 🔵 COP-GEN-Beta: Unified Generative Modelling of COPernicus Imagery Thumbnails")
12
  gr.Markdown("### Miguel Espinosa, Valerio Marsocci, Yuru Jia, Elliot J. Crowley, Mikolaj Czerkawski")
13
  gr.Markdown('[[Website](https://miquel-espinosa.github.io/cop-gen-beta/)] [[GitHub](https://github.com/miquel-espinosa/COP-GEN-Beta)] [[Model](https://huggingface.co/mespinosami/COP-GEN-Beta)] [[Dataset](https://huggingface.co/Major-TOM)]')
14
+ gr.Markdown('> ## ⚠️ NOTE: This is a prototype Beta model of COP-GEN. It is based on image thumbnails of [[Major TOM](https://huggingface.co/Major-TOM)] and does not yet support raw source data. The hillshade visualisation is used for elevation. The full model COP-GEN is coming soon.')
15
 
16
  with gr.Column(elem_classes="Main app"):
17
 
src/COP-GEN-Beta DELETED
@@ -1 +0,0 @@
1
- Subproject commit eef71c50f3a233c30f1e6f8b87d7815494eb1ff2
 
 
src/COP-GEN-Beta/.gitignore ADDED
@@ -0,0 +1,13 @@
1
+ data
2
+ ./data
3
+ assets
4
+ ./assets
5
+ workdir
6
+ ./workdir
7
+ __pycache__/
8
+ **__pycache__/
9
+ *.out
10
+ *.pth
11
+ out_images/
12
+ models/
13
+ models
src/COP-GEN-Beta/README.md ADDED
@@ -0,0 +1,341 @@
1
+ ![image/png](images/banner-github-simpler.png)
2
+
3
+ # [CVPRW 2025] 🌍 COP-GEN-Beta: Unified Generative Modelling of COPernicus Imagery Thumbnails
4
+
5
+ [![HF](https://img.shields.io/badge/%F0%9F%A4%97-Demo-yellow)](https://huggingface.co/mespinosami/COP-GEN-Beta)
6
+ [![GitHub](https://img.shields.io/badge/%E2%80%8B-COP--GEN--Beta-black?logo=github)](https://github.com/miquel-espinosa/COP-GEN-Beta)
7
+ [![website](https://img.shields.io/badge/🌐-Website-grey)](https://miquel-espinosa.github.io/cop-gen-beta/)
8
+ [![HF](https://img.shields.io/badge/%F0%9F%A4%97-Model-yellow)](https://huggingface.co/mespinosami/COP-GEN-Beta)
9
+ [![paper](https://img.shields.io/badge/arXiv-2504.08548-D12424)](https://www.arxiv.org/abs/2504.08548)
10
+ <a href="https://colab.research.google.com/github/ESA-PhiLab/Major-TOM/blob/main/03-Filtering-in-Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
11
+
12
+ This repository contains the official implementation of our paper:
13
+
14
+ [COP-GEN-Beta: Unified Generative Modelling of COPernicus Imagery Thumbnails,
15
+ *Miguel Espinosa*, *Valerio Marsocci*, *Yuru Jia*, *Elliot J. Crowley*, *Mikolaj Czerkawski*, CVPRW 2025](https://www.arxiv.org/pdf/2504.08548)
16
+
17
+ ### Abstract
18
+ > _In remote sensing, multi-modal data from various sensors capturing the same scene_
19
+ _offers rich opportunities, but learning a unified representation across these modalities remains a significant challenge._
20
+ _Traditional methods have often been limited to single or dual-modality approaches._
21
+ _In this paper, we introduce COP-GEN-Beta, a generative diffusion model trained on optical, radar, and elevation data from the Major TOM dataset._
22
+ _What sets COP-GEN-Beta apart is its ability to map any subset of modalities to any other, enabling zero-shot modality translation after training._
23
+ _This is achieved through a sequence-based diffusion transformer, where each modality is controlled by its own timestep embedding._
24
+ _We extensively evaluate COP-GEN-Beta on thumbnail images from the Major TOM dataset, demonstrating its effectiveness in generating high-quality samples._
25
+ _Qualitative and quantitative evaluations validate the model's performance, highlighting its potential as a powerful pre-trained model for future remote sensing tasks._
26
+
27
+ <!-- <details> -->
28
+ <!-- <summary><h3><b>Table of Contents</b></h3></summary> -->
29
+
30
+ ### Table of Contents
31
+ - [Architecture Overview](#cop-gen-beta-architecture-overview)
32
+ - [Training Code Instructions](#cop-gen-beta-training-code-instructions)
33
+ - [0. Basic folder setup](#0-basic-folder-setup)
34
+ - [1. Download training data](#1-download-training-data-subset-example-rome)
35
+ - [2. Patchify and encode thumbnails](#2-patchify-and-encode-thumbnails)
36
+ - [3. Pre-compute features with Stable Diffusion](#3-pre-compute-features-with-stable-diffusion-pretrained-autoencoder)
37
+ - [4. Convert dataset to LMDB (optional)](#4-convert-dataset-to-lmdb-optional)
38
+ - [5. Train the model](#5-train-the-model)
39
+ - [Inference Instructions](#cop-gen-beta-inference-instructions)
40
+ - [1. Download model checkpoint](#1-download-model-checkpoint)
41
+ - [2. Run inference on test set](#2-run-inference-on-test-set-rome-subset)
42
+ - [Example 1: Unconditional generation](#example-1-unconditional-generation)
43
+ - [Example 2: Single modality conditioning](#example-2-single-modality-conditioning)
44
+ - [Example 3: 2 modality conditioning](#example-3-2-modality-conditioning)
45
+ - [Example 4: 3 modality conditioning](#example-4-3-modality-conditioning)
46
+
47
+ <!-- </details> -->
48
+
49
+ # COP-GEN-Beta: Architecture Overview
50
+
51
+ We introduce COP-GEN-Beta, a diffusion model designed to handle multiple remote sensing modalities. Specifically, COP-GEN-Beta operates on four key EO modalities: Digital Elevation Model (DEM), Sentinel-1 Radar Terrain Corrected (S1 RTC), Sentinel-2 Level 1C (S2 L1C), and Sentinel-2 Level 2A (S2 L2A). Unlike previous approaches, which require a separate model per modality, COP-GEN-Beta learns joint, conditional, and marginal distributions within a unified framework.
52
+
53
+ This is achieved by (a) sampling a global and dense dataset of these modalities from Major TOM and encoding all images with a pretrained Stable Diffusion autoencoder, and (b) training a sequence-based denoising diffusion model with a transformer backbone, where each modality is supplied with its own timestep. This makes it possible to (c) generate all modalities conditioned on any available subset of them (a minimal sketch of the per-modality timestep idea follows the architecture figure below).
54
+
55
+ ![COP-GEN-Beta Architecture](images/cop-gen-beta-architecture.png)
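The per-modality timestep mechanism can be illustrated with a minimal PyTorch sketch. The class and variable names below are hypothetical and this is not the repository's `triffuser_multi_post_ln` implementation; it only assumes that each modality carries its own timestep, so that conditioning on a modality can be expressed by assigning it timestep 0 (i.e. treating it as already clean).

```python
# Minimal sketch (assumed names), not the repository implementation:
# each modality gets its own diffusion timestep, so conditioning on a
# modality amounts to feeding it timestep 0 ("already clean").
import torch
import torch.nn as nn

class PerModalityTimestepEmbed(nn.Module):
    def __init__(self, num_modalities: int, dim: int):
        super().__init__()
        # one learned timestep table per modality (illustrative design choice)
        self.embeds = nn.ModuleList(nn.Embedding(1000, dim) for _ in range(num_modalities))

    def forward(self, timesteps):
        # timesteps: list of (B,) integer tensors, one per modality
        return [emb(t) for emb, t in zip(self.embeds, timesteps)]

# Example: condition on S1 RTC (index 1) while generating the other three modalities.
B, dim = 2, 1024
t_noisy = torch.randint(0, 1000, (B,))
timesteps = [t_noisy, torch.zeros(B, dtype=torch.long), t_noisy, t_noisy]
embedder = PerModalityTimestepEmbed(num_modalities=4, dim=dim)
per_modality_embeddings = embedder(timesteps)  # four (B, dim) tensors, one per modality's tokens
```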
56
+
57
+ # COP-GEN-Beta: Results
58
+
59
+ COP-GEN-Beta's flexible sampling capabilities enable a wide range of downstream applications
60
+ through various modality translation combinations. By allowing generation of any subset of
61
+ modalities conditioned on any other subset, our model unlocks numerous practical use cases in
62
+ remote sensing, from atmospheric correction and DEM generation to dataset expansion.
63
+
64
+ ![COP-GEN-Beta Results](images/use-case-horizontal.png)
65
+
66
+
67
+
68
+ # COP-GEN-Beta: Training Code Instructions
69
+
70
+ ## 0. Basic folder setup
71
+
72
+ Data will be stored in `./data/`. Create a symlink if you need to keep the data on a different disk:
73
+
74
+ ```bash
75
+ ln -s /path/to/disk/with/storage/ ./data
76
+ ```
77
+
78
+ Download the Stable Diffusion autoencoder weights. First, create the target folder:
79
+
80
+ ```bash
81
+ mkdir -p ./assets/stable-diffusion
82
+ ```
83
+
84
+ Download the weights from [here](https://drive.google.com/drive/folders/1sV-IvcGUrZeIlTmtuKv9vDJiB4JEHL-f) and place them in `./assets/stable-diffusion`.
85
+
86
+
87
+ ## 1. Download training data (subset example: Rome)
88
+
89
+ Select a subset of data to download for training. For example, let's download the region of Rome.
90
+
91
+ ```bash
92
+ sh scripts/download_rome.sh
93
+ ```
94
+
95
+ Run `python3 scripts/download_world.py --help` to see the available options.
96
+
97
+ The generated folder structure will look like this:
98
+ ```
99
+ data/majorTOM
100
+ ├── Core-DEM
101
+ │ ├── metadata.parquet
102
+ ├── Core-S1RTC
103
+ │ ├── metadata.parquet
104
+ ├── Core-S2L1C
105
+ │ ├── metadata.parquet
106
+ ├── Core-S2L2A
107
+ │ ├── metadata.parquet
108
+ ├── rome
109
+ │ ├── Core-DEM
110
+ │ │ ├── metadata.parquet (metadata.parquet for rome subset)
111
+ │ │ ├── <grid_cell>
112
+ │ │ │ ├── ...
113
+ | │ │ ├── compressed.tif
114
+ | │ │ ├── DEM.tif
115
+ | │ │ └── thumbnail.png
116
+ │ ├── Core-S1RTC
117
+ │ │ ├── metadata.parquet (metadata.parquet for rome subset)
118
+ │ │ ├── <grid_cell>
119
+ │ │ │ ├── ...
120
+ │ │ │ ├── vh.tif
121
+ │ │ │ ├── vv.tif
122
+ │ │ │ ├── thumbnail.png
123
+ ```
124
+
125
+ ## 2. Patchify and encode thumbnails.
126
+
127
+ This step aligns the image modalities (finding their common grid_cells), patchifies the thumbnails into 256x256 patches (thumbnails are 1068x1068, so they are first cropped to 1024x1024 and then patchified), and creates the train/test splits. A simplified sketch of the crop-and-patchify logic follows the command below.
128
+
129
+ ```bash
130
+ python3 prepare_dataset_images.py --subset_path data/majorTOM/rome --output_dir data/majorTOM/rome/rome_thumbnail_png --bands thumbnail
131
+ ```
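For intuition, the crop-and-patchify step described above can be sketched in a few lines of NumPy. This is a simplified illustration, not the logic of `prepare_dataset_images.py`; the centre crop and the example filename are assumptions.

```python
# Simplified sketch of the crop + patchify step (assumes a 1068x1068 RGB
# thumbnail on disk; the centre crop is an assumption of this sketch).
import numpy as np
from PIL import Image

thumb = np.array(Image.open("thumbnail.png").convert("RGB"))   # (1068, 1068, 3)
off = (thumb.shape[0] - 1024) // 2
cropped = thumb[off:off + 1024, off:off + 1024]                 # 1024x1024 crop

patches = [
    cropped[y:y + 256, x:x + 256]
    for y in range(0, 1024, 256)
    for x in range(0, 1024, 256)
]                                                               # 16 patches of 256x256x3
```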
132
+
133
+ ## 3. Pre-compute features with Stable Diffusion pretrained autoencoder.
134
+
135
+ ```bash
136
+ bands=(DEM_thumbnail S1RTC_thumbnail S2L1C_thumbnail S2L2A_thumbnail)
137
+ splits=(train test)
138
+ for band in "${bands[@]}"; do
139
+ for split in "${splits[@]}"; do
140
+ python3 encode_majortom_images.py \
141
+ --path "data/majorTOM/rome/rome_thumbnail_png/${split}/${band}" \
142
+ --resolution 256 \
143
+ --output_dir "data/majorTOM/rome/rome_thumbnail_npy/${split}/${band}"
144
+ done
145
+ done
146
+ ```
147
+
148
+ Folder structure generated for the command above:
149
+ ```
150
+ data/majorTOM/rome/rome_thumbnail_npy/
151
+
152
+ ├── train
153
+ │ ├── DEM_thumbnail
154
+ │ │ ├── 0.npy
155
+ │ │ ├── 1.npy
156
+ │ │ ├── ...
157
+ │ ├── S1RTC_thumbnail
158
+ │ │ ├── 0.npy
159
+ │ │ ├── 1.npy
160
+ │ │ ├── ...
161
+ │ ├── S2L1C_thumbnail
162
+ │ │ ├── 0.npy
163
+ │ │ ├── 1.npy
164
+ │ │ ├── ...
165
+ │ ├── S2L2A_thumbnail
166
+ │ │ ├── 0.npy
167
+ │ │ ├── 1.npy
168
+ │ │ ├── ...
169
+ ├── test
170
+ │ ├── DEM_thumbnail
171
+ │ │ ├── 0.npy
172
+ │ │ ├── 1.npy
173
+ │ │ ├── ...
174
+ ```
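As a rough illustration of what the encoding step produces, the sketch below turns one 256x256 patch into the 8x32x32 "moments" tensor (latent mean and log-variance) that the `.npy` files above contain. It uses the `diffusers` `AutoencoderKL` API purely for convenience; the repository instead loads `assets/stable-diffusion/autoencoder_kl_ema.pth` with its own loader, so treat the checkpoint name here as an assumption.

```python
# Illustrative only: encode one 256x256 patch into 8x32x32 latent "moments"
# (mean + log-variance). The diffusers VAE checkpoint is a stand-in for the
# repository's assets/stable-diffusion/autoencoder_kl_ema.pth.
import numpy as np
import torch
from PIL import Image
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").eval()

img = Image.open("patch.png").convert("RGB").resize((256, 256))
x = torch.from_numpy(np.array(img)).float() / 127.5 - 1.0       # scale to [-1, 1]
x = x.permute(2, 0, 1).unsqueeze(0)                              # (1, 3, 256, 256)

with torch.no_grad():
    moments = vae.encode(x).latent_dist.parameters               # (1, 8, 32, 32)

np.save("0.npy", moments.squeeze(0).cpu().numpy())
```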
175
+ ## 4. Convert dataset to LMDB (optional).
176
+
177
+ Convert the npy files to an LMDB dataset, for both the train and test splits. (Lower `--batch-size` if the conversion fails, e.g. due to memory limits.) A short sketch for reading the resulting database back is shown after the commands.
178
+
179
+ ```bash
180
+ python3 create_lmdb.py \
181
+ --input-img-dir data/majorTOM/rome/rome_thumbnail_npy/train \
182
+ --output-dir data/majorTOM/rome/rome_thumbnail_npy_lmdb/train \
183
+ --input-type npy
184
+
185
+ python3 create_lmdb.py \
186
+ --input-img-dir data/majorTOM/rome/rome_thumbnail_npy/test \
187
+ --output-dir data/majorTOM/rome/rome_thumbnail_npy_lmdb/test \
188
+ --input-type npy
189
+ ```
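To sanity-check the conversion, a single record can be read back directly. This mirrors how `MajorTOM_Lmdb_FeatureDataset` in `datasets.py` decodes entries: each value is a pickled dict mapping a modality name to raw float32 bytes of shape (8, 32, 32).

```python
# Read one entry back from the generated LMDB to verify the conversion.
import pickle
import lmdb
import numpy as np

env = lmdb.open("data/majorTOM/rome/rome_thumbnail_npy_lmdb/test",
                readonly=True, lock=False)
with env.begin() as txn:
    key, value = next(iter(txn.cursor()))      # first (key, value) pair
    record = pickle.loads(value)               # {modality_name: raw float32 bytes}

for modality, raw in record.items():
    feats = np.frombuffer(raw, dtype=np.float32).reshape(8, 32, 32)
    print(key.decode(), modality, feats.shape)
```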
190
+
191
+ ## 5. Train the model
192
+
193
+ Train the model using the following command (2 GPUs, 4 modalities). Adjust the number of GPUs and the config file as needed.
194
+
195
+ NOTE on the config file:
196
+ - Since we are training on a toy dataset, we set the batch size to 8. Feel free to increase it for larger datasets.
197
+ - Logging frequency, eval frequency, and save frequency are set to 2 so that outputs appear quickly on this toy run. Feel free to increase them for larger datasets.
198
+
199
+ Visual results, checkpoints, and logs are stored in a generated folder called `workdir`.
200
+
201
+ ### Training with LMDB dataset
202
+
203
+ ```bash
204
+ export NUM_GPUS=2
205
+ accelerate launch \
206
+ --multi_gpu \
207
+ --num_processes $NUM_GPUS \
208
+ --mixed_precision fp16 \
209
+ train_triffuser_discrete.py \
210
+ --config="configs/majortom/discrete/lmdb/rome_dems1s2s2_cop_gen_beta.py"
211
+ ```
212
+
213
+ ### Training without LMDB (tuples dataset)
214
+
215
+ ```bash
216
+ export NUM_GPUS=2
217
+ accelerate launch \
218
+ --multi_gpu \
219
+ --num_processes $NUM_GPUS \
220
+ --mixed_precision fp16 \
221
+ train_triffuser_discrete.py \
222
+ --config="configs/majortom/discrete/rome_dems1s2s2_cop_gen_beta.py"
223
+ ```
224
+
225
+
226
+ # COP-GEN-Beta: Inference Instructions
227
+
228
+ COP-GEN-Beta is highly flexible at sampling time.
229
+ Given the 4 modalities (DEM, S1RTC, S2L1C, S2L2A), the following generation options exist (a conceptual sketch of conditional sampling follows this list):
230
+
231
+ - **Unconditional generation:** Generates tuples of 4 modalities without any condition.
232
+ - **Conditional generation:**
233
+ - **Single modality conditioning:** Generates missing modalities conditioned on a single modality.
234
+ - **2 modality conditioning:** Generates missing modalities conditioned on 2 modalities.
235
+ - **3 modality conditioning:** Generates missing modalities conditioned on 3 modalities.
236
+
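Conceptually, conditional sampling keeps the conditioning modalities fixed as clean latents (timestep 0) while the remaining modalities are iteratively denoised. The sketch below is a deliberately simplified, runnable toy; the network and the update rule are stand-ins, not the `sample_n_triffuser.py` / DPM-Solver implementation.

```python
# Toy sketch of subset-conditioned sampling: observed modalities stay clean,
# generated ones are iteratively denoised. Network and update rule are stubs.
import torch

def denoise_step(x, eps, t, steps):
    """Toy update rule, used only to make the sketch executable."""
    alpha = 1.0 - (t + 1) / steps
    return alpha * x + (1 - alpha) * (x - eps)

def toy_nnet(latents, timesteps):
    """Stand-in network: a real model would use the per-modality timesteps."""
    return {m: torch.zeros_like(z) for m, z in latents.items()}

def sample(nnet, observed, generate, shape=(4, 32, 32), steps=50):
    latents = {m: torch.randn(1, *shape) for m in generate}  # generated modalities start from noise
    latents.update(observed)                                  # conditioning modalities stay clean
    for t in reversed(range(steps)):
        timesteps = {m: (0 if m in observed else t) for m in latents}
        eps = nnet(latents, timesteps)
        for m in generate:                                    # only generated modalities are updated
            latents[m] = denoise_step(latents[m], eps[m], t, steps)
    return {m: latents[m] for m in generate}

out = sample(toy_nnet,
             observed={"s1_rtc": torch.zeros(1, 4, 32, 32)},
             generate=["dem", "s2_l2a", "s2_l1c"])
print({m: v.shape for m, v in out.items()})
```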
237
+ ## 1. Download model checkpoint
238
+
239
+ <!-- To upload the model to Hugging Face Hub just run in the pth folder: -->
240
+ <!-- huggingface-cli upload mespinosami/COP-GEN-Beta . -->
241
+
242
+ Download the model EMA checkpoint from [Hugging Face](https://huggingface.co/mespinosami/COP-GEN-Beta) ([direct download link](https://huggingface.co/mespinosami/COP-GEN-Beta/resolve/main/nnet_ema_114000.pth)) and place it in the `./models` folder.
243
+
244
+ This can be done by running:
245
+ ```bash
246
+ mkdir -p models
247
+ wget https://huggingface.co/mespinosami/COP-GEN-Beta/resolve/main/nnet_ema_114000.pth -O models/nnet_ema_114000.pth
248
+ ```
249
+
250
+ ## 2. Run inference on test set (Rome subset)
251
+
252
+ To see all the available inference options, run `python3 sample_n_triffuser.py --help`.
253
+ For instance:
254
+ - `--n_samples` controls the number of samples to generate for the same input condition
255
+ (useful to evaluate the generation variability),
256
+ - `--generate` is the list (comma separated) of modalities to generate,
257
+ - `--condition` is the list (comma separated) of modalities to condition on.
258
+
259
+
260
+ ### Example 1: Unconditional generation
261
+
262
+ Generates all modalities (DEM, S1RTC, S2L2A, S2L1C).
263
+ ```bash
264
+ python3 sample_n_triffuser.py \
265
+ --config configs/majortom/discrete/lmdb/rome_dems1s2s2_cop_gen_beta.py \
266
+ --data_path data/majorTOM/rome/rome_thumbnail_npy_lmdb/test \
267
+ --data_type lmdb \
268
+ --nnet_path models/nnet_ema_114000.pth \
269
+ --n_mod 4 \
270
+ --generate dem,s1_rtc,s2_l2a,s2_l1c \
271
+ --output_path out_images \
272
+ --n_samples 4 \
273
+ --save_as grid
274
+ ```
275
+
276
+ ### Example 2: Single modality conditioning
277
+
278
+ Conditioning on S1RTC to generate DEM, S2L2A, and S2L1C.
279
+
280
+ ```bash
281
+ python3 sample_n_triffuser.py \
282
+ --config configs/majortom/discrete/lmdb/rome_dems1s2s2_cop_gen_beta.py \
283
+ --data_path data/majorTOM/rome/rome_thumbnail_npy_lmdb/test \
284
+ --data_type lmdb \
285
+ --nnet_path models/nnet_ema_114000.pth \
286
+ --n_mod 4 \
287
+ --condition s1_rtc \
288
+ --generate dem,s2_l2a,s2_l1c \
289
+ --output_path out_images \
290
+ --n_samples 4 \
291
+ --save_as grid
292
+ ```
293
+
294
+ ### Example 3: 2 modality conditioning
295
+
296
+ Conditioning on DEM and S1RTC to generate S2L2A and S2L1C.
297
+
298
+ ```bash
299
+ python3 sample_n_triffuser.py \
300
+ --config configs/majortom/discrete/lmdb/rome_dems1s2s2_cop_gen_beta.py \
301
+ --data_path data/majorTOM/rome/rome_thumbnail_npy_lmdb/test \
302
+ --data_type lmdb \
303
+ --nnet_path models/nnet_ema_114000.pth \
304
+ --n_mod 4 \
305
+ --condition dem,s1_rtc \
306
+ --generate s2_l2a,s2_l1c \
307
+ --output_path out_images \
308
+ --n_samples 4 \
309
+ --save_as grid
310
+ ```
311
+
312
+ ### Example 4: 3 modality conditioning
313
+
314
+ Conditioning on DEM, S1RTC, and S2L2A to generate S2L1C.
315
+
316
+ ```bash
317
+ python3 sample_n_triffuser.py \
318
+ --config configs/majortom/discrete/lmdb/rome_dems1s2s2_cop_gen_beta.py \
319
+ --data_path data/majorTOM/rome/rome_thumbnail_npy_lmdb/test \
320
+ --data_type lmdb \
321
+ --nnet_path models/nnet_ema_114000.pth \
322
+ --n_mod 4 \
323
+ --condition dem,s1_rtc,s2_l2a \
324
+ --generate s2_l1c \
325
+ --output_path out_images \
326
+ --n_samples 4 \
327
+ --save_as grid
328
+ ```
329
+
330
+ # Citation
331
+
332
+ If you find this work useful, please cite it as follows:
333
+
334
+ ```bibtex
335
+ @inproceedings{espinosa2025copgenbeta,
336
+ title={COP-GEN-Beta: Unified Generative Modelling of COPernicus Imagery Thumbnails},
337
+ author={Espinosa, Miguel and Marsocci, Valerio and Jia, Yuru and Crowley, Elliot J. and Czerkawski, Mikolaj},
338
+ booktitle={CVPRW},
339
+ year={2025}
340
+ }
341
+ ```
src/COP-GEN-Beta/configs/majortom/discrete/lmdb/rome_dems1s2s2_cop_gen_beta.py ADDED
@@ -0,0 +1,64 @@
1
+ import ml_collections
2
+ def d(**kwargs):
3
+ """Helper of creating a config dict."""
4
+ return ml_collections.ConfigDict(initial_dictionary=kwargs)
5
+ def get_config():
6
+ config = ml_collections.ConfigDict()
7
+ config.seed = 1234
8
+ config.pred = "noise_pred"
9
+ config.z_shape = (4, 32, 32)
10
+ config.autoencoder = d(pretrained_path="assets/stable-diffusion/autoencoder_kl_ema.pth")
11
+ config.train = d(
12
+ n_steps=500000,
13
+ batch_size=8, # Increase to 512 for larger datasets
14
+ mode="uncond",
15
+ log_interval=2, # Increase to 100 for larger datasets
16
+ eval_interval=2, # Increase to 1500 for larger datasets
17
+ save_interval=2, # Increase to 1500 for larger datasets
18
+ multi_modal=True,
19
+ )
20
+ config.optimizer = d(
21
+ name="adamw",
22
+ lr=0.0002,
23
+ weight_decay=0.03,
24
+ betas=(0.99, 0.99),
25
+ )
26
+ config.lr_scheduler = d(name="customized", warmup_steps=5000)
27
+ config.nnet = d(
28
+ name="triffuser_multi_post_ln",
29
+ img_size=32,
30
+ in_chans=4,
31
+ patch_size=2,
32
+ embed_dim=1024,
33
+ depth=20,
34
+ num_heads=16,
35
+ mlp_ratio=4,
36
+ qkv_bias=False,
37
+ pos_drop_rate=0.,
38
+ drop_rate=0.,
39
+ attn_drop_rate=0.,
40
+ mlp_time_embed=False,
41
+ num_modalities=4,
42
+ use_checkpoint=True,
43
+ )
44
+ config.dataset = d(
45
+ name="majorTOM_lmdb_256_features",
46
+ path="data/majorTOM/rome/rome_thumbnail_npy_lmdb/train",
47
+ # name="majorTOM_tuples_256_features",
48
+ # paths=["data/majorTOM/northern_italy/northern_italy_thumbnail_npy/train/DEM_thumbnail",
49
+ # "data/majorTOM/northern_italy/northern_italy_thumbnail_npy/train/S1RTC_thumbnail",
50
+ # "data/majorTOM/northern_italy/northern_italy_thumbnail_npy/train/S2L1C_thumbnail",
51
+ # "data/majorTOM/northern_italy/northern_italy_thumbnail_npy/train/S2L2A_thumbnail"],
52
+ cfg=False,
53
+ p_uncond=0.1, # 0.15
54
+ )
55
+ config.sample = d(
56
+ sample_steps=50,
57
+ n_samples=50000,
58
+ mini_batch_size=50, # the decoder is large
59
+ algorithm="dpm_solver",
60
+ cfg=True,
61
+ scale=0.4,
62
+ path="",
63
+ )
64
+ return config
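For reference, the object returned by `get_config()` above is an `ml_collections.ConfigDict`, so its fields can be inspected and overridden programmatically. The sketch below is illustrative only and assumes the repository root is on `PYTHONPATH`; the actual training entry point remains `train_triffuser_discrete.py --config=...`.

```python
# Sketch: load the config above and tweak the toy-run settings for a larger dataset.
# Assumes the repository root is on PYTHONPATH so the configs package is importable.
from configs.majortom.discrete.lmdb import rome_dems1s2s2_cop_gen_beta as cfg_module

config = cfg_module.get_config()
print(config.train.batch_size, config.train.log_interval)   # 8, 2 in this toy config

config.train.batch_size = 512     # e.g. for a larger dataset
config.train.log_interval = 100
config.train.eval_interval = 1500
config.train.save_interval = 1500
print(config.nnet.embed_dim, config.nnet.depth, config.nnet.num_modalities)
```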
src/COP-GEN-Beta/configs/majortom/discrete/rome_dems1s2s2_cop_gen_beta.py ADDED
@@ -0,0 +1,62 @@
1
+ import ml_collections
2
+ def d(**kwargs):
3
+ """Helper of creating a config dict."""
4
+ return ml_collections.ConfigDict(initial_dictionary=kwargs)
5
+ def get_config():
6
+ config = ml_collections.ConfigDict()
7
+ config.seed = 1234
8
+ config.pred = "noise_pred"
9
+ config.z_shape = (4, 32, 32)
10
+ config.autoencoder = d(pretrained_path="assets/stable-diffusion/autoencoder_kl_ema.pth")
11
+ config.train = d(
12
+ n_steps=500000,
13
+ batch_size=8, # Increase to 512 for larger datasets
14
+ mode="uncond",
15
+ log_interval=2, # Increase to 100 for larger datasets
16
+ eval_interval=2, # Increase to 1500 for larger datasets
17
+ save_interval=2, # Increase to 1500 for larger datasets
18
+ multi_modal=True,
19
+ )
20
+ config.optimizer = d(
21
+ name="adamw",
22
+ lr=0.0002,
23
+ weight_decay=0.03,
24
+ betas=(0.99, 0.99),
25
+ )
26
+ config.lr_scheduler = d(name="customized", warmup_steps=5000)
27
+ config.nnet = d(
28
+ name="triffuser_multi_post_ln",
29
+ img_size=32,
30
+ in_chans=4,
31
+ patch_size=2,
32
+ embed_dim=1024,
33
+ depth=20,
34
+ num_heads=16,
35
+ mlp_ratio=4,
36
+ qkv_bias=False,
37
+ pos_drop_rate=0.,
38
+ drop_rate=0.,
39
+ attn_drop_rate=0.,
40
+ mlp_time_embed=False,
41
+ num_modalities=4,
42
+ use_checkpoint=True,
43
+ )
44
+ config.dataset = d(
45
+ name="majorTOM_tuples_256_features",
46
+ paths=["data/majorTOM/rome/rome_thumbnail_npy/train/DEM_thumbnail",
47
+ "data/majorTOM/rome/rome_thumbnail_npy/train/S1RTC_thumbnail",
48
+ "data/majorTOM/rome/rome_thumbnail_npy/train/S2L1C_thumbnail",
49
+ "data/majorTOM/rome/rome_thumbnail_npy/train/S2L2A_thumbnail"],
50
+ cfg=False,
51
+ p_uncond=0.1, # 0.15
52
+ )
53
+ config.sample = d(
54
+ sample_steps=50,
55
+ n_samples=50000,
56
+ mini_batch_size=50, # the decoder is large
57
+ algorithm="dpm_solver",
58
+ cfg=True,
59
+ scale=0.4,
60
+ path="",
61
+ )
62
+ return config
src/COP-GEN-Beta/create_lmdb.py ADDED
@@ -0,0 +1,213 @@
1
+ """
2
+ Author: Chenhongyi Yang
3
+ Reference: GPViT https://github.com/ChenhongyiYang/GPViT
4
+ """
5
+
6
+ """
7
+ This script will generate a paired LMDB database for all modalities found in the input directory.
8
+ Thus, the input directory should contain subdirectories for each modality, each containing a set of images.
9
+ The names for the paired images in the different subdirectories should be the same.
10
+
11
+ # Example:
12
+ python3 create_lmdb.py \
13
+ --input-img-dir data/majorTOM/northern_italy/northern_italy_thumbnail_npy/train \
14
+ --output-dir data/majorTOM/northern_italy/northern_italy_thumbnail_npy_lmdb/train \
15
+ --input-type npy
16
+ """
17
+
18
+
19
+ import glob
20
+ import blobfile as bf
21
+ import os
22
+ import re
23
+ import time
24
+ from collections import defaultdict
25
+ from concurrent.futures import ThreadPoolExecutor
26
+ from typing import Tuple
27
+ import pickle
28
+
29
+ import cv2
30
+ import lmdb
31
+
32
+ import argparse
33
+ parser = argparse.ArgumentParser('Convert LMDB dataset')
34
+ parser.add_argument('--input-img-dir', help='Path to the input directory (one subdirectory per modality)')
35
+ parser.add_argument('--output-dir', help='Path to the output LMDB dataset')
36
+ parser.add_argument('--input-type', choices=['png', 'npy'],
37
+ help='Type of input to encode: "png" for PNG images or "npy" for NPY features')
38
+ parser.add_argument('--batch-size', type=int, default=10000,
39
+ help='Batch size for processing images')
40
+ # parser.add_argument('val-img-dir', 'Path to ImageNet validation images')
41
+ # parser.add_argument('val-out', 'Path to output validation lmdb dataset')
42
+ args = parser.parse_args()
43
+
44
+ _10TB = 10 * (1 << 40)
45
+
46
+ class LmdbDataExporter(object):
47
+ """
48
+ making LMDB database
49
+ """
50
+ # label_pattern = re.compile(r'/.*/.*?(\d+)$')
51
+
52
+ def __init__(self,
53
+ img_dir=None,
54
+ output_path=None,
55
+ batch_size=None):
56
+ """
57
+ img_dir: imgs directory
58
+ output_path: LMDB output path
59
+ """
60
+ self.img_dir = img_dir
61
+ self.output_path = output_path
62
+ self.batch_size = batch_size
63
+
64
+ if not os.path.exists(img_dir):
65
+ raise Exception(f'{img_dir} does not exist!')
66
+
67
+ if not os.path.exists(output_path):
68
+ os.makedirs(output_path)
69
+
70
+ self.lmdb_env = lmdb.open(output_path, map_size=_10TB, max_dbs=4)
71
+ self.modalities = self._get_modalities()
72
+
73
+ def _get_modalities(self):
74
+ """Get list of modalities (subdirectories) in the input directory"""
75
+ return [d for d in os.listdir(self.img_dir)
76
+ if os.path.isdir(os.path.join(self.img_dir, d))]
77
+
78
+ def export(self):
79
+ idx = 0
80
+ results = []
81
+ st = time.time()
82
+ iter_img_lst = self.read_imgs()
83
+ length = self.get_length()
84
+ print(f'length: {length}')
85
+ while True:
86
+ items = []
87
+ try:
88
+ while len(items) < self.batch_size:
89
+ items.append(next(iter_img_lst))
90
+ except StopIteration:
91
+ break
92
+
93
+ with ThreadPoolExecutor() as executor:
94
+ results.extend(executor.map(self._extract_once, items))
95
+
96
+ if len(results) >= self.batch_size:
97
+ self.save_to_lmdb(results)
98
+ idx += self.batch_size
99
+ et = time.time()
100
+ print(f'time: {(et-st)}(s) count: {idx}')
101
+ st = time.time()
102
+ # Progressively decrease batch size for remaining items
103
+ remaining = length - idx
104
+ if remaining < self.batch_size:
105
+ self.batch_size = max(remaining // 2, 1)
106
+ print(f'batch_size is reduced to: {self.batch_size}')
107
+ del results[:]
108
+
109
+ et = time.time()
110
+ print(f'time: {(et-st)}(s) count: {idx}')
111
+ self.save_to_lmdb(results)
112
+ # self.save_total(idx)
113
+ print('Total length:', len(results))
114
+ del results[:]
115
+
116
+ def save_to_lmdb(self, results):
117
+ """
118
+ persist to lmdb
119
+ """
120
+ with self.lmdb_env.begin(write=True) as txn:
121
+ while results:
122
+ img_key, img_byte = results.pop()
123
+ if img_key is None or img_byte is None:
124
+ continue
125
+ txn.put(img_key, img_byte)
126
+
127
+ def save_total(self, total: int):
128
+ """
129
+ persist all numbers of imgs
130
+ """
131
+ with self.lmdb_env.begin(write=True, buffers=True) as txn:
132
+ txn.put('total'.encode(), str(total).encode())
133
+
134
+ def _extract_once(self, item) -> Tuple[bytes, bytes]:
135
+ image_name = item[1]
136
+ modality_data = item[2] # Dictionary of modality -> file path
137
+
138
+ # Create a dictionary to store all modality data
139
+ data_dict = {}
140
+
141
+ # Read each modality's data
142
+ for modality, file_path in modality_data.items():
143
+ if args.input_type == 'png':
144
+ img = cv2.imread(file_path)
145
+ if img is None:
146
+ print(f'{file_path} is a bad img file.')
147
+ return None, None
148
+ _, img_byte = cv2.imencode('.png', img)
149
+ data_dict[modality] = img_byte.tobytes()
150
+ else: # feature
151
+ try:
152
+ import numpy as np
153
+ features = np.load(file_path)
154
+ data_dict[modality] = features.tobytes()
155
+ except Exception as e:
156
+ print(f'Error loading {file_path}: {e}')
157
+ return None, None
158
+
159
+ return (image_name.encode('ascii'), pickle.dumps(data_dict))
160
+
161
+ def get_length(self):
162
+ # Just count files in the first modality directory
163
+ if not self.modalities:
164
+ return 0
165
+ first_modality_dir = os.path.join(self.img_dir, self.modalities[0])
166
+ img_list = glob.glob(os.path.join(first_modality_dir, '*.npy'))
167
+ return len(img_list)
168
+
169
+ def _list_image_files_recursively(self, data_dir):
170
+ results = []
171
+ for entry in sorted(bf.listdir(data_dir)):
172
+ full_path = bf.join(data_dir, entry)
173
+ ext = entry.split(".")[-1]
174
+ if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif", "npy"]:
175
+ results.append(full_path)
176
+ elif bf.isdir(full_path):
177
+ results.extend(self._list_image_files_recursively(full_path))
178
+ return results
179
+
180
+ def read_imgs(self):
181
+ # Create a dictionary to store files by their base name
182
+ file_groups = defaultdict(dict)
183
+
184
+ # File extension based on input type
185
+ extensions = ['.png'] if args.input_type == 'png' else ['.npy']
186
+
187
+ # Collect files from each modality
188
+ for modality in self.modalities:
189
+ modality_path = os.path.join(self.img_dir, modality)
190
+ for file_path in self._list_image_files_recursively(modality_path):
191
+ ext = os.path.splitext(file_path)[1].lower()
192
+ if ext in extensions:
193
+ base_name = os.path.basename(file_path)
194
+ file_groups[base_name][modality] = file_path
195
+
196
+ # Only yield complete groups
197
+ for idx, (base_name, modality_files) in enumerate(file_groups.items()):
198
+ if len(modality_files) == len(self.modalities):
199
+ item = (idx, base_name, modality_files)
200
+ yield item
201
+ else:
202
+ print(f"Skipping incomplete group {base_name}, found modalities: {list(modality_files.keys())}")
203
+
204
+
205
+ if __name__ == '__main__':
206
+ input_img_dir = args.input_img_dir
207
+ output_dir = args.output_dir
208
+
209
+ exporter = LmdbDataExporter(
210
+ input_img_dir,
211
+ output_dir,
212
+ batch_size=args.batch_size)
213
+ exporter.export()
src/COP-GEN-Beta/datasets.py ADDED
@@ -0,0 +1,885 @@
1
+ from torch.utils.data import Dataset
2
+ from torchvision import datasets
3
+ import torchvision.transforms as transforms
4
+ import numpy as np
5
+ import torch
6
+ import math
7
+ import random
8
+ from PIL import Image
9
+ import os
10
+ import glob
11
+ import einops
12
+ import torchvision.transforms.functional as F
13
+
14
+
15
+ class UnlabeledDataset(Dataset):
16
+ def __init__(self, dataset):
17
+ self.dataset = dataset
18
+
19
+ def __len__(self):
20
+ return len(self.dataset)
21
+
22
+ def __getitem__(self, item):
23
+ # data = tuple(self.dataset[item][:-1]) # remove label
24
+ data = self.dataset[item]
25
+ if len(data) == 1:
26
+ data = data[0]
27
+ return data
28
+
29
+
30
+ class LabeledDataset(Dataset):
31
+ def __init__(self, dataset, labels):
32
+ self.dataset = dataset
33
+ self.labels = labels
34
+
35
+ def __len__(self):
36
+ return len(self.dataset)
37
+
38
+ def __getitem__(self, item):
39
+ return self.dataset[item], self.labels[item]
40
+
41
+
42
+ class CFGDataset(Dataset): # for classifier free guidance
43
+ def __init__(self, dataset, p_uncond, empty_token):
44
+ self.dataset = dataset
45
+ self.p_uncond = p_uncond
46
+ self.empty_token = empty_token
47
+
48
+ def __len__(self):
49
+ return len(self.dataset)
50
+
51
+ def __getitem__(self, item):
52
+ x, y = self.dataset[item]
53
+ if random.random() < self.p_uncond:
54
+ y = self.empty_token
55
+ return x, y
56
+
57
+
58
+ class DatasetFactory(object):
59
+
60
+ def __init__(self):
61
+ self.train = None
62
+ self.test = None
63
+
64
+ def get_split(self, split, labeled=False, nosplit=False):
65
+ if nosplit:
66
+ return self.dataset
67
+ if split == "train":
68
+ dataset = self.train
69
+ elif split == "test":
70
+ dataset = self.test
71
+ else:
72
+ raise ValueError
73
+
74
+ if self.has_label:
75
+ return dataset if labeled else UnlabeledDataset(dataset)
76
+ else:
77
+ assert not labeled
78
+ return dataset
79
+
80
+ def unpreprocess(self, v): # to B C H W and [0, 1]
81
+ v = 0.5 * (v + 1.)
82
+ v.clamp_(0., 1.)
83
+ return v
84
+
85
+ @property
86
+ def has_label(self):
87
+ return True
88
+
89
+ @property
90
+ def data_shape(self):
91
+ raise NotImplementedError
92
+
93
+ @property
94
+ def data_dim(self):
95
+ return int(np.prod(self.data_shape))
96
+
97
+ @property
98
+ def fid_stat(self):
99
+ return None
100
+
101
+ def sample_label(self, n_samples, device):
102
+ raise NotImplementedError
103
+
104
+ def label_prob(self, k):
105
+ raise NotImplementedError
106
+
107
+
108
+ # CIFAR10
109
+
110
+ class CIFAR10(DatasetFactory):
111
+ r""" CIFAR10 dataset
112
+
113
+ Information of the raw dataset:
114
+ train: 50,000
115
+ test: 10,000
116
+ shape: 3 * 32 * 32
117
+ """
118
+
119
+ def __init__(self, path, random_flip=False, cfg=False, p_uncond=None):
120
+ super().__init__()
121
+
122
+ transform_train = [transforms.ToTensor(), transforms.Normalize(0.5, 0.5)]
123
+ transform_test = [transforms.ToTensor(), transforms.Normalize(0.5, 0.5)]
124
+ if random_flip: # only for train
125
+ transform_train.append(transforms.RandomHorizontalFlip())
126
+ transform_train = transforms.Compose(transform_train)
127
+ transform_test = transforms.Compose(transform_test)
128
+ self.train = datasets.CIFAR10(path, train=True, transform=transform_train, download=True)
129
+ self.test = datasets.CIFAR10(path, train=False, transform=transform_test, download=True)
130
+
131
+ assert len(self.train.targets) == 50000
132
+ self.K = max(self.train.targets) + 1
133
+ self.cnt = torch.tensor([len(np.where(np.array(self.train.targets) == k)[0]) for k in range(self.K)]).float()
134
+ self.frac = [self.cnt[k] / 50000 for k in range(self.K)]
135
+ print(f'{self.K} classes')
136
+ print(f'cnt: {self.cnt}')
137
+ print(f'frac: {self.frac}')
138
+
139
+ if cfg: # classifier free guidance
140
+ assert p_uncond is not None
141
+ print(f'prepare the dataset for classifier free guidance with p_uncond={p_uncond}')
142
+ self.train = CFGDataset(self.train, p_uncond, self.K)
143
+
144
+ @property
145
+ def data_shape(self):
146
+ return 3, 32, 32
147
+
148
+ @property
149
+ def fid_stat(self):
150
+ return 'assets/fid_stats/fid_stats_cifar10_train_pytorch.npz'
151
+
152
+ def sample_label(self, n_samples, device):
153
+ return torch.multinomial(self.cnt, n_samples, replacement=True).to(device)
154
+
155
+ def label_prob(self, k):
156
+ return self.frac[k]
157
+
158
+
159
+ # ImageNet
160
+
161
+
162
+ class FeatureDataset(Dataset):
163
+ def __init__(self, path):
164
+ super().__init__()
165
+ self.path = path
166
+ # names = sorted(os.listdir(path))
167
+ # self.files = [os.path.join(path, name) for name in names]
168
+
169
+ def __len__(self):
170
+ return 1_281_167 * 2 # consider the random flip
171
+
172
+ def __getitem__(self, idx):
173
+ path = os.path.join(self.path, f'{idx}.npy')
174
+ z, label = np.load(path, allow_pickle=True)
175
+ return z, label
176
+
177
+
178
+ class MajorTOM_S2_FeatureDataset(Dataset):
179
+
180
+ def __init__(self, path, transform=None):
181
+ super().__init__()
182
+ self.path = path
183
+ self.transform = transform
184
+ # names = sorted(os.listdir(path))
185
+ # self.files = [os.path.join(path, name) for name in names]
186
+
187
+ def __len__(self):
188
+ return len(glob.glob(f"{self.path}/*.npy"))
189
+
190
+ def __getitem__(self, idx):
191
+ path = os.path.join(self.path, f"{idx}.npy")
192
+ moment = np.load(path, allow_pickle=True).copy()
193
+ if self.transform is not None:
194
+ moment = self.transform(moment)
195
+ return moment
196
+
197
+
198
+ class MajorTOM_Tuples_FeatureDataset(Dataset):
199
+
200
+ def __init__(self, paths, transform=None):
201
+ super().__init__()
202
+ self.paths = paths
203
+ self.transform = transform
204
+ print(f"Gathering filenames...")
205
+ self.filenames = [os.path.splitext(os.path.basename(f))[0] for f in glob.glob(f"{self.paths[0]}/*.npy")]
206
+ print(f"Found {len(self.filenames)} filenames across all paths")
207
+
208
+ def __len__(self):
209
+ return len(self.filenames)
210
+
211
+ def __getitem__(self, idx):
212
+ # Return npy files for each modality. Always in the same order
213
+ moments = []
214
+ for path in self.paths:
215
+ path = os.path.join(path, f"{self.filenames[idx]}.npy")
216
+ moment = np.load(path, allow_pickle=True).copy()
217
+ if self.transform is not None:
218
+ moment = self.transform(moment)
219
+ moments.append(moment)
220
+ return moments
221
+
222
+
223
+ class MajorTOM_S2_Features(DatasetFactory): # the moments calculated by Stable Diffusion image encoder
224
+ def __init__(self, path, cfg=False, p_uncond=None):
225
+ super().__init__()
226
+ print("Prepare dataset...")
227
+ # transform_train = [transforms.ToTensor()]
228
+ transform_train = []
229
+ self.train = MajorTOM_S2_FeatureDataset(
230
+ path, transform=transforms.Compose(transform_train)
231
+ )
232
+ self.path = path
233
+ print("Prepare dataset ok")
234
+ self.K = 1000
235
+
236
+ if cfg: # classifier free guidance
237
+ assert p_uncond is not None
238
+ print(f"prepare the dataset for classifier free guidance with p_uncond={p_uncond}")
239
+ self.train = CFGDataset(self.train, p_uncond, self.K)
240
+
241
+ def get_split(self, split, labeled=False):
242
+ if split == "train":
243
+ dataset = self.train
244
+ elif split == "test":
245
+ dataset = self.test
246
+ else:
247
+ raise ValueError
248
+
249
+ if self.has_label:
250
+ return dataset if labeled else UnlabeledDataset(dataset)
251
+ else:
252
+ assert not labeled
253
+ return dataset
254
+
255
+ @property
256
+ def data_shape(self):
257
+ return 4, 133, 133
258
+
259
+ @property
260
+ def fid_stat(self):
261
+ return f"assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz"
262
+
263
+ def sample_label(self, n_samples, device):
264
+ return torch.randint(0, 1000, (n_samples,), device=device)
265
+
266
+
267
+ class MajorTOM_Tuples_Features(DatasetFactory): # the moments calculated by Stable Diffusion image encoder
268
+ def __init__(self, paths, cfg=False, p_uncond=None):
269
+ super().__init__()
270
+ print("Prepare dataset...")
271
+ # transform_train = [transforms.ToTensor()]
272
+ transform_train = []
273
+ self.train = MajorTOM_Tuples_FeatureDataset(
274
+ paths, transform=transforms.Compose(transform_train)
275
+ )
276
+ self.paths = paths
277
+ print("Prepare dataset ok")
278
+ self.K = 1000
279
+
280
+ if cfg: # classifier free guidance
281
+ assert p_uncond is not None
282
+ print(f"prepare the dataset for classifier free guidance with p_uncond={p_uncond}")
283
+ self.train = CFGDataset(self.train, p_uncond, self.K)
284
+
285
+ def get_split(self, split, labeled=False):
286
+ if split == "train":
287
+ dataset = self.train
288
+ elif split == "test":
289
+ dataset = self.test
290
+ else:
291
+ raise ValueError
292
+
293
+ if self.has_label:
294
+ return dataset if labeled else UnlabeledDataset(dataset)
295
+ else:
296
+ assert not labeled
297
+ return dataset
298
+
299
+ @property
300
+ def data_shape(self):
301
+ return "blablabla"
302
+
303
+ @property
304
+ def fid_stat(self):
305
+ return f"assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz"
306
+
307
+ def sample_label(self, n_samples, device):
308
+ raise NotImplementedError
309
+ return torch.randint(0, 1000, (n_samples,), device=device)
310
+
311
+
312
+ class MajorTOM_Lmdb_FeatureDataset(Dataset):
313
+ def __init__(self, path, transform=None, return_filename=False):
314
+ super().__init__()
315
+ import pickle
316
+
317
+ self.transform = transform
318
+ self.path = path # Store the path instead of the environment
319
+ self.return_filename = return_filename
320
+
321
+ # Create a temporary environment just to get the stats and keys
322
+ import lmdb
323
+ env = lmdb.open(
324
+ path,
325
+ max_readers=1,
326
+ readonly=True,
327
+ lock=False,
328
+ readahead=False,
329
+ meminit=False,
330
+ )
331
+
332
+ # Get total number of entries
333
+ with env.begin(write=False) as txn:
334
+ self.length = txn.stat()["entries"]
335
+
336
+ # Load or create cache of keys
337
+ root_split = path.split("/")
338
+ cache_file = os.path.join("/".join(root_split[:-1]), f"_cache_{root_split[-1]}")
339
+ if os.path.isfile(cache_file):
340
+ self.keys = pickle.load(open(cache_file, "rb"))
341
+ else:
342
+ with env.begin(write=False) as txn:
343
+ self.keys = [key for key, _ in txn.cursor()]
344
+ pickle.dump(self.keys, open(cache_file, "wb"))
345
+
346
+ # Close the temporary environment
347
+ env.close()
348
+
349
+ # Create environment lazily in each worker
350
+ self._env = None
351
+
352
+ def _init_db(self):
353
+ """Initialize LMDB environment"""
354
+ import lmdb
355
+ self._env = lmdb.open(
356
+ self.path,
357
+ max_readers=1,
358
+ readonly=True,
359
+ lock=False,
360
+ readahead=False,
361
+ meminit=False,
362
+ )
363
+
364
+ @property
365
+ def env(self):
366
+ """Get LMDB environment, creating it if necessary"""
367
+ if self._env is None:
368
+ self._init_db()
369
+ return self._env
370
+
371
+ def __len__(self):
372
+ return self.length
373
+
374
+ def __getitem__(self, idx):
375
+ # Get data from LMDB
376
+ import pickle
377
+ import numpy as np
378
+
379
+ key = self.keys[idx]
380
+ filename = key.decode('utf-8') if isinstance(key, bytes) else key
381
+ filename = os.path.basename(filename) # get filename without path
382
+ filename = os.path.splitext(filename)[0] # remove .npy extension
383
+
384
+ with self.env.begin(write=False) as txn:
385
+ data = pickle.loads(txn.get(key))
386
+
387
+ # Convert bytes to data for each modality
388
+ decoded_data = {}
389
+ for k, bytes_data in data.items():
390
+ # Convert bytes back to numpy array with the expected shape (8, 32, 32).
391
+ # TODO: This is currently hardcoded.
392
+ features = np.frombuffer(bytes_data, dtype=np.float32).reshape(8, 32, 32).copy()
393
+ decoded_data[k] = features
394
+
395
+ # Apply transforms if any
396
+ if self.transform is not None:
397
+ decoded_data = {k: self.transform(v) for k, v in decoded_data.items()}
398
+
399
+ # Convert the dictionary values to a list in a consistent order
400
+ moments = [decoded_data[k] for k in sorted(decoded_data.keys())]
401
+
402
+ if self.return_filename:
403
+ return moments, filename
404
+ return moments
405
+
406
+ def __del__(self):
407
+ if self._env is not None:
408
+ self._env.close()
409
+
410
+
411
+ class MajorTOM_Lmdb_Features(DatasetFactory):
412
+ def __init__(self, path, cfg=False, p_uncond=None, return_filename=False):
413
+ super().__init__()
414
+ print("Prepare dataset...")
415
+ transform_train = []
416
+ self.return_filename = return_filename
417
+ self.train = MajorTOM_Lmdb_FeatureDataset(
418
+ path, transform=transforms.Compose(transform_train), return_filename=return_filename
419
+ )
420
+ self.path = path
421
+ print("Prepare dataset ok")
422
+ self.K = 1000
423
+
424
+ if cfg: # classifier free guidance
425
+ assert p_uncond is not None
426
+ print(f"prepare the dataset for classifier free guidance with p_uncond={p_uncond}")
427
+ self.train = CFGDataset(self.train, p_uncond, self.K)
428
+
429
+ def get_split(self, split, labeled=False):
430
+ if split == "train":
431
+ dataset = self.train
432
+ elif split == "test":
433
+ dataset = self.test
434
+ else:
435
+ raise ValueError
436
+
437
+ if self.has_label:
438
+ return dataset if labeled else UnlabeledDataset(dataset)
439
+ else:
440
+ assert not labeled
441
+ return dataset
442
+
443
+ @property
444
+ def data_shape(self):
445
+ return "blablabla"
446
+
447
+ @property
448
+ def fid_stat(self):
449
+ return f"assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz"
450
+
451
+ def sample_label(self, n_samples, device):
452
+ raise NotImplementedError
453
+ return torch.randint(0, 1000, (n_samples,), device=device)
454
+
455
+
456
+ class ImageNet256Features(DatasetFactory): # the moments calculated by Stable Diffusion image encoder
457
+ def __init__(self, path, cfg=False, p_uncond=None):
458
+ super().__init__()
459
+ print('Prepare dataset...')
460
+ self.train = FeatureDataset(path)
461
+ print('Prepare dataset ok')
462
+ self.K = 1000
463
+
464
+ if cfg: # classifier free guidance
465
+ assert p_uncond is not None
466
+ print(f'prepare the dataset for classifier free guidance with p_uncond={p_uncond}')
467
+ self.train = CFGDataset(self.train, p_uncond, self.K)
468
+
469
+ @property
470
+ def data_shape(self):
471
+ return 4, 32, 32
472
+
473
+ @property
474
+ def fid_stat(self):
475
+ return f'assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz'
476
+
477
+ def sample_label(self, n_samples, device):
478
+ return torch.randint(0, 1000, (n_samples,), device=device)
479
+
480
+
481
+ class ImageNet512Features(DatasetFactory): # the moments calculated by Stable Diffusion image encoder
482
+ def __init__(self, path, cfg=False, p_uncond=None):
483
+ super().__init__()
484
+ print('Prepare dataset...')
485
+ self.train = FeatureDataset(path)
486
+ print('Prepare dataset ok')
487
+ self.K = 1000
488
+
489
+ if cfg: # classifier free guidance
490
+ assert p_uncond is not None
491
+ print(f'prepare the dataset for classifier free guidance with p_uncond={p_uncond}')
492
+ self.train = CFGDataset(self.train, p_uncond, self.K)
493
+
494
+ @property
495
+ def data_shape(self):
496
+ return 4, 64, 64
497
+
498
+ @property
499
+ def fid_stat(self):
500
+ return f'assets/fid_stats/fid_stats_imagenet512_guided_diffusion.npz'
501
+
502
+ def sample_label(self, n_samples, device):
503
+ return torch.randint(0, 1000, (n_samples,), device=device)
504
+
505
+
506
+ class ImageNet(DatasetFactory):
507
+ def __init__(self, path, resolution, random_crop=False, random_flip=True):
508
+ super().__init__()
509
+
510
+ print(f'Counting ImageNet files from {path}')
511
+ train_files = _list_image_files_recursively(os.path.join(path, 'train'))
512
+ class_names = [os.path.basename(path).split("_")[0] for path in train_files]
513
+ sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
514
+ train_labels = [sorted_classes[x] for x in class_names]
515
+ print('Finish counting ImageNet files')
516
+
517
+ self.train = ImageDataset(resolution, train_files, labels=train_labels, random_crop=random_crop, random_flip=random_flip)
518
+ self.resolution = resolution
519
+ if len(self.train) != 1_281_167:
520
+ print(f'Missing train samples: {len(self.train)} < 1281167')
521
+
522
+ self.K = max(self.train.labels) + 1
523
+ cnt = dict(zip(*np.unique(self.train.labels, return_counts=True)))
524
+ self.cnt = torch.tensor([cnt[k] for k in range(self.K)]).float()
525
+ self.frac = [self.cnt[k] / len(self.train.labels) for k in range(self.K)]
526
+ print(f'{self.K} classes')
527
+ print(f'cnt[:10]: {self.cnt[:10]}')
528
+ print(f'frac[:10]: {self.frac[:10]}')
529
+
530
+ @property
531
+ def data_shape(self):
532
+ return 3, self.resolution, self.resolution
533
+
534
+ @property
535
+ def fid_stat(self):
536
+ return f'assets/fid_stats/fid_stats_imagenet{self.resolution}_guided_diffusion.npz'
537
+
538
+ def sample_label(self, n_samples, device):
539
+ return torch.multinomial(self.cnt, n_samples, replacement=True).to(device)
540
+
541
+ def label_prob(self, k):
542
+ return self.frac[k]
543
+
544
+ class MajorTOMThumbnail(DatasetFactory):
545
+ def __init__(self, path, resolution):
546
+ super().__init__()
547
+
548
+ print(f'Counting MajorTOM thumbnail files from {path}')
549
+ files_list = _list_image_files_recursively(path)
550
+ print('Finish counting MajorTOM thumbnail files')
551
+
552
+ self.dataset = MajorTOMThumbnailDataset(resolution, files_list)
553
+ self.resolution = resolution
554
+ if len(self.dataset) != 1_281_167:
555
+ print(f'Missing train samples: {len(self.dataset)} < 1281167')
556
+
557
+ @property
558
+ def data_shape(self):
559
+ return 3, self.resolution, self.resolution
560
+
561
+ @property
562
+ def has_label(self):
563
+ return False
564
+
565
+ @property
566
+ def fid_stat(self):
567
+ return f'assets/fid_stats/fid_stats_imagenet{self.resolution}_guided_diffusion.npz'
568
+
569
+
570
+ class MajorTOMThumbnailDataset(Dataset):
571
+ def __init__(
572
+ self,
573
+ resolution,
574
+ image_paths,
575
+ ):
576
+ super().__init__()
577
+ self.resolution = resolution
578
+ self.image_paths = image_paths
579
+
580
+ def __len__(self):
581
+ return len(self.image_paths)
582
+
583
+ def __getitem__(self, idx):
584
+ path = self.image_paths[idx]
585
+ filename = os.path.basename(path).split('.')[0]
586
+ pil_image = Image.open(path)
587
+ pil_image.load()
588
+ pil_image = pil_image.convert("RGB")
589
+
590
+ # check that the image has the correct resolution
591
+ if pil_image.size != (self.resolution, self.resolution):
592
+ raise ValueError(f"Image at {path} has size {pil_image.size}, expected {self.resolution}x{self.resolution}")
593
+
594
+ arr = np.array(pil_image).astype(np.float32) / 127.5 - 1
595
+
596
+ return np.transpose(arr, [2, 0, 1]), filename
597
+
598
+ def _list_image_files_recursively(data_dir):
599
+ results = []
600
+ for entry in sorted(os.listdir(data_dir)):
601
+ full_path = os.path.join(data_dir, entry)
602
+ ext = entry.split(".")[-1]
603
+ if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
604
+ results.append(full_path)
605
+ elif os.path.isdir(full_path):
606
+ results.extend(_list_image_files_recursively(full_path))
607
+ return results
608
+
609
+
610
+ class ImageDataset(Dataset):
611
+ def __init__(
612
+ self,
613
+ resolution,
614
+ image_paths,
615
+ labels,
616
+ random_crop=False,
617
+ random_flip=True,
618
+ ):
619
+ super().__init__()
620
+ self.resolution = resolution
621
+ self.image_paths = image_paths
622
+ self.labels = labels
623
+ self.random_crop = random_crop
624
+ self.random_flip = random_flip
625
+
626
+ def __len__(self):
627
+ return len(self.image_paths)
628
+
629
+ def __getitem__(self, idx):
630
+ path = self.image_paths[idx]
631
+ pil_image = Image.open(path)
632
+ pil_image.load()
633
+ pil_image = pil_image.convert("RGB")
634
+
635
+ if self.random_crop:
636
+ arr = random_crop_arr(pil_image, self.resolution)
637
+ else:
638
+ arr = center_crop_arr(pil_image, self.resolution)
639
+
640
+ if self.random_flip and random.random() < 0.5:
641
+ arr = arr[:, ::-1]
642
+
643
+ arr = arr.astype(np.float32) / 127.5 - 1
644
+
645
+ label = np.array(self.labels[idx], dtype=np.int64)
646
+ return np.transpose(arr, [2, 0, 1]), label
647
+
648
+
649
+ def center_crop_arr(pil_image, image_size):
650
+ # We are not on a new enough PIL to support the `reducing_gap`
651
+ # argument, which uses BOX downsampling at powers of two first.
652
+ # Thus, we do it by hand to improve downsample quality.
653
+ while min(*pil_image.size) >= 2 * image_size:
654
+ pil_image = pil_image.resize(
655
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
656
+ )
657
+
658
+ scale = image_size / min(*pil_image.size)
659
+ pil_image = pil_image.resize(
660
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
661
+ )
662
+
663
+ arr = np.array(pil_image)
664
+ crop_y = (arr.shape[0] - image_size) // 2
665
+ crop_x = (arr.shape[1] - image_size) // 2
666
+ return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
667
+
668
+
669
+ def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
670
+ min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
671
+ max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
672
+ smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)
673
+
674
+ # We are not on a new enough PIL to support the `reducing_gap`
675
+ # argument, which uses BOX downsampling at powers of two first.
676
+ # Thus, we do it by hand to improve downsample quality.
677
+ while min(*pil_image.size) >= 2 * smaller_dim_size:
678
+ pil_image = pil_image.resize(
679
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
680
+ )
681
+
682
+ scale = smaller_dim_size / min(*pil_image.size)
683
+ pil_image = pil_image.resize(
684
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
685
+ )
686
+
687
+ arr = np.array(pil_image)
688
+ crop_y = random.randrange(arr.shape[0] - image_size + 1)
689
+ crop_x = random.randrange(arr.shape[1] - image_size + 1)
690
+ return arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]
691
+
692
+
693
+ # CelebA
694
+
695
+
696
+ class Crop(object):
697
+ def __init__(self, x1, x2, y1, y2):
698
+ self.x1 = x1
699
+ self.x2 = x2
700
+ self.y1 = y1
701
+ self.y2 = y2
702
+
703
+ def __call__(self, img):
704
+ return F.crop(img, self.x1, self.y1, self.x2 - self.x1, self.y2 - self.y1)
705
+
706
+ def __repr__(self):
707
+ return self.__class__.__name__ + "(x1={}, x2={}, y1={}, y2={})".format(
708
+ self.x1, self.x2, self.y1, self.y2
709
+ )
710
+
711
+
712
+ class CelebA(DatasetFactory):
713
+ r""" train: 162,770
714
+ val: 19,867
715
+ test: 19,962
716
+ shape: 3 * width * width
717
+ """
718
+
719
+ def __init__(self, path, resolution=64):
720
+ super().__init__()
721
+
722
+ self.resolution = resolution
723
+
724
+ cx = 89
725
+ cy = 121
726
+ x1 = cy - 64
727
+ x2 = cy + 64
728
+ y1 = cx - 64
729
+ y2 = cx + 64
730
+
731
+ transform = transforms.Compose([Crop(x1, x2, y1, y2), transforms.Resize(self.resolution),
732
+ transforms.RandomHorizontalFlip(), transforms.ToTensor(),
733
+ transforms.Normalize(0.5, 0.5)])
734
+ self.train = datasets.CelebA(root=path, split="train", target_type=[], transform=transform, download=True)
735
+ self.train = UnlabeledDataset(self.train)
736
+
737
+ @property
738
+ def data_shape(self):
739
+ return 3, self.resolution, self.resolution
740
+
741
+ @property
742
+ def fid_stat(self):
743
+ return 'assets/fid_stats/fid_stats_celeba64_train_50000_ddim.npz'
744
+
745
+ @property
746
+ def has_label(self):
747
+ return False
748
+
749
+
750
+ # MS COCO
751
+
752
+
753
+ def center_crop(width, height, img):
754
+ resample = {'box': Image.BOX, 'lanczos': Image.LANCZOS}['lanczos']
755
+ crop = np.min(img.shape[:2])
756
+ img = img[(img.shape[0] - crop) // 2: (img.shape[0] + crop) // 2,
757
+ (img.shape[1] - crop) // 2: (img.shape[1] + crop) // 2]
758
+ try:
759
+ img = Image.fromarray(img, 'RGB')
760
+ except:
761
+ img = Image.fromarray(img)
762
+ img = img.resize((width, height), resample)
763
+
764
+ return np.array(img).astype(np.uint8)
765
+
766
+
767
+ class MSCOCODatabase(Dataset):
768
+ def __init__(self, root, annFile, size=None):
769
+ from pycocotools.coco import COCO
770
+ self.root = root
771
+ self.height = self.width = size
772
+
773
+ self.coco = COCO(annFile)
774
+ self.keys = list(sorted(self.coco.imgs.keys()))
775
+
776
+ def _load_image(self, key: int):
777
+ path = self.coco.loadImgs(key)[0]["file_name"]
778
+ return Image.open(os.path.join(self.root, path)).convert("RGB")
779
+
780
+ def _load_target(self, key: int):
781
+ return self.coco.loadAnns(self.coco.getAnnIds(key))
782
+
783
+ def __len__(self):
784
+ return len(self.keys)
785
+
786
+ def __getitem__(self, index):
787
+ key = self.keys[index]
788
+ image = self._load_image(key)
789
+ image = np.array(image).astype(np.uint8)
790
+ image = center_crop(self.width, self.height, image).astype(np.float32)
791
+ image = (image / 127.5 - 1.0).astype(np.float32)
792
+ image = einops.rearrange(image, 'h w c -> c h w')
793
+
794
+ anns = self._load_target(key)
795
+ target = []
796
+ for ann in anns:
797
+ target.append(ann['caption'])
798
+
799
+ return image, target
800
+
801
+
802
+ def get_feature_dir_info(root):
803
+ files = glob.glob(os.path.join(root, '*.npy'))
804
+ files_caption = glob.glob(os.path.join(root, '*_*.npy'))
805
+ num_data = len(files) - len(files_caption)
806
+ n_captions = {k: 0 for k in range(num_data)}
807
+ for f in files_caption:
808
+ name = os.path.split(f)[-1]
809
+ k1, k2 = os.path.splitext(name)[0].split('_')
810
+ n_captions[int(k1)] += 1
811
+ return num_data, n_captions
812
+
813
+
814
+ class MSCOCOFeatureDataset(Dataset):
815
+ # the image features are obtained by sampling from the encoded moments
816
+ def __init__(self, root):
817
+ self.root = root
818
+ self.num_data, self.n_captions = get_feature_dir_info(root)
819
+
820
+ def __len__(self):
821
+ return self.num_data
822
+
823
+ def __getitem__(self, index):
824
+ z = np.load(os.path.join(self.root, f'{index}.npy'))
825
+ k = random.randint(0, self.n_captions[index] - 1)
826
+ c = np.load(os.path.join(self.root, f'{index}_{k}.npy'))
827
+ return z, c
828
+
829
+
830
+ class MSCOCO256Features(DatasetFactory): # the moments calculated by Stable Diffusion image encoder & the contexts calculated by clip
831
+ def __init__(self, path, cfg=False, p_uncond=None):
832
+ super().__init__()
833
+ print('Prepare dataset...')
834
+ self.train = MSCOCOFeatureDataset(os.path.join(path, 'train'))
835
+ self.test = MSCOCOFeatureDataset(os.path.join(path, 'val'))
836
+ assert len(self.train) == 82783
837
+ assert len(self.test) == 40504
838
+ print('Prepare dataset ok')
839
+
840
+ self.empty_context = np.load(os.path.join(path, 'empty_context.npy'))
841
+
842
+ if cfg: # classifier free guidance
843
+ assert p_uncond is not None
844
+ print(f'prepare the dataset for classifier free guidance with p_uncond={p_uncond}')
845
+ self.train = CFGDataset(self.train, p_uncond, self.empty_context)
846
+
847
+ # text embedding extracted by clip
848
+ # for visualization in t2i
849
+ self.prompts, self.contexts = [], []
850
+ for f in sorted(os.listdir(os.path.join(path, 'run_vis')), key=lambda x: int(x.split('.')[0])):
851
+ prompt, context = np.load(os.path.join(path, 'run_vis', f), allow_pickle=True)
852
+ self.prompts.append(prompt)
853
+ self.contexts.append(context)
854
+ self.contexts = np.array(self.contexts)
855
+
856
+ @property
857
+ def data_shape(self):
858
+ return 4, 32, 32
859
+
860
+ @property
861
+ def fid_stat(self):
862
+ return f'assets/fid_stats/fid_stats_mscoco256_val.npz'
863
+
864
+
865
+ def get_dataset(name, **kwargs):
866
+ if name == 'cifar10':
867
+ return CIFAR10(**kwargs)
868
+ elif name == 'imagenet':
869
+ return ImageNet(**kwargs)
870
+ elif name == 'imagenet256_features':
871
+ return ImageNet256Features(**kwargs)
872
+ elif name == 'imagenet512_features':
873
+ return ImageNet512Features(**kwargs)
874
+ elif name == "majorTOM_S2_256_features":
875
+ return MajorTOM_S2_Features(**kwargs)
876
+ elif name == "majorTOM_tuples_256_features":
877
+ return MajorTOM_Tuples_Features(**kwargs)
878
+ elif name == "majorTOM_lmdb_256_features":
879
+ return MajorTOM_Lmdb_Features(**kwargs)
880
+ elif name == 'celeba':
881
+ return CelebA(**kwargs)
882
+ elif name == 'mscoco256_features':
883
+ return MSCOCO256Features(**kwargs)
884
+ else:
885
+ raise NotImplementedError(name)
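Note: a minimal usage sketch of the `get_dataset` factory defined above. Keyword arguments are simply forwarded to the selected factory class; the paths below are illustrative placeholders, not part of the repository.

```python
# Hedged example: paths are hypothetical; kwargs mirror the factory signatures above.
from datasets import get_dataset

# CelebA at 64x64 resolution (see the CelebA factory above)
celeba = get_dataset('celeba', path='./data/celeba', resolution=64)

# MS-COCO latent features with classifier-free guidance dropout (see MSCOCO256Features)
coco = get_dataset('mscoco256_features', path='./data/coco256_features', cfg=True, p_uncond=0.1)
print(coco.data_shape)  # (4, 32, 32)
```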
src/COP-GEN-Beta/dpm_solver_pp.py ADDED
@@ -0,0 +1,952 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+ import numpy as np
5
+ import torch.distributed as dist
6
+
7
+
8
+ def interpolate_fn(x: torch.Tensor, xp: torch.Tensor, yp: torch.Tensor) -> torch.Tensor:
9
+ """Performs piecewise linear interpolation for x, using xp and yp keypoints (knots).
10
+ Performs separate interpolation for each channel.
11
+ Args:
12
+ x: [N, C] points to be calibrated (interpolated). Batch with C channels.
13
+ xp: [C, K] x coordinates of the PWL knots. C is the number of channels, K is the number of knots.
14
+ yp: [C, K] y coordinates of the PWL knots. C is the number of channels, K is the number of knots.
15
+ Returns:
16
+ Interpolated points of the shape [N, C].
17
+ The piecewise linear function extends for the whole x axis (the outermost keypoints define the outermost
18
+ infinite lines).
19
+ For example:
20
+ >>> interpolate_fn(torch.tensor([[0.5]]), torch.tensor([[0.0, 1.0]]), torch.tensor([[0.0, 2.0]]))
21
+ tensor([[1.0000]])
22
+ >>> interpolate_fn(torch.tensor([[-10]]), torch.tensor([[0.0, 1.0]]), torch.tensor([[0.0, 2.0]]))
23
+ tensor([[-20.0000]])
24
+ """
25
+ x_breakpoints = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((x.shape[0], 1, 1))], dim=2)
26
+ num_x_points = xp.shape[1]
27
+ sorted_x_breakpoints, x_indices = torch.sort(x_breakpoints, dim=2)
28
+ x_idx = torch.argmin(x_indices, dim=2)
29
+ cand_start_idx = x_idx - 1
30
+ start_idx = torch.where(
31
+ torch.eq(x_idx, 0),
32
+ torch.tensor(1, device=x.device),
33
+ torch.where(
34
+ torch.eq(x_idx, num_x_points), torch.tensor(num_x_points - 2, device=x.device), cand_start_idx,
35
+ ),
36
+ )
37
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
38
+ start_x = torch.gather(sorted_x_breakpoints, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
39
+ end_x = torch.gather(sorted_x_breakpoints, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
40
+ start_idx2 = torch.where(
41
+ torch.eq(x_idx, 0),
42
+ torch.tensor(0, device=x.device),
43
+ torch.where(
44
+ torch.eq(x_idx, num_x_points), torch.tensor(num_x_points - 2, device=x.device), cand_start_idx,
45
+ ),
46
+ )
47
+ y_positions_expanded = yp.unsqueeze(0).expand(x.shape[0], -1, -1)
48
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
49
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
50
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
51
+ return cand
52
+
53
+
54
+ class NoiseScheduleVP:
55
+ def __init__(self, schedule='discrete', beta_0=1e-4, beta_1=2e-2, total_N=1000, betas=None, alphas_cumprod=None):
56
+ """Create a wrapper class for the forward SDE (VP type).
57
+
58
+ The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
59
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
60
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
61
+
62
+ log_alpha_t = self.marginal_log_mean_coeff(t)
63
+ sigma_t = self.marginal_std(t)
64
+ lambda_t = self.marginal_lambda(t)
65
+
66
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
67
+
68
+ t = self.inverse_lambda(lambda_t)
69
+
70
+ ===============================================================
71
+
72
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
73
+ schedule are the default settings in DDPM and improved-DDPM:
74
+
75
+ beta_min: A `float` number. The smallest beta for the linear schedule.
76
+ beta_max: A `float` number. The largest beta for the linear schedule.
77
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
78
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
79
+ T: A `float` number. The ending time of the forward process.
80
+
81
+ Note that the original DDPM (linear schedule) used the discrete-time label (0 to 999). We convert the discrete-time
82
+ label to continuous time (following Song et al., 2021), so the beta here is 1000x larger than those in DDPM.
83
+
84
+ ===============================================================
85
+
86
+ Args:
87
+ schedule: A `str`. The noise schedule of the forward SDE ('linear' or 'cosine').
88
+
89
+ Returns:
90
+ A wrapper object of the forward SDE (VP type).
91
+ """
92
+ if schedule not in ['linear', 'discrete', 'cosine']:
93
+ raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'linear', 'discrete' or 'cosine'".format(schedule))
94
+ self.total_N = total_N
95
+ self.beta_0 = beta_0 * 1000.
96
+ self.beta_1 = beta_1 * 1000.
97
+
98
+ if schedule == 'discrete':
99
+ if betas is not None:
100
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
101
+ else:
102
+ assert alphas_cumprod is not None
103
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
104
+ self.total_N = len(log_alphas)
105
+ self.t_discrete = torch.linspace(1. / self.total_N, 1., self.total_N).reshape((1, -1))
106
+ self.log_alpha_discrete = log_alphas.reshape((1, -1))
107
+
108
+ self.cosine_s = 0.008
109
+ self.cosine_beta_max = 999.
110
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
111
+ self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
112
+ self.schedule = schedule
113
+ if schedule == 'cosine':
114
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
115
+ # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
116
+ self.T = 0.9946
117
+ else:
118
+ self.T = 1.
119
+
120
+ def marginal_log_mean_coeff(self, t):
121
+ """
122
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
123
+ """
124
+ if self.schedule == 'linear':
125
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
126
+ elif self.schedule == 'discrete':
127
+ return interpolate_fn(t.reshape((-1, 1)), self.t_discrete.clone().to(t.device), self.log_alpha_discrete.clone().to(t.device)).reshape((-1,))
128
+ elif self.schedule == 'cosine':
129
+ log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
130
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
131
+ return log_alpha_t
132
+ else:
133
+ raise ValueError("Unsupported noise schedule {}".format(self.schedule))
134
+
135
+ def marginal_alpha(self, t):
136
+ return torch.exp(self.marginal_log_mean_coeff(t))
137
+
138
+ def marginal_std(self, t):
139
+ """
140
+ Compute sigma_t of a given continuous-time label t in [0, T].
141
+ """
142
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
143
+
144
+ def marginal_lambda(self, t):
145
+ """
146
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
147
+ """
148
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
149
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
150
+ return log_mean_coeff - log_std
151
+
152
+ def inverse_lambda(self, lamb):
153
+ """
154
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
155
+ """
156
+ if self.schedule == 'linear':
157
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
158
+ Delta = self.beta_0**2 + tmp
159
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
160
+ elif self.schedule == 'discrete':
161
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
162
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_discrete.clone().to(lamb.device), [1]), torch.flip(self.t_discrete.clone().to(lamb.device), [1]))
163
+ return t.reshape((-1,))
164
+ else:
165
+ log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
166
+ t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
167
+ t = t_fn(log_alpha)
168
+ return t
169
+
170
+
171
+ def model_wrapper(model, noise_schedule=None, is_cond_classifier=False, classifier_fn=None, classifier_scale=1., time_input_type='1', total_N=1000, model_kwargs={}, is_deis=False):
172
+ """Create a wrapper function for the noise prediction model.
173
+
174
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
175
+ firstly wrap the model function to a function that accepts the continuous time as the input.
176
+
177
+ The input `model` has the following format:
178
+
179
+ ``
180
+ model(x, t_input, **model_kwargs) -> noise
181
+ ``
182
+
183
+ where `x` and `noise` have the same shape, and `t_input` is the time label of the model.
184
+ (may be discrete-time labels (i.e. 0 to 999) or continuous-time labels (i.e. epsilon to T).)
185
+
186
+ We wrap the model function to the following format:
187
+
188
+ ``
189
+ def model_fn(x, t_continuous) -> noise:
190
+ t_input = get_model_input_time(t_continuous)
191
+ return model(x, t_input, **model_kwargs)
192
+ ``
193
+
194
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
195
+
196
+ For DPMs with classifier guidance, we also combine the model output with the classifier gradient as used in [1].
197
+
198
+ [1] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," in Advances in Neural
199
+ Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
200
+
201
+ ===============================================================
202
+
203
+ Args:
204
+ model: A noise prediction model with the following format:
205
+ ``
206
+ def model(x, t_input, **model_kwargs):
207
+ return noise
208
+ ``
209
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP. Only used for the classifier guidance.
210
+ is_cond_classifier: A `bool`. Whether to use the classifier guidance.
211
+ classifier_fn: A classifier function. Only used for the classifier guidance. The format is:
212
+ ``
213
+ def classifier_fn(x, t_input):
214
+ return logits
215
+ ``
216
+ classifier_scale: A `float`. The scale for the classifier guidance.
217
+ time_input_type: A `str`. The type for the time input of the model. We support three types:
218
+ - '0': The continuous-time type. In this case, the model is trained on the continuous time,
219
+ so `t_input` = `t_continuous`.
220
+ - '1': The Type-1 discrete type described in the Appendix of DPM-Solver paper.
221
+ **For discrete-time DPMs, we recommend to use this type for DPM-Solver**.
222
+ - '2': The Type-2 discrete type described in the Appendix of DPM-Solver paper.
223
+ total_N: A `int`. The total number of the discrete-time DPMs (default is 1000), used when `time_input_type`
224
+ is '1' or '2'.
225
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
226
+ Returns:
227
+ A function that accepts the continuous time as the input, with the following format:
228
+ ``
229
+ def model_fn(x, t_continuous):
230
+ t_input = get_model_input_time(t_continuous)
231
+ return model(x, t_input, **model_kwargs)
232
+ ``
233
+ """
234
+ def get_model_input_time(t_continuous):
235
+ """
236
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
237
+ """
238
+ if time_input_type == '0':
239
+ # discrete_type == '0' means that the model is continuous-time model.
240
+ # For continuous-time DPMs, the continuous time equals to the discrete time.
241
+ return t_continuous
242
+ elif time_input_type == '1':
243
+ # Type-1 discrete label, as detailed in the Appendix of DPM-Solver.
244
+ return 1000. * torch.max(t_continuous - 1. / total_N, torch.zeros_like(t_continuous).to(t_continuous))
245
+ elif time_input_type == '2':
246
+ # Type-2 discrete label, as detailed in the Appendix of DPM-Solver.
247
+ max_N = (total_N - 1) / total_N * 1000.
248
+ return max_N * t_continuous
249
+ else:
250
+ raise ValueError("Unsupported time input type {}, must be '0' or '1' or '2'".format(time_input_type))
251
+
252
+ def cond_fn(x, t_discrete, y):
253
+ """
254
+ Compute the gradient of the classifier, multiplied by the scale of the classifier guidance.
255
+ """
256
+ assert y is not None
257
+ with torch.enable_grad():
258
+ x_in = x.detach().requires_grad_(True)
259
+ logits = classifier_fn(x_in, t_discrete)
260
+ log_probs = F.log_softmax(logits, dim=-1)
261
+ selected = log_probs[range(len(logits)), y.view(-1)]
262
+ return classifier_scale * torch.autograd.grad(selected.sum(), x_in)[0]
263
+
264
+ def model_fn(x, t_continuous):
265
+ """
266
+ The noise prediction model function that is used for DPM-Solver.
267
+ """
268
+ if t_continuous.reshape((-1,)).shape[0] == 1:
269
+ t_continuous = torch.ones((x.shape[0],)).to(x.device) * t_continuous
270
+ if is_cond_classifier:
271
+ y = model_kwargs.get("y", None)
272
+ if y is None:
273
+ raise ValueError("For classifier guidance, the label y has to be in the input.")
274
+ t_discrete = get_model_input_time(t_continuous)
275
+ noise_uncond = model(x, t_discrete, **model_kwargs)
276
+ cond_grad = cond_fn(x, t_discrete, y)
277
+ if is_deis:
278
+ sigma_t = noise_schedule.marginal_std(t_continuous / 1000.)
279
+ else:
280
+ sigma_t = noise_schedule.marginal_std(t_continuous)
281
+ dims = len(cond_grad.shape) - 1
282
+ return noise_uncond - sigma_t[(...,) + (None,)*dims] * cond_grad
283
+ else:
284
+ t_discrete = get_model_input_time(t_continuous)
285
+ return model(x, t_discrete, **model_kwargs)
286
+
287
+ return model_fn
288
+
289
+
290
+ class DPM_Solver:
291
+ def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
292
+ """Construct a DPM-Solver.
293
+
294
+ Args:
295
+ model_fn: A noise prediction model function which accepts the continuous-time input
296
+ (t in [epsilon, T]):
297
+ ``
298
+ def model_fn(x, t_continuous):
299
+ return noise
300
+ ``
301
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
302
+ """
303
+ self.model = model_fn
304
+ self.noise_schedule = noise_schedule
305
+ self.predict_x0 = predict_x0
306
+ self.thresholding = thresholding
307
+ self.max_val = max_val
308
+
309
+ def model_fn(self, x, t):
310
+ if self.predict_x0:
311
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
312
+ noise = self.model(x, t)
313
+ dims = len(x.shape) - 1
314
+ x0 = (x - sigma_t[(...,) + (None,)*dims] * noise) / alpha_t[(...,) + (None,)*dims]
315
+ if self.thresholding:
316
+ p = 0.995
317
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
318
+ s = torch.maximum(s, torch.ones_like(s).to(s.device))[(...,) + (None,)*dims]
319
+ x0 = torch.clamp(x0, -s, s) / (s / self.max_val)
320
+ return x0
321
+ else:
322
+ return self.model(x, t)
323
+
324
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
325
+ """Compute the intermediate time steps for sampling.
326
+
327
+ Args:
328
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
329
+ - 'logSNR': uniform logSNR for the time steps, **recommended for DPM-Solver**.
330
+ - 'time_uniform': uniform time for the time steps. (Used in DDIM and DDPM.)
331
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
332
+ t_T: A `float`. The starting time of the sampling (default is T).
333
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
334
+ N: A `int`. The total number of the spacing of the time steps.
335
+ device: A torch device.
336
+ Returns:
337
+ A pytorch tensor of the time steps, with the shape (N + 1,).
338
+ """
339
+ if skip_type == 'logSNR':
340
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
341
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
342
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
343
+ # print(torch.min(torch.abs(logSNR_steps - self.noise_schedule.marginal_lambda(self.noise_schedule.inverse_lambda(logSNR_steps)))).item())
344
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
345
+ elif skip_type == 't2':
346
+ t_order = 2
347
+ t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
348
+ return t
349
+ elif skip_type == 'time_uniform':
350
+ return torch.linspace(t_T, t_0, N + 1).to(device)
351
+ elif skip_type == 'time_quadratic':
352
+ t = torch.linspace(t_0, t_T, 10000000).to(device)
353
+ quadratic_t = torch.sqrt(t)
354
+ quadratic_steps = torch.linspace(quadratic_t[0], quadratic_t[-1], N + 1).to(device)
355
+ return torch.flip(torch.cat([t[torch.searchsorted(quadratic_t, quadratic_steps)[:-1]], t_T * torch.ones((1,)).to(device)], dim=0), dims=[0])
356
+ else:
357
+ raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
358
+
359
+ def get_time_steps_for_dpm_solver_fast(self, skip_type, t_T, t_0, steps, order, device):
360
+ """
361
+ Compute the intermediate time steps and the order of each step for sampling by DPM-Solver-fast.
362
+
363
+ We recommend DPM-Solver-fast for fast sampling of DPMs. Given a fixed number of function evaluations by `steps`,
364
+ the sampling procedure by DPM-Solver-fast is:
365
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
366
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
367
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
368
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
369
+
370
+ ============================================
371
+ Args:
372
+ t_T: A `float`. The starting time of the sampling (default is T).
373
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
374
+ steps: A `int`. The total number of function evaluations (NFE).
375
+ device: A torch device.
376
+ Returns:
377
+ orders: A list of the solver order of each step.
378
+ timesteps: A pytorch tensor of the time steps, with the shape of (K + 1,).
379
+ """
380
+ if order == 3:
381
+ K = steps // 3 + 1
382
+ if steps % 3 == 0:
383
+ orders = [3,] * (K - 2) + [2, 1]
384
+ elif steps % 3 == 1:
385
+ orders = [3,] * (K - 1) + [1]
386
+ else:
387
+ orders = [3,] * (K - 1) + [2]
388
+ timesteps = self.get_time_steps(skip_type, t_T, t_0, K, device)
389
+ return orders, timesteps
390
+ elif order == 2:
391
+ K = steps // 2
392
+ if steps % 2 == 0:
393
+ orders = [2,] * K
394
+ else:
395
+ orders = [2,] * K + [1]
396
+ timesteps = self.get_time_steps(skip_type, t_T, t_0, K, device)
397
+ return orders, timesteps
398
+ else:
399
+ raise ValueError("order must be >= 2")
400
+
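A quick worked example of the step allocation described in `get_time_steps_for_dpm_solver_fast` above: with `order=3` and `steps=10`, K = 10 // 3 + 1 = 4 and `steps % 3 == 1`, so `orders = [3, 3, 3, 1]`, i.e. three DPM-Solver-3 steps followed by a single DPM-Solver-1 step, using exactly 10 function evaluations over K + 1 = 5 time points.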
401
+ def denoise_fn(self, x, s, noise_s=None):
402
+ ns = self.noise_schedule
403
+ dims = len(x.shape) - 1
404
+ log_alpha_s = ns.marginal_log_mean_coeff(s)
405
+ sigma_s = ns.marginal_std(s)
406
+
407
+ if noise_s is None:
408
+ noise_s = self.model_fn(x, s)
409
+ x_0 = (
410
+ (x - sigma_s[(...,) + (None,)*dims] * noise_s) / torch.exp(log_alpha_s)[(...,) + (None,)*dims]
411
+ )
412
+ return x_0
413
+
414
+ def dpm_solver_first_update(self, x, s, t, noise_s=None, return_noise=False):
415
+ """
416
+ A single step for DPM-Solver-1.
417
+
418
+ Args:
419
+ x: A pytorch tensor. The initial value at time `s`.
420
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
421
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
422
+ return_noise: A `bool`. If true, also return the predicted noise at time `s`.
423
+ Returns:
424
+ x_t: A pytorch tensor. The approximated solution at time `t`.
425
+ """
426
+ ns = self.noise_schedule
427
+ dims = len(x.shape) - 1
428
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
429
+ h = lambda_t - lambda_s
430
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
431
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
432
+ alpha_t = torch.exp(log_alpha_t)
433
+
434
+ if self.predict_x0:
435
+ phi_1 = (torch.exp(-h) - 1.) / (-1.)
436
+ if noise_s is None:
437
+ noise_s = self.model_fn(x, s)
438
+ x_t = (
439
+ (sigma_t / sigma_s)[(...,) + (None,)*dims] * x
440
+ + (alpha_t * phi_1)[(...,) + (None,)*dims] * noise_s
441
+ )
442
+ if return_noise:
443
+ return x_t, {'noise_s': noise_s}
444
+ else:
445
+ return x_t
446
+ else:
447
+ phi_1 = torch.expm1(h)
448
+ if noise_s is None:
449
+ noise_s = self.model_fn(x, s)
450
+ x_t = (
451
+ torch.exp(log_alpha_t - log_alpha_s)[(...,) + (None,)*dims] * x
452
+ - (sigma_t * phi_1)[(...,) + (None,)*dims] * noise_s
453
+ )
454
+ if return_noise:
455
+ return x_t, {'noise_s': noise_s}
456
+ else:
457
+ return x_t
458
+
459
+ def dpm_solver_second_update(self, x, s, t, r1=0.5, noise_s=None, return_noise=False, solver_type='dpm_solver'):
460
+ """
461
+ A single step for DPM-Solver-2.
462
+
463
+ Args:
464
+ x: A pytorch tensor. The initial value at time `s`.
465
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
466
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
467
+ r1: A `float`. The hyperparameter of the second-order solver. We recommend the default setting `0.5`.
468
+ noise_s: A pytorch tensor. The predicted noise at time `s`.
469
+ If `noise_s` is None, we compute the predicted noise by `x` and `s`; otherwise we directly use it.
470
+ return_noise: A `bool`. If true, also return the predicted noise at time `s` and `s1` (the intermediate time).
471
+ Returns:
472
+ x_t: A pytorch tensor. The approximated solution at time `t`.
473
+ """
474
+ if r1 is None:
475
+ r1 = 0.5
476
+ ns = self.noise_schedule
477
+ dims = len(x.shape) - 1
478
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
479
+ h = lambda_t - lambda_s
480
+ lambda_s1 = lambda_s + r1 * h
481
+ s1 = ns.inverse_lambda(lambda_s1)
482
+ log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t)
483
+ sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
484
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
485
+
486
+ if self.predict_x0:
487
+ phi_11 = torch.expm1(-r1 * h)
488
+ phi_1 = torch.expm1(-h)
489
+
490
+ if noise_s is None:
491
+ noise_s = self.model_fn(x, s)
492
+ x_s1 = (
493
+ (sigma_s1 / sigma_s)[(...,) + (None,)*dims] * x
494
+ - (alpha_s1 * phi_11)[(...,) + (None,)*dims] * noise_s
495
+ )
496
+ noise_s1 = self.model_fn(x_s1, s1)
497
+ if solver_type == 'dpm_solver':
498
+ x_t = (
499
+ (sigma_t / sigma_s)[(...,) + (None,)*dims] * x
500
+ - (alpha_t * phi_1)[(...,) + (None,)*dims] * noise_s
501
+ - (0.5 / r1) * (alpha_t * phi_1)[(...,) + (None,)*dims] * (noise_s1 - noise_s)
502
+ )
503
+ elif solver_type == 'taylor':
504
+ x_t = (
505
+ (sigma_t / sigma_s)[(...,) + (None,)*dims] * x
506
+ - (alpha_t * phi_1)[(...,) + (None,)*dims] * noise_s
507
+ + (1. / r1) * (alpha_t * ((torch.exp(-h) - 1.) / h + 1.))[(...,) + (None,)*dims] * (noise_s1 - noise_s)
508
+ )
509
+ else:
510
+ raise ValueError("solver_type must be either dpm_solver or taylor, got {}".format(solver_type))
511
+ else:
512
+ phi_11 = torch.expm1(r1 * h)
513
+ phi_1 = torch.expm1(h)
514
+
515
+ if noise_s is None:
516
+ noise_s = self.model_fn(x, s)
517
+ x_s1 = (
518
+ torch.exp(log_alpha_s1 - log_alpha_s)[(...,) + (None,)*dims] * x
519
+ - (sigma_s1 * phi_11)[(...,) + (None,)*dims] * noise_s
520
+ )
521
+ noise_s1 = self.model_fn(x_s1, s1)
522
+ if solver_type == 'dpm_solver':
523
+ x_t = (
524
+ torch.exp(log_alpha_t - log_alpha_s)[(...,) + (None,)*dims] * x
525
+ - (sigma_t * phi_1)[(...,) + (None,)*dims] * noise_s
526
+ - (0.5 / r1) * (sigma_t * phi_1)[(...,) + (None,)*dims] * (noise_s1 - noise_s)
527
+ )
528
+ elif solver_type == 'taylor':
529
+ x_t = (
530
+ torch.exp(log_alpha_t - log_alpha_s)[(...,) + (None,)*dims] * x
531
+ - (sigma_t * phi_1)[(...,) + (None,)*dims] * noise_s
532
+ - (1. / r1) * (sigma_t * ((torch.exp(h) - 1.) / h - 1.))[(...,) + (None,)*dims] * (noise_s1 - noise_s)
533
+ )
534
+ else:
535
+ raise ValueError("solver_type must be either dpm_solver or taylor, got {}".format(solver_type))
536
+ if return_noise:
537
+ return x_t, {'noise_s': noise_s, 'noise_s1': noise_s1}
538
+ else:
539
+ return x_t
540
+
541
+
542
+ def dpm_multistep_second_update(self, x, noise_prev_list, t_prev_list, t, solver_type="dpm_solver"):
543
+ ns = self.noise_schedule
544
+ dims = len(x.shape) - 1
545
+ noise_prev_1, noise_prev_0 = noise_prev_list
546
+ t_prev_1, t_prev_0 = t_prev_list
547
+ lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
548
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
549
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
550
+ alpha_t = torch.exp(log_alpha_t)
551
+
552
+ h_0 = lambda_prev_0 - lambda_prev_1
553
+ h = lambda_t - lambda_prev_0
554
+ r0 = h_0 / h
555
+ D1_0 = (1. / r0)[(...,) + (None,)*dims] * (noise_prev_0 - noise_prev_1)
556
+ if self.predict_x0:
557
+ if solver_type == 'taylor':
558
+ x_t = (
559
+ (sigma_t / sigma_prev_0)[(...,) + (None,)*dims] * x
560
+ - (alpha_t * (torch.exp(-h) - 1.))[(...,) + (None,)*dims] * noise_prev_0
561
+ + (alpha_t * ((torch.exp(-h) - 1.) / h + 1.))[(...,) + (None,)*dims] * D1_0
562
+ )
563
+ elif solver_type == 'dpm_solver':
564
+ x_t = (
565
+ (sigma_t / sigma_prev_0)[(...,) + (None,)*dims] * x
566
+ - (alpha_t * (torch.exp(-h) - 1.))[(...,) + (None,)*dims] * noise_prev_0
567
+ - 0.5 * (alpha_t * (torch.exp(-h) - 1.))[(...,) + (None,)*dims] * D1_0
568
+ )
569
+ else:
570
+ if solver_type == 'taylor':
571
+ x_t = (
572
+ torch.exp(log_alpha_t - log_alpha_prev_0)[(...,) + (None,)*dims] * x
573
+ - (sigma_t * (torch.exp(h) - 1.))[(...,) + (None,)*dims] * noise_prev_0
574
+ - (sigma_t * ((torch.exp(h) - 1.) / h - 1.))[(...,) + (None,)*dims] * D1_0
575
+ )
576
+ elif solver_type == 'dpm_solver':
577
+ x_t = (
578
+ torch.exp(log_alpha_t - log_alpha_prev_0)[(...,) + (None,)*dims] * x
579
+ - (sigma_t * (torch.exp(h) - 1.))[(...,) + (None,)*dims] * noise_prev_0
580
+ - 0.5 * (sigma_t * (torch.exp(h) - 1.))[(...,) + (None,)*dims] * D1_0
581
+ )
582
+ return x_t
583
+
584
+
585
+ def dpm_multistep_third_update(self, x, noise_prev_list, t_prev_list, t, solver_type='dpm_solver'):
586
+ ns = self.noise_schedule
587
+ dims = len(x.shape) - 1
588
+ noise_prev_2, noise_prev_1, noise_prev_0 = noise_prev_list
589
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
590
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
591
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
592
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
593
+ alpha_t = torch.exp(log_alpha_t)
594
+
595
+ h_1 = lambda_prev_1 - lambda_prev_2
596
+ h_0 = lambda_prev_0 - lambda_prev_1
597
+ h = lambda_t - lambda_prev_0
598
+ r0, r1 = h_0 / h, h_1 / h
599
+ D1_0 = (1. / r0)[(...,) + (None,)*dims] * (noise_prev_0 - noise_prev_1)
600
+ D1_1 = (1. / r1)[(...,) + (None,)*dims] * (noise_prev_1 - noise_prev_2)
601
+ D1 = D1_0 + (r0 / (r0 + r1))[(...,) + (None,)*dims] * (D1_0 - D1_1)
602
+ D2 = (1. / (r0 + r1))[(...,) + (None,)*dims] * (D1_0 - D1_1)
603
+ if self.predict_x0:
604
+ x_t = (
605
+ (sigma_t / sigma_prev_0)[(...,) + (None,)*dims] * x
606
+ - (alpha_t * (torch.exp(-h) - 1.))[(...,) + (None,)*dims] * noise_prev_0
607
+ + (alpha_t * ((torch.exp(-h) - 1.) / h + 1.))[(...,) + (None,)*dims] * D1
608
+ - (alpha_t * ((torch.exp(-h) - 1. + h) / h**2 - 0.5))[(...,) + (None,)*dims] * D2
609
+ )
610
+ else:
611
+ x_t = (
612
+ torch.exp(log_alpha_t - log_alpha_prev_0)[(...,) + (None,)*dims] * x
613
+ - (sigma_t * (torch.exp(h) - 1.))[(...,) + (None,)*dims] * noise_prev_0
614
+ - (sigma_t * ((torch.exp(h) - 1.) / h - 1.))[(...,) + (None,)*dims] * D1
615
+ - (sigma_t * ((torch.exp(h) - 1. - h) / h**2 - 0.5))[(...,) + (None,)*dims] * D2
616
+ )
617
+ return x_t
618
+
619
+ def dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., noise_s=None, noise_s1=None, noise_s2=None, return_noise=False, solver_type='dpm_solver'):
620
+ """
621
+ A single step for DPM-Solver-3.
622
+
623
+ Args:
624
+ x: A pytorch tensor. The initial value at time `s`.
625
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
626
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
627
+ r1: A `float`. The hyperparameter of the third-order solver. We recommend the default setting `1 / 3`.
628
+ r2: A `float`. The hyperparameter of the third-order solver. We recommend the default setting `2 / 3`.
629
+ noise_s: A pytorch tensor. The predicted noise at time `s`.
630
+ If `noise_s` is None, we compute the predicted noise by `x` and `s`; otherwise we directly use it.
631
+ noise_s1: A pytorch tensor. The predicted noise at time `s1` (the intermediate time given by `r1`).
632
+ If `noise_s1` is None, we compute the predicted noise by `s1`; otherwise we directly use it.
633
+ Returns:
634
+ x_t: A pytorch tensor. The approximated solution at time `t`.
635
+ """
636
+ if r1 is None:
637
+ r1 = 1. / 3.
638
+ if r2 is None:
639
+ r2 = 2. / 3.
640
+ ns = self.noise_schedule
641
+ dims = len(x.shape) - 1
642
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
643
+ h = lambda_t - lambda_s
644
+ lambda_s1 = lambda_s + r1 * h
645
+ lambda_s2 = lambda_s + r2 * h
646
+ s1 = ns.inverse_lambda(lambda_s1)
647
+ s2 = ns.inverse_lambda(lambda_s2)
648
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
649
+ sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t)
650
+ alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
651
+
652
+ if self.predict_x0:
653
+ phi_11 = torch.expm1(-r1 * h)
654
+ phi_12 = torch.expm1(-r2 * h)
655
+ phi_1 = torch.expm1(-h)
656
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
657
+ phi_2 = phi_1 / h + 1.
658
+ phi_3 = phi_2 / h - 0.5
659
+
660
+ if noise_s is None:
661
+ noise_s = self.model_fn(x, s)
662
+ if noise_s1 is None:
663
+ x_s1 = (
664
+ (sigma_s1 / sigma_s)[(...,) + (None,)*dims] * x
665
+ - (alpha_s1 * phi_11)[(...,) + (None,)*dims] * noise_s
666
+ )
667
+ noise_s1 = self.model_fn(x_s1, s1)
668
+ if noise_s2 is None:
669
+ x_s2 = (
670
+ (sigma_s2 / sigma_s)[(...,) + (None,)*dims] * x
671
+ - (alpha_s2 * phi_12)[(...,) + (None,)*dims] * noise_s
672
+ + r2 / r1 * (alpha_s2 * phi_22)[(...,) + (None,)*dims] * (noise_s1 - noise_s)
673
+ )
674
+ noise_s2 = self.model_fn(x_s2, s2)
675
+ if solver_type == 'dpm_solver':
676
+ x_t = (
677
+ (sigma_t / sigma_s)[(...,) + (None,)*dims] * x
678
+ - (alpha_t * phi_1)[(...,) + (None,)*dims] * noise_s
679
+ + (1. / r2) * (alpha_t * phi_2)[(...,) + (None,)*dims] * (noise_s2 - noise_s)
680
+ )
681
+ elif solver_type == 'taylor':
682
+ D1_0 = (1. / r1) * (noise_s1 - noise_s)
683
+ D1_1 = (1. / r2) * (noise_s2 - noise_s)
684
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
685
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
686
+ x_t = (
687
+ (sigma_t / sigma_s)[(...,) + (None,)*dims] * x
688
+ - (alpha_t * phi_1)[(...,) + (None,)*dims] * noise_s
689
+ + (alpha_t * phi_2)[(...,) + (None,)*dims] * D1
690
+ - (alpha_t * phi_3)[(...,) + (None,)*dims] * D2
691
+ )
692
+ else:
693
+ raise ValueError("solver_type must be either dpm_solver or taylor, got {}".format(solver_type))
694
+ else:
695
+ phi_11 = torch.expm1(r1 * h)
696
+ phi_12 = torch.expm1(r2 * h)
697
+ phi_1 = torch.expm1(h)
698
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
699
+ phi_2 = phi_1 / h - 1.
700
+ phi_3 = phi_2 / h - 0.5
701
+
702
+ if noise_s is None:
703
+ noise_s = self.model_fn(x, s)
704
+ if noise_s1 is None:
705
+ x_s1 = (
706
+ torch.exp(log_alpha_s1 - log_alpha_s)[(...,) + (None,)*dims] * x
707
+ - (sigma_s1 * phi_11)[(...,) + (None,)*dims] * noise_s
708
+ )
709
+ noise_s1 = self.model_fn(x_s1, s1)
710
+ if noise_s2 is None:
711
+ x_s2 = (
712
+ torch.exp(log_alpha_s2 - log_alpha_s)[(...,) + (None,)*dims] * x
713
+ - (sigma_s2 * phi_12)[(...,) + (None,)*dims] * noise_s
714
+ - r2 / r1 * (sigma_s2 * phi_22)[(...,) + (None,)*dims] * (noise_s1 - noise_s)
715
+ )
716
+ noise_s2 = self.model_fn(x_s2, s2)
717
+ if solver_type == 'dpm_solver':
718
+ x_t = (
719
+ torch.exp(log_alpha_t - log_alpha_s)[(...,) + (None,)*dims] * x
720
+ - (sigma_t * phi_1)[(...,) + (None,)*dims] * noise_s
721
+ - (1. / r2) * (sigma_t * phi_2)[(...,) + (None,)*dims] * (noise_s2 - noise_s)
722
+ )
723
+ elif solver_type == 'taylor':
724
+ D1_0 = (1. / r1) * (noise_s1 - noise_s)
725
+ D1_1 = (1. / r2) * (noise_s2 - noise_s)
726
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
727
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
728
+ x_t = (
729
+ torch.exp(log_alpha_t - log_alpha_s)[(...,) + (None,)*dims] * x
730
+ - (sigma_t * phi_1)[(...,) + (None,)*dims] * noise_s
731
+ - (sigma_t * phi_2)[(...,) + (None,)*dims] * D1
732
+ - (sigma_t * phi_3)[(...,) + (None,)*dims] * D2
733
+ )
734
+ else:
735
+ raise ValueError("solver_type must be either dpm_solver or taylor, got {}".format(solver_type))
736
+
737
+ if return_noise:
738
+ return x_t, {'noise_s': noise_s, 'noise_s1': noise_s1, 'noise_s2': noise_s2}
739
+ else:
740
+ return x_t
741
+
742
+ def dpm_solver_update(self, x, s, t, order, return_noise=False, solver_type='dpm_solver', r1=None, r2=None):
743
+ """
744
+ A single step for DPM-Solver of the given order `order`.
745
+
746
+ Args:
747
+ x: A pytorch tensor. The initial value at time `s`.
748
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
749
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
750
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
751
+ Returns:
752
+ x_t: A pytorch tensor. The approximated solution at time `t`.
753
+ """
754
+ if order == 1:
755
+ return self.dpm_solver_first_update(x, s, t, return_noise=return_noise)
756
+ elif order == 2:
757
+ return self.dpm_solver_second_update(x, s, t, return_noise=return_noise, solver_type=solver_type, r1=r1)
758
+ elif order == 3:
759
+ return self.dpm_solver_third_update(x, s, t, return_noise=return_noise, solver_type=solver_type, r1=r1, r2=r2)
760
+ else:
761
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
762
+
763
+ def dpm_multistep_update(self, x, noise_prev_list, t_prev_list, t, order, solver_type='taylor'):
764
+ """
765
+ A single step for DPM-Solver of the given order `order`.
766
+
767
+ Args:
768
+ x: A pytorch tensor. The initial value at time `s`.
769
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
770
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
771
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
772
+ Returns:
773
+ x_t: A pytorch tensor. The approximated solution at time `t`.
774
+ """
775
+ if order == 1:
776
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, noise_s=noise_prev_list[-1])
777
+ elif order == 2:
778
+ return self.dpm_multistep_second_update(x, noise_prev_list, t_prev_list, t, solver_type=solver_type)
779
+ elif order == 3:
780
+ return self.dpm_multistep_third_update(x, noise_prev_list, t_prev_list, t, solver_type=solver_type)
781
+ else:
782
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
783
+
784
+ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpm_solver'):
785
+ """
786
+ The adaptive step size solver based on DPM-Solver.
787
+
788
+ Args:
789
+ x: A pytorch tensor. The initial value at time `t_T`.
790
+ order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
791
+ t_T: A `float`. The starting time of the sampling (default is T).
792
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
793
+ h_init: A `float`. The initial step size (for logSNR).
794
+ atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1].
795
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
796
+ theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1].
797
+ t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
798
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
799
+ Returns:
800
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
801
+
802
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
803
+ """
804
+ ns = self.noise_schedule
805
+ s = t_T * torch.ones((x.shape[0],)).to(x)
806
+ lambda_s = ns.marginal_lambda(s)
807
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
808
+ h = h_init * torch.ones_like(s).to(x)
809
+ x_prev = x
810
+ nfe = 0
811
+ if order == 2:
812
+ r1 = 0.5
813
+ lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_noise=True)
814
+ higher_update = lambda x, s, t, **kwargs: self.dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs)
815
+ elif order == 3:
816
+ r1, r2 = 1. / 3., 2. / 3.
817
+ lower_update = lambda x, s, t: self.dpm_solver_second_update(x, s, t, r1=r1, return_noise=True, solver_type=solver_type)
818
+ higher_update = lambda x, s, t, **kwargs: self.dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs)
819
+ else:
820
+ raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
821
+ while torch.abs((s - t_0)).mean() > t_err:
822
+ t = ns.inverse_lambda(lambda_s + h)
823
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
824
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
825
+ delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
826
+ norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
827
+ E = norm_fn((x_higher - x_lower) / delta).max()
828
+ if torch.all(E <= 1.):
829
+ x = x_higher
830
+ s = t
831
+ x_prev = x_lower
832
+ lambda_s = ns.marginal_lambda(s)
833
+ h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
834
+ nfe += order
835
+ print('adaptive solver nfe', nfe)
836
+ return x
837
+
838
+ def sample(self, x, steps=10, eps=1e-4, T=None, order=3, skip_type='time_uniform',
839
+ denoise=False, method='fast', solver_type='dpm_solver', atol=0.0078,
840
+ rtol=0.05, timesteps=None,
841
+ ):
842
+ """
843
+ Compute the sample at time `eps` by DPM-Solver, given the initial `x` at time `T`.
844
+
845
+ We support the following algorithms:
846
+
847
+ - Adaptive step size DPM-Solver (i.e. DPM-Solver-12 and DPM-Solver-23)
848
+
849
+ - Fixed order DPM-Solver (i.e. DPM-Solver-1, DPM-Solver-2 and DPM-Solver-3).
850
+
851
+ - Fast version of DPM-Solver (i.e. DPM-Solver-fast), which uses uniform logSNR steps and combines
852
+ different orders of DPM-Solver.
853
+
854
+ **We recommend DPM-Solver-fast for both fast sampling in few steps (<=20) and fast convergence in many steps (50 to 100).**
855
+
856
+ Choosing the algorithms:
857
+
858
+ - If `method` is 'adaptive':
859
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
860
+ If `order`=2, we use DPM-Solver-12 which combines DPM-Solver-1 and DPM-Solver-2.
861
+ If `order`=3, we use DPM-Solver-23 which combines DPM-Solver-2 and DPM-Solver-3.
862
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
863
+ (NFE) and the sample quality.
864
+
865
+ - If `method` is 'fast':
866
+ We ignore `order` and use DPM-Solver-fast with number of function evaluations (NFE) = `steps`.
867
+ We ignore `skip_type` and use uniform logSNR steps for DPM-Solver-fast.
868
+ Given a fixed NFE=`steps`, the sampling procedure by DPM-Solver-fast is:
869
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
870
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
871
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
872
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
873
+
874
+ - If `method` is 'singlestep':
875
+ We use DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
876
+ We support three types of `skip_type`:
877
+ - 'logSNR': uniform logSNR for the time steps, **recommended for DPM-Solver**.
878
+ - 'time_uniform': uniform time for the time steps. (Used in DDIM and DDPM.)
879
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM.)
880
+
881
+ =====================================================
882
+ Args:
883
+ x: A pytorch tensor. The initial value at time `T` (a sample from the normal distribution).
884
+ steps: A `int`. The total number of function evaluations (NFE).
885
+ eps: A `float`. The ending time of the sampling.
886
+ We recommend `eps`=1e-3 when `steps` <= 15; and `eps`=1e-4 when `steps` > 15.
887
+ T: A `float`. The starting time of the sampling. Default is `None`.
888
+ If `T` is None, we use self.noise_schedule.T.
889
+ order: A `int`. The order of DPM-Solver.
890
+ skip_type: A `str`. The type for the spacing of the time steps. Default is 'time_uniform'.
891
+ method: A `str`. The sampling method: 'adaptive', 'multistep', 'fast' or 'singlestep'. Default is 'fast'.
892
+ solver_type: A `str`. The solver variant, either 'dpm_solver' or 'taylor'.
893
+ atol: A `float`. The absolute tolerance of the adaptive step size solver.
894
+ rtol: A `float`. The relative tolerance of the adaptive step size solver.
895
+ Returns:
896
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
897
+
898
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
899
+ """
900
+ t_0 = eps
901
+ t_T = self.noise_schedule.T if T is None else T
902
+ device = x.device
903
+ if method == 'adaptive':
904
+ with torch.no_grad():
905
+ x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type)
906
+ elif method == 'multistep':
907
+ assert steps >= order
908
+ if timesteps is None:
909
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
910
+ assert timesteps.shape[0] - 1 == steps
911
+ with torch.no_grad():
912
+ vec_t = timesteps[0].expand((x.shape[0]))
913
+ noise_prev_list = [self.model_fn(x, vec_t)]
914
+ t_prev_list = [vec_t]
915
+ for init_order in range(1, order):
916
+ vec_t = timesteps[init_order].expand(x.shape[0])
917
+ x = self.dpm_multistep_update(x, noise_prev_list, t_prev_list, vec_t, init_order, solver_type=solver_type)
918
+ noise_prev_list.append(self.model_fn(x, vec_t))
919
+ t_prev_list.append(vec_t)
920
+ for step in range(order, steps + 1):
921
+ vec_t = timesteps[step].expand(x.shape[0])
922
+ x = self.dpm_multistep_update(x, noise_prev_list, t_prev_list, vec_t, order, solver_type=solver_type)
923
+ for i in range(order - 1):
924
+ t_prev_list[i] = t_prev_list[i + 1]
925
+ noise_prev_list[i] = noise_prev_list[i + 1]
926
+ t_prev_list[-1] = vec_t
927
+ if step < steps:
928
+ noise_prev_list[-1] = self.model_fn(x, vec_t)
929
+ elif method == 'fast':
930
+ orders, _ = self.get_time_steps_for_dpm_solver_fast(skip_type=skip_type, t_T=t_T, t_0=t_0, steps=steps, order=order, device=device)
931
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
932
+ with torch.no_grad():
933
+ i = 0
934
+ for order in orders:
935
+ vec_s, vec_t = torch.ones((x.shape[0],)).to(device) * timesteps[i], torch.ones((x.shape[0],)).to(device) * timesteps[i + order]
936
+ h = self.noise_schedule.marginal_lambda(timesteps[i + order]) - self.noise_schedule.marginal_lambda(timesteps[i])
937
+ r1 = None if order <= 1 else (self.noise_schedule.marginal_lambda(timesteps[i + 1]) - self.noise_schedule.marginal_lambda(timesteps[i])) / h
938
+ r2 = None if order <= 2 else (self.noise_schedule.marginal_lambda(timesteps[i + 2]) - self.noise_schedule.marginal_lambda(timesteps[i])) / h
939
+ x = self.dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
940
+ i += order
941
+ elif method == 'singlestep':
942
+ N_steps = steps // order
943
+ orders = [order,] * N_steps
944
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=N_steps, device=device)
945
+ assert len(timesteps) - 1 == N_steps
946
+ with torch.no_grad():
947
+ for i, order in enumerate(orders):
948
+ vec_s, vec_t = torch.ones((x.shape[0],)).to(device) * timesteps[i], torch.ones((x.shape[0],)).to(device) * timesteps[i + 1]
949
+ x = self.dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type)
950
+ if denoise:
951
+ x = self.denoise_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
952
+ return x
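Below is a hedged sketch of how the pieces in this file are typically wired together, following the formats described in the `model_wrapper` and `DPM_Solver` docstrings. The noise-prediction network is a dummy stand-in (it returns zeros) and the beta schedule is an assumption; a real trained network would be used in practice.

```python
import torch
from dpm_solver_pp import NoiseScheduleVP, model_wrapper, DPM_Solver

# Dummy stand-in for a trained noise-prediction network: nnet(x, t_input) -> noise.
nnet = lambda x, t: torch.zeros_like(x)

# Assumed DDPM-style discrete beta schedule (1000 steps).
betas = torch.linspace(1e-4, 2e-2, 1000)

ns = NoiseScheduleVP(schedule='discrete', betas=betas)
model_fn = model_wrapper(nnet, noise_schedule=ns, time_input_type='1', total_N=1000)
solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)

x_T = torch.randn(4, 4, 32, 32)          # start from Gaussian noise in latent space
x_0 = solver.sample(x_T, steps=50, eps=1e-4, T=1., order=3,
                    skip_type='time_uniform', method='fast', solver_type='dpm_solver')
print(x_0.shape)  # torch.Size([4, 4, 32, 32])
```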
src/COP-GEN-Beta/encode_majortom_images.py ADDED
@@ -0,0 +1,95 @@
1
+ import torch.nn as nn
2
+ import numpy as np
3
+ import torch
4
+ from datasets import MajorTOMThumbnail
5
+ from torch.utils.data import DataLoader
6
+ from libs.autoencoder import get_model
7
+ import argparse
8
+ from tqdm import tqdm
9
+ import os
10
+
11
+ torch.manual_seed(0)
12
+ np.random.seed(0)
13
+
14
+
15
+ def get_existing_encoded_files(output_dir, image_paths):
16
+ """Return the sets of filenames that already have encoded features saved and those still missing."""
17
+ existing_files = set()
18
+ missing_files = set()
19
+
20
+ for img_path in image_paths:
21
+ filename = os.path.basename(img_path).split('.')[0]
22
+ npy_path = os.path.join(output_dir, f'{filename}.npy')
23
+
24
+ if os.path.exists(npy_path):
25
+ existing_files.add(filename)
26
+ else:
27
+ missing_files.add(filename)
28
+
29
+ print(f"\nFound {len(existing_files)} already encoded files")
30
+ print(f"Missing {len(missing_files)} files to encode")
31
+
32
+ return existing_files, missing_files
33
+
34
+ def main(resolution=256):
35
+ parser = argparse.ArgumentParser()
36
+ parser.add_argument('--path', type=str)
37
+ parser.add_argument('--resolution', type=int, default=256)
38
+ parser.add_argument('--output_dir', type=str)
39
+ args = parser.parse_args()
40
+
41
+ # Create output directory if it doesn't exist
42
+ os.makedirs(args.output_dir, exist_ok=True)
43
+
44
+ datafactory = MajorTOMThumbnail(path=args.path, resolution=args.resolution)
45
+ dataset = datafactory.get_split(split=None, labeled=False, nosplit=True)
46
+ image_paths = dataset.image_paths
47
+
48
+ # Check for existing encoded files
49
+ # existing_files, missing_files = get_existing_encoded_files(args.output_dir, image_paths)
50
+ # TODO: Restart is not working yet
51
+ existing_files = set()
52
+ missing_files = set(image_paths)
53
+
54
+ if len(missing_files) == 0:
55
+ print("All files have already been encoded. Exiting...")
56
+ return
57
+
58
+ # Filter dataset to only process missing files
59
+ filtered_indices = [i for i, path in enumerate(image_paths)
60
+ if os.path.basename(path).split('.')[0] not in existing_files]
61
+ dataset.image_paths = [image_paths[i] for i in filtered_indices]
62
+
63
+ dataset_loader = DataLoader(dataset, batch_size=128, shuffle=False, drop_last=False,
64
+ num_workers=8, pin_memory=True, persistent_workers=True)
65
+
66
+ model = get_model('assets/stable-diffusion/autoencoder_kl.pth')
67
+ # model = nn.DataParallel(model)
68
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
69
+ model.to(device)
70
+
71
+ processed_count = 0
72
+ for img, filename in tqdm(dataset_loader, desc="Encoding images", unit="batch"):
73
+ img = img.to(device)
74
+ moments = model(img, fn='encode_moments')
75
+ moments = moments.detach().cpu().numpy()
76
+
77
+ for moment, fname in zip(moments, filename):
78
+ np.save(f'{args.output_dir}/{fname}.npy', moment)
79
+ processed_count += 1
80
+
81
+ print(f'\nProcessed {processed_count} new files')
82
+ print(f'Total encoded files: {len(existing_files) + processed_count}')
83
+
84
+ # features = []
85
+ # labels = []
86
+ # features = np.concatenate(features, axis=0)
87
+ # labels = np.concatenate(labels, axis=0)
88
+ # print(f'features.shape={features.shape}')
89
+ # print(f'labels.shape={labels.shape}')
90
+ # np.save(f'imagenet{resolution}_features.npy', features)
91
+ # np.save(f'imagenet{resolution}_labels.npy', labels)
92
+
93
+
94
+ if __name__ == "__main__":
95
+ main()
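The encoding script above is driven by the three argparse flags defined in `main`; a hedged invocation (directory names are placeholders) would look like `python encode_majortom_images.py --path <thumbnail_dir> --resolution 256 --output_dir <feature_dir>`. It expects the Stable Diffusion autoencoder weights at `assets/stable-diffusion/autoencoder_kl.pth`, as hard-coded in the call to `get_model`.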
src/COP-GEN-Beta/libs/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # codes from third party
src/COP-GEN-Beta/libs/autoencoder.py ADDED
@@ -0,0 +1,519 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from einops import rearrange
5
+
6
+
7
+ class LinearAttention(nn.Module):
8
+ def __init__(self, dim, heads=4, dim_head=32):
9
+ super().__init__()
10
+ self.heads = heads
11
+ hidden_dim = dim_head * heads
12
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
13
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
14
+
15
+ def forward(self, x):
16
+ b, c, h, w = x.shape
17
+ qkv = self.to_qkv(x)
18
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
19
+ k = k.softmax(dim=-1)
20
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
21
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
22
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
23
+ return self.to_out(out)
24
+
25
+
26
+ def nonlinearity(x):
27
+ # swish
28
+ return x*torch.sigmoid(x)
29
+
30
+
31
+ def Normalize(in_channels, num_groups=32):
32
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
33
+
34
+
35
+ class Upsample(nn.Module):
36
+ def __init__(self, in_channels, with_conv):
37
+ super().__init__()
38
+ self.with_conv = with_conv
39
+ if self.with_conv:
40
+ self.conv = torch.nn.Conv2d(in_channels,
41
+ in_channels,
42
+ kernel_size=3,
43
+ stride=1,
44
+ padding=1)
45
+
46
+ def forward(self, x):
47
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
48
+ if self.with_conv:
49
+ x = self.conv(x)
50
+ return x
51
+
52
+
53
+ class Downsample(nn.Module):
54
+ def __init__(self, in_channels, with_conv):
55
+ super().__init__()
56
+ self.with_conv = with_conv
57
+ if self.with_conv:
58
+ # no asymmetric padding in torch conv, must do it ourselves
59
+ self.conv = torch.nn.Conv2d(in_channels,
60
+ in_channels,
61
+ kernel_size=3,
62
+ stride=2,
63
+ padding=0)
64
+
65
+ def forward(self, x):
66
+ if self.with_conv:
67
+ pad = (0,1,0,1)
68
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
69
+ x = self.conv(x)
70
+ else:
71
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
72
+ return x
73
+
74
+
75
+ class ResnetBlock(nn.Module):
76
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
77
+ dropout, temb_channels=512):
78
+ super().__init__()
79
+ self.in_channels = in_channels
80
+ out_channels = in_channels if out_channels is None else out_channels
81
+ self.out_channels = out_channels
82
+ self.use_conv_shortcut = conv_shortcut
83
+
84
+ self.norm1 = Normalize(in_channels)
85
+ self.conv1 = torch.nn.Conv2d(in_channels,
86
+ out_channels,
87
+ kernel_size=3,
88
+ stride=1,
89
+ padding=1)
90
+ if temb_channels > 0:
91
+ self.temb_proj = torch.nn.Linear(temb_channels,
92
+ out_channels)
93
+ self.norm2 = Normalize(out_channels)
94
+ self.dropout = torch.nn.Dropout(dropout)
95
+ self.conv2 = torch.nn.Conv2d(out_channels,
96
+ out_channels,
97
+ kernel_size=3,
98
+ stride=1,
99
+ padding=1)
100
+ if self.in_channels != self.out_channels:
101
+ if self.use_conv_shortcut:
102
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
103
+ out_channels,
104
+ kernel_size=3,
105
+ stride=1,
106
+ padding=1)
107
+ else:
108
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
109
+ out_channels,
110
+ kernel_size=1,
111
+ stride=1,
112
+ padding=0)
113
+
114
+ def forward(self, x, temb):
115
+ h = x
116
+ h = self.norm1(h)
117
+ h = nonlinearity(h)
118
+ h = self.conv1(h)
119
+
120
+ if temb is not None:
121
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
122
+
123
+ h = self.norm2(h)
124
+ h = nonlinearity(h)
125
+ h = self.dropout(h)
126
+ h = self.conv2(h)
127
+
128
+ if self.in_channels != self.out_channels:
129
+ if self.use_conv_shortcut:
130
+ x = self.conv_shortcut(x)
131
+ else:
132
+ x = self.nin_shortcut(x)
133
+
134
+ return x+h
135
+
136
+
137
+ class LinAttnBlock(LinearAttention):
138
+ """to match AttnBlock usage"""
139
+ def __init__(self, in_channels):
140
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
141
+
142
+
143
+ class AttnBlock(nn.Module):
144
+ def __init__(self, in_channels):
145
+ super().__init__()
146
+ self.in_channels = in_channels
147
+
148
+ self.norm = Normalize(in_channels)
149
+ self.q = torch.nn.Conv2d(in_channels,
150
+ in_channels,
151
+ kernel_size=1,
152
+ stride=1,
153
+ padding=0)
154
+ self.k = torch.nn.Conv2d(in_channels,
155
+ in_channels,
156
+ kernel_size=1,
157
+ stride=1,
158
+ padding=0)
159
+ self.v = torch.nn.Conv2d(in_channels,
160
+ in_channels,
161
+ kernel_size=1,
162
+ stride=1,
163
+ padding=0)
164
+ self.proj_out = torch.nn.Conv2d(in_channels,
165
+ in_channels,
166
+ kernel_size=1,
167
+ stride=1,
168
+ padding=0)
169
+
170
+
171
+ def forward(self, x):
172
+ h_ = x
173
+ h_ = self.norm(h_)
174
+ q = self.q(h_)
175
+ k = self.k(h_)
176
+ v = self.v(h_)
177
+
178
+ # compute attention
179
+ b,c,h,w = q.shape
180
+ q = q.reshape(b,c,h*w)
181
+ q = q.permute(0,2,1) # b,hw,c
182
+ k = k.reshape(b,c,h*w) # b,c,hw
183
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
184
+ w_ = w_ * (int(c)**(-0.5))
185
+ w_ = torch.nn.functional.softmax(w_, dim=2)
186
+
187
+ # attend to values
188
+ v = v.reshape(b,c,h*w)
189
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
190
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
191
+ h_ = h_.reshape(b,c,h,w)
192
+
193
+ h_ = self.proj_out(h_)
194
+
195
+ return x+h_
196
+
197
+
198
+ def make_attn(in_channels, attn_type="vanilla"):
199
+ assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
200
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
201
+ if attn_type == "vanilla":
202
+ return AttnBlock(in_channels)
203
+ elif attn_type == "none":
204
+ return nn.Identity(in_channels)
205
+ else:
206
+ return LinAttnBlock(in_channels)
207
+
208
+
209
+ class Encoder(nn.Module):
210
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
211
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
212
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
213
+ **ignore_kwargs):
214
+ super().__init__()
215
+ if use_linear_attn: attn_type = "linear"
216
+ self.ch = ch
217
+ self.temb_ch = 0
218
+ self.num_resolutions = len(ch_mult)
219
+ self.num_res_blocks = num_res_blocks
220
+ self.resolution = resolution
221
+ self.in_channels = in_channels
222
+
223
+ # downsampling
224
+ self.conv_in = torch.nn.Conv2d(in_channels,
225
+ self.ch,
226
+ kernel_size=3,
227
+ stride=1,
228
+ padding=1)
229
+
230
+ curr_res = resolution
231
+ in_ch_mult = (1,)+tuple(ch_mult)
232
+ self.in_ch_mult = in_ch_mult
233
+ self.down = nn.ModuleList()
234
+ for i_level in range(self.num_resolutions):
235
+ block = nn.ModuleList()
236
+ attn = nn.ModuleList()
237
+ block_in = ch*in_ch_mult[i_level]
238
+ block_out = ch*ch_mult[i_level]
239
+ for i_block in range(self.num_res_blocks):
240
+ block.append(ResnetBlock(in_channels=block_in,
241
+ out_channels=block_out,
242
+ temb_channels=self.temb_ch,
243
+ dropout=dropout))
244
+ block_in = block_out
245
+ if curr_res in attn_resolutions:
246
+ attn.append(make_attn(block_in, attn_type=attn_type))
247
+ down = nn.Module()
248
+ down.block = block
249
+ down.attn = attn
250
+ if i_level != self.num_resolutions-1:
251
+ down.downsample = Downsample(block_in, resamp_with_conv)
252
+ curr_res = curr_res // 2
253
+ self.down.append(down)
254
+
255
+ # middle
256
+ self.mid = nn.Module()
257
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
258
+ out_channels=block_in,
259
+ temb_channels=self.temb_ch,
260
+ dropout=dropout)
261
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
262
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
263
+ out_channels=block_in,
264
+ temb_channels=self.temb_ch,
265
+ dropout=dropout)
266
+
267
+ # end
268
+ self.norm_out = Normalize(block_in)
269
+ self.conv_out = torch.nn.Conv2d(block_in,
270
+ 2*z_channels if double_z else z_channels,
271
+ kernel_size=3,
272
+ stride=1,
273
+ padding=1)
274
+
275
+ def forward(self, x):
276
+ # timestep embedding
277
+ temb = None
278
+
279
+ # downsampling
280
+ hs = [self.conv_in(x)]
281
+ for i_level in range(self.num_resolutions):
282
+ for i_block in range(self.num_res_blocks):
283
+ h = self.down[i_level].block[i_block](hs[-1], temb)
284
+ if len(self.down[i_level].attn) > 0:
285
+ h = self.down[i_level].attn[i_block](h)
286
+ hs.append(h)
287
+ if i_level != self.num_resolutions-1:
288
+ hs.append(self.down[i_level].downsample(hs[-1]))
289
+
290
+ # middle
291
+ h = hs[-1]
292
+ h = self.mid.block_1(h, temb)
293
+ h = self.mid.attn_1(h)
294
+ h = self.mid.block_2(h, temb)
295
+
296
+ # end
297
+ h = self.norm_out(h)
298
+ h = nonlinearity(h)
299
+ h = self.conv_out(h)
300
+ return h
301
+
302
+
303
+ class Decoder(nn.Module):
304
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
305
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
306
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
307
+ attn_type="vanilla", **ignorekwargs):
308
+ super().__init__()
309
+ if use_linear_attn: attn_type = "linear"
310
+ self.ch = ch
311
+ self.temb_ch = 0
312
+ self.num_resolutions = len(ch_mult)
313
+ self.num_res_blocks = num_res_blocks
314
+ self.resolution = resolution
315
+ self.in_channels = in_channels
316
+ self.give_pre_end = give_pre_end
317
+ self.tanh_out = tanh_out
318
+
319
+ # compute in_ch_mult, block_in and curr_res at lowest res
320
+ in_ch_mult = (1,)+tuple(ch_mult)
321
+ block_in = ch*ch_mult[self.num_resolutions-1]
322
+ curr_res = resolution // 2**(self.num_resolutions-1)
323
+ self.z_shape = (1,z_channels,curr_res,curr_res)
324
+ print("Working with z of shape {} = {} dimensions.".format(
325
+ self.z_shape, np.prod(self.z_shape)))
326
+
327
+ # z to block_in
328
+ self.conv_in = torch.nn.Conv2d(z_channels,
329
+ block_in,
330
+ kernel_size=3,
331
+ stride=1,
332
+ padding=1)
333
+
334
+ # middle
335
+ self.mid = nn.Module()
336
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
337
+ out_channels=block_in,
338
+ temb_channels=self.temb_ch,
339
+ dropout=dropout)
340
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
341
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
342
+ out_channels=block_in,
343
+ temb_channels=self.temb_ch,
344
+ dropout=dropout)
345
+
346
+ # upsampling
347
+ self.up = nn.ModuleList()
348
+ for i_level in reversed(range(self.num_resolutions)):
349
+ block = nn.ModuleList()
350
+ attn = nn.ModuleList()
351
+ block_out = ch*ch_mult[i_level]
352
+ for i_block in range(self.num_res_blocks+1):
353
+ block.append(ResnetBlock(in_channels=block_in,
354
+ out_channels=block_out,
355
+ temb_channels=self.temb_ch,
356
+ dropout=dropout))
357
+ block_in = block_out
358
+ if curr_res in attn_resolutions:
359
+ attn.append(make_attn(block_in, attn_type=attn_type))
360
+ up = nn.Module()
361
+ up.block = block
362
+ up.attn = attn
363
+ if i_level != 0:
364
+ up.upsample = Upsample(block_in, resamp_with_conv)
365
+ curr_res = curr_res * 2
366
+ self.up.insert(0, up) # prepend to get consistent order
367
+
368
+ # end
369
+ self.norm_out = Normalize(block_in)
370
+ self.conv_out = torch.nn.Conv2d(block_in,
371
+ out_ch,
372
+ kernel_size=3,
373
+ stride=1,
374
+ padding=1)
375
+
376
+ def forward(self, z):
377
+ #assert z.shape[1:] == self.z_shape[1:]
378
+ self.last_z_shape = z.shape
379
+
380
+ # timestep embedding
381
+ temb = None
382
+
383
+ # z to block_in
384
+ h = self.conv_in(z)
385
+
386
+ # middle
387
+ h = self.mid.block_1(h, temb)
388
+ h = self.mid.attn_1(h)
389
+ h = self.mid.block_2(h, temb)
390
+
391
+ # upsampling
392
+ for i_level in reversed(range(self.num_resolutions)):
393
+ for i_block in range(self.num_res_blocks+1):
394
+ h = self.up[i_level].block[i_block](h, temb)
395
+ if len(self.up[i_level].attn) > 0:
396
+ h = self.up[i_level].attn[i_block](h)
397
+ if i_level != 0:
398
+ h = self.up[i_level].upsample(h)
399
+
400
+ # end
401
+ if self.give_pre_end:
402
+ return h
403
+
404
+ h = self.norm_out(h)
405
+ h = nonlinearity(h)
406
+ h = self.conv_out(h)
407
+ if self.tanh_out:
408
+ h = torch.tanh(h)
409
+ return h
410
+
411
+
412
+ class FrozenAutoencoderKL(nn.Module):
413
+ def __init__(self, ddconfig, embed_dim, pretrained_path, scale_factor=0.18215):
414
+ super().__init__()
415
+ print(f'Create autoencoder with scale_factor={scale_factor}')
416
+ self.encoder = Encoder(**ddconfig)
417
+ self.decoder = Decoder(**ddconfig)
418
+ assert ddconfig["double_z"]
419
+ self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
420
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
421
+ self.embed_dim = embed_dim
422
+ self.scale_factor = scale_factor
423
+ m, u = self.load_state_dict(torch.load(pretrained_path, map_location='cpu'))
424
+ assert len(m) == 0 and len(u) == 0
425
+ self.eval()
426
+ self.requires_grad_(False)
427
+
428
+ def encode_moments(self, x):
429
+ h = self.encoder(x)
430
+ moments = self.quant_conv(h)
431
+ return moments
432
+
433
+ def sample(self, moments):
434
+ mean, logvar = torch.chunk(moments, 2, dim=1)
435
+ logvar = torch.clamp(logvar, -30.0, 20.0)
436
+ std = torch.exp(0.5 * logvar)
437
+ z = mean + std * torch.randn_like(mean)
438
+ z = self.scale_factor * z
439
+ return z
440
+
441
+ def encode(self, x):
442
+ moments = self.encode_moments(x)
443
+ z = self.sample(moments)
444
+ return z
445
+
446
+ def decode(self, z):
447
+ z = (1. / self.scale_factor) * z
448
+ z = self.post_quant_conv(z)
449
+ dec = self.decoder(z)
450
+ return dec
451
+
452
+ def forward(self, inputs, fn):
453
+ if fn == 'encode_moments':
454
+ return self.encode_moments(inputs)
455
+ elif fn == 'encode':
456
+ return self.encode(inputs)
457
+ elif fn == 'decode':
458
+ return self.decode(inputs)
459
+ else:
460
+ raise NotImplementedError
461
+
462
+
463
+ def get_model(pretrained_path, scale_factor=0.18215):
464
+ ddconfig = dict(
465
+ double_z=True,
466
+ z_channels=4,
467
+ resolution=256,
468
+ in_channels=3,
469
+ out_ch=3,
470
+ ch=128,
471
+ ch_mult=[1, 2, 4, 4],
472
+ num_res_blocks=2,
473
+ attn_resolutions=[],
474
+ dropout=0.0
475
+ )
476
+ return FrozenAutoencoderKL(ddconfig, 4, pretrained_path, scale_factor)
477
+
478
+
479
+ def main():
480
+ import torchvision.transforms as transforms
481
+ from torchvision.utils import save_image
482
+ import os
483
+ from PIL import Image
484
+
485
+ model = get_model('assets/stable-diffusion/autoencoder_kl.pth')
486
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
487
+ model = model.to(device)
488
+
489
+ scale_factor = 0.18215
490
+ T = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(256), transforms.ToTensor()])
491
+ path = 'imgs'
492
+ fnames = os.listdir(path)
493
+ for fname in fnames:
494
+ p = os.path.join(path, fname)
495
+ img = Image.open(p)
496
+ img = T(img)
497
+ img = img * 2. - 1
498
+ img = img[None, ...]
499
+ img = img.to(device)
500
+
501
+ # with torch.cuda.amp.autocast():
502
+ # moments = model.encode_moments(img)
503
+ # mean, logvar = torch.chunk(moments, 2, dim=1)
504
+ # logvar = torch.clamp(logvar, -30.0, 20.0)
505
+ # std = torch.exp(0.5 * logvar)
506
+ # zs = [(mean + std * torch.randn_like(mean)) * scale_factor for _ in range(4)]
507
+ # recons = [model.decode(z) for z in zs]
508
+
509
+ with torch.cuda.amp.autocast():
510
+ print('test encode & decode')
511
+ recons = [model.decode(model.encode(img)) for _ in range(4)]
512
+
513
+ out = torch.cat([img, *recons], dim=0)
514
+ out = (out + 1) * 0.5
515
+ save_image(out, f'recons_{fname}')
516
+
517
+
518
+ if __name__ == "__main__":
519
+ main()
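A short usage sketch for the frozen autoencoder defined in this file, assuming the checkpoint has been placed at the path used throughout the repository and that the repository root is on `PYTHONPATH`; the random input stands in for a normalised thumbnail in [-1, 1].

```python
# Illustrative encode/decode round trip with the frozen KL autoencoder.
import torch
from libs.autoencoder import get_model

device = 'cuda' if torch.cuda.is_available() else 'cpu'
autoencoder = get_model('assets/stable-diffusion/autoencoder_kl.pth').to(device)

x = torch.rand(1, 3, 256, 256, device=device) * 2 - 1    # inputs are expected in [-1, 1]
with torch.no_grad():
    z = autoencoder(x, fn='encode')        # (1, 4, 32, 32) scaled latent
    recon = autoencoder(z, fn='decode')    # back to (1, 3, 256, 256)
print(z.shape, recon.shape)
```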
src/COP-GEN-Beta/libs/timm.py ADDED
@@ -0,0 +1,112 @@
1
+ # code from timm 0.3.2
2
+ import torch
3
+ import torch.nn as nn
4
+ import math
5
+ import warnings
6
+
7
+
8
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
9
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
10
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
11
+ def norm_cdf(x):
12
+ # Computes standard normal cumulative distribution function
13
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
14
+
15
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
16
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
17
+ "The distribution of values may be incorrect.",
18
+ stacklevel=2)
19
+
20
+ with torch.no_grad():
21
+ # Values are generated by using a truncated uniform distribution and
22
+ # then using the inverse CDF for the normal distribution.
23
+ # Get upper and lower cdf values
24
+ l = norm_cdf((a - mean) / std)
25
+ u = norm_cdf((b - mean) / std)
26
+
27
+ # Uniformly fill tensor with values from [l, u], then translate to
28
+ # [2l-1, 2u-1].
29
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
30
+
31
+ # Use inverse cdf transform for normal distribution to get truncated
32
+ # standard normal
33
+ tensor.erfinv_()
34
+
35
+ # Transform to proper mean, std
36
+ tensor.mul_(std * math.sqrt(2.))
37
+ tensor.add_(mean)
38
+
39
+ # Clamp to ensure it's in the proper range
40
+ tensor.clamp_(min=a, max=b)
41
+ return tensor
42
+
43
+
44
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
45
+ # type: (Tensor, float, float, float, float) -> Tensor
46
+ r"""Fills the input Tensor with values drawn from a truncated
47
+ normal distribution. The values are effectively drawn from the
48
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
49
+ with values outside :math:`[a, b]` redrawn until they are within
50
+ the bounds. The method used for generating the random values works
51
+ best when :math:`a \leq \text{mean} \leq b`.
52
+ Args:
53
+ tensor: an n-dimensional `torch.Tensor`
54
+ mean: the mean of the normal distribution
55
+ std: the standard deviation of the normal distribution
56
+ a: the minimum cutoff value
57
+ b: the maximum cutoff value
58
+ Examples:
59
+ >>> w = torch.empty(3, 5)
60
+ >>> nn.init.trunc_normal_(w)
61
+ """
62
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
63
+
64
+
65
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
66
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
67
+
68
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
69
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
70
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
71
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
72
+ 'survival rate' as the argument.
73
+
74
+ """
75
+ if drop_prob == 0. or not training:
76
+ return x
77
+ keep_prob = 1 - drop_prob
78
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
79
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
80
+ random_tensor.floor_() # binarize
81
+ output = x.div(keep_prob) * random_tensor
82
+ return output
83
+
84
+
85
+ class DropPath(nn.Module):
86
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
87
+ """
88
+ def __init__(self, drop_prob=None):
89
+ super(DropPath, self).__init__()
90
+ self.drop_prob = drop_prob
91
+
92
+ def forward(self, x):
93
+ return drop_path(x, self.drop_prob, self.training)
94
+
95
+
96
+ class Mlp(nn.Module):
97
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
98
+ super().__init__()
99
+ out_features = out_features or in_features
100
+ hidden_features = hidden_features or in_features
101
+ self.fc1 = nn.Linear(in_features, hidden_features)
102
+ self.act = act_layer()
103
+ self.fc2 = nn.Linear(hidden_features, out_features)
104
+ self.drop = nn.Dropout(drop)
105
+
106
+ def forward(self, x):
107
+ x = self.fc1(x)
108
+ x = self.act(x)
109
+ x = self.drop(x)
110
+ x = self.fc2(x)
111
+ x = self.drop(x)
112
+ return x
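A minimal sketch (not part of the commit) of how the three helpers vendored above, `trunc_normal_`, `DropPath` and `Mlp`, compose into a residual MLP block with stochastic depth, assuming the repository root is on `PYTHONPATH`.

```python
# Illustrative residual block built from the vendored timm helpers.
import torch
import torch.nn as nn
from libs.timm import trunc_normal_, DropPath, Mlp

class TinyBlock(nn.Module):
    def __init__(self, dim=64, drop_path=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=4 * dim)
        self.drop_path = DropPath(drop_path)      # randomly skips the branch at train time
        self.apply(self._init)

    @staticmethod
    def _init(m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)     # truncated-normal weight init

    def forward(self, x):
        return x + self.drop_path(self.mlp(self.norm(x)))

tokens = torch.randn(2, 16, 64)                   # (batch, tokens, dim)
print(TinyBlock()(tokens).shape)                  # torch.Size([2, 16, 64])
```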
src/COP-GEN-Beta/libs/triffuser_multi_post_ln.py ADDED
@@ -0,0 +1,290 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+ from .timm import trunc_normal_, DropPath, Mlp
5
+ import einops
6
+ import torch.utils.checkpoint
7
+ import torch.nn.functional as F
8
+
9
+ if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
10
+ ATTENTION_MODE = 'flash'
11
+ else:
12
+ try:
13
+ import xformers
14
+ import xformers.ops
15
+ ATTENTION_MODE = 'xformers'
16
+ except:
17
+ ATTENTION_MODE = 'math'
18
+ print(f'attention mode is {ATTENTION_MODE}')
19
+
20
+
21
+ def timestep_embedding(timesteps, dim, max_period=10000):
22
+ """
23
+ Create sinusoidal timestep embeddings.
24
+
25
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
26
+ These may be fractional.
27
+ :param dim: the dimension of the output.
28
+ :param max_period: controls the minimum frequency of the embeddings.
29
+ :return: an [N x dim] Tensor of positional embeddings.
30
+ """
31
+ half = dim // 2
32
+ freqs = torch.exp(
33
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
34
+ ).to(device=timesteps.device)
35
+ args = timesteps[:, None].float() * freqs[None]
36
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
37
+ if dim % 2:
38
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
39
+ return embedding
40
+
41
+
42
+ def patchify(imgs, patch_size):
43
+ x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
44
+ return x
45
+
46
+
47
+ def unpatchify(x, in_chans):
48
+ patch_size = int((x.shape[2] // in_chans) ** 0.5)
49
+ h = w = int(x.shape[1] ** .5)
50
+ assert h * w == x.shape[1] and patch_size ** 2 * in_chans == x.shape[2]
51
+ x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size)
52
+ return x
53
+
54
+
55
+ def interpolate_pos_emb(pos_emb, old_shape, new_shape):
56
+ pos_emb = einops.rearrange(pos_emb, 'B (H W) C -> B C H W', H=old_shape[0], W=old_shape[1])
57
+ pos_emb = F.interpolate(pos_emb, new_shape, mode='bilinear')
58
+ pos_emb = einops.rearrange(pos_emb, 'B C H W -> B (H W) C')
59
+ return pos_emb
60
+
61
+
62
+ class Attention(nn.Module):
63
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
64
+ super().__init__()
65
+ self.num_heads = num_heads
66
+ head_dim = dim // num_heads
67
+ self.scale = qk_scale or head_dim ** -0.5
68
+
69
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
70
+ self.attn_drop = nn.Dropout(attn_drop)
71
+ self.proj = nn.Linear(dim, dim)
72
+ self.proj_drop = nn.Dropout(proj_drop)
73
+
74
+ def forward(self, x):
75
+ B, L, C = x.shape
76
+
77
+ qkv = self.qkv(x)
78
+ if ATTENTION_MODE == 'flash':
79
+ qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
80
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
81
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
82
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
83
+ elif ATTENTION_MODE == 'xformers':
84
+ qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
85
+ q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
86
+ x = xformers.ops.memory_efficient_attention(q, k, v)
87
+ x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
88
+ elif ATTENTION_MODE == 'math':
89
+ with torch.amp.autocast(device_type='cuda', enabled=False):
90
+ qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
91
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
92
+ attn = (q @ k.transpose(-2, -1)) * self.scale
93
+ attn = attn.softmax(dim=-1)
94
+ attn = self.attn_drop(attn)
95
+ x = (attn @ v).transpose(1, 2).reshape(B, L, C)
96
+ else:
97
+ raise NotImplementedError
98
+
99
+ x = self.proj(x)
100
+ x = self.proj_drop(x)
101
+ return x
102
+
103
+
104
+ class Block(nn.Module):
105
+
106
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
107
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
108
+ super().__init__()
109
+ self.norm1 = norm_layer(dim) if skip else None
110
+ self.norm2 = norm_layer(dim)
111
+
112
+ self.attn = Attention(
113
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
114
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
115
+ self.norm3 = norm_layer(dim)
116
+ mlp_hidden_dim = int(dim * mlp_ratio)
117
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
118
+ self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
119
+ self.use_checkpoint = use_checkpoint
120
+
121
+ def forward(self, x, skip=None):
122
+ if self.use_checkpoint:
123
+ return torch.utils.checkpoint.checkpoint(self._forward, x, skip)
124
+ else:
125
+ return self._forward(x, skip)
126
+
127
+ def _forward(self, x, skip=None):
128
+ if self.skip_linear is not None:
129
+ x = self.skip_linear(torch.cat([x, skip], dim=-1))
130
+ x = self.norm1(x)
131
+ x = x + self.drop_path(self.attn(x))
132
+ x = self.norm2(x)
133
+
134
+ x = x + self.drop_path(self.mlp(x))
135
+ x = self.norm3(x)
136
+
137
+ return x
138
+
139
+
140
+ class PatchEmbed(nn.Module):
141
+ """ Image to Patch Embedding
142
+ """
143
+ def __init__(self, patch_size, in_chans=3, embed_dim=768):
144
+ super().__init__()
145
+ self.patch_size = patch_size
146
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
147
+
148
+ def forward(self, x):
149
+ B, C, H, W = x.shape
150
+ assert H % self.patch_size == 0 and W % self.patch_size == 0
151
+ x = self.proj(x).flatten(2).transpose(1, 2)
152
+ return x
153
+
154
+
155
+ class Triffuser(nn.Module):
156
+ def __init__(self, img_size, in_chans, patch_size, embed_dim=768, depth=12,
157
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, pos_drop_rate=0., drop_rate=0., attn_drop_rate=0.,
158
+ norm_layer=nn.LayerNorm, mlp_time_embed=False, use_checkpoint=False,
159
+ num_modalities=None,
160
+ # text_dim=None,
161
+ # num_text_tokens=None,
162
+ clip_img_dim=None # All modalities with the same clip dimension
163
+ ):
164
+ super().__init__()
165
+ self.in_chans = in_chans
166
+ self.patch_size = patch_size
167
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
168
+ self.num_modalities = num_modalities
169
+ if num_modalities is None:
170
+ raise ValueError("num_modalities must be provided")
171
+
172
+ self.patch_embeds = nn.ModuleList([PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) for _ in range(num_modalities)])
173
+ self.img_size = (img_size, img_size) if isinstance(img_size, int) else img_size # the default img size
174
+ assert self.img_size[0] % patch_size == 0 and self.img_size[1] % patch_size == 0
175
+ self.num_patches = (self.img_size[0] // patch_size) * (self.img_size[1] // patch_size)
176
+
177
+ self.time_img_embeds = nn.ModuleList([nn.Sequential(
178
+ nn.Linear(embed_dim, 4 * embed_dim),
179
+ nn.SiLU(),
180
+ nn.Linear(4 * embed_dim, embed_dim),
181
+ ) if mlp_time_embed else nn.Identity() for _ in range(num_modalities)])
182
+
183
+ # self.text_embed = nn.Linear(text_dim, embed_dim)
184
+ # self.text_out = nn.Linear(embed_dim, text_dim)
185
+
186
+ # TODO: We skip clip embedding for now
187
+ # self.clip_img_embed = nn.Linear(clip_img_dim, embed_dim)
188
+ # self.clip_img_out = nn.Linear(embed_dim, clip_img_dim)
189
+
190
+ # self.num_text_tokens = num_text_tokens
191
+ # TODO: ATM we assume the same num_patches for all modalities
192
+ # 1 for time embedding token of each modality
193
+ # num_patches for each modality (assuming the same number of patches for all modalities)
194
+ self.num_tokens = 1 * self.num_modalities + self.num_patches * self.num_modalities
195
+
196
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
197
+ self.pos_drop = nn.Dropout(p=pos_drop_rate)
198
+
199
+ self.in_blocks = nn.ModuleList([
200
+ Block(
201
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
202
+ drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer, use_checkpoint=use_checkpoint)
203
+ for _ in range(depth // 2)])
204
+
205
+ self.mid_block = Block(
206
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
207
+ drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer, use_checkpoint=use_checkpoint)
208
+
209
+ self.out_blocks = nn.ModuleList([
210
+ Block(
211
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
212
+ drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer, skip=True, use_checkpoint=use_checkpoint)
213
+ for _ in range(depth // 2)])
214
+
215
+ self.norm = norm_layer(embed_dim)
216
+ self.patch_dim = patch_size ** 2 * in_chans
217
+ self.decoder_preds = nn.ModuleList([nn.Linear(embed_dim, self.patch_dim, bias=True) for _ in range(num_modalities)])
218
+
219
+ trunc_normal_(self.pos_embed, std=.02)
220
+ self.apply(self._init_weights)
221
+
222
+ def _init_weights(self, m):
223
+ if isinstance(m, nn.Linear):
224
+ trunc_normal_(m.weight, std=.02)
225
+ if isinstance(m, nn.Linear) and m.bias is not None:
226
+ nn.init.constant_(m.bias, 0)
227
+ elif isinstance(m, nn.LayerNorm):
228
+ nn.init.constant_(m.bias, 0)
229
+ nn.init.constant_(m.weight, 1.0)
230
+
231
+ @torch.jit.ignore
232
+ def no_weight_decay(self):
233
+ return {'pos_embed'}
234
+
235
+ def forward(self, imgs, t_imgs):
236
+
237
+ assert len(imgs) == len(t_imgs) == self.num_modalities
238
+
239
+ # TODO: We are still assuming all images have the same shape
240
+ _, _, H, W = imgs[0].shape
241
+
242
+ imgs = [self.patch_embeds[i](img) for i, img in enumerate(imgs)]
243
+
244
+ t_imgs_token = [self.time_img_embeds[i](timestep_embedding(t_img, self.embed_dim)) for i, t_img in enumerate(t_imgs)]
245
+ t_imgs_token = [t_img_token.unsqueeze(dim=1) for t_img_token in t_imgs_token]
246
+
247
+ # text = self.text_embed(text)
248
+ # clip_img = self.clip_img_embed(clip_img)
249
+ x = torch.cat((*t_imgs_token, *imgs), dim=1)
250
+
251
+ num_img_tokens = [img.size(1) for img in imgs] # Each image might have different number of tokens
252
+ num_t_tokens = [1] * self.num_modalities # There is only one time token for each modality
253
+
254
+ # TODO: ATM assume all modality images have the same shape
255
+ if H == self.img_size[0] and W == self.img_size[1]:
256
+ pos_embed = self.pos_embed
257
+ else: # interpolate the positional embedding when the input image is not of the default shape
258
+ raise NotImplementedError("Why are we here? Images are not of the default shape. Interpolate positional embedding.")
259
+ pos_embed_others, pos_embed_patches = torch.split(self.pos_embed, [1 + 1 + num_text_tokens + 1, self.num_patches], dim=1)
260
+ pos_embed_patches = interpolate_pos_emb(pos_embed_patches, (self.img_size[0] // self.patch_size, self.img_size[1] // self.patch_size),
261
+ (H // self.patch_size, W // self.patch_size))
262
+ pos_embed = torch.cat((pos_embed_others, pos_embed_patches), dim=1)
263
+
264
+ x = x + pos_embed
265
+ x = self.pos_drop(x)
266
+
267
+ skips = []
268
+ for blk in self.in_blocks:
269
+ x = blk(x)
270
+ skips.append(x)
271
+
272
+ x = self.mid_block(x)
273
+
274
+ for blk in self.out_blocks:
275
+ x = blk(x, skips.pop())
276
+
277
+ x = self.norm(x)
278
+
279
+ all_t_imgs = x.split((*num_t_tokens, *num_img_tokens), dim=1)
280
+
281
+ t_imgs_token_out = all_t_imgs[:self.num_modalities]
282
+ imgs_out = all_t_imgs[self.num_modalities:]
283
+
284
+ imgs_out = [self.decoder_preds[i](img_out) for i, img_out in enumerate(imgs_out)]
285
+ imgs_out = [unpatchify(img_out, self.in_chans) for img_out in imgs_out]
286
+
287
+ # clip_img_out = self.clip_img_out(clip_img_out)
288
+ # text_out = self.text_out(text_out)
289
+
290
+ return imgs_out
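A hedged smoke test of the multi-modal backbone defined above: the hyperparameters are small placeholders rather than the values used to train COP-GEN-Beta, and the inputs mimic 4-channel 32×32 autoencoder latents for three modalities. It assumes the repository root is on `PYTHONPATH`.

```python
# Illustrative forward pass through the Triffuser transformer.
import torch
from libs.triffuser_multi_post_ln import Triffuser

model = Triffuser(
    img_size=32, in_chans=4, patch_size=2,
    embed_dim=64, depth=4, num_heads=4,
    mlp_time_embed=True, num_modalities=3,     # e.g. three thumbnail modalities
)

imgs = [torch.randn(2, 4, 32, 32) for _ in range(3)]                  # one latent batch per modality
t_imgs = [torch.randint(0, 1000, (2,)).float() for _ in range(3)]     # one diffusion timestep per modality

outs = model(imgs, t_imgs)
print([tuple(o.shape) for o in outs])   # three tensors of shape (2, 4, 32, 32)
```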
src/COP-GEN-Beta/majortom/NMajorTOM.py ADDED
@@ -0,0 +1,170 @@
1
+ import os
2
+ import pandas as pd
3
+ import torch
4
+ from torch.utils.data import Dataset
5
+ from pathlib import Path
6
+ import rasterio as rio
7
+ from PIL import Image
8
+ import torchvision.transforms as transforms
9
+ import random
10
+
11
+ class NMajorTOM(Dataset):
12
+ """NMajorTOM Dataset with multiple modalities (https://huggingface.co/Major-TOM)
13
+
14
+ Args:
15
+ modalities (dict): Dictionary of modality configurations, where each key is a modality name
16
+ and value is a dict containing:
17
+ - df: Metadata dataframe for that modality
18
+ - local_dir: Root directory for that modality
19
+ - tif_bands: List of tif bands to read
20
+ - png_bands: List of png bands to read
21
+ - tif_transforms: List of transforms for tif files
22
+ - png_transforms: List of transforms for png files
23
+ random_flip (bool): Whether to randomly flip all modalities together
24
+ ratio_train_test (float): Ratio of training samples (e.g., 0.8 for 80% train, 20% test)
25
+ seed (int): Random seed for reproducible train/test splits
26
+ """
27
+
28
+ def __init__(self, modalities, random_flip=True, ratio_train_test=0.8, seed=42):
29
+ super().__init__()
30
+ self.modalities = {}
31
+
32
+ # Set random seed for reproducibility
33
+ random.seed(seed)
34
+
35
+ # Process each modality's configuration
36
+ for modality_name, config in modalities.items():
37
+ # Drop rows that are complete duplicates across all relevant columns
38
+ num_rows = len(config['df'])
39
+ if modality_name == 'S1RTC':
40
+ relevant_cols = ['grid_cell', 'grid_row_u', 'grid_col_r', 'product_id',
41
+ 'timestamp', 'nodata', 'orbit_state', 'centre_lat',
42
+ 'centre_lon', 'crs', 'parquet_url', 'geometry']
43
+ else:
44
+ relevant_cols = list(config['df'].keys())
45
+ config['df'] = config['df'].drop_duplicates(subset=relevant_cols)
46
+ print(f"Dropped {num_rows - len(config['df'])} duplicates from {modality_name}")
47
+
48
+ # By now, we should have no duplicate grid_cells
49
+ if config['df']['grid_cell'].duplicated().any():
50
+ raise ValueError(f"Found rows with duplicate grid_cells but different values in modality {modality_name}")
51
+
52
+ self.modalities[modality_name] = {
53
+ 'df': config['df'],
54
+ 'local_dir': Path(config['local_dir']) if isinstance(config['local_dir'], str) else config['local_dir'],
55
+ 'tif_bands': config['tif_bands'] if not isinstance(config['tif_bands'], str) else [config['tif_bands']],
56
+ 'png_bands': config['png_bands'] if not isinstance(config['png_bands'], str) else [config['png_bands']],
57
+ 'tif_transforms': transforms.Compose(config['tif_transforms']) if config['tif_transforms'] is not None else None,
58
+ 'png_transforms': transforms.Compose(config['png_transforms']) if config['png_transforms'] is not None else None
59
+ }
60
+
61
+ self.random_flip = random_flip
62
+
63
+ # Get the set of grid_cells for each modality
64
+ grid_cells_by_modality = {
65
+ name: set(mod['df']['grid_cell'].values)
66
+ for name, mod in self.modalities.items()
67
+ }
68
+
69
+ # Check that all modalities share the same grid_cells
70
+ if len(grid_cells_by_modality) > 0:
71
+ reference_grid_cells = grid_cells_by_modality[list(grid_cells_by_modality.keys())[0]]
72
+ for modality_name, grid_cells in grid_cells_by_modality.items():
73
+ if grid_cells != reference_grid_cells:
74
+ missing = reference_grid_cells - grid_cells
75
+ extra = grid_cells - reference_grid_cells
76
+ error_msg = f"Modality {modality_name} has mismatched grid_cells.\n"
77
+ if missing:
78
+ error_msg += f"Missing grid_cells: {missing}\n"
79
+ if extra:
80
+ error_msg += f"Extra grid_cells: {extra}"
81
+ raise ValueError(error_msg)
82
+
83
+ # Sort all dataframes by grid_cell for consistent sampling
84
+ for modality in self.modalities.values():
85
+ modality['df'] = modality['df'].sort_values('grid_cell').reset_index(drop=True)
86
+
87
+
88
+ print("Creating train/test split...")
89
+
90
+ # After sorting dataframes, create train/test split
91
+ all_grid_cells = list(reference_grid_cells)
92
+ random.shuffle(all_grid_cells)
93
+
94
+ n_train = int(len(all_grid_cells) * ratio_train_test)
95
+ self.train_grid_cells = set(all_grid_cells[:n_train])
96
+ self.test_grid_cells = set(all_grid_cells[n_train:])
97
+
98
+ # Let's create a dictionary of grid_cells to split
99
+ self.grid_cell_to_split = {grid_cell: 'train' if grid_cell in self.train_grid_cells else 'test' for grid_cell in reference_grid_cells}
100
+
101
+ print(f"Split dataset into {len(self.train_grid_cells)} train and {len(self.test_grid_cells)} test grid cells")
102
+
103
+ def __len__(self):
104
+ # Return length of any modality (they should all be the same)
105
+ assert len(self.modalities) > 0, "No modalities provided"
106
+ # Get len for each modality and make sure they are the same
107
+ lengths = [len(mod['df']) for mod in self.modalities.values()]
108
+ if not all(x == lengths[0] for x in lengths):
109
+ raise ValueError("All modalities must have the same number of samples")
110
+ return lengths[0]
111
+
112
+ def __getitem__(self, idx):
113
+ result = {}
114
+
115
+ # Generate the same random flip decision for all modalities
116
+ do_flip = self.random_flip and random.random() < 0.5
117
+
118
+ # Get the grid cell for this index (they're all the same across modalities)
119
+ first_modality = list(self.modalities.keys())[0]
120
+ current_grid_cell = self.modalities[first_modality]['df'].iloc[idx]['grid_cell']
121
+
122
+ # Determine if this sample is in train or test set
123
+ split = self.grid_cell_to_split[current_grid_cell]
124
+
125
+ for modality_name, modality in self.modalities.items():
126
+ meta = modality['df'].iloc[idx]
127
+ product_id = meta.product_id if 'product_id' in meta.index else "id"
128
+ grid_cell = meta.grid_cell
129
+ row = grid_cell.split('_')[0]
130
+
131
+ path = modality['local_dir'] / Path(f"{row}/{grid_cell}/{product_id}")
132
+ out_dict = {}
133
+
134
+ # Process TIF bands
135
+ for band in modality['tif_bands']:
136
+ with rio.open(path / f'{band}.tif') as f:
137
+ out = f.read() # out = torch.from_numpy(f.read()).float()
138
+ if modality['tif_transforms'] is not None:
139
+ out = modality['tif_transforms'](out)
140
+ out_dict[band] = out
141
+
142
+ # Process PNG bands
143
+ for band in modality['png_bands']:
144
+ out = Image.open(path / f'{band}.png')
145
+ if modality['png_transforms'] is not None:
146
+ out = modality['png_transforms'](out)
147
+ out_dict[band] = out
148
+
149
+ # Apply the same random flip to all bands in this modality
150
+ if do_flip:
151
+ out_dict = {k: v.flip(-1) for k, v in out_dict.items()}
152
+
153
+ # Add split information to the output dictionary
154
+ out_dict['split'] = split
155
+ out_dict['grid_cell'] = current_grid_cell
156
+
157
+ result[modality_name] = out_dict
158
+
159
+ # Assert the grid_cells are the same for all modalities in the resulting dictionary
160
+ if len(result) > 0:
161
+ first_modality = list(result.keys())[0]
162
+ first_grid_cell = self.modalities[first_modality]['df'].iloc[idx]['grid_cell']
163
+ for modality_name in result.keys():
164
+ current_grid_cell = self.modalities[modality_name]['df'].iloc[idx]['grid_cell']
165
+ if current_grid_cell != first_grid_cell:
166
+ raise ValueError(f"Mismatched grid_cells found: {current_grid_cell} != {first_grid_cell}")
167
+ # Add grid_cell to the output dictionary for verification
168
+ result[modality_name]['grid_cell'] = current_grid_cell
169
+
170
+ return result
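A configuration sketch for the dataset class above; every path is a placeholder, and both metadata frames are assumed to have been filtered to the same set of grid cells and downloaded locally (e.g. with `majortom/download_world.py`), so this will not run as-is.

```python
# Hypothetical two-modality setup reading only the PNG thumbnails.
import pandas as pd
import torchvision.transforms as T
from majortom.NMajorTOM import NMajorTOM

s2_df = pd.read_parquet('data/majorTOM/Core-S2L2A/metadata.parquet')   # placeholder path
s1_df = pd.read_parquet('data/majorTOM/Core-S1RTC/metadata.parquet')   # placeholder path

png_tf = [T.Resize(256), T.CenterCrop(256), T.ToTensor()]

modalities = {
    'S2L2A': dict(df=s2_df, local_dir='data/majorTOM/Core-S2L2A',
                  tif_bands=[], png_bands=['thumbnail'],
                  tif_transforms=None, png_transforms=png_tf),
    'S1RTC': dict(df=s1_df, local_dir='data/majorTOM/Core-S1RTC',
                  tif_bands=[], png_bands=['thumbnail'],
                  tif_transforms=None, png_transforms=png_tf),
}

dataset = NMajorTOM(modalities, random_flip=True, ratio_train_test=0.8, seed=42)
sample = dataset[0]
print(sample['S2L2A']['thumbnail'].shape, sample['S2L2A']['split'])
```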
src/COP-GEN-Beta/majortom/coverage_vis.py ADDED
@@ -0,0 +1,149 @@
1
+ import pandas as pd
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from mpl_toolkits.basemap import Basemap
5
+ import PIL
6
+
7
+ def get_mask(df):
8
+ """
9
+ Take a Major TOM dataframe and create a mask corresponding to available cells
10
+ """
11
+
12
+ mask = np.zeros((2004,4008), dtype=np.uint8)
13
+ row_offset = -1002
14
+ col_offset = -2004
15
+
16
+ nodata = df['nodata'].values > 0.5
17
+
18
+ yy = mask.shape[0] - (np.array(df['grid_row_u']) - row_offset) - 1
19
+ xx = np.array(df['grid_col_r']) - col_offset
20
+
21
+ yy = yy[~nodata]
22
+ xx = xx[~nodata]
23
+
24
+ mask[yy, xx] = 255
25
+
26
+ return PIL.Image.fromarray(mask)
27
+
28
+ def fig2img(fig):
29
+ """Convert a Matplotlib figure to a PIL Image and return it"""
30
+ import io
31
+ buf = io.BytesIO()
32
+ fig.savefig(buf)
33
+ buf.seek(0)
34
+ img = PIL.Image.open(buf)
35
+ return img
36
+
37
+ def light_basemap():
38
+ """
39
+ Bright coloured contours
40
+ """
41
+
42
+ with plt.ioff():
43
+ fig, ax = plt.subplots(figsize=(48,24), dpi=167)
44
+
45
+ m = Basemap(projection='sinu', lat_0=0, lon_0=0, resolution='l', ax=ax)
46
+ m.fillcontinents(color="#9eba9b", lake_color='#CCDDFF')
47
+ m.drawmapboundary(fill_color="#CCDDFF")
48
+ m.drawcountries(color="#666666", linewidth=1)
49
+ m.drawcoastlines(color="#666666", linewidth=1)
50
+
51
+ plt.gca().set_axis_off()
52
+ plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0,
53
+ hspace = 0, wspace = 0)
54
+ plt.margins(0,0)
55
+
56
+ return fig2img(fig)
57
+
58
+ def dark_basemap():
59
+ """
60
+ Dark contours
61
+ """
62
+
63
+ with plt.ioff():
64
+ fig, ax = plt.subplots(figsize=(48,24), dpi=167)
65
+
66
+ m = Basemap(projection='sinu', lat_0=0, lon_0=0, resolution='l', ax=ax)
67
+ m.fillcontinents(color="#242424", lake_color='#242424')
68
+ m.drawmapboundary(fill_color="#242424")
69
+ m.drawcountries(color="#000000", linewidth=1)
70
+ m.drawcoastlines(color="#000000", linewidth=1)
71
+
72
+ plt.gca().set_axis_off()
73
+ plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0,
74
+ hspace = 0, wspace = 0)
75
+ plt.margins(0,0)
76
+
77
+ return fig2img(fig)
78
+
79
+ def get_coveragemap(input, input2=None):
80
+ """
81
+ Creates a complete coloured Major TOM coverage figure in the same style as in the official documentation
82
+
83
+ Optionally, input2 can be provided; the map then uses extra colours to indicate cells available only in input (green) or only in input2 (blue)
84
+ """
85
+
86
+ if input2 is None:
87
+ return single_coveragemap(input)
88
+ else:
89
+ cmap1 = single_coveragemap(input)
90
+ cmap2 = single_coveragemap(input2)
91
+
92
+ # arrays for mixing
93
+ inp1_arr = np.array(cmap1)[...,:3]
94
+ inp2_arr = np.array(cmap2)[...,:3]
95
+
96
+ common_arr = inp1_arr*(inp1_arr.sum(-1) == inp2_arr.sum(-1))[:,:,None]
97
+ common_arr[:,:,(1,2)] = 0
98
+ inp1_arr[:,:,(0,2)] = 0 # Green - indicates presence of S2 only
99
+ inp2_arr[:,:,(0,1)] = 0 # Blue - indicates presence of DEM only
100
+
101
+ return PIL.Image.fromarray(((common_arr + inp1_arr + inp2_arr)).astype(np.uint8))
102
+
103
+
104
+ def single_coveragemap(input):
105
+ """
106
+ Creates a complete coloured Major TOM coverage figure in the same style as in the official documentation
107
+ """
108
+
109
+ # compute mask if df is provided
110
+ if isinstance(input, pd.DataFrame):
111
+ mask = get_mask(input)
112
+ else:
113
+ mask = input
114
+
115
+ basemap = light_basemap()
116
+ basemap_d = dark_basemap()
117
+
118
+ outside_earth = np.array(basemap.convert('RGBA'))[:, :, 0] == 255
119
+ outside_earth = PIL.Image.fromarray(outside_earth)
120
+
121
+ mask = mask.resize(basemap.size, PIL.Image.NEAREST)
122
+
123
+ basemap.putalpha(mask)
124
+
125
+ # Mask outside of earth
126
+ basemap.paste(outside_earth, (0,0), outside_earth)
127
+
128
+ basemap_d.paste(basemap, (0,0), basemap)
129
+
130
+ return basemap_d
131
+
132
+ if __name__ == '__main__':
133
+ DATASET_NAME = 'Major-TOM/Core-S2L2A'
134
+ meta_path = 'https://huggingface.co/datasets/{}/resolve/main/metadata.parquet'.format(DATASET_NAME)
135
+ df = pd.read_parquet(meta_path)
136
+
137
+ # This is how you make a coverage figure!
138
+ coverage_img = get_coveragemap(df)
139
+
140
+ coverage_img.save('coverage-example.png', format='PNG')
141
+
142
+ # and this is how you can create an overlap figure for 2 datasets!
143
+ DATASET_NAME = 'Major-TOM/Core-DEM'
144
+ meta_path = 'https://huggingface.co/datasets/{}/resolve/main/metadata.parquet'.format(DATASET_NAME)
145
+ dem_df = pd.read_parquet(meta_path)
146
+
147
+ coverage_img = get_coveragemap(df,dem_df)
148
+
149
+ coverage_img.save('overlap-coverage-example.png', format='PNG')
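Beyond the two examples in the `__main__` block, the same helpers can export just the binary availability mask for a filtered subset of cells. This is an illustrative sketch (the cloud-cover threshold is an arbitrary choice), assuming the `majortom` package directory is importable.

```python
# Illustrative: save the raw availability mask for relatively cloud-free S2L2A cells.
import pandas as pd
from majortom.coverage_vis import get_mask

df = pd.read_parquet('https://huggingface.co/datasets/Major-TOM/Core-S2L2A/resolve/main/metadata.parquet')
get_mask(df[df['cloud_cover'] < 10]).save('coverage-mask-low-cloud.png')
```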
src/COP-GEN-Beta/majortom/download_world.py ADDED
@@ -0,0 +1,1009 @@
1
+ import argparse
2
+ from shapely.geometry import box
3
+ from concurrent.futures import ThreadPoolExecutor
4
+ from pathlib import Path
5
+ import os
6
+ import geopandas as gpd
7
+ import pandas as pd
8
+ import pyarrow.parquet as pq
9
+ from typing import List, Dict, Set
10
+ import logging
11
+ import urllib.request
12
+ from concurrent import futures
13
+ import fsspec
14
+ from tqdm import tqdm
15
+ import tempfile
16
+ import time
17
+ import random
18
+
19
+
20
+ S2L2A_METADATA = ['grid_cell', 'grid_row_u', 'grid_col_r', 'product_id', 'timestamp', 'cloud_cover', 'nodata', 'centre_lat', 'centre_lon', 'crs', 'parquet_url', 'parquet_row', 'geometry']
21
+ S2L1C_METADATA = ['grid_cell', 'grid_row_u', 'grid_col_r', 'product_id', 'timestamp', 'cloud_cover', 'nodata', 'centre_lat', 'centre_lon', 'crs', 'parquet_url', 'parquet_row', 'geometry']
22
+ S1RTC_METADATA = ['grid_cell', 'grid_row_u', 'grid_col_r', 'product_id', 'timestamp', 'nodata', 'orbit_state', 'centre_lat', 'centre_lon', 'crs', 'parquet_url', 'parquet_row']
23
+ DEM_METADATA = ['grid_cell', 'grid_row_u', 'grid_col_r', 'nodata', 'max_val', 'min_val', 'centre_lat', 'centre_lon', 'crs', 'parquet_url', 'parquet_row', '__index_level_0__']
24
+
25
+ METADATA_COLUMNS = {
26
+ 'Core-S2L2A': S2L2A_METADATA,
27
+ 'Core-S2L1C': S2L1C_METADATA,
28
+ 'Core-S1RTC': S1RTC_METADATA,
29
+ 'Core-DEM': DEM_METADATA
30
+ }
31
+
32
+ CONTENT = {
33
+ 'Core-S2L2A': ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12', 'cloud_mask'],
34
+ 'Core-S2L1C': ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12', 'cloud_mask'],
35
+ 'Core-S1RTC': ['vv', 'vh'],
36
+ 'Core-DEM': ['DEM', 'compressed']
37
+ }
38
+
39
+ # Default max workers for extraction (can be higher as it's CPU-bound)
40
+ MAX_WORKERS = 32
41
+ # Default max workers for download (more conservative to avoid network issues)
42
+ DEFAULT_DOWNLOAD_WORKERS = 8
43
+
44
+ def parse_args():
45
+
46
+ if "INTERACTIVE" in os.environ: # Set INTERACTIVE=1 when running manually
47
+ return argparse.Namespace(
48
+ data_dir="./data/majorTOM",
49
+ bbox=[-180.0, -90.0, 180.0, 90.0],
50
+ sources=['Core-S2L2A', 'Core-S2L1C', 'Core-S1RTC', 'Core-DEM'],
51
+ subset_name="world",
52
+ start_date="2017-01-01",
53
+ end_date="2025-01-01",
54
+ cloud_cover=[0, 10],
55
+ preview=True,
56
+ mode="full",
57
+ delete_parquets=False,
58
+ download_workers=DEFAULT_DOWNLOAD_WORKERS,
59
+ revalidate=False
60
+ )
61
+ else:
62
+ parser = argparse.ArgumentParser(description='Download satellite imagery from Major-TOM dataset')
63
+ parser.add_argument('--data-dir', type=str, default='./data/majorTOM',
64
+ help='Data directory for downloaded files')
65
+ parser.add_argument('--bbox', type=float, nargs=4,
66
+ default=[2.9559111595, 43.8179931641, 55.4920501709, 65.808380127],
67
+ help='Bounding box coordinates: minx miny maxx maxy')
68
+ parser.add_argument('--sources', type=str, nargs='+',
69
+ default=['Core-S2L2A', 'Core-S2L1C', 'Core-S1RTC'],
70
+ help='List of source names for the datasets')
71
+ parser.add_argument('--subset-name', type=str, required=True,
72
+ help='Name for the geographical subset being created')
73
+ parser.add_argument('--start-date', type=str, default='2017-01-01',
74
+ help='Start date for temporal range (YYYY-MM-DD)')
75
+ parser.add_argument('--end-date', type=str, default='2025-01-01',
76
+ help='End date for temporal range (YYYY-MM-DD)')
77
+ parser.add_argument('--cloud-cover', type=float, nargs=2, default=[0, 10],
78
+ help='Cloud cover range (min max)')
79
+ parser.add_argument('--criteria', type=str, default=None,
80
+ help='Criteria for timestamp deduplication. Currently we support "latest"')
81
+ parser.add_argument('--n-samples', type=int, default=None,
82
+ help='Number of samples to download')
83
+ parser.add_argument('--seed', type=int, default=None,
84
+ help='Random seed for reproducibility')
85
+ parser.add_argument('--preview', action='store_true',
86
+ help='If True, only print the number of samples for each source that will be downloaded')
87
+ parser.add_argument('--mode', type=str, choices=['full', 'download', 'extract'], default='full',
88
+ help='Mode of operation: full (download and extract), download (download parquets only), extract (extract from downloaded parquets)')
89
+ parser.add_argument('--delete-parquets', action='store_true',
90
+ help='Delete parquet files after extraction (only used with extract mode)')
91
+ parser.add_argument('--download-workers', type=int, default=DEFAULT_DOWNLOAD_WORKERS,
92
+ help=f'Number of parallel workers for downloading files. Default: {DEFAULT_DOWNLOAD_WORKERS}. Reduce this number if downloads are slow.')
93
+ parser.add_argument('--revalidate', action='store_true',
94
+ help='Force revalidation of all parquet files and redownload if corrupted')
95
+ return parser.parse_args()
96
+
97
+
98
+ def fix_crs(df):
99
+ if df['crs'].iloc[0].startswith('EPSG:EPSG:'):
100
+ df['crs'] = df['crs'].str.replace('EPSG:EPSG:', 'EPSG:', regex=False)
101
+ return df
102
+
103
+ def my_filter_metadata(df,
104
+ region=None,
105
+ daterange=None,
106
+ cloud_cover=(0,100),
107
+ nodata=(0, 1.0)
108
+ ):
109
+ """Filters the Major-TOM dataframe based on several parameters
110
+
111
+ Args:
112
+ df (geopandas dataframe): Parent dataframe
113
+ region (shapely geometry object) : Region of interest
114
+ daterange (tuple) : Inclusive range of dates (example format: '2020-01-01')
115
+ cloud_cover (tuple) : Inclusive percentage range (0-100) of cloud cover
116
+ nodata (tuple) : Inclusive fraction (0.0-1.0) of no data allowed in a sample
117
+
118
+ Returns:
119
+ df: a filtered dataframe
120
+ """
121
+ # temporal filtering
122
+ if daterange is not None and 'timestamp' in df.columns:
123
+ assert (isinstance(daterange, list) or isinstance(daterange, tuple)) and len(daterange)==2
124
+ df = df[df.timestamp >= daterange[0]]
125
+ df = df[df.timestamp <= daterange[1]]
126
+
127
+ # spatial filtering
128
+ if region is not None:
129
+ idxs = df.sindex.query(region)
130
+ df = df.take(idxs)
131
+ # cloud filtering
132
+ if cloud_cover is not None:
133
+ df = df[df.cloud_cover >= cloud_cover[0]]
134
+ df = df[df.cloud_cover <= cloud_cover[1]]
135
+
136
+ # nodata filtering
137
+ if nodata is not None:
138
+ df = df[df.nodata >= nodata[0]]
139
+ df = df[df.nodata <= nodata[1]]
140
+
141
+ return df
142
+
143
+ def my_filter_download(df, local_dir, source_name, by_row=False, verbose=False, tif_columns=None, download_workers=DEFAULT_DOWNLOAD_WORKERS):
144
+ """Downloads and unpacks the data of Major-TOM based on a metadata dataframe"""
145
+ if isinstance(local_dir, str):
146
+ local_dir = Path(local_dir)
147
+
148
+ # identify all parquets that need to be downloaded (group them)
149
+ urls = df.parquet_url.unique()
150
+ print(f'Starting parallel download of {len(urls)} parquet files.') if verbose else None
151
+
152
+ def process_parquet(url):
153
+ # Create a unique temporary file for each thread
154
+ temp_file = tempfile.NamedTemporaryFile(suffix=".parquet", dir=local_dir).name
155
+
156
+ # identify all relevant rows for this parquet
157
+ rows = df[df.parquet_url == url].parquet_row.unique()
158
+
159
+ max_retries = 3
160
+ retry_delay = 5 # seconds
161
+ success = False
162
+ last_error = None
163
+
164
+ for attempt in range(max_retries):
165
+ try:
166
+ if not by_row:
167
+ # Create an opener with a longer timeout
168
+ opener = urllib.request.build_opener()
169
+ opener.addheaders = [('User-agent', 'Mozilla/5.0')]
170
+ urllib.request.install_opener(opener)
171
+
172
+ # Download with timeout using urlopen (30 minutes timeout)
173
+ with urllib.request.urlopen(url, timeout=1800) as response:
174
+ with open(temp_file, 'wb') as out_file:
175
+ out_file.write(response.read())
176
+ temp_path = temp_file
177
+ else:
178
+ f = fsspec.open(url)
179
+ temp_path = f.open()
180
+
181
+ # Process the downloaded parquet file
182
+ try:
183
+ with pq.ParquetFile(temp_path) as pf:
184
+ for row_idx in rows:
185
+ table = pf.read_row_group(row_idx)
186
+
187
+ product_id = table['product_id'][0].as_py() if 'product_id' in table.column_names else "id"
188
+ grid_cell = table['grid_cell'][0].as_py()
189
+ row = grid_cell.split('_')[0]
190
+
191
+ dest = local_dir / Path(f"{source_name}/{row}/{grid_cell}/{product_id}")
192
+ dest.mkdir(exist_ok=True, parents=True)
193
+
194
+ if tif_columns == 'all':
195
+ columns = [col for col in table.column_names if col[0] == 'B']
196
+ if source_name in ['Core-S2L1C', 'Core-S2L2A']:
197
+ columns.append('cloud_mask')
198
+ elif tif_columns is None:
199
+ columns = []
200
+ else:
201
+ columns = tif_columns
202
+
203
+ # Save tifs
204
+ for col in columns:
205
+ with open(dest / f"{col}.tif", "wb") as f:
206
+ f.write(table[col][0].as_py())
207
+
208
+ # Save thumbnail
209
+ with open(dest / "thumbnail.png", "wb") as f:
210
+ f.write(table['thumbnail'][0].as_py())
211
+
212
+ success = True
213
+ break # Successfully processed the file, exit retry loop
214
+
215
+ except Exception as e:
216
+ last_error = f"Error processing parquet content: {str(e)}"
217
+ if attempt < max_retries - 1:
218
+ print(f"Error processing parquet content for {url}, attempt {attempt + 1}/{max_retries}: {str(e)}")
219
+ time.sleep(retry_delay)
220
+ continue
221
+
222
+ finally:
223
+ # Cleanup
224
+ if not by_row:
225
+ try:
226
+ os.remove(temp_path)
227
+ except:
228
+ pass
229
+ else:
230
+ try:
231
+ f.close()
232
+ except:
233
+ pass
234
+
235
+ except urllib.error.HTTPError as e:
236
+ last_error = f"HTTP Error {e.code}: {str(e)}"
237
+ if e.code == 504 and attempt < max_retries - 1:
238
+ print(f"Timeout error for {url}, attempt {attempt + 1}/{max_retries}. Retrying in {retry_delay} seconds...")
239
+ time.sleep(retry_delay)
240
+ continue
241
+ except Exception as e:
242
+ last_error = str(e)
243
+ if attempt < max_retries - 1:
244
+ print(f"Error downloading {url}, attempt {attempt + 1}/{max_retries}: {str(e)}")
245
+ time.sleep(retry_delay)
246
+ continue
247
+
248
+ return {
249
+ 'url': url,
250
+ 'success': success,
251
+ 'error': last_error if not success else None
252
+ }
253
+
254
+ # Use ThreadPoolExecutor for parallel downloads
255
+ # max_workers = min(len(urls), MAX_WORKERS*4) # Use more workers since it's I/O bound
256
+ max_workers = min(len(urls), download_workers)
257
+ print(f"Using {max_workers} workers for parallel downloads") if verbose else None
258
+ results = []
259
+
260
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
261
+ future_to_url = {executor.submit(process_parquet, url): url for url in urls}
262
+
263
+ for future in tqdm(
264
+ futures.as_completed(future_to_url),
265
+ total=len(urls),
266
+ desc=f'Downloading {source_name} parquets'
267
+ ):
268
+ results.append(future.result())
269
+
270
+ # Process results and handle failures
271
+ failed_downloads = [r for r in results if not r['success']]
272
+ if failed_downloads:
273
+ print(f"\nWarning: Failed to download {len(failed_downloads)} parquet files for {source_name}")
274
+ print("\nFailed downloads:")
275
+ for fail in failed_downloads:
276
+ print(f"URL: {fail['url']}")
277
+ print(f"Error: {fail['error']}")
278
+ print("---")
279
+ raise RuntimeError(f"Some parquet files failed to download for {source_name}. Please retry the download.")
280
+
281
+ print(f"Successfully downloaded and processed {len(urls) - len(failed_downloads)} parquet files for {source_name}")
282
+
283
+
284
+ def my_metadata_from_url(access_url, local_url):
285
+ local_url, response = urllib.request.urlretrieve(access_url, local_url)
286
+ df = pq.read_table(local_url).to_pandas()
287
+ if 'timestamp' in df.columns:
288
+ df['timestamp'] = pd.to_datetime(df.timestamp)
289
+ df = fix_crs(df) # Fix CRS typo if present
290
+ gdf = gpd.GeoDataFrame(
291
+ df, geometry=gpd.points_from_xy(df.centre_lon, df.centre_lat), crs=df.crs.iloc[0]
292
+ )
293
+ return gdf
294
+
295
+ def get_metadata(source: str, output_dir: Path) -> gpd.GeoDataFrame:
296
+ """Fetch metadata from HuggingFace dataset for a specific source"""
297
+ access_url = f"https://huggingface.co/datasets/Major-TOM/{source}/resolve/main/metadata.parquet?download=true"
298
+ local_url = output_dir / source / "metadata.parquet"
299
+ local_url.parent.mkdir(exist_ok=True, parents=True)
300
+
301
+ if local_url.exists():
302
+ print(f"Using cached metadata for {source}")
303
+ df = pq.read_table(local_url).to_pandas()
304
+ if 'timestamp' in df.columns:
305
+ df['timestamp'] = pd.to_datetime(df.timestamp)
306
+ df = fix_crs(df)
307
+ gdf = gpd.GeoDataFrame(
308
+ df, geometry=gpd.points_from_xy(df.centre_lon, df.centre_lat), crs=df.crs.iloc[0]
309
+ )
310
+ else:
311
+ print(f"Downloading metadata for {source}...")
312
+ gdf = my_metadata_from_url(access_url, local_url)
313
+
314
+ return gdf
315
+
316
+
317
+ def filter_data(gdf, bbox, cloud_cover, date_range):
318
+ """Filter metadata based on given parameters"""
319
+ region = box(*bbox)
320
+ return my_filter_metadata(
321
+ gdf,
322
+ cloud_cover=cloud_cover,
323
+ region=region,
324
+ daterange=date_range,
325
+ nodata=(0.0, 0.0)
326
+ )
327
+
328
+
329
+ def find_common_samples(filtered_dfs: Dict[str, gpd.GeoDataFrame]) -> Dict[str, gpd.GeoDataFrame]:
330
+ """Find samples that share common grid cells across all datasets"""
331
+ # Create sets of grid_cells for each dataset
332
+ grid_cell_sets = {
333
+ source: set(df['grid_cell'].unique())
334
+ for source, df in filtered_dfs.items()
335
+ }
336
+
337
+ # Find intersection of all grid cell sets
338
+ common_grid_cells = set.intersection(*grid_cell_sets.values())
339
+ print(f"\033[92mFound {len(common_grid_cells)} common grid cells across all sources\033[0m")
340
+
341
+ # Filter dataframes to keep only rows with common grid cells
342
+ filtered_common = {}
343
+ for source, df in filtered_dfs.items():
344
+ filtered_common[source] = df[df['grid_cell'].isin(common_grid_cells)]
345
+ print(f"{source}: {len(filtered_common[source])} samples for common grid cells")
346
+
347
+ return filtered_common
348
+
349
+
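The pairing step above is the core of the multi-modal dataset construction: only grid cells present in every source survive. A minimal standalone sketch of the same set-intersection logic, using plain pandas and toy grid-cell identifiers (illustrative only, not part of the committed code):

import pandas as pd

dfs = {
    'Core-S2L2A': pd.DataFrame({'grid_cell': ['100U_200R', '101U_201R', '102U_202R']}),
    'Core-S1RTC': pd.DataFrame({'grid_cell': ['100U_200R', '102U_202R']}),
}

# Intersect the grid-cell sets of all sources, then keep only rows in that intersection,
# mirroring find_common_samples above.
common = set.intersection(*(set(df['grid_cell']) for df in dfs.values()))
paired = {source: df[df['grid_cell'].isin(common)] for source, df in dfs.items()}
assert all(set(df['grid_cell']) == common for df in paired.values())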
350
+ def download_source_files(df: gpd.GeoDataFrame, output_dir: Path, source: str, mode: str = 'full', delete_parquets: bool = False, download_workers: int = DEFAULT_DOWNLOAD_WORKERS, revalidate: bool = False):
351
+ """Download files for a specific source"""
352
+ print(f"Processing files for {source}...")
353
+
354
+ if mode == 'download':
355
+ # Only download parquet files without extracting
356
+ download_parquet_files(
357
+ df,
358
+ local_dir=output_dir,
359
+ source_name=source,
360
+ download_workers=download_workers,
361
+ revalidate=revalidate,
362
+ verbose=True
363
+ )
364
+ elif mode == 'extract':
365
+ # Extract data from already downloaded parquet files
366
+ extract_from_parquet_files(
367
+ df,
368
+ local_dir=output_dir,
369
+ source_name=source,
370
+ delete_parquets=delete_parquets,
371
+ verbose=True,
372
+ tif_columns=CONTENT[source]
373
+ )
374
+ else: # mode == 'full'
375
+ # Use the original function for backwards compatibility
376
+ my_filter_download(
377
+ df,
378
+ local_dir=output_dir,
379
+ source_name=source,
380
+ by_row=False,
381
+ verbose=True,
382
+ tif_columns=CONTENT[source],
383
+ download_workers=download_workers
384
+ )
385
+
386
+ def get_and_filter_source(args, source: str, data_dir: Path) -> gpd.GeoDataFrame:
387
+ """Process a single source: get metadata and filter it"""
388
+ source_dir = data_dir / source
389
+ source_dir.mkdir(exist_ok=True, parents=True)
390
+
391
+ # Get and filter metadata for each source
392
+ gdf = get_metadata(source, data_dir)
393
+
394
+ # Only apply cloud cover filter for Sentinel-2 sources
395
+ cloud_cover_filter = tuple(args.cloud_cover) if source.startswith('Core-S2') else None
396
+
397
+ filtered_df = filter_data(
398
+ gdf,
399
+ bbox=args.bbox,
400
+ cloud_cover=cloud_cover_filter,
401
+ date_range=(args.start_date, args.end_date)
402
+ )
403
+ print(f"Found {len(filtered_df)} samples for {source} in the specified region")
404
+ return filtered_df
405
+
406
+ def download_source_parallel(source_df_tuple: tuple, subset_dir: Path, mode: str = 'full', delete_parquets: bool = False, download_workers: int = DEFAULT_DOWNLOAD_WORKERS, revalidate: bool = False):
407
+ """Download files for a source sequentially, with resume capability"""
408
+ source, df = source_df_tuple
409
+ source_subset_dir = subset_dir / source
410
+ source_subset_dir.mkdir(exist_ok=True, parents=True)
411
+
412
+ # Save filtered metadata
413
+ metadata_path = source_subset_dir / "metadata.parquet"
414
+ df.to_parquet(metadata_path)
415
+ print(f"Saved filtered metadata for {source} to {metadata_path}")
416
+
417
+ # If we're only downloading parquet files, we don't need to check for existing tif files
418
+ if mode == 'download':
419
+ download_source_files(df, subset_dir, source, mode=mode, delete_parquets=delete_parquets, download_workers=download_workers, revalidate=revalidate)
420
+ print(f"Completed parquet downloads for {source}")
421
+ return
422
+
423
+ # If we're extracting and the extraction metadata exists, we don't need to check for existing files
424
+ parquet_dir = subset_dir / source / "parquets"
425
+ extraction_file = parquet_dir / "extraction_metadata.parquet"
426
+ filtered_df_file = parquet_dir / "filtered_df.parquet"
427
+
428
+ if mode == 'extract' and extraction_file.exists() and filtered_df_file.exists():
429
+ print(f"Using saved extraction metadata for {source}")
430
+ download_source_files(df, subset_dir, source, mode=mode, delete_parquets=delete_parquets, download_workers=download_workers, revalidate=revalidate)
431
+ print(f"Completed extraction for {source}")
432
+ return
433
+
434
+ # Filter out already processed grid cells more efficiently
435
+ def get_existing_files(df, subset_dir, source):
436
+ # Create all possible paths
437
+ # For DEM, use 'id' as the product_id; for other sources, use the actual product_id
438
+ product_ids = df['product_id'] if 'product_id' in df.columns else pd.Series(['id'] * len(df))
439
+ grid_cells = df['grid_cell']
440
+ row_dirs = grid_cells.str.split('_').str[0]
441
+
442
+ # Vectorized path creation
443
+ paths = [
444
+ subset_dir / source / row_dir / grid_cell / product_id / "thumbnail.png"
445
+ for row_dir, grid_cell, product_id in zip(row_dirs, grid_cells, product_ids)
446
+ ]
447
+
448
+ # Batch existence check
449
+ exists_mask = [path.exists() for path in tqdm(paths, desc=f"Checking existing files for {source}", unit="file")]
450
+ return pd.Series(exists_mask, index=df.index)
451
+
452
+ # Create mask of unprocessed files
453
+ exists_mask = get_existing_files(df, subset_dir, source)
454
+ df_to_process = df[~exists_mask]
455
+
456
+ if len(df_to_process) == 0:
457
+ print(f"All files for {source} are already processed. Skipping.")
458
+ return
459
+
460
+ print(f"Found {len(df) - len(df_to_process)} already processed files")
461
+ print(f"Processing remaining {len(df_to_process)} files for {source}...")
462
+
463
+ # Process the remaining data files
464
+ download_source_files(df_to_process, subset_dir, source, mode=mode, delete_parquets=delete_parquets, download_workers=download_workers, revalidate=revalidate)
465
+ print(f"Completed processing for {source}")
466
+
467
+ def remove_duplicates(common_dfs: Dict[str, gpd.GeoDataFrame],
468
+ criteria: str = None) -> Dict[str, gpd.GeoDataFrame]:
469
+ """Remove duplicates from common dataframes based on source-specific relevant columns."""
470
+ for source, df in common_dfs.items():
471
+ num_rows = len(df)
472
+
473
+ if 'timestamp' in df.columns:
474
+ if criteria == "latest":
475
+ # Sort by timestamp and keep the latest
476
+ df = df.sort_values(by='timestamp', ascending=False)
477
+ elif criteria is None:
478
+ raise ValueError("Please specify a value for 'criteria'. Currently we do not support multiple timestamps for the same grid_cell.")
479
+ else:
480
+ raise ValueError("Criteria not supported")
481
+
482
+ # TODO:
483
+ # Product_id includes the timestamp.
484
+ # We ignore one of the two orbit_states to avoid duplicates.
485
+ # We can also ignore cloud_cover since we have already filtered by cloud_cover
486
+ # We also ignore crs. Apparently, there are rows that are entirely duplicates except for the crs (weird, but observed)
487
+ # We also ignore centre_lat and centre_lon since they are not always aligned
488
+ subset_columns = [col for col in df.columns if col not in [
489
+ 'parquet_row', 'parquet_url', 'geometry', 'timestamp', 'product_id',
490
+ 'orbit_state', 'cloud_cover', 'crs', 'centre_lat', 'centre_lon'
491
+ ]]
492
+ df = df.drop_duplicates(subset=subset_columns)
493
+
494
+ # Verify no remaining duplicates in grid_cell
495
+ if df['grid_cell'].duplicated().any():
496
+ print(df[df['grid_cell'].duplicated()])
497
+ raise ValueError(f"Found rows with duplicate grid_cells but different values in source {source}")
498
+
499
+ common_dfs[source] = df
500
+ print(f"\033[94mDropped {num_rows - len(df)} duplicates from {source}\033[0m")
501
+
502
+ return common_dfs
503
+
504
+ def sample_common_dfs(common_dfs: Dict[str, gpd.GeoDataFrame], n_samples: int, seed: int) -> Dict[str, gpd.GeoDataFrame]:
505
+ """Sample common dataframes to have n_samples samples per source"""
506
+ # Get all unique grid cells that appear in all dataframes
507
+ grid_cells_sets = [set(df['grid_cell'].unique()) for df in common_dfs.values()]
508
+ all_grid_cells = list(set.intersection(*grid_cells_sets))
509
+ if not all_grid_cells:
510
+ raise ValueError("No common grid cells found across all sources")
511
+
512
+ # Sort grid cells for reproducibility before sampling
513
+ all_grid_cells.sort()
514
+
515
+ # Randomly sample grid cells
516
+ random.seed(seed)
517
+ sampled_grid_cells = set(random.sample(all_grid_cells, min(n_samples, len(all_grid_cells))))
518
+
519
+ # Filter each dataframe to only include the sampled grid cells
520
+ result = {}
521
+ for source, df in common_dfs.items():
522
+ result[source] = df[df['grid_cell'].isin(sampled_grid_cells)]
523
+ print(f"Sampled {len(result[source])} rows for {source}")
524
+
525
+ return result
526
+
527
+ def is_valid_parquet(parquet_path):
528
+ """
529
+ Checks if a parquet file is valid and not empty.
530
+
531
+ Args:
532
+ parquet_path: Path to the parquet file
533
+
534
+ Returns:
535
+ bool: True if the parquet file is valid, False otherwise
536
+ """
537
+ try:
538
+ # Check if file exists and has a non-zero size (not empty)
539
+ if not os.path.exists(parquet_path) or os.path.getsize(parquet_path) == 0:
540
+ return False
541
+
542
+ # Try to open and read metadata from the parquet file
543
+ with pq.ParquetFile(parquet_path) as pf:
544
+ # Check if there's at least one row group
545
+ if pf.num_row_groups == 0:
546
+ return False
547
+
548
+ # Try to read metadata of the first row group to verify basic integrity
549
+ pf.metadata
550
+
551
+ # Optionally, try reading a small sample of data to further verify
552
+ table = pf.read_row_group(0, columns=['grid_cell'])
553
+
554
+ return True
555
+ except Exception as e:
556
+ print(f"Error validating parquet file {parquet_path}: {str(e)}")
557
+ return False
558
+
559
+ def download_parquet_files(df, local_dir, source_name, download_workers=DEFAULT_DOWNLOAD_WORKERS, revalidate=False, verbose=False):
560
+ """Downloads only the parquet files without extracting data, saving them to disk"""
561
+ if isinstance(local_dir, str):
562
+ local_dir = Path(local_dir)
563
+
564
+ # Create a directory to store parquet files
565
+ parquet_dir = local_dir / source_name / "parquets"
566
+ parquet_dir.mkdir(exist_ok=True, parents=True)
567
+
568
+ # Identify all parquets that need to be downloaded
569
+ urls = df.parquet_url.unique()
570
+ print(f'Starting parallel download of {len(urls)} parquet files.') if verbose else None
571
+
572
+ def download_parquet(url):
573
+ # Get the filename from the URL
574
+ filename = url.split('/')[-1].split('?')[0]
575
+ parquet_path = parquet_dir / filename
576
+
577
+ # Skip if file already exists and is valid (and we're not forcing revalidation)
578
+ if parquet_path.exists() and not revalidate:
579
+ if is_valid_parquet(parquet_path):
580
+ return {
581
+ 'url': url,
582
+ 'path': parquet_path,
583
+ 'success': True,
584
+ 'error': None,
585
+ 'skipped': True
586
+ }
587
+ else:
588
+ # File exists but is corrupted or empty, delete it for redownload
589
+ print(f"Found corrupted or invalid parquet file: {parquet_path}. Will redownload.")
590
+ try:
591
+ os.remove(parquet_path)
592
+ except Exception as e:
593
+ print(f"Warning: Failed to delete corrupted file {parquet_path}: {str(e)}")
594
+ elif parquet_path.exists() and revalidate:
595
+ # If we're revalidating, check the file and delete if invalid
596
+ if not is_valid_parquet(parquet_path):
597
+ print(f"Revalidation: Found corrupted parquet file: {parquet_path}. Will redownload.")
598
+ try:
599
+ os.remove(parquet_path)
600
+ except Exception as e:
601
+ print(f"Warning: Failed to delete corrupted file {parquet_path}: {str(e)}")
602
+ else:
603
+ # File is valid, skip download
604
+ print(f"Revalidation: Confirmed valid parquet file: {parquet_path}")
605
+ return {
606
+ 'url': url,
607
+ 'path': parquet_path,
608
+ 'success': True,
609
+ 'error': None,
610
+ 'skipped': True
611
+ }
612
+
613
+ max_retries = 3
614
+ retry_delay = 5 # seconds
615
+ success = False
616
+ last_error = None
617
+
618
+ for attempt in range(max_retries):
619
+ try:
620
+ # Create an opener with a longer timeout
621
+ opener = urllib.request.build_opener()
622
+ opener.addheaders = [('User-agent', 'Mozilla/5.0')]
623
+ urllib.request.install_opener(opener)
624
+
625
+ # Download with timeout using urlopen (30 minutes timeout)
626
+ with urllib.request.urlopen(url, timeout=1800) as response:
627
+ with open(parquet_path, 'wb') as out_file:
628
+ out_file.write(response.read())
629
+
630
+ # Verify the downloaded file is valid
631
+ if not is_valid_parquet(parquet_path):
632
+ last_error = "Downloaded file is corrupted or invalid"
633
+ if attempt < max_retries - 1:
634
+ print(f"Error: Downloaded parquet file is corrupted, attempt {attempt + 1}/{max_retries}. Retrying...")
635
+ os.remove(parquet_path)
636
+ time.sleep(retry_delay)
637
+ continue
638
+
639
+ success = True
640
+ break # Successfully downloaded, exit retry loop
641
+
642
+ except urllib.error.HTTPError as e:
643
+ last_error = f"HTTP Error {e.code}: {str(e)}"
644
+ if e.code == 504 and attempt < max_retries - 1:
645
+ print(f"Timeout error for {url}, attempt {attempt + 1}/{max_retries}. Retrying in {retry_delay} seconds...")
646
+ time.sleep(retry_delay)
647
+ continue
648
+ except Exception as e:
649
+ last_error = str(e)
650
+ if attempt < max_retries - 1:
651
+ print(f"Error downloading {url}, attempt {attempt + 1}/{max_retries}: {str(e)}")
652
+ time.sleep(retry_delay)
653
+ continue
654
+ # Make sure the file is deleted if it was partially downloaded
655
+ if parquet_path.exists():
656
+ try:
657
+ os.remove(parquet_path)
658
+ except:
659
+ pass
660
+
661
+ return {
662
+ 'url': url,
663
+ 'path': parquet_path if success else None,
664
+ 'success': success,
665
+ 'error': last_error if not success else None,
666
+ 'skipped': False
667
+ }
668
+
669
+ # Use ThreadPoolExecutor for parallel downloads with the specified number of workers
670
+ max_workers = min(len(urls), download_workers)
671
+ print(f"Using {max_workers} workers for parallel downloads") if verbose else None
672
+ results = []
673
+
674
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
675
+ future_to_url = {executor.submit(download_parquet, url): url for url in urls}
676
+
677
+ for future in tqdm(
678
+ futures.as_completed(future_to_url),
679
+ total=len(urls),
680
+ desc=f'Downloading {source_name} parquets'
681
+ ):
682
+ results.append(future.result())
683
+
684
+ # Process results and handle failures
685
+ failed_downloads = [r for r in results if not r['success']]
686
+ skipped_downloads = [r for r in results if r['skipped']]
687
+
688
+ if failed_downloads:
689
+ print(f"\nWarning: Failed to download {len(failed_downloads)} parquet files for {source_name}")
690
+ print("\nFailed downloads:")
691
+ for fail in failed_downloads:
692
+ print(f"URL: {fail['url']}")
693
+ print(f"Error: {fail['error']}")
694
+ print("---")
695
+ raise RuntimeError(f"Some parquet files failed to download for {source_name}. Please retry the download.")
696
+
697
+ print(f"Successfully downloaded {len(results) - len(failed_downloads) - len(skipped_downloads)} parquet files for {source_name}")
698
+ print(f"Skipped {len(skipped_downloads)} valid existing parquet files")
699
+
700
+ # Create a mapping from URLs to local file paths
701
+ url_to_path = {r['url']: r['path'] for r in results if r['success']}
702
+
703
+ # Save the URL to path mapping
704
+ mapping_file = parquet_dir / "url_to_path.parquet"
705
+ mapping_df = pd.DataFrame({
706
+ 'url': list(url_to_path.keys()),
707
+ 'path': [str(path) for path in url_to_path.values()]
708
+ })
709
+ mapping_df.to_parquet(mapping_file)
710
+
711
+ # Save the extraction metadata - creating a dataframe that maps URLs to the rows that need to be extracted
712
+ extraction_meta = []
713
+ for url in urls:
714
+ rows = df[df.parquet_url == url].parquet_row.unique()
715
+ for row in rows:
716
+ extraction_meta.append({
717
+ 'url': url,
718
+ 'row': row
719
+ })
720
+
721
+ extraction_df = pd.DataFrame(extraction_meta)
722
+ extraction_file = parquet_dir / "extraction_metadata.parquet"
723
+ extraction_df.to_parquet(extraction_file)
724
+
725
+ # Also save the full filtered dataframe for reference
726
+ filtered_df_file = parquet_dir / "filtered_df.parquet"
727
+ df.to_parquet(filtered_df_file)
728
+
729
+ print(f"Saved extraction metadata for {len(extraction_meta)} rows across {len(urls)} parquet files")
730
+
731
+ return url_to_path
732
+
733
+ def extract_from_parquet_files(df, local_dir, source_name, delete_parquets=False, verbose=False, tif_columns=None):
734
+ """Extracts data from already downloaded parquet files"""
735
+ if isinstance(local_dir, str):
736
+ local_dir = Path(local_dir)
737
+
738
+ # Path to the directory where parquet files are stored
739
+ parquet_dir = local_dir / source_name / "parquets"
740
+
741
+ # Check if the URL to path mapping exists
742
+ mapping_file = parquet_dir / "url_to_path.parquet"
743
+ if not mapping_file.exists():
744
+ raise FileNotFoundError(f"URL to path mapping file not found at {mapping_file}. Please download parquet files first.")
745
+
746
+ # Load the URL to path mapping
747
+ mapping_df = pd.read_parquet(mapping_file)
748
+ url_to_path = dict(zip(mapping_df['url'], mapping_df['path']))
749
+
750
+ # Try to load the extraction metadata if it exists, otherwise use the provided dataframe
751
+ extraction_file = parquet_dir / "extraction_metadata.parquet"
752
+ filtered_df_file = parquet_dir / "filtered_df.parquet"
753
+
754
+ if extraction_file.exists() and filtered_df_file.exists():
755
+ print("Using saved extraction metadata")
756
+ extraction_df = pd.read_parquet(extraction_file)
757
+
758
+ # We need to load the original filtered df to get all the metadata
759
+ saved_df = pd.read_parquet(filtered_df_file)
760
+
761
+ # If a specific subset of df was provided, filter extraction_df to only those URLs
762
+ if df is not None:
763
+ urls_to_extract = df.parquet_url.unique()
764
+ extraction_df = extraction_df[extraction_df['url'].isin(urls_to_extract)]
765
+ saved_df = saved_df[saved_df.parquet_url.isin(urls_to_extract)]
766
+
767
+ # Replace the input df with the saved one
768
+ df = saved_df
769
+ else:
770
+ # If no saved metadata, create extraction_df from the provided df
771
+ print("No saved extraction metadata found, using provided dataframe")
772
+ extraction_df = []
773
+ for url in df.parquet_url.unique():
774
+ rows = df[df.parquet_url == url].parquet_row.unique()
775
+ for row in rows:
776
+ extraction_df.append({
777
+ 'url': url,
778
+ 'row': row
779
+ })
780
+ extraction_df = pd.DataFrame(extraction_df)
781
+
782
+ # Get all unique URLs that need to be processed
783
+ urls = extraction_df['url'].unique()
784
+ print(f'Starting extraction from {len(urls)} parquet files.') if verbose else None
785
+
786
+ # Check if all required parquet files exist and are valid
787
+ missing_or_invalid_urls = []
788
+ for url in urls:
789
+ if url not in url_to_path:
790
+ missing_or_invalid_urls.append((url, "Missing"))
791
+ elif not is_valid_parquet(url_to_path[url]):
792
+ missing_or_invalid_urls.append((url, "Invalid/Corrupted"))
793
+
794
+ if missing_or_invalid_urls:
795
+ print(f"Warning: {len(missing_or_invalid_urls)} parquet files are missing or corrupted. Please download them first.")
796
+ print("Issues with URLs:")
797
+ for url, issue in missing_or_invalid_urls[:5]: # Show first 5 problem URLs
798
+ print(f" {url} - {issue}")
799
+ if len(missing_or_invalid_urls) > 5:
800
+ print(f" ... and {len(missing_or_invalid_urls) - 5} more")
801
+ raise FileNotFoundError("Some required parquet files are missing or corrupted. Please run the download step again.")
802
+
803
+ def process_parquet(url):
804
+ # Get the local path of the parquet file
805
+ parquet_path = url_to_path[url]
806
+
807
+ # Get the rows in this parquet file that we need to extract
808
+ rows = extraction_df[extraction_df['url'] == url]['row'].unique()
809
+
810
+ success = False
811
+ last_error = None
812
+
813
+ try:
814
+ with pq.ParquetFile(parquet_path) as pf:
815
+ for row_idx in rows:
816
+ table = pf.read_row_group(row_idx)
817
+
818
+ product_id = table['product_id'][0].as_py() if 'product_id' in table.column_names else "id"
819
+ grid_cell = table['grid_cell'][0].as_py()
820
+ row = grid_cell.split('_')[0]
821
+
822
+ dest = local_dir / Path(f"{source_name}/{row}/{grid_cell}/{product_id}")
823
+ dest.mkdir(exist_ok=True, parents=True)
824
+
825
+ if tif_columns == 'all':
826
+ columns = [col for col in table.column_names if col[0] == 'B']
827
+ if source_name in ['Core-S2L1C', 'Core-S2L2A']:
828
+ columns.append('cloud_mask')
829
+ elif tif_columns is None:
830
+ columns = []
831
+ else:
832
+ columns = tif_columns
833
+
834
+ # Save tifs
835
+ for col in columns:
836
+ with open(dest / f"{col}.tif", "wb") as f:
837
+ f.write(table[col][0].as_py())
838
+
839
+ # Save thumbnail
840
+ with open(dest / "thumbnail.png", "wb") as f:
841
+ f.write(table['thumbnail'][0].as_py())
842
+
843
+ success = True
844
+
845
+ # Delete the parquet file if requested
846
+ if delete_parquets:
847
+ try:
848
+ os.remove(parquet_path)
849
+ except Exception as e:
850
+ print(f"Warning: Failed to delete parquet file {parquet_path}: {str(e)}")
851
+
852
+ except Exception as e:
853
+ last_error = str(e)
854
+
855
+ return {
856
+ 'url': url,
857
+ 'path': parquet_path,
858
+ 'success': success,
859
+ 'error': last_error if not success else None
860
+ }
861
+
862
+ # Use ThreadPoolExecutor for parallel processing
863
+ max_workers = min(len(urls), MAX_WORKERS)
864
+ results = []
865
+
866
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
867
+ future_to_url = {executor.submit(process_parquet, url): url for url in urls}
868
+
869
+ for future in tqdm(
870
+ futures.as_completed(future_to_url),
871
+ total=len(urls),
872
+ desc=f'Extracting from {source_name} parquets'
873
+ ):
874
+ results.append(future.result())
875
+
876
+ # Process results and handle failures
877
+ failed_extractions = [r for r in results if not r['success']]
878
+ if failed_extractions:
879
+ print(f"\nWarning: Failed to extract from {len(failed_extractions)} parquet files for {source_name}")
880
+ print("\nFailed extractions:")
881
+ for fail in failed_extractions:
882
+ print(f"URL: {fail['url']}")
883
+ print(f"Path: {fail['path']}")
884
+ print(f"Error: {fail['error']}")
885
+ print("---")
886
+ raise RuntimeError(f"Some parquet extractions failed for {source_name}.")
887
+
888
+ print(f"Successfully extracted data from {len(results) - len(failed_extractions)} parquet files for {source_name}")
889
+
890
+ # Clean up the metadata files if all parquet files were deleted
891
+ if delete_parquets and not any(os.path.exists(r['path']) for r in results):
892
+ try:
893
+ # Delete all metadata files
894
+ for meta_file in [mapping_file, extraction_file, filtered_df_file]:
895
+ if meta_file.exists():
896
+ os.remove(meta_file)
897
+
898
+ # Try to remove the parquets directory if it's empty
899
+ if os.path.exists(parquet_dir) and not os.listdir(parquet_dir):
900
+ os.rmdir(parquet_dir)
901
+
902
+ print(f"Cleaned up metadata files and directory for {source_name}")
903
+ except Exception as e:
904
+ print(f"Warning: Failed to clean up metadata files or directory: {str(e)}")
905
+
906
+ return results
907
+
908
+ def main():
909
+ args = parse_args()
910
+ logging.basicConfig(level=logging.INFO)
911
+
912
+ data_dir = Path(args.data_dir)
913
+ subset_dir = data_dir / args.subset_name
914
+
915
+ # Always process metadata and filtering for all modes
916
+ print("\033[92mFetching and filtering metadata...\033[0m")
917
+
918
+ # Parallel processing of metadata fetching and filtering
919
+ max_workers = min(len(args.sources), MAX_WORKERS)
920
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
921
+ future_to_source = {
922
+ executor.submit(get_and_filter_source, args, source, data_dir): source
923
+ for source in args.sources
924
+ }
925
+
926
+ # Collect results while maintaining order
927
+ filtered_dfs = {}
928
+ for future in futures.as_completed(future_to_source):
929
+ source = future_to_source[future]
930
+ try:
931
+ filtered_dfs[source] = future.result()
932
+ except Exception as e:
933
+ print(f"Error processing {source}: {e}")
934
+ raise e
935
+
936
+ # Synchronization point: find common samples across all sources
937
+ common_dfs = find_common_samples(filtered_dfs)
938
+
939
+ # Remove duplicates for each of the common_dfs
940
+ common_dfs = remove_duplicates(common_dfs, criteria=args.criteria)
941
+
942
+ # After removing duplicates, print the number of samples for each source
943
+ print("\033[92mAfter removing duplicates:\033[0m")
944
+ for source, df in common_dfs.items():
945
+ print(f"{source}: {len(df)} samples for common grid cells")
946
+
947
+ if args.preview:
948
+ return
949
+
950
+ if args.n_samples is not None: # Else, we download all samples.
951
+ print(f"Sampling {args.n_samples} samples per source...")
952
+ common_dfs = sample_common_dfs(common_dfs, args.n_samples, args.seed)
953
+ print(f"Done sampling {args.n_samples} grid cells per source!")
954
+
955
+ # Remove Core-DEM from common_dfs, because it is already downloaded.
956
+ # Comment / Uncomment when needed.
957
+ # common_dfs.pop('Core-DEM')
958
+ # common_dfs.pop('Core-S1RTC')
959
+ # common_dfs.pop('Core-S2L1C')
960
+ # common_dfs.pop('Core-S2L2A')
961
+ print(f"We will only process the following modalities: {list(common_dfs.keys())}")
962
+
963
+ # Print information about download workers
964
+ if args.mode in ['download', 'full']:
965
+ print(f"\033[94mUsing {args.download_workers} workers for parallel downloads\033[0m")
966
+ print("If downloads are slow, try reducing this number with the --download-workers parameter")
967
+
968
+ if args.revalidate:
969
+ print("\033[94mRevalidating all parquet files (will check for corrupted files)\033[0m")
970
+ else:
971
+ print("Use --revalidate to force checking of existing parquet files for corruption")
972
+
973
+ # Execute the appropriate action based on mode
974
+ if args.mode == 'download':
975
+ print("\033[92mStarting download of parquet files...\033[0m")
976
+ for source, df in common_dfs.items():
977
+ print(f"\033[94mDownloading parquets for modality: {source}\033[0m")
978
+ download_source_parallel((source, df), subset_dir, mode='download',
979
+ delete_parquets=args.delete_parquets,
980
+ download_workers=args.download_workers,
981
+ revalidate=args.revalidate)
982
+ print("\033[92mParquet file download complete.\033[0m")
983
+ print("To extract data from these parquet files, run this script with --mode extract")
984
+
985
+ elif args.mode == 'extract':
986
+ print("\033[92mStarting extraction from parquet files...\033[0m")
987
+ for source, df in common_dfs.items():
988
+ print(f"\033[94mExtracting data for modality: {source}\033[0m")
989
+ download_source_parallel((source, df), subset_dir, mode='extract',
990
+ delete_parquets=args.delete_parquets,
991
+ download_workers=args.download_workers,
992
+ revalidate=args.revalidate)
993
+ print("\033[92mData extraction complete.\033[0m")
994
+ if args.delete_parquets:
995
+ print("Parquet files have been deleted.")
996
+ else:
997
+ print("To delete the parquet files, run this script with --mode extract --delete-parquets")
998
+
999
+ else: # mode == 'full'
1000
+ print("\033[92mStarting full download and extraction process...\033[0m")
1001
+ for source, df in common_dfs.items():
1002
+ print(f"\033[94mProcessing modality: {source}\033[0m")
1003
+ download_source_parallel((source, df), subset_dir, mode='full',
1004
+ download_workers=args.download_workers,
1005
+ revalidate=args.revalidate)
1006
+ print("\033[92mDownload and extraction complete.\033[0m")
1007
+
1008
+ if __name__ == "__main__":
1009
+ main()
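The script is normally driven through its CLI (--mode download, then --mode extract, or --mode full in one pass), but the helpers above can also be called directly. A minimal sketch of the two-phase flow; the module name download_majortom and all paths, sources and filter values below are illustrative placeholders, not part of the committed code:

from pathlib import Path
# Hypothetical import path; import from wherever this script lives in your checkout.
from download_majortom import get_metadata, filter_data, download_parquet_files, extract_from_parquet_files

out = Path("data/example_subset")                      # placeholder output directory
gdf = get_metadata("Core-S2L2A", out)                  # caches metadata.parquet under out/Core-S2L2A/
df = filter_data(gdf, bbox=(-3.5, 55.8, -3.0, 56.1),   # (min_lon, min_lat, max_lon, max_lat)
                 cloud_cover=(0, 10),
                 date_range=("2020-01-01", "2021-01-01"))

# Phase 1: fetch the parquet shards (resumable; corrupted files are detected and re-downloaded).
download_parquet_files(df, local_dir=out, source_name="Core-S2L2A",
                       download_workers=4, revalidate=False, verbose=True)

# Phase 2: unpack thumbnails (and, if requested, tif bands) from the cached parquets.
extract_from_parquet_files(df, local_dir=out, source_name="Core-S2L2A",
                           delete_parquets=False, verbose=True, tif_columns=None)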
src/COP-GEN-Beta/prepare_dataset_images.py ADDED
@@ -0,0 +1,488 @@
1
+ from libs.autoencoder import get_model
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import torch
5
+ from torch.utils.data import DataLoader
6
+ import torchvision
7
+ import torchvision.transforms as transforms
8
+ from tqdm import tqdm
9
+ import os
10
+ import argparse
11
+ from pathlib import Path
12
+ import glob
13
+ from majortom.NMajorTOM import NMajorTOM
14
+ import pyarrow.parquet as pq
15
+ import geopandas as gpd
16
+ import pandas as pd
17
+ from majortom.coverage_vis import get_coveragemap
18
+
19
+ torch.manual_seed(0)
20
+ np.random.seed(0)
21
+
22
+ PATCH_SIZE = 256
23
+ GRID_SIZE = 4 # 4x4 grid of patches
24
+
25
+ SATELLITE_CONFIGS = {
26
+ 'S2L2A': {
27
+ 'tif_bands': ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12', 'cloud_mask'],
28
+ 'png_bands': ['thumbnail'],
29
+ 'tif_transforms': [],
30
+ 'png_transforms': [
31
+ transforms.CenterCrop(PATCH_SIZE * GRID_SIZE), # Crop to 1024x1024
32
+ transforms.ToTensor(),
33
+ transforms.Normalize(mean=(0.5,), std=(0.5,))
34
+ ]
35
+ },
36
+ 'S2L1C': {
37
+ 'tif_bands': ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12', 'cloud_mask'],
38
+ 'png_bands': ['thumbnail'],
39
+ 'tif_transforms': [],
40
+ 'png_transforms': [
41
+ transforms.CenterCrop(PATCH_SIZE * GRID_SIZE),
42
+ transforms.ToTensor(),
43
+ transforms.Normalize(mean=(0.5,), std=(0.5,))
44
+ ]
45
+ },
46
+ 'S1RTC': {
47
+ 'tif_bands': ['vv', 'vh'],
48
+ 'png_bands': ['thumbnail'],
49
+ 'tif_transforms': [],
50
+ 'png_transforms': [
51
+ transforms.CenterCrop(PATCH_SIZE * GRID_SIZE),
52
+ transforms.ToTensor(),
53
+ transforms.Normalize(mean=(0.5,), std=(0.5,))
54
+ ]
55
+ },
56
+ 'DEM': {
57
+ 'tif_bands': ['DEM', 'compressed'],
58
+ 'png_bands': ['thumbnail'],
59
+ 'tif_transforms': [],
60
+ 'png_transforms': [
61
+ transforms.Resize(1068), # First, interpolate to match the resolution of the other modalities (1068x1068)
62
+ transforms.CenterCrop(PATCH_SIZE * GRID_SIZE),
63
+ transforms.ToTensor(),
64
+ transforms.Normalize(mean=(0.5,), std=(0.5,))
65
+ ]
66
+ }
67
+ }
68
+
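All thumbnails are center-cropped to PATCH_SIZE * GRID_SIZE = 1024 px; DEM thumbnails are first resized to 1068 px to match the other modalities' thumbnail resolution before the crop. A standalone sketch of what the 'DEM' png_transforms above produce (the 356 px dummy input is illustrative only):

from PIL import Image
import torchvision.transforms as transforms

dem_tf = transforms.Compose([
    transforms.Resize(1068),                        # match the 1068x1068 resolution of the other thumbnails
    transforms.CenterCrop(256 * 4),                 # PATCH_SIZE * GRID_SIZE = 1024 px
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),  # scale pixels to roughly [-1, 1]
])

thumb = Image.new("L", (356, 356))                  # dummy single-band DEM thumbnail
x = dem_tf(thumb)
print(x.shape)                                      # torch.Size([1, 1024, 1024])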
69
+ def fix_crs(df):
70
+ if df['crs'].iloc[0].startswith('EPSG:EPSG:'):
71
+ df['crs'] = df['crs'].str.replace('EPSG:EPSG:', 'EPSG:', regex=False)
72
+ return df
73
+
74
+ def load_metadata(path):
75
+ df = pq.read_table(path).to_pandas()
76
+ if 'timestamp' in df.columns:
77
+ df['timestamp'] = pd.to_datetime(df.timestamp)
78
+ df = fix_crs(df)
79
+ gdf = gpd.GeoDataFrame(
80
+ df, geometry=gpd.points_from_xy(df.centre_lon, df.centre_lat), crs=df.crs.iloc[0]
81
+ )
82
+ return gdf
83
+
84
+ def process_satellite(subset_path, satellite_types, bands_per_type, ratio_train_test, seed):
85
+ """Process multiple satellite types simultaneously while ensuring they're paired"""
86
+ modalities = {}
87
+ filtered_dfs = {}
88
+
89
+ # First, load metadata for all satellite types
90
+ for sat_type in satellite_types:
91
+ metadata_path = os.path.join(subset_path, f"Core-{sat_type}", "metadata.parquet")
92
+ if not os.path.exists(metadata_path):
93
+ print(f"Skipping {sat_type}: metadata not found at {metadata_path}")
94
+ continue
95
+
96
+ gdf = load_metadata(metadata_path)
97
+ local_dir = os.path.join(subset_path, f"Core-{sat_type}")
98
+
99
+ # Split bands into tif and png based on configuration
100
+ tif_bands = [b for b in bands_per_type[sat_type] if b in SATELLITE_CONFIGS[sat_type]['tif_bands']]
101
+ png_bands = [b for b in bands_per_type[sat_type] if b in SATELLITE_CONFIGS[sat_type]['png_bands']]
102
+
103
+ print(f"\nChecking files for {sat_type}...")
104
+
105
+ # Check which indices have all required files
106
+ valid_indices = []
107
+
108
+ for idx in tqdm(range(len(gdf)), desc=f"Validating {sat_type} samples", unit="samples"):
109
+ row = gdf.iloc[idx]
110
+ grid_cell = row.grid_cell
111
+ row_id = grid_cell.split('_')[0]
112
+ product_id = row.product_id if 'product_id' in row.index else "id"
113
+
114
+ base_path = os.path.join(local_dir, row_id, grid_cell, product_id)
115
+ all_files_exist = True
116
+
117
+ # Check TIF files
118
+ for band in tif_bands:
119
+ if not os.path.exists(os.path.join(base_path, f"{band}.tif")):
120
+ all_files_exist = False
121
+ break
122
+
123
+ # Check PNG files
124
+ if all_files_exist: # Only check PNGs if TIFs exist
125
+ for band in png_bands:
126
+ if not os.path.exists(os.path.join(base_path, f"{band}.png")):
127
+ all_files_exist = False
128
+ break
129
+
130
+ if all_files_exist:
131
+ valid_indices.append(idx)
132
+
133
+ filtered_df = gdf.iloc[valid_indices].copy()
134
+ print(f"Found {len(filtered_df)} valid samples out of {len(gdf)} for {sat_type}")
135
+ filtered_dfs[sat_type] = filtered_df
136
+
137
+ # Find common grid cells across all modalities
138
+ grid_cell_sets = {
139
+ source: set(df['grid_cell'].unique())
140
+ for source, df in filtered_dfs.items()
141
+ }
142
+
143
+ # Find intersection of all grid cell sets
144
+ common_grid_cells = set.intersection(*grid_cell_sets.values())
145
+ print(f"\nFound {len(common_grid_cells)} common grid cells across all modalities")
146
+
147
+ # Filter all modalities to keep only common grid cells
148
+ for sat_type in satellite_types:
149
+ if sat_type not in filtered_dfs:
150
+ continue
151
+
152
+ df = filtered_dfs[sat_type]
153
+ df = df[df['grid_cell'].isin(common_grid_cells)]
154
+ print(f"{sat_type}: {len(df)} samples for common grid cells")
155
+
156
+ modalities[sat_type] = {
157
+ 'df': df,
158
+ 'local_dir': os.path.join(subset_path, f"Core-{sat_type}"),
159
+ 'tif_bands': tif_bands,
160
+ 'png_bands': png_bands,
161
+ 'tif_transforms': SATELLITE_CONFIGS[sat_type]['tif_transforms'],
162
+ 'png_transforms': SATELLITE_CONFIGS[sat_type]['png_transforms']
163
+ }
164
+
165
+ dataset = NMajorTOM(modalities=modalities, ratio_train_test=ratio_train_test, seed=seed)
166
+
167
+ return dataset, len(common_grid_cells)
168
+
169
+ def is_valid_image(filepath):
170
+ """Check if an image file is valid and can be opened. Deletes the file if corrupted."""
171
+ try:
172
+ from PIL import Image
173
+ with Image.open(filepath) as img:
174
+ img.verify() # Verify it's actually an image
175
+ return True
176
+ except Exception:
177
+ print(f" Warning: Corrupted or invalid image found: {filepath}")
178
+ try:
179
+ os.remove(filepath)
180
+ print(f" Deleted corrupted file: {filepath}")
181
+ except Exception as e:
182
+ print(f" Failed to delete corrupted file {filepath}: {e}")
183
+ return False
184
+
185
+ def get_existing_complete_grid_cells(output_dir, satellite_types, bands_per_type, num_grid_cells, expected_patches=16):
186
+ """Returns a set of grid_cells that already have all their patches for all modalities"""
187
+ complete_grid_cells_by_sat = {}
188
+ corrupted_grid_cells = set() # Track grid cells with corrupted files
189
+
190
+ for sat_type in satellite_types:
191
+ sat_base_dir = f"{sat_type}_{'_'.join(bands_per_type[sat_type])}"
192
+ complete_grid_cells_by_sat[sat_type] = set()
193
+
194
+ # Check both train and test directories
195
+ for split in ['train', 'test']:
196
+ dir_path = os.path.join(output_dir, split, sat_base_dir)
197
+ print(f" Checking {dir_path} for existing complete grid cells")
198
+ if not os.path.exists(dir_path):
199
+ print(f" Warning: Directory {dir_path} does not exist")
200
+ continue
201
+
202
+ # Get all PNG files and extract their grid cells
203
+ png_files = glob.glob(os.path.join(dir_path, "*.png"))
204
+ print(f" Found {len(png_files)} PNG files in {dir_path}")
205
+ current_grid_cells = {}
206
+
207
+ for f in png_files:
208
+ # This will now delete the file if it's corrupted
209
+ if not is_valid_image(f):
210
+ # Get the grid cell from the corrupted file
211
+ base_name = os.path.basename(f)
212
+ corrupted_grid_cell = "_".join(base_name.split("_")[:-1])
213
+ # Add to set of corrupted grid cells
214
+ corrupted_grid_cells.add(corrupted_grid_cell)
215
+ # Remove this grid cell from our complete cells since we'll need to regenerate it
216
+ if corrupted_grid_cell in current_grid_cells:
217
+ del current_grid_cells[corrupted_grid_cell]
218
+ continue
219
+
220
+ base_name = os.path.basename(f)
221
+ grid_cell = "_".join(base_name.split("_")[:-1]) # Remove patch number
222
+ current_grid_cells[grid_cell] = current_grid_cells.get(grid_cell, 0) + 1
223
+
224
+ # Keep only grid cells with exactly the expected number of patches
225
+ complete_cells = {gc for gc, count in current_grid_cells.items() if count == expected_patches}
226
+ print(f" Found {len(complete_cells)} complete grid cells in {split} split for {sat_type}")
227
+ complete_grid_cells_by_sat[sat_type].update(complete_cells)
228
+
229
+ print(f"Total complete grid cells for {sat_type}: {len(complete_grid_cells_by_sat[sat_type])}")
230
+
231
+ # Find grid cells that are complete across all satellite types
232
+ if not complete_grid_cells_by_sat:
233
+ return set()
234
+
235
+ complete_grid_cells = set.intersection(*complete_grid_cells_by_sat.values())
236
+
237
+ # Remove any grid cells that had corrupted files
238
+ complete_grid_cells = complete_grid_cells - corrupted_grid_cells
239
+
240
+ # Print detailed debugging information
241
+ print("\nComplete grid cells by satellite type:")
242
+ for sat_type, cells in complete_grid_cells_by_sat.items():
243
+ print(f"{sat_type}: {len(cells)} grid cells")
244
+ print(f"\nGrid cells complete across all types: {len(complete_grid_cells)}")
245
+ if corrupted_grid_cells:
246
+ print(f"Removed {len(corrupted_grid_cells)} grid cells due to corrupted files")
247
+
248
+ if len(complete_grid_cells) < num_grid_cells:
249
+ # Find which grid cells are missing from which satellite types
250
+ all_grid_cells = set.union(*complete_grid_cells_by_sat.values())
251
+ print("\nAnalyzing missing grid cells:")
252
+ for grid_cell in all_grid_cells:
253
+ missing_from = [sat_type for sat_type in satellite_types
254
+ if grid_cell not in complete_grid_cells_by_sat[sat_type]]
255
+ if missing_from:
256
+ print(f"Grid cell {grid_cell} is missing from: {', '.join(missing_from)}")
257
+
258
+ return complete_grid_cells
259
+
260
+ def crop_images(dataset, satellite_types, bands_per_type, output_dir, num_grid_cells, flip=False, center_crop=False):
261
+ """Extract features for all modalities simultaneously while ensuring they're paired"""
262
+ from concurrent.futures import ThreadPoolExecutor
263
+ import itertools
264
+
265
+ # Create output directories if saving PNGs
266
+ for sat_type in satellite_types:
267
+ sat_base_dir = f"{sat_type}_{'_'.join(bands_per_type[sat_type])}"
268
+ os.makedirs(os.path.join(output_dir, 'train', sat_base_dir), exist_ok=True)
269
+ os.makedirs(os.path.join(output_dir, 'test', sat_base_dir), exist_ok=True)
270
+
271
+ # Get already processed grid cells
272
+ print("Checking for existing complete grid cells...")
273
+ # Adjust the expected patch count based on center_crop mode
274
+ expected_patches = 1 if center_crop else GRID_SIZE * GRID_SIZE
275
+ # complete_grid_cells = get_existing_complete_grid_cells(output_dir, satellite_types, bands_per_type, num_grid_cells, expected_patches)
276
+ complete_grid_cells = set()
277
+ print(f"Found {len(complete_grid_cells)} already processed grid cells")
278
+
279
+ # Pre-calculate patch positions (only used if not center_crop)
280
+ patch_positions = list(itertools.product(range(GRID_SIZE), range(GRID_SIZE)))
281
+
282
+ def process_sample(sample):
283
+ """Process a single sample (large image) and return metadata for all its patches"""
284
+ # Check if this grid cell is already processed
285
+ grid_cell = sample[satellite_types[0]]['grid_cell']
286
+ if grid_cell in complete_grid_cells:
287
+ print(f"Skipping {grid_cell} because it already has all its patches")
288
+ return []
289
+
290
+ sample_metadata = []
291
+
292
+ for sat_type in satellite_types:
293
+ modality_data = sample[sat_type]
294
+ split = modality_data['split']
295
+ grid_cell = modality_data['grid_cell']
296
+
297
+ img = modality_data['thumbnail']
298
+
299
+ if center_crop:
300
+ # Calculate center crop coordinates
301
+ h, w = img.shape[-2:]
302
+ start_h = (h - PATCH_SIZE) // 2
303
+ start_w = (w - PATCH_SIZE) // 2
304
+ patch = img[:, start_h:start_h + PATCH_SIZE, start_w:start_w + PATCH_SIZE]
305
+ patches = patch.unsqueeze(0) # Add batch dimension
306
+ else:
307
+ # Original patchifying logic
308
+ C = img.size(0)
309
+ patches = img.unfold(1, PATCH_SIZE, PATCH_SIZE).unfold(2, PATCH_SIZE, PATCH_SIZE)
310
+ patches = patches.permute(0, 1, 2, 3, 4).reshape(C, -1, PATCH_SIZE, PATCH_SIZE)
311
+ patches = patches.permute(1, 0, 2, 3) # [N_patches, C, H, W]
312
+
313
+ if sat_type == 'DEM':
314
+ patches = patches.repeat(1, 3, 1, 1)
315
+
316
+ # Compute paths once
317
+ sat_base_dir = f"{sat_type}_thumbnail"
318
+ save_dir = os.path.join(output_dir, split, sat_base_dir)
319
+
320
+ # Batch denormalize
321
+ patches_denorm = (patches.detach().cpu() + 1) / 2
322
+
323
+ # Save images
324
+ for patch_idx, patch in enumerate(patches_denorm):
325
+ if center_crop:
326
+ filename = f"{grid_cell}_center.png"
327
+ metadata = {
328
+ 'grid_cell': grid_cell,
329
+ 'satellite': sat_type,
330
+ 'bands': 'thumbnail',
331
+ 'split': split,
332
+ 'patch_num': 0,
333
+ 'patch_row': (GRID_SIZE - 1) // 2,
334
+ 'patch_col': (GRID_SIZE - 1) // 2
335
+ }
336
+ else:
337
+ filename = f"{grid_cell}_{patch_idx}.png"
338
+ metadata = {
339
+ 'grid_cell': grid_cell,
340
+ 'satellite': sat_type,
341
+ 'bands': 'thumbnail',
342
+ 'split': split,
343
+ 'patch_num': patch_idx,
344
+ 'patch_row': patch_positions[patch_idx][0],
345
+ 'patch_col': patch_positions[patch_idx][1]
346
+ }
347
+
348
+ torchvision.utils.save_image(patch, os.path.join(save_dir, filename))
349
+ sample_metadata.append(metadata)
350
+
351
+ return sample_metadata
352
+
353
+ # Process samples in parallel
354
+ all_metadata = []
355
+ total_samples = len(dataset)
356
+
357
+ print(f"Processing {total_samples} samples...")
358
+
359
+ with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
360
+ # Create a list to store futures
361
+ futures = []
362
+
363
+ # Submit tasks with progress bar around the dataset iteration
364
+ for sample in tqdm(dataset, total=total_samples,
365
+ desc="Processing samples",
366
+ unit="sample",
367
+ dynamic_ncols=True):
368
+ future = executor.submit(process_sample, sample)
369
+ futures.append(future)
370
+
371
+ # Collect results
372
+ for future in futures:
373
+ metadata = future.result()
374
+ if metadata: # Only add metadata for newly processed samples
375
+ all_metadata.extend(metadata)
376
+
377
+ # Convert to DataFrame and split by train/test
378
+ if all_metadata: # Only process if we have new metadata
379
+ df = pd.DataFrame(all_metadata)
380
+ train_df = df[df['split'] == 'train'].drop('split', axis=1)
381
+ test_df = df[df['split'] == 'test'].drop('split', axis=1)
382
+
383
+ # Load existing metadata if it exists and append new data
384
+ train_path = os.path.join(output_dir, 'train_metadata.parquet')
385
+ test_path = os.path.join(output_dir, 'test_metadata.parquet')
386
+
387
+ if os.path.exists(train_path):
388
+ existing_train = pd.read_parquet(train_path)
389
+ train_df = pd.concat([existing_train, train_df], ignore_index=True)
390
+ # Deduplicate based on all columns
391
+ train_df = train_df.drop_duplicates(subset=['grid_cell', 'satellite', 'patch_num'])
392
+
393
+ if os.path.exists(test_path):
394
+ existing_test = pd.read_parquet(test_path)
395
+ test_df = pd.concat([existing_test, test_df], ignore_index=True)
396
+ # Deduplicate based on all columns
397
+ test_df = test_df.drop_duplicates(subset=['grid_cell', 'satellite', 'patch_num'])
398
+
399
+ # Save metadata
400
+ train_df.to_parquet(train_path)
401
+ test_df.to_parquet(test_path)
402
+
403
+ print(f"Processed {len(all_metadata) // (16 * len(satellite_types))} new grid cells")
404
+ print(f"Total metadata: {len(train_df)} training and {len(test_df)} testing samples")
405
+ else:
406
+ print("No new grid cells to process")
407
+
408
+ def visualize_patches(dataset, satellite_types, bands_per_type, output_dir):
409
+ """Visualize the coverage of patches in a world map"""
410
+ # Take the first satellite type since they're all paired
411
+ sat_type = satellite_types[0]
412
+ modality = dataset.modalities[sat_type]
413
+ df = modality['df']
414
+
415
+ # Let's split into train and test.
416
+ # First, add the split column to the dataframe based on the grid_cell_to_split dictionary
417
+ df['split'] = df['grid_cell'].map(dataset.grid_cell_to_split)
418
+
419
+ # Create coverage map
420
+ coverage_img_all = get_coveragemap(df)
421
+ coverage_img_train = get_coveragemap(df[df['split'] == 'train'])
422
+ coverage_img_test = get_coveragemap(df[df['split'] == 'test'])
423
+ coverage_img_train_test = get_coveragemap(df[df['split'] == 'train'], df[df['split'] == 'test'])
424
+
425
+ # Save the coverage map
426
+ coverage_path_all = os.path.join(output_dir, 'coverage_map_all.png')
427
+ coverage_path_train = os.path.join(output_dir, 'coverage_map_train.png')
428
+ coverage_path_test = os.path.join(output_dir, 'coverage_map_test.png')
429
+ coverage_path_train_test = os.path.join(output_dir, 'coverage_map_train_test.png')
430
+ coverage_img_all.save(coverage_path_all, format='PNG')
431
+ coverage_img_train.save(coverage_path_train, format='PNG')
432
+ coverage_img_test.save(coverage_path_test, format='PNG')
433
+ coverage_img_train_test.save(coverage_path_train_test, format='PNG')
434
+ print(f"Saved coverage maps to {coverage_path_all}, {coverage_path_train}, {coverage_path_test} and {coverage_path_train_test}")
435
+
436
+
437
+ def main():
438
+ parser = argparse.ArgumentParser(description='Extract features from MajorTOM dataset')
439
+ parser.add_argument('--subset_path', required=True, help='Path to the subset folder')
440
+ parser.add_argument('--output_dir', required=True, help='Path to the output directory')
441
+ parser.add_argument('--bands', nargs='+', required=True, help='Bands to process (e.g., B01 B02 B03 DEM vv vh thumbnail)')
442
+ parser.add_argument('--ratio_train_test', type=float, default=0.95, help='Ratio of training to testing data')
443
+ parser.add_argument('--flip', action='store_true', help='Flip the patches')
444
+ parser.add_argument('--visualize', action='store_true', help='Visualize the patches in a world map')
445
+ parser.add_argument('--seed', type=int, default=42, help='Random seed')
446
+ parser.add_argument('--center_crop', action='store_true', help='Use center crop instead of patchifying')
447
+ args = parser.parse_args()
448
+
449
+ # Get subset name from path
450
+ subset_name = Path(args.subset_path).name
451
+
452
+ print("Flip is set to", args.flip)
453
+ print("Seed is set to", args.seed)
454
+ print("Subset path is", args.subset_path)
455
+ print("Bands are", args.bands)
456
+ print("Ratio train test is", args.ratio_train_test)
457
+ print("Visualize is set to", args.visualize)
458
+ print("Center crop is set to", args.center_crop)
459
+ # Create the main output directory
460
+ all_bands = '_'.join(sorted(args.bands))
461
+ os.makedirs(args.output_dir, exist_ok=True)
462
+
463
+ # Group bands by satellite type
464
+ bands_per_type = {}
465
+ satellite_types = []
466
+ for sat_type, config in SATELLITE_CONFIGS.items():
467
+ all_sat_bands = config['tif_bands'] + config['png_bands']
468
+ sat_bands = [b for b in args.bands if b in all_sat_bands]
469
+ if sat_bands:
470
+ bands_per_type[sat_type] = sat_bands
471
+ satellite_types.append(sat_type)
472
+
473
+ if satellite_types:
474
+ # Process all satellite types together
475
+ dataset, num_grid_cells = process_satellite(args.subset_path, satellite_types, bands_per_type, args.ratio_train_test, args.seed)
476
+
477
+ if args.visualize:
478
+ print("==> Visualizing patches...")
479
+ visualize_patches(dataset, satellite_types, bands_per_type, args.output_dir)
480
+ print("==> Done visualizing patches! Exiting...")
481
+ # exit()
482
+
483
+ print("==> Cropping images...")
484
+ crop_images(dataset, satellite_types, bands_per_type, args.output_dir, num_grid_cells,
485
+ flip=args.flip, center_crop=args.center_crop)
486
+
487
+ if __name__ == "__main__":
488
+ main()
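For reference, a standalone sketch of the unfold-based patchify used in process_sample above: a [C, 1024, 1024] thumbnail becomes a 4x4 grid of 256 px patches in row-major order (dummy tensor, not part of the committed code):

import torch

PATCH = 256
img = torch.randn(3, 4 * PATCH, 4 * PATCH)                     # dummy 1024x1024 RGB thumbnail
patches = img.unfold(1, PATCH, PATCH).unfold(2, PATCH, PATCH)  # [C, 4, 4, 256, 256]
patches = patches.reshape(img.size(0), -1, PATCH, PATCH)       # [C, 16, 256, 256]
patches = patches.permute(1, 0, 2, 3)                          # [16, C, 256, 256], row-major over the grid
assert patches.shape == (16, 3, PATCH, PATCH)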
src/COP-GEN-Beta/sample_n_triffuser.py ADDED
@@ -0,0 +1,652 @@
1
+ import ml_collections
2
+ import torch
3
+ import random
4
+ import utils
5
+ from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
6
+ from absl import logging
7
+ import einops
8
+ import libs.autoencoder
9
+ from torchvision.utils import save_image, make_grid
10
+ import torchvision.transforms as standard_transforms
11
+ import numpy as np
12
+ from PIL import Image
13
+ import time
14
+ import copy
15
+ from datasets import get_dataset
16
+ from torch.utils.data import Dataset, DataLoader
17
+ from tqdm.auto import tqdm
18
+ from torch.utils._pytree import tree_map
19
+ import glob
20
+ import os
21
+ import functools
22
+ from concurrent.futures import ThreadPoolExecutor
23
+
24
+ # Add profiling tools
25
+ class Profiler:
26
+ def __init__(self):
27
+ self.times = {}
28
+
29
+ def profile(self, func):
30
+ @functools.wraps(func)
31
+ def wrapper(*args, **kwargs):
32
+ start_time = time.time()
33
+ result = func(*args, **kwargs)
34
+ end_time = time.time()
35
+
36
+ func_name = func.__name__
37
+ if func_name not in self.times:
38
+ self.times[func_name] = []
39
+ self.times[func_name].append(end_time - start_time)
40
+
41
+ return result
42
+ return wrapper
43
+
44
+ def summary(self):
45
+ print("\n----- Profiling Summary -----")
46
+ for func_name, times in self.times.items():
47
+ avg_time = sum(times) / len(times)
48
+ total_time = sum(times)
49
+ calls = len(times)
50
+ print(f"{func_name}: {total_time:.2f}s total, {avg_time:.4f}s avg, {calls} calls")
51
+ print("----------------------------\n")
52
+
53
+ profiler = Profiler()
54
+
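A small usage sketch for the Profiler above; the decorated function is a stand-in, and it assumes the profiler instance defined in this file is in scope:

@profiler.profile
def encode_batch(x):
    return x * 2          # stand-in for real work such as autoencoder encoding

for _ in range(3):
    encode_batch(1.0)

profiler.summary()        # prints total / average wall-clock time and call count per decorated function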
55
+ MODALITIES = {
56
+ 4: ['dem', 's1_rtc', 's2_l1c', 's2_l2a'],
57
+ 3: ['dem', 's1_rtc', 's2_l1c'],
58
+ 2: ['s1_rtc', 's2_l2a'],
59
+ }
60
+
61
+ MODEL_RESOLUTION = 256
62
+
63
+ """
64
+ Sampling for any n modalities
65
+
66
+ > python3 sample_n_triffuser.py --config=path --data_path=path --nnet_path=path \
67
+ --n_mod=int --n_samples=int \
68
+ --generate=[modalities] --condition=[modalities]
69
+
70
+
71
+ Generate all modalities unconditional (joint):
72
+ python3 sample_n_triffuser.py --n_mod=4 --generate=s2_l1c,s2_l2a,s1_rtc,dem
73
+
74
+ Generate a pair unconditional (joint):
75
+ python3 sample_n_triffuser.py --n_mod=4 --generate=s1_rtc,s2_l2a
76
+
77
+ Generate s1_rtc and s2_l2a, conditioned on dem and s2_l1c (conditional):
78
+ python3 sample_n_triffuser.py --n_mod=4 --generate=s1_rtc,s2_l2a --condition=dem,s2_l1c
79
+
80
+ Generate dem conditioned on s1_rtc, s2_l1c, s2_l2a (conditional):
81
+ python3 sample_n_triffuser.py --n_mod=4 --generate=dem --condition=s1_rtc,s2_l1c,s2_l2a
82
+
83
+ Generate dem conditioned on s1_rtc (conditional) (the rest are automatically ignored: s2_l1c, s2_l2a):
84
+ python3 sample_n_triffuser.py --n_mod=4 --generate=dem --condition=s1_rtc
85
+
86
+ Generate dem unconditional (marginal) (no condition, the rest are ignored):
87
+ python3 sample_n_triffuser.py --n_mod=4 --generate=dem
88
+
89
+
90
+ Note:
91
+ --generate flag is mandatory
92
+ "generate" modalities and "condition" modalities should always be different
93
+
94
+
95
+ """
96
+
97
+ class CustomImageDataset(Dataset):
98
+ def __init__(self, folder_path, transform=None):
99
+ self.folder_path = folder_path
100
+ self.transform = transform
101
+ self.image_files = glob.glob(os.path.join(folder_path, "*.png"))
102
+ print("There are", len(self.image_files), "images in the dataset")
103
+
104
+ def __len__(self):
105
+ return len(self.image_files)
106
+
107
+ def __getitem__(self, idx):
108
+ image = Image.open(self.image_files[idx]).convert("RGB")
109
+ if self.transform:
110
+ image = self.transform(image)
111
+ # Return both the image and the filename
112
+ return image, os.path.basename(self.image_files[idx])
113
+
114
+
115
+
116
+ def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
117
+ _betas = (
118
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
119
+ )
120
+ return _betas.numpy()
121
+
122
+ @profiler.profile
123
+ def prepare_contexts(config, images, filenames, device, autoencoder=None):
124
+ """
125
+ If a modality is conditional, we need to return the npy feature encodings
126
+ If a modality is unconditional, we need to return random noise
127
+
128
+ batch_shape = (n_modalities, B, C, H, W)
129
+
130
+ Returns:
131
+ img_contexts: Tensor containing contexts for each modality
132
+ processed_filenames: List of filenames, duplicated and labeled with version suffixes if n_samples > 1
133
+ """
134
+
135
+ # Create a noise tensor with the same shape as the images batch
136
+ if config.data_type == 'lmdb':
137
+ effective_batch_size = images[0].shape[0] * config.n_samples if config.n_samples > 1 else images[0].shape[0]
138
+ img_contexts = torch.randn(config.num_modalities, effective_batch_size, *images[0].shape[1:], device=device)
139
+ elif config.data_type == 'folder-img':
140
+ # Calculate effective batch size (original batch size * n_samples)
141
+ effective_batch_size = images.shape[0] * config.n_samples if config.n_samples > 1 else images.shape[0]
142
+ # Multiply the images_batch shape by 2 because we have both mean and variance
143
+ # as output from the autoencoder
144
+ img_contexts = torch.randn(config.num_modalities, effective_batch_size, 2 * config.z_shape[0],
145
+ config.z_shape[1], config.z_shape[2], device=device)
146
+
147
+ # Process filenames - duplicate them if n_samples > 1 and add version suffixes
148
+ processed_filenames = []
149
+ if config.n_samples > 1:
150
+ for filename in filenames:
151
+ for i in range(config.n_samples):
152
+ processed_filenames.append(f"{filename}_v{i+1}")
153
+ else:
154
+ processed_filenames = filenames
155
+
156
+ # For each modality in the images_batch, if it is conditional, load and duplicate the npy feature encodings
157
+ for i, modality in enumerate(config.modalities):
158
+ if config.condition_modalities_mask[i]:
159
+ if config.data_type == 'lmdb':
160
+ # Duplicate each conditional input n_samples times
161
+ img_contexts[i] = images[i].repeat_interleave(config.n_samples, dim=0)
162
+ elif config.data_type == 'folder-img':
163
+ assert autoencoder is not None, "Autoencoder must be provided for folder-img data type"
164
+ # Duplicate each conditional input n_samples times
165
+ duplicated_batch = images.repeat_interleave(config.n_samples, dim=0)
166
+ img_contexts[i] = autoencoder.encode_moments(duplicated_batch)
167
+
168
+ # Padding the latents experiment
169
+ # duplicated_batch = images.repeat_interleave(config.n_samples, dim=0)
170
+ # intermediate_latents = autoencoder.encode_moments(duplicated_batch)
171
+ # padded_latents = torch.nn.functional.pad(intermediate_latents, (8, 8, 8, 8), mode='reflect')
172
+ # img_contexts[i] = padded_latents
173
+
174
+ return img_contexts, processed_filenames
175
+
176
+ def unpreprocess(v): # to B C H W and [0, 1]
177
+ v = 0.5 * (v + 1.)
178
+ v.clamp_(0., 1.)
179
+ return v
180
+
181
+
182
+ def set_seed(seed: int):
183
+ random.seed(seed)
184
+ np.random.seed(seed)
185
+ torch.manual_seed(seed)
186
+ torch.cuda.manual_seed_all(seed)
187
+
188
+
189
+
190
+ def evaluate(config):
191
+ if config.get('benchmark', False):
192
+ torch.backends.cudnn.benchmark = True
193
+ torch.backends.cudnn.deterministic = False
194
+
195
+ # Create output directory once at the start
196
+ os.makedirs(config.output_path, exist_ok=True)
197
+ # Create a directory for each modality if we are saving as pngs
198
+ if config.save_as == 'pngs':
199
+ for modality in config.generate_modalities:
200
+ os.makedirs(os.path.join(config.output_path, modality), exist_ok=True)
201
+
202
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
203
+ set_seed(config.seed)
204
+
205
+ config = ml_collections.FrozenConfigDict(config)
206
+ utils.set_logger(log_level='info')
207
+
208
+ _betas = stable_diffusion_beta_schedule()
209
+ N = len(_betas)
210
+
211
+ nnet = utils.get_nnet(**config.nnet)
212
+ logging.info(f'load nnet from {config.nnet_path}')
213
+ nnet.load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
214
+ nnet.to(device)
215
+ nnet.eval()
216
+
217
+ if config.data_type == 'lmdb':
218
+ # Edit the dataset path to the data path from the command line arguments
219
+ dataset_config = ml_collections.ConfigDict(config.to_dict())
220
+ dataset_config.dataset.path = config.data_path
221
+
222
+ # Always return the filename
223
+ dataset_config.dataset.return_filename = True
224
+
225
+ dataset = get_dataset(**dataset_config.dataset)
226
+ # TODO: This is not intuitive. Split is train but it is returning the test set. See datasets.py
227
+ test_dataset = dataset.get_split(split='train', labeled=False)
228
+ # Create a generator with fixed seed for reproducible shuffling
229
+ g = torch.Generator()
230
+ g.manual_seed(config.seed) # Using the same seed as set earlier in the code
231
+ dataloader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=True, drop_last=False,
232
+ num_workers=8, pin_memory=True, persistent_workers=True, generator=g)
233
+
234
+ elif config.data_type == 'folder-img':
235
+ print("config.data_path", config.data_path)
236
+
237
+ if config.resolution >= MODEL_RESOLUTION:
238
+ transform = standard_transforms.Compose([
239
+ standard_transforms.CenterCrop(MODEL_RESOLUTION),
240
+ standard_transforms.ToTensor(),
241
+ standard_transforms.Normalize(mean=(0.5,), std=(0.5,)),
242
+ ])
243
+ else:
244
+ padding_4sides = (MODEL_RESOLUTION - config.resolution) // 2
245
+ transform = standard_transforms.Compose([
246
+ standard_transforms.CenterCrop(config.resolution),
247
+ standard_transforms.ToTensor(),
248
+ torch.nn.ReflectionPad2d(padding_4sides),
249
+ standard_transforms.Normalize(mean=(0.5,), std=(0.5,)),
250
+ ])
251
+ dataset = CustomImageDataset(config.data_path, transform=transform)
252
+ dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=False, drop_last=False,
253
+ num_workers=8, pin_memory=True, persistent_workers=True)
254
+ else:
255
+ raise ValueError(f"Invalid data type: {config.data_type}. Must be one of ['lmdb', 'folder-img']")
256
+
257
+ autoencoder = libs.autoencoder.get_model(**config.autoencoder)
258
+ autoencoder.to(device)
259
+
260
+
261
+ @profiler.profile
262
+ def split_joint(x, z_imgs, config):
263
+ """
264
+ Input:
265
+ x: (B, C, H, W)
266
+ is only the modalities that are being denoised
267
+ z_imgs: (M, B, C, H, W)
268
+ the original img_latents for all modalities
269
+ (but we only use the ones for the modalities that are being denoised)
270
+ config: config
271
+
272
+ First, split the input into the modalities into correct shape
273
+ Second, return a full list of the modalities,
274
+ including the ones being conditioned on and the ones being ignored.
275
+
276
+ Returns list of all modalities (some are denoised, some are conditioned on, some are ignored)
277
+
278
+ """
279
+
280
+ C, H, W = config.z_shape
281
+ z_dim = C * H * W
282
+ z_generated = x.split([z_dim] * len(config.generate_modalities), dim=1)
283
+ z_generated = {modality: einops.rearrange(z_i, 'B (C H W) -> B C H W', C=C, H=H, W=W)
284
+ for z_i, modality in zip(z_generated, config.generate_modalities)}
285
+
286
+ z = []
287
+ for i, modality in enumerate(config.modalities):
288
+ # Modalities that are being denoised
289
+ if modality in config.generate_modalities:
290
+ z.append(z_generated[modality])
291
+ # Modalities that are being conditioned on
292
+ elif modality in config.condition_modalities:
293
+ z.append(z_imgs[i])
294
+ # Modalities that are ignored
295
+ else:
296
+ z.append(torch.randn(x.shape[0], C, H, W, device=device))
297
+
298
+ return z
299
+
300
+
301
+ @profiler.profile
302
+ def combine_joint(z):
303
+ """
304
+ Input:
305
+ z: list of ONLY the modalities that are being denoised
306
+ Returns:
307
+ z: (B, C * H * W)
308
+ """
309
+ z = torch.concat([einops.rearrange(z_i, 'B C H W -> B (C H W)') for z_i in z], dim=-1)
310
+ return z
311
+
312
+ @torch.cuda.amp.autocast()
313
+ @profiler.profile
314
+ def encode(_batch):
315
+ return autoencoder.encode(_batch)
316
+
317
+ @torch.cuda.amp.autocast()
318
+ @profiler.profile
319
+ def decode(_batch):
320
+ return autoencoder.decode(_batch)
321
+
322
+ def get_data_generator():
323
+ # Run single epoch
324
+ for data in tqdm(dataloader, desc='epoch'):
325
+ yield data
326
+
327
+ logging.info("Num of modalities: %d", config.num_modalities)
328
+ logging.info("Num of images in dataloader: %d", len(dataloader))
329
+ logging.info("Generate modalities: %s", config.generate_modalities)
330
+ logging.info("Condition modalities: %s", config.condition_modalities)
331
+ logging.info("Condition modalities mask: %s", config.condition_modalities_mask)
332
+ logging.info("Generate modalities mask: %s", config.generate_modalities_mask)
333
+ logging.info(f'N={N}')
334
+
335
+
336
+ @profiler.profile
337
+ def run_nnet(x, t, z_imgs):
338
+
339
+ timesteps = [t if mask else torch.zeros_like(t) for mask in config.generate_modalities_mask]
340
+
341
+ # ==== EXPAND TO ALL MODALITIES ====
342
+ z = split_joint(x, z_imgs, config=config)
343
+ # z = {modality1: z_generated_modality1, modality2: z_conditioned_modality2, ...}
344
+
345
+ # == DEBUG CODE: Decode, unprocess, and save both modalities side by side
346
+ # z_decoded_1 = decode(z[0])
347
+ # z_decoded_2 = decode(z[1])
348
+ # z_decoded_1 = unpreprocess(z_decoded_1)
349
+ # z_decoded_2 = unpreprocess(z_decoded_2)
350
+ # z_decoded_combined = torch.cat([z_decoded_1, z_decoded_2], dim=-1) # Concatenate along width dimension
351
+ # print(f"saving image z_decoded_combined_{t}.png")
352
+ # save_image(z_decoded_combined, os.path.join(config.output_path, f"z_supeeerdecoded_{t}.png"))
353
+ # == DEBUG CODE END ==
354
+
355
+ """
356
+ nnet expects:
357
+ - z: (M, B, C, H, W)
358
+ - t_imgs: (M, B)
359
+ where M is the number of modalities.
360
+
361
+ That is, z should be a list of M batches, each batch corresponding to a modality.
362
+ E.g. num_modalities(M)=3, batch_size(B)=16, z_shape(C, H, W)=(4, 32, 32) ->
363
+ z = [(16, 4, 32, 32), (16, 4, 32, 32), (16, 4, 32, 32)]
364
+ t_imgs = [(16,), (16,), (16,)]
365
+ """
366
+
367
+ z_out = nnet(z, t_imgs=timesteps)
368
+
369
+ # ==== SELECT ONLY THE GENERATED MODALITIES for the denoising process ====
370
+ z_out_generated = [z_out[i]
371
+ for i, modality in enumerate(config.modalities)
372
+ if modality in config.generate_modalities]
373
+
374
+ x_out = combine_joint(z_out_generated)
375
+
376
+ if config.sample.scale == 0.:
377
+ return x_out
378
+
379
+ return x_out # TODO: Implement classifier-free guidance if there is time
380
+
381
+ @profiler.profile
382
+ def sample_fn(z_imgs, **kwargs):
383
+ # Calculate effective batch size
384
+ effective_batch_size = z_imgs[0].shape[0]
385
+
386
+ # Generate random initial noise for the modalities being generated/denoised
387
+ _z_init = torch.randn(len(config.generate_modalities), effective_batch_size, *z_imgs[0].shape[1:], device=device)
388
+
389
+ _x_init = combine_joint(_z_init)
390
+ noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
391
+
392
+ @profiler.profile
393
+ def model_fn(x, t_continuous):
394
+ t = t_continuous * N
395
+ return run_nnet(x, t, z_imgs)
396
+
397
+ dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
398
+ with torch.no_grad():
399
+ with torch.autocast(device_type=device):
400
+ start_time = time.time()
401
+ x = dpm_solver.sample(_x_init, steps=config.sample.sample_steps, eps=1. / N, T=1.)
402
+ end_time = time.time()
403
+ print(f'\ngenerate {config.batch_size} samples with {config.sample.sample_steps} steps takes {end_time - start_time:.2f}s')
404
+
405
+ _zs = split_joint(x, z_imgs, config=config)
406
+
407
+ # Replace the conditional modalities with the original images
408
+ for i, mask in enumerate(config.condition_modalities_mask):
409
+ if mask:
410
+ _zs[i] = z_imgs[i]
411
+
412
+ return _zs
413
+
414
+
415
+ data_generator = get_data_generator()
416
+ for idx_batch, batch in enumerate(data_generator):
417
+
418
+ batch_start_time = time.time()
419
+
420
+ # Unpack the batch into images and filenames
421
+ original_images, original_filenames = batch
422
+
423
+ # print(filenames)
424
+
425
+ # Track data loading and preprocessing time
426
+ preprocess_start = time.time()
427
+ images = tree_map(lambda x: x.to(device), original_images)
428
+ # In addition to preparing the contexts (returns mean and variance),
429
+ # we need to actually sample the values from the distribution
430
+ img_contexts, filenames = prepare_contexts(config, images, original_filenames, device=device, autoencoder=autoencoder)
431
+ z_imgs = torch.stack([autoencoder.sample(img_context) for img_context in img_contexts])
432
+ preprocess_time = time.time() - preprocess_start
433
+
434
+ # Track sampling time
435
+ sample_start = time.time()
436
+ _zs = sample_fn(z_imgs)
437
+ sample_time = time.time() - sample_start
438
+
439
+ # Track decoding time
440
+ decode_start = time.time()
441
+ samples_unstacked = [unpreprocess(decode(_z)) for _z in _zs]
442
+
443
+ # Crop back to input resolution if it is smaller than MODEL_RESOLUTION
444
+ if config.resolution < MODEL_RESOLUTION:
445
+ samples_unstacked = [standard_transforms.functional.center_crop(sample, output_size=config.resolution)
446
+ for sample in samples_unstacked]
447
+
448
+ samples = torch.stack(samples_unstacked, dim=0)
449
+ decode_time = time.time() - decode_start
450
+
451
+ # Track saving time
452
+ save_start = time.time()
453
+
454
+ if config.save_as == 'grid':
455
+
456
+ b = samples.shape[1] # batch size
457
+ # Properly interleave samples from all modalities
458
+ # For each sample index, get all modalities before moving to next sample
459
+ samples = torch.stack([samples[j, i] for i in range(b) for j in range(config.nnet.num_modalities)]).view(-1, *samples.shape[2:])
460
+ # If the number of modalities is 3 then we plot in 9 columns
461
+ n_cols = 9 if config.nnet.num_modalities == 3 else 8
462
+ samples = make_grid(samples, n_cols)
463
+ save_path = os.path.join(config.output_path, f'grid_{idx_batch}.png')
464
+ save_image(samples, save_path)
465
+
466
+
467
+ # plot_real_images = '/home/s2254242/projects/pangaea_terramind/data/test_set_1/test' # We want to plot into a grid_real_images_{idx_batch}.png the real images
468
+ plot_real_images = ''
469
+
470
+ if plot_real_images != '':
471
+ # Load real images from files
472
+ real_images_list = []
473
+ for filename in original_filenames:
474
+ for modality in config.modalities:
475
+ img_path = os.path.join(plot_real_images, modality, f"{filename}.png")
476
+ img = Image.open(img_path).convert("RGB")
477
+ img_tensor = standard_transforms.ToTensor()(img)
478
+ real_images_list.append(img_tensor)
479
+
480
+ # Stack and create grid
481
+ real_images = torch.stack(real_images_list)
482
+ real_grid = make_grid(real_images, n_cols)
483
+ real_save_path = os.path.join(config.output_path, f'grid_real_{idx_batch}.png')
484
+ save_image(real_grid, real_save_path)
485
+
486
+ elif config.save_as == 'pngs':
487
+ # Define a helper function to save a single image
488
+ def save_single_image(args):
489
+ modality_idx, modality, b_idx = args
490
+ filename = filenames[b_idx] if isinstance(filenames, list) else filenames
491
+ save_path = os.path.join(os.path.join(config.output_path, modality), f"{filename}.png")
492
+ save_image(samples[modality_idx][b_idx], save_path)
493
+
494
+ # Create a list of all save operations needed
495
+ save_tasks = []
496
+ for i, modality in enumerate(config.modalities):
497
+ if modality in config.generate_modalities:
498
+ modality_dir = os.path.join(config.output_path, modality)
499
+ for b_idx in range(samples[i].shape[0]):
500
+ save_tasks.append((i, modality, b_idx))
501
+
502
+ # Use ThreadPoolExecutor to parallelize the saving process
503
+ max_workers = min(16, len(save_tasks)) # Limit to 16 threads max
504
+ if max_workers > 0: # Only create pool if there are tasks
505
+ print(f"Saving {len(save_tasks)} images using {max_workers} threads...")
506
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
507
+ list(tqdm(executor.map(save_single_image, save_tasks), total=len(save_tasks), desc="Saving images"))
508
+
509
+ elif config.data_type == 'folder-img':
510
+ # Get indices for all modalities we want to save
511
+ save_modalities = ['s1_rtc', 's2_l2a']
512
+ append_real_from_paths = []  # optional directories of real images to append next to the samples (e.g. a local 'data/pastis_pngs/sar/' folder)
513
+ modality_indices = [config.modalities.index(m) for m in save_modalities]
514
+
515
+ for i in range(min(config.batch_size, len(filenames))):
516
+ # Stack the samples from different modalities horizontally
517
+ concat_samples = torch.cat([samples[idx, i] for idx in modality_indices], dim=2)
518
+
519
+ # Append real images from specified paths
520
+ real_images = []
521
+ for real_path in append_real_from_paths:
522
+ real_img_path = os.path.join(real_path, filenames[i])
523
+ real_img = Image.open(real_img_path).convert("RGB")
524
+ real_img_tensor = standard_transforms.ToTensor()(real_img)
525
+ real_images.append(real_img_tensor)
526
+
527
+ if real_images:  # guard: append_real_from_paths is empty unless real-image directories are configured
+     real_images_tensor = torch.cat(real_images, dim=2) if len(real_images) > 1 else real_images[0]
+     concat_samples = torch.cat([concat_samples, real_images_tensor.to(device)], dim=2)
529
+
530
+ save_path = os.path.join(config.output_path, filenames[i])
531
+ save_image(concat_samples, save_path)
532
+
533
+ save_time = time.time() - save_start
534
+
535
+ batch_total_time = time.time() - batch_start_time
536
+
537
+ print(f'\nBatch {idx_batch} timing:')
538
+ print(f' Preprocessing: {preprocess_time:.2f}s ({preprocess_time/batch_total_time*100:.1f}%)')
539
+ print(f' Sampling: {sample_time:.2f}s ({sample_time/batch_total_time*100:.1f}%)')
540
+ print(f' Decoding: {decode_time:.2f}s ({decode_time/batch_total_time*100:.1f}%)')
541
+ print(f' Saving: {save_time:.2f}s ({save_time/batch_total_time*100:.1f}%)')
542
+ print(f' Total: {batch_total_time:.2f}s')
543
+
544
+ print(f'\nGPU memory usage: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB')
545
+ print(f'\nresults are saved in {os.path.join(config.output_path)} :)')
546
+
547
+ # After processing, display the profiling summary
548
+ if idx_batch % 5 == 0 or idx_batch == len(dataloader) - 1:
549
+ profiler.summary()
550
+
551
+
552
+ from absl import flags
553
+ from absl import app
554
+ from ml_collections import config_flags
555
+ import os
556
+
557
+
558
+ FLAGS = flags.FLAGS
559
+ config_flags.DEFINE_config_file(
560
+ "config", None, "Configuration.", lock_config=False)
561
+ flags.DEFINE_string("data_path", None, "Path to the data")
562
+ flags.DEFINE_string("data_type", 'lmdb', "Type of data to load (lmdb, folder-img)")
563
+ flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
564
+ flags.DEFINE_string("output_path", None, "The path to save the generated images")
565
+ flags.DEFINE_integer("n_mod", None, "Number of modalities")
566
+ flags.DEFINE_integer("n_samples", 1, "The number of samples to generate with the same condition")
567
+ flags.DEFINE_string("generate", None, "Comma-separated list of modalities to generate (s2_l1c,s2_l2a,s1_rtc,dem)")
568
+ flags.DEFINE_string("condition", None, "Comma-separated list of modalities to condition on (s2_l1c,s2_l2a,s1_rtc,dem)")
569
+ flags.DEFINE_string("save_as", 'grid', "How to save the generated images (grid, pngs)")
570
+ flags.DEFINE_integer("resolution", 256, "The resolution of the images to generate")
571
+ flags.DEFINE_integer("seed", None, "Random seed for reproducibility (overrides config seed)")
572
+
573
+
574
+ def main(argv):
575
+ config = FLAGS.config
576
+ config.nnet_path = FLAGS.nnet_path
577
+ config.data_path = FLAGS.data_path
578
+ config.save_as = FLAGS.save_as
579
+ config.n_samples = FLAGS.n_samples if FLAGS.n_samples else 1
580
+ config.resolution = FLAGS.resolution
581
+
582
+ # Override seed if provided from command line
583
+ if FLAGS.seed is not None:
584
+ config.seed = FLAGS.seed
585
+
586
+ # batch_size controls the number of unique conditional images we use
587
+ config.batch_size = 6
588
+
589
+ config.modalities = MODALITIES[FLAGS.n_mod]
590
+
591
+ if FLAGS.generate is None:
592
+ raise ValueError("--generate flag is mandatory")
593
+
594
+ # Parse generate and condition modalities
595
+ config.generate_modalities = FLAGS.generate.split(',')
596
+ config.condition_modalities = FLAGS.condition.split(',') if FLAGS.condition else []
597
+
598
+ # Sort the modalities by the order of the config.modalities
599
+ config.generate_modalities = sorted(config.generate_modalities, key=lambda x: config.modalities.index(x))
600
+ config.condition_modalities = sorted(config.condition_modalities, key=lambda x: config.modalities.index(x))
601
+
602
+ config.generate_modalities_mask = [mod in config.generate_modalities for mod in config.modalities]
603
+ config.condition_modalities_mask = [mod in config.condition_modalities for mod in config.modalities]
604
+
605
+ # Validate modalities
606
+ valid_modalities = {'s2_l1c', 's2_l2a', 's1_rtc', 'dem'}
607
+ for mod in config.generate_modalities + config.condition_modalities:
608
+ if mod not in valid_modalities:
609
+ raise ValueError(f"Invalid modality: {mod}. Must be one of {valid_modalities}")
610
+
611
+ # Check that generate and condition modalities don't overlap
612
+ if set(config.generate_modalities) & set(config.condition_modalities):
613
+ raise ValueError("Generate and condition modalities must be different")
614
+
615
+ if FLAGS.data_type == 'lmdb':
616
+ # Check that there exists a data.mdb and a lock.mdb in the data path
617
+ if not os.path.exists(os.path.join(config.data_path, 'data.mdb')):
618
+ raise ValueError(f"data.mdb does not exist in {config.data_path}")
619
+ if not os.path.exists(os.path.join(config.data_path, 'lock.mdb')):
620
+ raise ValueError(f"lock.mdb does not exist in {config.data_path}")
621
+ elif FLAGS.data_type == 'folder-img':
622
+ # raise NotImplementedError("Folder-img data type not implemented")
623
+ pass
624
+ else:
625
+ raise ValueError(f"Invalid data type: {FLAGS.data_type}. Must be one of ['lmdb', 'folder-img']")
626
+ config.data_type = FLAGS.data_type
627
+
628
+ assert config.nnet.num_modalities == FLAGS.n_mod, "Number of modalities in the nnet must match the number of modalities in the command line arguments"
629
+ config.num_modalities = FLAGS.n_mod
630
+
631
+ # Format the output path based on conditions and modalities
632
+ clean_generate = [mod.replace('_', '') for mod in config.generate_modalities]
633
+ if config.condition_modalities:
634
+ clean_condition = [mod.replace('_', '') for mod in config.condition_modalities]
635
+ output_dir = f"condition_{'_'.join(clean_condition)}_generate_{'_'.join(clean_generate)}_{config.n_samples}samples"
636
+ else:
637
+ output_dir = f"generate_{'_'.join(clean_generate)}_{config.n_samples}samples"
638
+
639
+ if config.save_as == 'grid':
640
+ config.output_path = os.path.join(FLAGS.output_path, 'grids', output_dir)
641
+ else:
642
+ config.output_path = os.path.join(FLAGS.output_path, output_dir)
643
+
644
+ evaluate(config)
645
+
646
+ # Print final profiling summary
647
+ print("\n===== FINAL PROFILING SUMMARY =====")
648
+ profiler.summary()
649
+
650
+
651
+ if __name__ == "__main__":
652
+ app.run(main)
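For orientation, here is a minimal sketch (illustrative only, not part of the committed file) of how the `--generate`/`--condition` flags above translate into the per-modality masks used by `run_nnet`, and of the `combine_joint`/`split_joint` flattening round trip. It assumes the 4-modality ordering `MODALITIES[4] = ['dem', 's1_rtc', 's2_l1c', 's2_l2a']` and a latent shape of `(4, 32, 32)`:

```python
# Illustrative sketch; mirrors sample_n_triffuser.py but is not part of the repository.
import torch
import einops

modalities = ['dem', 's1_rtc', 's2_l1c', 's2_l2a']      # MODALITIES[4]
generate   = ['s1_rtc', 's2_l2a']                        # --generate=s1_rtc,s2_l2a
condition  = ['dem', 's2_l1c']                           # --condition=dem,s2_l1c

generate_mask  = [m in generate for m in modalities]     # [False, True, False, True]
condition_mask = [m in condition for m in modalities]    # [True, False, True, False]

# combine_joint: flatten the latents of the generated modalities into one vector per sample
B, (C, H, W) = 2, (4, 32, 32)
z_gen = [torch.randn(B, C, H, W) for _ in generate]
x = torch.cat([einops.rearrange(z_i, 'B C H W -> B (C H W)') for z_i in z_gen], dim=-1)
assert x.shape == (B, len(generate) * C * H * W)

# split_joint (inverse direction): recover the per-modality latents from the flat vector
chunks = x.split([C * H * W] * len(generate), dim=1)
z_back = [einops.rearrange(c, 'B (C H W) -> B C H W', C=C, H=H, W=W) for c in chunks]
assert all(torch.equal(a, b) for a, b in zip(z_gen, z_back))
```

In the script itself, conditioned modalities keep their encoded latents and ignored modalities are filled with fresh noise before the full stack is passed to `nnet`.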
src/COP-GEN-Beta/scripts/download_rome.sh ADDED
@@ -0,0 +1,18 @@
1
+ #!/bin/bash
2
+
3
+ DATA_DIR="./data/majorTOM"
4
+ START_DATE="2017-01-01"
5
+ END_DATE="2025-01-01"
6
+ SOURCES=("Core-S2L2A" "Core-S2L1C" "Core-S1RTC" "Core-DEM")
7
+
8
+ python3 majortom/download_world.py \
9
+ --data-dir $DATA_DIR \
10
+ --sources "${SOURCES[@]}" \
11
+ --start-date $START_DATE \
12
+ --end-date $END_DATE \
13
+ --cloud-cover 0 10 \
14
+ --subset-name "rome" \
15
+ --bbox 12.2 41.6 13.0 42.2 \
16
+ --criteria "latest" \
17
+ --n-samples 10 \
18
+ --seed 42
src/COP-GEN-Beta/tools/extract_parquet.py ADDED
@@ -0,0 +1,115 @@
1
+ import os
2
+ import sys
3
+ import pyarrow.parquet as pq
4
+ import pandas as pd
5
+ import json
6
+ from pathlib import Path
7
+ import argparse
8
+
9
+ def parse_args():
10
+ parser = argparse.ArgumentParser(description='Extract all content from a parquet file')
11
+ parser.add_argument('--parquet-file', type=str, required=True,
12
+ help='Name of the parquet file to extract')
13
+ parser.add_argument('--output-dir', type=str, default='./extracted_data',
14
+ help='Directory to save extracted data (default: ./extracted_data)')
15
+ return parser.parse_args()
16
+
17
+ def extract_parquet_content(parquet_path, output_dir):
18
+ """Extract all content from a parquet file and save it to the output directory"""
19
+ output_dir = Path(output_dir)
20
+ output_dir.mkdir(exist_ok=True, parents=True)
21
+
22
+ print(f"Extracting data from {parquet_path} to {output_dir}")
23
+
24
+ # Open the parquet file
25
+ pf = pq.ParquetFile(parquet_path)
26
+ print(f"File contains {pf.num_row_groups} row groups")
27
+
28
+ # Process each row group
29
+ for rg_idx in range(pf.num_row_groups):
30
+ print(f"\nProcessing row group {rg_idx+1}/{pf.num_row_groups}")
31
+
32
+ # Read the row group
33
+ table = pf.read_row_group(rg_idx)
34
+ df = table.to_pandas()
35
+
36
+ # Create a directory for this row group
37
+ if pf.num_row_groups > 1:
38
+ rg_dir = output_dir / f"row_group_{rg_idx}"
39
+ else:
40
+ rg_dir = output_dir
41
+ rg_dir.mkdir(exist_ok=True)
42
+
43
+ # Get metadata to create more meaningful directory names if possible
44
+ product_id = df['product_id'][0] if 'product_id' in df.columns else f"sample_{rg_idx}"
45
+ grid_cell = df['grid_cell'][0] if 'grid_cell' in df.columns else ""
46
+
47
+ # Create a more descriptive directory name if possible
48
+ sample_dir = rg_dir / f"{grid_cell}_{product_id}" if grid_cell else rg_dir / product_id
49
+ sample_dir.mkdir(exist_ok=True)
50
+
51
+ # Extract and save metadata to JSON
52
+ metadata = {}
53
+ for col in df.columns:
54
+ if df[col].dtype != 'object' or (len(df[col]) > 0 and not isinstance(df[col].iloc[0], bytes)):
55
+ # Convert non-binary data to JSON-serializable format
56
+ try:
57
+ if col == 'timestamp' and pd.api.types.is_datetime64_any_dtype(df[col]):
58
+ metadata[col] = df[col].iloc[0].strftime('%Y-%m-%d %H:%M:%S')
59
+ else:
60
+ value = df[col].iloc[0]
61
+ # Handle numpy types
62
+ if hasattr(value, 'item'):
63
+ metadata[col] = value.item()
64
+ else:
65
+ metadata[col] = value
66
+ except Exception as e:
67
+ metadata[col] = f"Error converting: {str(e)}"
68
+
69
+ # Save metadata
70
+ with open(sample_dir / "metadata.json", "w") as f:
71
+ json.dump(metadata, f, indent=2, default=str)
72
+
73
+ # Extract and save binary data
74
+ binary_columns = []
75
+ for col in df.columns:
76
+ if df[col].dtype == 'object' and len(df[col]) > 0 and isinstance(df[col].iloc[0], bytes):
77
+ binary_columns.append(col)
78
+ binary_data = df[col].iloc[0]
79
+
80
+ # Determine file extension based on common column naming conventions
81
+ if col == 'thumbnail':
82
+ extension = '.png'
83
+ elif col.startswith('B') and col[1:].isdigit(): # Sentinel-2 bands
84
+ extension = '.tif'
85
+ elif col in ['vv', 'vh']: # Sentinel-1 bands
86
+ extension = '.tif'
87
+ elif col == 'DEM': # DEM data
88
+ extension = '.tif'
89
+ elif col == 'cloud_mask':
90
+ extension = '.tif'
91
+ else:
92
+ extension = '.bin' # Generic binary data
93
+
94
+ # Save binary data
95
+ file_path = sample_dir / f"{col}{extension}"
96
+ with open(file_path, "wb") as f:
97
+ f.write(binary_data)
98
+ print(f" Saved {col}{extension}, size: {len(binary_data)/1024:.1f} KB")
99
+
100
+ print(f" Extracted metadata and {len(binary_columns)} binary files to {sample_dir}")
101
+
102
+ def main():
103
+ args = parse_args()
104
+ parquet_path = Path(args.parquet_file)
105
+
106
+ if not parquet_path.exists():
107
+ print(f"Error: File {parquet_path} not found")
108
+ sys.exit(1)
109
+
110
+ # Extract all content
111
+ extract_parquet_content(parquet_path, args.output_dir)
112
+ print("\nExtraction complete!")
113
+
114
+ if __name__ == "__main__":
115
+ main()
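A short usage sketch for the extractor above; the file and directory names are placeholders, not paths from the repository:

```python
# Placeholder paths; the tool only assumes a Major TOM style parquet file with binary columns.
from pathlib import Path
from tools.extract_parquet import extract_parquet_content  # import path assumes running from the repo root

parquet_file = Path("data/majorTOM/rome/Core-S2L2A/part_0.parquet")   # hypothetical input
extract_parquet_content(parquet_file, Path("./extracted_data/rome_s2l2a"))
# Each row group is written as <grid_cell>_<product_id>/metadata.json plus thumbnail.png / *.tif files.
```

The same extraction can be run from the command line via the `--parquet-file` and `--output-dir` flags defined in `parse_args`.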
src/COP-GEN-Beta/tools/fid_score.py ADDED
@@ -0,0 +1,260 @@
1
+ """Calculates the Frechet Inception Distance (FID) to evalulate GANs
2
+
3
+ The FID metric calculates the distance between two distributions of images.
4
+ Typically, we have summary statistics (mean & covariance matrix) of one
5
+ of these distributions, while the 2nd distribution is given by a GAN.
6
+
7
+ When run as a stand-alone program, it compares the distribution of
8
+ images that are stored as PNG/JPEG at a specified location with a
9
+ distribution given by summary statistics (in pickle format).
10
+
11
+ The FID is calculated by assuming that X_1 and X_2 are the activations of
12
+ the pool_3 layer of the inception net for generated samples and real world
13
+ samples respectively.
14
+
15
+ See --help to see further details.
16
+
17
+ Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
18
+ of Tensorflow
19
+
20
+ Copyright 2018 Institute of Bioinformatics, JKU Linz
21
+
22
+ Licensed under the Apache License, Version 2.0 (the "License");
23
+ you may not use this file except in compliance with the License.
24
+ You may obtain a copy of the License at
25
+
26
+ http://www.apache.org/licenses/LICENSE-2.0
27
+
28
+ Unless required by applicable law or agreed to in writing, software
29
+ distributed under the License is distributed on an "AS IS" BASIS,
30
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31
+ See the License for the specific language governing permissions and
32
+ limitations under the License.
33
+ """
34
+ import os
35
+ import pathlib
36
+
37
+ import numpy as np
38
+ import torch
39
+ import torchvision.transforms as TF
40
+ from PIL import Image
41
+ from scipy import linalg
42
+ from torch.nn.functional import adaptive_avg_pool2d
43
+
44
+ try:
45
+ from tqdm import tqdm
46
+ except ImportError:
47
+ # If tqdm is not available, provide a mock version of it
48
+ def tqdm(x):
49
+ return x
50
+
51
+ from .inception import InceptionV3
52
+
53
+
54
+ IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
55
+ 'tif', 'tiff', 'webp'}
56
+
57
+
58
+ class ImagePathDataset(torch.utils.data.Dataset):
59
+ def __init__(self, files, transforms=None):
60
+ self.files = files
61
+ self.transforms = transforms
62
+
63
+ def __len__(self):
64
+ return len(self.files)
65
+
66
+ def __getitem__(self, i):
67
+ path = self.files[i]
68
+ img = Image.open(path).convert('RGB')
69
+ if self.transforms is not None:
70
+ img = self.transforms(img)
71
+ return img
72
+
73
+
74
+ def get_activations(files, model, batch_size=50, dims=2048, device='cpu', num_workers=8):
75
+ """Calculates the activations of the pool_3 layer for all images.
76
+
77
+ Params:
78
+ -- files : List of image files paths
79
+ -- model : Instance of inception model
80
+ -- batch_size : Batch size of images for the model to process at once.
81
+ Make sure that the number of samples is a multiple of
82
+ the batch size, otherwise some samples are ignored. This
83
+ behavior is retained to match the original FID score
84
+ implementation.
85
+ -- dims : Dimensionality of features returned by Inception
86
+ -- device : Device to run calculations
87
+ -- num_workers : Number of parallel dataloader workers
88
+
89
+ Returns:
90
+ -- A numpy array of dimension (num images, dims) that contains the
91
+ activations of the given tensor when feeding inception with the
92
+ query tensor.
93
+ """
94
+ model.eval()
95
+
96
+ if batch_size > len(files):
97
+ print(('Warning: batch size is bigger than the data size. '
98
+ 'Setting batch size to data size'))
99
+ batch_size = len(files)
100
+
101
+ dataset = ImagePathDataset(files, transforms=TF.ToTensor())
102
+ dataloader = torch.utils.data.DataLoader(dataset,
103
+ batch_size=batch_size,
104
+ shuffle=False,
105
+ drop_last=False,
106
+ num_workers=num_workers)
107
+
108
+ pred_arr = np.empty((len(files), dims))
109
+
110
+ start_idx = 0
111
+
112
+ for batch in tqdm(dataloader):
113
+ batch = batch.to(device)
114
+
115
+ with torch.no_grad():
116
+ pred = model(batch)[0]
117
+
118
+ # If model output is not scalar, apply global spatial average pooling.
119
+ # This happens if you choose a dimensionality not equal 2048.
120
+ if pred.size(2) != 1 or pred.size(3) != 1:
121
+ pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
122
+
123
+ pred = pred.squeeze(3).squeeze(2).cpu().numpy()
124
+
125
+ pred_arr[start_idx:start_idx + pred.shape[0]] = pred
126
+
127
+ start_idx = start_idx + pred.shape[0]
128
+
129
+ return pred_arr
130
+
131
+
132
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
133
+ """Numpy implementation of the Frechet Distance.
134
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
135
+ and X_2 ~ N(mu_2, C_2) is
136
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
137
+
138
+ Stable version by Dougal J. Sutherland.
139
+
140
+ Params:
141
+ -- mu1 : Numpy array containing the activations of a layer of the
142
+ inception net (like returned by the function 'get_predictions')
143
+ for generated samples.
144
+ -- mu2 : The sample mean over activations, precalculated on a
145
+ representative data set.
146
+ -- sigma1: The covariance matrix over activations for generated samples.
147
+ -- sigma2: The covariance matrix over activations, precalculated on a
148
+ representative data set.
149
+
150
+ Returns:
151
+ -- : The Frechet Distance.
152
+ """
153
+
154
+ mu1 = np.atleast_1d(mu1)
155
+ mu2 = np.atleast_1d(mu2)
156
+
157
+ sigma1 = np.atleast_2d(sigma1)
158
+ sigma2 = np.atleast_2d(sigma2)
159
+
160
+ assert mu1.shape == mu2.shape, \
161
+ 'Training and test mean vectors have different lengths'
162
+ assert sigma1.shape == sigma2.shape, \
163
+ 'Training and test covariances have different dimensions'
164
+
165
+ diff = mu1 - mu2
166
+
167
+ # Product might be almost singular
168
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
169
+ if not np.isfinite(covmean).all():
170
+ msg = ('fid calculation produces singular product; '
171
+ 'adding %s to diagonal of cov estimates') % eps
172
+ print(msg)
173
+ offset = np.eye(sigma1.shape[0]) * eps
174
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
175
+
176
+ # Numerical error might give slight imaginary component
177
+ if np.iscomplexobj(covmean):
178
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
179
+ m = np.max(np.abs(covmean.imag))
180
+ raise ValueError('Imaginary component {}'.format(m))
181
+ covmean = covmean.real
182
+
183
+ tr_covmean = np.trace(covmean)
184
+
185
+ return (diff.dot(diff) + np.trace(sigma1)
186
+ + np.trace(sigma2) - 2 * tr_covmean)
187
+
188
+
189
+ def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
190
+ device='cpu', num_workers=8):
191
+ """Calculation of the statistics used by the FID.
192
+ Params:
193
+ -- files : List of image files paths
194
+ -- model : Instance of inception model
195
+ -- batch_size : The images numpy array is split into batches with
196
+ batch size batch_size. A reasonable batch size
197
+ depends on the hardware.
198
+ -- dims : Dimensionality of features returned by Inception
199
+ -- device : Device to run calculations
200
+ -- num_workers : Number of parallel dataloader workers
201
+
202
+ Returns:
203
+ -- mu : The mean over samples of the activations of the pool_3 layer of
204
+ the inception model.
205
+ -- sigma : The covariance matrix of the activations of the pool_3 layer of
206
+ the inception model.
207
+ """
208
+ act = get_activations(files, model, batch_size, dims, device, num_workers)
209
+ mu = np.mean(act, axis=0)
210
+ sigma = np.cov(act, rowvar=False)
211
+ return mu, sigma
212
+
213
+
214
+ def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=8):
215
+ if path.endswith('.npz'):
216
+ with np.load(path) as f:
217
+ m, s = f['mu'][:], f['sigma'][:]
218
+ else:
219
+ path = pathlib.Path(path)
220
+ files = sorted([file for ext in IMAGE_EXTENSIONS
221
+ for file in path.glob('*.{}'.format(ext))])
222
+ m, s = calculate_activation_statistics(files, model, batch_size,
223
+ dims, device, num_workers)
224
+
225
+ return m, s
226
+
227
+
228
+ def save_statistics_of_path(path, out_path, device=None, batch_size=50, dims=2048, num_workers=8):
229
+ if device is None:
230
+ device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
231
+ else:
232
+ device = torch.device(device)
233
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
234
+ model = InceptionV3([block_idx]).to(device)
235
+ m1, s1 = compute_statistics_of_path(path, model, batch_size, dims, device, num_workers)
236
+ np.savez(out_path, mu=m1, sigma=s1)
237
+
238
+
239
+ def calculate_fid_given_paths(paths, device=None, batch_size=50, dims=2048, num_workers=8):
240
+ """Calculates the FID of two paths"""
241
+ if device is None:
242
+ device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
243
+ else:
244
+ device = torch.device(device)
245
+
246
+ for p in paths:
247
+ if not os.path.exists(p):
248
+ raise RuntimeError('Invalid path: %s' % p)
249
+
250
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
251
+
252
+ model = InceptionV3([block_idx]).to(device)
253
+
254
+ m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
255
+ dims, device, num_workers)
256
+ m2, s2 = compute_statistics_of_path(paths[1], model, batch_size,
257
+ dims, device, num_workers)
258
+ fid_value = calculate_frechet_distance(m1, s1, m2, s2)
259
+
260
+ return fid_value
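A brief usage sketch (directory names are placeholders) for the two entry points above, caching the reference statistics once and reusing them:

```python
# Placeholder directories; both must contain images with extensions from IMAGE_EXTENSIONS.
from tools.fid_score import save_statistics_of_path, calculate_fid_given_paths

# Pre-compute mean/covariance of the reference set and store them as .npz.
save_statistics_of_path("data/real_thumbnails/s2_l2a", "stats_s2l2a.npz", batch_size=50, dims=2048)

# FID between generated samples and the cached reference statistics.
fid = calculate_fid_given_paths(["out_images/generate_s2l2a_1samples/s2_l2a", "stats_s2l2a.npz"])
print(f"FID: {fid:.2f}")
```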
src/COP-GEN-Beta/tools/inception.py ADDED
@@ -0,0 +1,328 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision
5
+
6
+ try:
7
+ from torchvision.models.utils import load_state_dict_from_url
8
+ except ImportError:
9
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
10
+
11
+ # Inception weights ported to Pytorch from
12
+ # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13
+ FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
14
+
15
+
16
+ class InceptionV3(nn.Module):
17
+ """Pretrained InceptionV3 network returning feature maps"""
18
+
19
+ # Index of default block of inception to return,
20
+ # corresponds to output of final average pooling
21
+ DEFAULT_BLOCK_INDEX = 3
22
+
23
+ # Maps feature dimensionality to their output blocks indices
24
+ BLOCK_INDEX_BY_DIM = {
25
+ 64: 0, # First max pooling features
26
+ 192: 1, # Second max pooling features
27
+ 768: 2, # Pre-aux classifier features
28
+ 2048: 3 # Final average pooling features
29
+ }
30
+
31
+ def __init__(self,
32
+ output_blocks=(DEFAULT_BLOCK_INDEX,),
33
+ resize_input=True,
34
+ normalize_input=True,
35
+ requires_grad=False,
36
+ use_fid_inception=True):
37
+ """Build pretrained InceptionV3
38
+
39
+ Parameters
40
+ ----------
41
+ output_blocks : list of int
42
+ Indices of blocks to return features of. Possible values are:
43
+ - 0: corresponds to output of first max pooling
44
+ - 1: corresponds to output of second max pooling
45
+ - 2: corresponds to output which is fed to aux classifier
46
+ - 3: corresponds to output of final average pooling
47
+ resize_input : bool
48
+ If true, bilinearly resizes input to width and height 299 before
49
+ feeding input to model. As the network without fully connected
50
+ layers is fully convolutional, it should be able to handle inputs
51
+ of arbitrary size, so resizing might not be strictly needed
52
+ normalize_input : bool
53
+ If true, scales the input from range (0, 1) to the range the
54
+ pretrained Inception network expects, namely (-1, 1)
55
+ requires_grad : bool
56
+ If true, parameters of the model require gradients. Possibly useful
57
+ for finetuning the network
58
+ use_fid_inception : bool
59
+ If true, uses the pretrained Inception model used in Tensorflow's
60
+ FID implementation. If false, uses the pretrained Inception model
61
+ available in torchvision. The FID Inception model has different
62
+ weights and a slightly different structure from torchvision's
63
+ Inception model. If you want to compute FID scores, you are
64
+ strongly advised to set this parameter to true to get comparable
65
+ results.
66
+ """
67
+ super(InceptionV3, self).__init__()
68
+
69
+ self.resize_input = resize_input
70
+ self.normalize_input = normalize_input
71
+ self.output_blocks = sorted(output_blocks)
72
+ self.last_needed_block = max(output_blocks)
73
+
74
+ assert self.last_needed_block <= 3, \
75
+ 'Last possible output block index is 3'
76
+
77
+ self.blocks = nn.ModuleList()
78
+
79
+ if use_fid_inception:
80
+ inception = fid_inception_v3()
81
+ else:
82
+ inception = _inception_v3(pretrained=True)
83
+
84
+ # Block 0: input to maxpool1
85
+ block0 = [
86
+ inception.Conv2d_1a_3x3,
87
+ inception.Conv2d_2a_3x3,
88
+ inception.Conv2d_2b_3x3,
89
+ nn.MaxPool2d(kernel_size=3, stride=2)
90
+ ]
91
+ self.blocks.append(nn.Sequential(*block0))
92
+
93
+ # Block 1: maxpool1 to maxpool2
94
+ if self.last_needed_block >= 1:
95
+ block1 = [
96
+ inception.Conv2d_3b_1x1,
97
+ inception.Conv2d_4a_3x3,
98
+ nn.MaxPool2d(kernel_size=3, stride=2)
99
+ ]
100
+ self.blocks.append(nn.Sequential(*block1))
101
+
102
+ # Block 2: maxpool2 to aux classifier
103
+ if self.last_needed_block >= 2:
104
+ block2 = [
105
+ inception.Mixed_5b,
106
+ inception.Mixed_5c,
107
+ inception.Mixed_5d,
108
+ inception.Mixed_6a,
109
+ inception.Mixed_6b,
110
+ inception.Mixed_6c,
111
+ inception.Mixed_6d,
112
+ inception.Mixed_6e,
113
+ ]
114
+ self.blocks.append(nn.Sequential(*block2))
115
+
116
+ # Block 3: aux classifier to final avgpool
117
+ if self.last_needed_block >= 3:
118
+ block3 = [
119
+ inception.Mixed_7a,
120
+ inception.Mixed_7b,
121
+ inception.Mixed_7c,
122
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
123
+ ]
124
+ self.blocks.append(nn.Sequential(*block3))
125
+
126
+ for param in self.parameters():
127
+ param.requires_grad = requires_grad
128
+
129
+ def forward(self, inp):
130
+ """Get Inception feature maps
131
+
132
+ Parameters
133
+ ----------
134
+ inp : torch.autograd.Variable
135
+ Input tensor of shape Bx3xHxW. Values are expected to be in
136
+ range (0, 1)
137
+
138
+ Returns
139
+ -------
140
+ List of torch.autograd.Variable, corresponding to the selected output
141
+ block, sorted ascending by index
142
+ """
143
+ outp = []
144
+ x = inp
145
+
146
+ if self.resize_input:
147
+ x = F.interpolate(x,
148
+ size=(299, 299),
149
+ mode='bilinear',
150
+ align_corners=False)
151
+
152
+ if self.normalize_input:
153
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
154
+
155
+ for idx, block in enumerate(self.blocks):
156
+ x = block(x)
157
+ if idx in self.output_blocks:
158
+ outp.append(x)
159
+
160
+ if idx == self.last_needed_block:
161
+ break
162
+
163
+ return outp
164
+
165
+
166
+ def _inception_v3(*args, **kwargs):
167
+ """Wraps `torchvision.models.inception_v3`
168
+
169
+ Skips default weight initialization if supported by torchvision version.
170
+ See https://github.com/mseitzer/pytorch-fid/issues/28.
171
+ """
172
+ try:
173
+ version = tuple(map(int, torchvision.__version__.split('.')[:2]))
174
+ except ValueError:
175
+ # Just a caution against weird version strings
176
+ version = (0,)
177
+
178
+ if version >= (0, 6):
179
+ kwargs['init_weights'] = False
180
+
181
+ return torchvision.models.inception_v3(*args, **kwargs)
182
+
183
+
184
+ def fid_inception_v3():
185
+ """Build pretrained Inception model for FID computation
186
+
187
+ The Inception model for FID computation uses a different set of weights
188
+ and has a slightly different structure than torchvision's Inception.
189
+
190
+ This method first constructs torchvision's Inception and then patches the
191
+ necessary parts that are different in the FID Inception model.
192
+ """
193
+ inception = _inception_v3(num_classes=1008,
194
+ aux_logits=False,
195
+ pretrained=False)
196
+ inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
197
+ inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
198
+ inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
199
+ inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
200
+ inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
201
+ inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
202
+ inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
203
+ inception.Mixed_7b = FIDInceptionE_1(1280)
204
+ inception.Mixed_7c = FIDInceptionE_2(2048)
205
+
206
+ state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
207
+ inception.load_state_dict(state_dict)
208
+ return inception
209
+
210
+
211
+ class FIDInceptionA(torchvision.models.inception.InceptionA):
212
+ """InceptionA block patched for FID computation"""
213
+ def __init__(self, in_channels, pool_features):
214
+ super(FIDInceptionA, self).__init__(in_channels, pool_features)
215
+
216
+ def forward(self, x):
217
+ branch1x1 = self.branch1x1(x)
218
+
219
+ branch5x5 = self.branch5x5_1(x)
220
+ branch5x5 = self.branch5x5_2(branch5x5)
221
+
222
+ branch3x3dbl = self.branch3x3dbl_1(x)
223
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
224
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
225
+
226
+ # Patch: Tensorflow's average pool does not use the padded zeros in
227
+ # its average calculation
228
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
229
+ count_include_pad=False)
230
+ branch_pool = self.branch_pool(branch_pool)
231
+
232
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
233
+ return torch.cat(outputs, 1)
234
+
235
+
236
+ class FIDInceptionC(torchvision.models.inception.InceptionC):
237
+ """InceptionC block patched for FID computation"""
238
+ def __init__(self, in_channels, channels_7x7):
239
+ super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
240
+
241
+ def forward(self, x):
242
+ branch1x1 = self.branch1x1(x)
243
+
244
+ branch7x7 = self.branch7x7_1(x)
245
+ branch7x7 = self.branch7x7_2(branch7x7)
246
+ branch7x7 = self.branch7x7_3(branch7x7)
247
+
248
+ branch7x7dbl = self.branch7x7dbl_1(x)
249
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
250
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
251
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
252
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
253
+
254
+ # Patch: Tensorflow's average pool does not use the padded zeros in
255
+ # its average calculation
256
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
257
+ count_include_pad=False)
258
+ branch_pool = self.branch_pool(branch_pool)
259
+
260
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
261
+ return torch.cat(outputs, 1)
262
+
263
+
264
+ class FIDInceptionE_1(torchvision.models.inception.InceptionE):
265
+ """First InceptionE block patched for FID computation"""
266
+ def __init__(self, in_channels):
267
+ super(FIDInceptionE_1, self).__init__(in_channels)
268
+
269
+ def forward(self, x):
270
+ branch1x1 = self.branch1x1(x)
271
+
272
+ branch3x3 = self.branch3x3_1(x)
273
+ branch3x3 = [
274
+ self.branch3x3_2a(branch3x3),
275
+ self.branch3x3_2b(branch3x3),
276
+ ]
277
+ branch3x3 = torch.cat(branch3x3, 1)
278
+
279
+ branch3x3dbl = self.branch3x3dbl_1(x)
280
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
281
+ branch3x3dbl = [
282
+ self.branch3x3dbl_3a(branch3x3dbl),
283
+ self.branch3x3dbl_3b(branch3x3dbl),
284
+ ]
285
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
286
+
287
+ # Patch: Tensorflow's average pool does not use the padded zeros in
288
+ # its average calculation
289
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
290
+ count_include_pad=False)
291
+ branch_pool = self.branch_pool(branch_pool)
292
+
293
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
294
+ return torch.cat(outputs, 1)
295
+
296
+
297
+ class FIDInceptionE_2(torchvision.models.inception.InceptionE):
298
+ """Second InceptionE block patched for FID computation"""
299
+ def __init__(self, in_channels):
300
+ super(FIDInceptionE_2, self).__init__(in_channels)
301
+
302
+ def forward(self, x):
303
+ branch1x1 = self.branch1x1(x)
304
+
305
+ branch3x3 = self.branch3x3_1(x)
306
+ branch3x3 = [
307
+ self.branch3x3_2a(branch3x3),
308
+ self.branch3x3_2b(branch3x3),
309
+ ]
310
+ branch3x3 = torch.cat(branch3x3, 1)
311
+
312
+ branch3x3dbl = self.branch3x3dbl_1(x)
313
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
314
+ branch3x3dbl = [
315
+ self.branch3x3dbl_3a(branch3x3dbl),
316
+ self.branch3x3dbl_3b(branch3x3dbl),
317
+ ]
318
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
319
+
320
+ # Patch: The FID Inception model uses max pooling instead of average
321
+ # pooling. This is likely an error in this specific Inception
322
+ # implementation, as other Inception models use average pooling here
323
+ # (which matches the description in the paper).
324
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
325
+ branch_pool = self.branch_pool(branch_pool)
326
+
327
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
328
+ return torch.cat(outputs, 1)
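A minimal sketch (input shapes assumed for illustration) of extracting pool_3 features with the wrapper above, which is what `fid_score.get_activations` does internally:

```python
import torch
from tools.inception import InceptionV3

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]      # final average pooling features
model = InceptionV3([block_idx]).eval()               # downloads the FID Inception weights on first use

images = torch.rand(8, 3, 256, 256)                   # values in [0, 1], as forward() expects
with torch.no_grad():
    features = model(images)[0]                       # (8, 2048, 1, 1) after the adaptive average pool
features = features.squeeze(3).squeeze(2)             # (8, 2048)
```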
src/COP-GEN-Beta/tools/inspect_parquet.py ADDED
@@ -0,0 +1,92 @@
1
+ import os
2
+ import sys
3
+ import pyarrow.parquet as pq
4
+ import pandas as pd
5
+ from pathlib import Path
6
+ import argparse
7
+
8
+ def parse_args():
9
+ parser = argparse.ArgumentParser(description='Extract information from a parquet file')
10
+ parser.add_argument('--parquet-file', type=str, required=True,
11
+ help='Name of the parquet file in the current directory')
12
+ parser.add_argument('--row-group', type=int, default=None,
13
+ help='Specific row group to extract (default: all row groups)')
14
+ parser.add_argument('--sample-binary', action='store_true',
15
+ help='Print sample of binary content (first 100 bytes)')
16
+ return parser.parse_args()
17
+
18
+ def main():
19
+ args = parse_args()
20
+ parquet_path = Path(args.parquet_file)
21
+
22
+ if not parquet_path.exists():
23
+ print(f"Error: File {parquet_path} not found")
24
+ sys.exit(1)
25
+
26
+ print(f"\n--- Analyzing parquet file: {parquet_path} ---\n")
27
+
28
+ # Open the parquet file
29
+ pf = pq.ParquetFile(parquet_path)
30
+
31
+ # Print basic file information
32
+ print(f"File size: {parquet_path.stat().st_size / (1024*1024):.2f} MB")
33
+ print(f"Number of row groups: {pf.num_row_groups}")
34
+ print(f"Number of rows: {pf.metadata.num_rows}")
35
+ print(f"Number of columns: {len(pf.schema_arrow)}")
36
+
37
+ # Print schema information
38
+ print("\nSchema:")
39
+ for i, field in enumerate(pf.schema_arrow):
40
+ print(f" {i+1}. {field.name}: {field.type}")
41
+
42
+ # Process row groups
43
+ row_groups = [args.row_group] if args.row_group is not None else range(pf.num_row_groups)
44
+
45
+ for rg_idx in row_groups:
46
+ if rg_idx >= pf.num_row_groups:
47
+ print(f"Error: Row group {rg_idx} does not exist (max: {pf.num_row_groups-1})")
48
+ continue
49
+
50
+ print(f"\n--- Row Group {rg_idx} ---")
51
+ # Get row group metadata
52
+ rg_metadata = pf.metadata.row_group(rg_idx)
53
+ print(f"Row count: {rg_metadata.num_rows}")
54
+
55
+ # Read the row group
56
+ table = pf.read_row_group(rg_idx)
57
+ df = table.to_pandas()
58
+
59
+ # Display information about each column
60
+ print("\nColumn information:")
61
+ for col_name in df.columns:
62
+ col_data = df[col_name]
63
+ dtype = col_data.dtype
64
+
65
+ if dtype == 'object':
66
+ # Check if it's binary data
67
+ if len(col_data) > 0 and isinstance(col_data.iloc[0], bytes):
68
+ item_size = len(col_data.iloc[0])
69
+ print(f" {col_name}: Binary data, size: {item_size / 1024:.2f} KB")
70
+
71
+ if args.sample_binary and item_size > 0:
72
+ print(f" Sample (first 100 bytes): {col_data.iloc[0][:100]}")
73
+ else:
74
+ # For non-binary object columns
75
+ print(f" {col_name}: Object type, example: {col_data.iloc[0]}")
76
+ else:
77
+ # For numeric or other columns
78
+ if col_data.size > 0:
79
+ print(f" {col_name}: {dtype}, min: {col_data.min()}, max: {col_data.max()}, example: {col_data.iloc[0]}")
80
+ else:
81
+ print(f" {col_name}: {dtype}, empty column")
82
+
83
+ # Print specific metadata fields for Major-TOM dataset
84
+ if 'product_id' in df.columns:
85
+ print(f"\nProduct ID: {df['product_id'].iloc[0]}")
86
+ if 'grid_cell' in df.columns:
87
+ print(f"Grid Cell: {df['grid_cell'].iloc[0]}")
88
+ if 'timestamp' in df.columns:
89
+ print(f"Timestamp: {df['timestamp'].iloc[0]}")
90
+
91
+ if __name__ == "__main__":
92
+ main()
src/COP-GEN-Beta/tools/print_parquet_urls.py ADDED
@@ -0,0 +1,31 @@
1
+ import pandas as pd
2
+ import pyarrow.parquet as pq
3
+ import argparse
4
+ from pathlib import Path
5
+
6
+ def parse_args():
7
+ parser = argparse.ArgumentParser(description='Read metadata.parquet and print download URLs')
8
+ parser.add_argument('--metadata-path', type=str, required=True,
9
+ help='Path to the metadata.parquet file')
10
+ return parser.parse_args()
11
+
12
+ def main():
13
+ args = parse_args()
14
+ metadata_path = Path(args.metadata_path)
15
+
16
+ # Read the parquet file
17
+ print(f"Reading metadata from: {metadata_path}")
18
+ df = pq.read_table(metadata_path).to_pandas()
19
+
20
+ # Extract unique parquet URLs
21
+ unique_urls = df['parquet_url'].unique()
22
+
23
+ # Print the URLs
24
+ print(f"\nFound {len(unique_urls)} unique parquet file URLs:")
25
+ for url in unique_urls:
26
+ print(url)
27
+
28
+ print(f"\nTotal number of samples in metadata: {len(df)}")
29
+
30
+ if __name__ == "__main__":
31
+ main()
src/COP-GEN-Beta/train_triffuser_discrete.py ADDED
@@ -0,0 +1,408 @@
+ import ml_collections
+ import torch
+ from torch import multiprocessing as mp
+ from datasets import get_dataset
+ from torchvision.utils import make_grid, save_image
+ import utils
+ import einops
+ from torch.utils._pytree import tree_map
+ import accelerate
+ from torch.utils.data import DataLoader
+ from tqdm.auto import tqdm
+ from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
+ import tempfile
+ from tools.fid_score import calculate_fid_given_paths
+ from absl import logging
+ import builtins
+ import os
+ import wandb
+ import libs.autoencoder
+ import numpy as np
+
+
+ def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
+     _betas = (
+         torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
+     )
+     return _betas.numpy()
+
+
+ def get_skip(alphas, betas):
+     N = len(betas) - 1
+     skip_alphas = np.ones([N + 1, N + 1], dtype=betas.dtype)
+     for s in range(N + 1):
+         skip_alphas[s, s + 1:] = alphas[s + 1:].cumprod()
+     skip_betas = np.zeros([N + 1, N + 1], dtype=betas.dtype)
+     for t in range(N + 1):
+         prod = betas[1: t + 1] * skip_alphas[1: t + 1, t]
+         skip_betas[:t, t] = (prod[::-1].cumsum())[::-1]
+     return skip_alphas, skip_betas
+
+
+ def stp(s, ts: torch.Tensor):  # scalar tensor product
+     if isinstance(s, np.ndarray):
+         s = torch.from_numpy(s).type_as(ts)
+     extra_dims = (1,) * (ts.dim() - 1)
+     return s.view(-1, *extra_dims) * ts
+
+
+ def mos(a, start_dim=1):  # mean of square
+     return a.pow(2).flatten(start_dim=start_dim).mean(dim=-1)
+
+
+ class Schedule(object):  # discrete time
+     def __init__(self, _betas):
+         r""" _betas[0...999] = betas[1...1000]
+         for n>=1, betas[n] is the variance of q(xn|xn-1)
+         for n=0, betas[0]=0
+         """
+
+         self._betas = _betas
+         self.betas = np.append(0., _betas)
+         self.alphas = 1. - self.betas
+         self.N = len(_betas)
+
+         assert isinstance(self.betas, np.ndarray) and self.betas[0] == 0
+         assert isinstance(self.alphas, np.ndarray) and self.alphas[0] == 1
+         assert len(self.betas) == len(self.alphas)
+
+         # skip_alphas[s, t] = alphas[s + 1: t + 1].prod()
+         self.skip_alphas, self.skip_betas = get_skip(self.alphas, self.betas)
+         self.cum_alphas = self.skip_alphas[0]  # cum_alphas = alphas.cumprod()
+         self.cum_betas = self.skip_betas[0]
+         self.snr = self.cum_alphas / self.cum_betas
+
+     def tilde_beta(self, s, t):
+         return self.skip_betas[s, t] * self.cum_betas[s] / self.cum_betas[t]
+
+     def sample(self, x0, multi_modal=False):  # sample from q(xn|x0), where n is uniform
+         if multi_modal:
+             n_list = []
+             eps_list = []
+             xn_list = []
+             for x0_i in x0:
+                 n = np.random.choice(list(range(1, self.N + 1)), (len(x0_i),))
+                 eps = torch.randn_like(x0_i)
+                 xn = stp(self.cum_alphas[n] ** 0.5, x0_i) + stp(self.cum_betas[n] ** 0.5, eps)
+                 n_list.append(torch.tensor(n, device=x0_i.device))
+                 eps_list.append(eps)
+                 xn_list.append(xn)
+             return n_list, eps_list, xn_list
+         else:
+             n = np.random.choice(list(range(1, self.N + 1)), (len(x0),))
+             eps = torch.randn_like(x0)
+             xn = stp(self.cum_alphas[n] ** 0.5, x0) + stp(self.cum_betas[n] ** 0.5, eps)
+             return torch.tensor(n, device=x0.device), eps, xn
+
+     def __repr__(self):
+         return f'Schedule({self.betas[:10]}..., {self.N})'
+
+
+ def LSimple(x0, nnet, schedule, multi_modal=False, **kwargs):
+     if multi_modal:
+         n_list, eps_list, xn_list = schedule.sample(x0, multi_modal=multi_modal)  # n in {1, ..., 1000}
+         eps_pred = nnet(xn_list, n_list, **kwargs)
+         return sum(mos(n - np_) for n, np_ in zip(eps_list, eps_pred))
+     else:
+         n, eps, xn = schedule.sample(x0)  # n in {1, ..., 1000}
+         eps_pred = nnet(xn, n, **kwargs)
+         return mos(eps - eps_pred)
+
+
+ def train(config):
+     if config.get('benchmark', False):
+         torch.backends.cudnn.benchmark = True
+         torch.backends.cudnn.deterministic = False
+
+     mp.set_start_method('spawn')
+     accelerator = accelerate.Accelerator()
+     device = accelerator.device
+     accelerate.utils.set_seed(config.seed, device_specific=True)
+     logging.info(f'Process {accelerator.process_index} using device: {device}')
+
+     config.mixed_precision = accelerator.mixed_precision
+     config = ml_collections.FrozenConfigDict(config)
+
+     assert config.train.batch_size % accelerator.num_processes == 0
+     mini_batch_size = config.train.batch_size // accelerator.num_processes
+
+     if accelerator.is_main_process:
+         os.makedirs(config.ckpt_root, exist_ok=True)
+         os.makedirs(config.sample_dir, exist_ok=True)
+     accelerator.wait_for_everyone()
+     if accelerator.is_main_process:
+         wandb.init(dir=os.path.abspath(config.workdir), project=f'uvit_{config.dataset.name}', config=config.to_dict(),
+                    name=config.hparams, job_type='train', mode='offline')
+         utils.set_logger(log_level='info', fname=os.path.join(config.workdir, 'output.log'))
+         logging.info(config)
+     else:
+         utils.set_logger(log_level='error')
+         builtins.print = lambda *args: None
+     logging.info(f'Run on {accelerator.num_processes} devices')
+
+     dataset = get_dataset(**config.dataset)
+     assert os.path.exists(dataset.fid_stat)
+     train_dataset = dataset.get_split(split='train', labeled=config.train.mode == 'cond')
+     train_dataset_loader = DataLoader(train_dataset, batch_size=mini_batch_size, shuffle=True, drop_last=True,
+                                       num_workers=8, pin_memory=True, persistent_workers=True)
+
+     train_state = utils.initialize_train_state(config, device)
+     nnet, nnet_ema, optimizer, train_dataset_loader = accelerator.prepare(
+         train_state.nnet, train_state.nnet_ema, train_state.optimizer, train_dataset_loader)
+     lr_scheduler = train_state.lr_scheduler
+     train_state.resume(config.ckpt_root)
+
+     autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path)
+     autoencoder.to(device)
+
+     @ torch.cuda.amp.autocast()
+     def encode(_batch):
+         return autoencoder.encode(_batch)
+
+     @ torch.cuda.amp.autocast()
+     def decode(_batch):
+         return autoencoder.decode(_batch)
+
+     def get_data_generator():
+         while True:
+             for data in tqdm(train_dataset_loader, disable=not accelerator.is_main_process, desc='epoch'):
+                 yield data
+
+     data_generator = get_data_generator()
+
+     _betas = stable_diffusion_beta_schedule()
+     _schedule = Schedule(_betas)
+     logging.info(f'use {_schedule}')
+
+
+     def train_step(_batch):
+         _metrics = dict()
+         optimizer.zero_grad()
+         if config.train.mode == 'uncond':  # Multi-modal data. Sample each modality independently
+             if config.train.multi_modal:
+                 _zs = [autoencoder.sample(modality) if 'feature' in config.dataset.name else encode(modality) for modality in _batch]
+                 loss = LSimple(_zs, nnet, _schedule, multi_modal=config.train.multi_modal)
+             else:
+                 _z = autoencoder.sample(_batch) if 'feature' in config.dataset.name else encode(_batch)
+                 loss = LSimple(_z, nnet, _schedule)
+         elif config.train.mode == 'cond':
+             _z = autoencoder.sample(_batch[0]) if 'feature' in config.dataset.name else encode(_batch[0])
+             loss = LSimple(_z, nnet, _schedule, y=_batch[1])
+         else:
+             raise NotImplementedError(config.train.mode)
+         _metrics['loss'] = accelerator.gather(loss.detach()).mean()
+         accelerator.backward(loss.mean())
+         optimizer.step()
+         lr_scheduler.step()
+         train_state.ema_update(config.get('ema_rate', 0.9999))
+         train_state.step += 1
+         return dict(lr=train_state.optimizer.param_groups[0]['lr'], **_metrics)
+
+     def dpm_solver_sample(_n_samples, _sample_steps, **kwargs):
+         _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
+         noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
+
+         def model_fn(x, t_continuous):
+             t = t_continuous * _schedule.N
+             eps_pre = nnet_ema(x, t, **kwargs)
+             return eps_pre
+
+         dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
+         _z = dpm_solver.sample(_z_init, steps=_sample_steps, eps=1. / _schedule.N, T=1.)
+         return decode(_z)
+
+     def combine_joint(z):
+         z = torch.concat([einops.rearrange(z_i, 'B C H W -> B (C H W)') for z_i in z], dim=-1)
+         return z
+
+     def split_joint(x, n_modalities):
+         C, H, W = config.z_shape
+         z_dim = C * H * W
+         z = x.split([z_dim] * n_modalities, dim=1)
+         z = [einops.rearrange(z_i, 'B (C H W) -> B C H W', C=C, H=H, W=W) for z_i in z]
+         return z
+
+     def dpm_solver_sample_multi_modal(_n_modalities, _n_samples, _sample_steps, **kwargs):
+         """Sample all modalities jointly by denoising their concatenated latents."""
+
+         _z_init = torch.randn(_n_modalities, _n_samples, *config.z_shape, device=device)
+         _z_init = combine_joint(_z_init)
+         noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
+
+         def model_fn(x, t_continuous):
+             t = t_continuous * _schedule.N
+
+             timesteps = [t] * _n_modalities
+             z = split_joint(x, _n_modalities)
+             z_out = nnet_ema(z, t_imgs=timesteps)
+             x_out = combine_joint(z_out)
+             # eps_pre = nnet_ema(x, t, **kwargs)
+             return x_out
+
+         dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
+         _zs = dpm_solver.sample(_z_init, steps=_sample_steps, eps=1. / _schedule.N, T=1.)
+         _zs = split_joint(_zs, _n_modalities)
+         samples_unstacked = [decode(_z) for _z in _zs]
+         return samples_unstacked
+
+     def eval_step(n_samples, sample_steps):
+         logging.info(f'eval_step: n_samples={n_samples}, sample_steps={sample_steps}, '
+                      f'mini_batch_size={config.sample.mini_batch_size}')
+
+         def sample_fn(_n_samples):
+             if config.train.mode == 'uncond':
+                 kwargs = dict()
+             elif config.train.mode == 'cond':
+                 kwargs = dict(y=dataset.sample_label(_n_samples, device=device))
+             else:
+                 raise NotImplementedError
+             return dpm_solver_sample(_n_samples, sample_steps, **kwargs)
+
+
+         with tempfile.TemporaryDirectory() as temp_path:
+             path = config.sample.path or temp_path
+             if accelerator.is_main_process:
+                 os.makedirs(path, exist_ok=True)
+             utils.sample2dir(accelerator, path, n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess)
+
+             _fid = 0
+             if accelerator.is_main_process:
+                 _fid = calculate_fid_given_paths((dataset.fid_stat, path))
+                 logging.info(f'step={train_state.step} fid{n_samples}={_fid}')
+                 with open(os.path.join(config.workdir, 'eval.log'), 'a') as f:
+                     print(f'step={train_state.step} fid{n_samples}={_fid}', file=f)
+                 wandb.log({f'fid{n_samples}': _fid}, step=train_state.step)
+             _fid = torch.tensor(_fid, device=device)
+             _fid = accelerator.reduce(_fid, reduction='sum')
+
+         return _fid.item()
+
+     logging.info(f'Start fitting, step={train_state.step}, mixed_precision={config.mixed_precision}')
+
+     step_fid = []
+     while train_state.step < config.train.n_steps:
+         nnet.train()
+         batch = tree_map(lambda x: x.to(device), next(data_generator))
+         metrics = train_step(batch)
+
+         nnet.eval()
+         if accelerator.is_main_process and train_state.step % config.train.log_interval == 0:
+             logging.info(utils.dct2str(dict(step=train_state.step, **metrics)))
+             logging.info(config.workdir)
+             wandb.log(metrics, step=train_state.step)
+
+         if accelerator.is_main_process and train_state.step % config.train.eval_interval == 0:
+             torch.cuda.empty_cache()
+             logging.info('Save a grid of images...')
+             if config.train.mode == 'uncond':
+                 if config.train.multi_modal:
+                     samples = dpm_solver_sample_multi_modal(_n_modalities=config.nnet.num_modalities, _n_samples=5 * 10, _sample_steps=50)
+                 else:
+                     samples = dpm_solver_sample(_n_samples=5 * 10, _sample_steps=50)
+             elif config.train.mode == 'cond':
+                 y = einops.repeat(torch.arange(5, device=device) % dataset.K, 'nrow -> (nrow ncol)', ncol=10)
+                 samples = dpm_solver_sample(_n_samples=5 * 10, _sample_steps=50, y=y)
+             else:
+                 raise NotImplementedError
+
+             if config.train.multi_modal:
+                 samples = torch.stack([dataset.unpreprocess(sample) for sample in samples], dim=0)  # stack instead of cat
+                 b = samples.shape[1]  # batch size
+                 # Properly interleave samples from all modalities
+                 # For each sample index, get all modalities before moving to next sample
+                 samples = torch.stack([samples[j, i] for i in range(b) for j in range(config.nnet.num_modalities)]).view(-1, *samples.shape[2:])
+                 # If the number of modalities is 3 then we plot in 9 columns
+                 n_cols = 9 if config.nnet.num_modalities == 3 else 10
+                 samples = make_grid(samples, n_cols)
+             else:
+                 samples = make_grid(dataset.unpreprocess(samples), 10)
+             save_image(samples, os.path.join(config.sample_dir, f'{train_state.step}.png'))
+             wandb.log({'samples': wandb.Image(samples)}, step=train_state.step)
+             torch.cuda.empty_cache()
+         accelerator.wait_for_everyone()
+
+         if train_state.step % config.train.save_interval == 0 or train_state.step == config.train.n_steps:
+             torch.cuda.empty_cache()
+             logging.info(f'Save and eval checkpoint {train_state.step}...')
+             if accelerator.is_main_process:
+                 try:
+                     train_state.save(os.path.join(config.ckpt_root, f'{train_state.step}.ckpt'))
+                 except Exception as e:
+                     logging.error(f" ==> Failed to save checkpoint: {e}!!!")
+             accelerator.wait_for_everyone()
+             # TODO: Skip FID for now
+             # fid = eval_step(n_samples=10000, sample_steps=50)  # calculate fid of the saved checkpoint
+             # step_fid.append((train_state.step, fid))
+             torch.cuda.empty_cache()
+         accelerator.wait_for_everyone()
+
+     logging.info(f'Finish fitting, step={train_state.step}')
+     logging.info(f'step_fid: {step_fid}')
+     step_best = sorted(step_fid, key=lambda x: x[1])[0][0]
+     logging.info(f'step_best: {step_best}')
+     train_state.load(os.path.join(config.ckpt_root, f'{step_best}.ckpt'))
+     del metrics
+     accelerator.wait_for_everyone()
+     eval_step(n_samples=config.sample.n_samples, sample_steps=config.sample.sample_steps)
+
+
+
+ from absl import flags
+ from absl import app
+ from ml_collections import config_flags
+ import sys
+ from pathlib import Path
+
+
+ FLAGS = flags.FLAGS
+ config_flags.DEFINE_config_file(
+     "config", None, "Training configuration.", lock_config=False)
+ flags.mark_flags_as_required(["config"])
+ flags.DEFINE_string("workdir", None, "Work unit directory.")
+
+
+ def get_config_name():
+     argv = sys.argv
+     for i in range(1, len(argv)):
+         if argv[i].startswith('--config='):
+             return Path(argv[i].split('=')[-1]).stem
+
+ def get_config_path():
+     argv = sys.argv
+     for i in range(1, len(argv)):
+         if argv[i].startswith('--config='):
+             path = argv[i].split('=')[-1]
+             if path.startswith('configs/'):
+                 path = path[len('configs/'):]
+             return path
+
+ def get_hparams():
+     argv = sys.argv
+     lst = []
+     for i in range(1, len(argv)):
+         assert '=' in argv[i]
+         if argv[i].startswith('--config.') and not argv[i].startswith('--config.dataset.path'):
+             hparam, val = argv[i].split('=')
+             hparam = hparam.split('.')[-1]
+             if hparam.endswith('path'):
+                 val = Path(val).stem
+             lst.append(f'{hparam}={val}')
+     hparams = '-'.join(lst)
+     if hparams == '':
+         hparams = 'default'
+     return hparams
+
+
+ def main(argv):
+     config = FLAGS.config
+     # config.config_name = get_config_name()
+     config.config_name = get_config_path().strip('.py')
+     config.hparams = get_hparams()
+     config.workdir = FLAGS.workdir or os.path.join('workdir', config.config_name, config.hparams)
+     config.ckpt_root = os.path.join(config.workdir, 'ckpts')
+     config.sample_dir = os.path.join(config.workdir, 'samples')
+     train(config)
+
+
+ if __name__ == "__main__":
+     app.run(main)
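
Editor's note: for joint sampling, `dpm_solver_sample_multi_modal` flattens every modality's latent into one long vector per sample, runs a single DPM-Solver trajectory over that packed tensor, and splits it back per modality before decoding. The self-contained sketch below reproduces that pack/unpack round trip with made-up shapes; it mirrors `combine_joint`/`split_joint` above and is not part of the commit.

```python
# Minimal sketch (illustrative shapes only) of the flatten/split round trip
# that the multi-modal sampler relies on.
import einops
import torch

n_modalities, B, C, H, W = 3, 2, 4, 16, 16            # placeholder latent shape

def combine_joint(z):                                   # list of (B, C, H, W) -> (B, M*C*H*W)
    return torch.concat([einops.rearrange(z_i, 'B C H W -> B (C H W)') for z_i in z], dim=-1)

def split_joint(x, n_modalities):                       # inverse of combine_joint
    z = x.split([C * H * W] * n_modalities, dim=1)
    return [einops.rearrange(z_i, 'B (C H W) -> B C H W', C=C, H=H, W=W) for z_i in z]

zs = [torch.randn(B, C, H, W) for _ in range(n_modalities)]
packed = combine_joint(zs)                              # single tensor the solver steps over
unpacked = split_joint(packed, n_modalities)
assert all(torch.equal(a, b) for a, b in zip(zs, unpacked))
```

Because the packing is a pure reshape, the solver can treat all modalities as one diffusion state while the denoising network still receives per-modality tensors.
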
src/COP-GEN-Beta/utils.py ADDED
@@ -0,0 +1,240 @@
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ import os
+ from tqdm import tqdm
+ from torchvision.utils import save_image
+ from absl import logging
+ from PIL import Image, ImageDraw, ImageFont
+
+ def set_logger(log_level='info', fname=None):
+     import logging as _logging
+     handler = logging.get_absl_handler()
+     formatter = _logging.Formatter('%(asctime)s - %(filename)s - %(message)s')
+     handler.setFormatter(formatter)
+     logging.set_verbosity(log_level)
+     if fname is not None:
+         handler = _logging.FileHandler(fname)
+         handler.setFormatter(formatter)
+         logging.get_absl_logger().addHandler(handler)
+
+
+ def dct2str(dct):
+     return str({k: f'{v:.6g}' for k, v in dct.items()})
+
+
+ def get_nnet(name, **kwargs):
+     if name == 'uvit':
+         from libs.uvit import UViT
+         return UViT(**kwargs)
+     elif name == 'uvit_t2i':
+         from libs.uvit_t2i import UViT
+         return UViT(**kwargs)
+     elif name == 'uvit_multi_post_ln':
+         from libs.uvit_multi_post_ln import UViT
+         return UViT(**kwargs)
+     elif name == 'uvit_multi_post_ln_v1':
+         from libs.uvit_multi_post_ln_v1 import UViT
+         return UViT(**kwargs)
+     elif name == 'triffuser_multi_post_ln':
+         from libs.triffuser_multi_post_ln import Triffuser
+         return Triffuser(**kwargs)
+     else:
+         raise NotImplementedError(name)
+
+
+ def set_seed(seed: int):
+     if seed is not None:
+         torch.manual_seed(seed)
+         np.random.seed(seed)
+
+
+ def get_optimizer(params, name, **kwargs):
+     if name == 'adam':
+         from torch.optim import Adam
+         return Adam(params, **kwargs)
+     elif name == 'adamw':
+         from torch.optim import AdamW
+         return AdamW(params, **kwargs)
+     else:
+         raise NotImplementedError(name)
+
+
+ def customized_lr_scheduler(optimizer, warmup_steps=-1):
+     from torch.optim.lr_scheduler import LambdaLR
+     def fn(step):
+         if warmup_steps > 0:
+             return min(step / warmup_steps, 1)
+         else:
+             return 1
+     return LambdaLR(optimizer, fn)
+
+
+ def get_lr_scheduler(optimizer, name, **kwargs):
+     if name == 'customized':
+         return customized_lr_scheduler(optimizer, **kwargs)
+     elif name == 'cosine':
+         from torch.optim.lr_scheduler import CosineAnnealingLR
+         return CosineAnnealingLR(optimizer, **kwargs)
+     else:
+         raise NotImplementedError(name)
+
+
+ def ema(model_dest: nn.Module, model_src: nn.Module, rate):
+     param_dict_src = dict(model_src.named_parameters())
+     for p_name, p_dest in model_dest.named_parameters():
+         p_src = param_dict_src[p_name]
+         assert p_src is not p_dest
+         p_dest.data.mul_(rate).add_((1 - rate) * p_src.data)
+
+
+ class TrainState(object):
+     def __init__(self, optimizer, lr_scheduler, step, nnet=None, nnet_ema=None):
+         self.optimizer = optimizer
+         self.lr_scheduler = lr_scheduler
+         self.step = step
+         self.nnet = nnet
+         self.nnet_ema = nnet_ema
+
+     def ema_update(self, rate=0.9999):
+         if self.nnet_ema is not None:
+             ema(self.nnet_ema, self.nnet, rate)
+
+     def save(self, path):
+         os.makedirs(path, exist_ok=True)
+         torch.save(self.step, os.path.join(path, 'step.pth'))
+         for key, val in self.__dict__.items():
+             if key != 'step' and val is not None:
+                 torch.save(val.state_dict(), os.path.join(path, f'{key}.pth'))
+
+     def load(self, path):
+         logging.info(f'load from {path}')
+         self.step = torch.load(os.path.join(path, 'step.pth'))
+         for key, val in self.__dict__.items():
+             if key != 'step' and val is not None:
+                 val.load_state_dict(torch.load(os.path.join(path, f'{key}.pth'), map_location='cpu'))
+
+     def resume(self, ckpt_root, step=None):
+         if not os.path.exists(ckpt_root):
+             return
+         if step is None:
+             ckpts = list(filter(lambda x: '.ckpt' in x, os.listdir(ckpt_root)))
+             if not ckpts:
+                 return
+             steps = map(lambda x: int(x.split(".")[0]), ckpts)
+             step = max(steps)
+         ckpt_path = os.path.join(ckpt_root, f'{step}.ckpt')
+         logging.info(f'resume from {ckpt_path}')
+         self.load(ckpt_path)
+
+     def to(self, device):
+         for key, val in self.__dict__.items():
+             if isinstance(val, nn.Module):
+                 val.to(device)
+
+
+ def cnt_params(model):
+     return sum(param.numel() for param in model.parameters())
+
+
+ def initialize_train_state(config, device):
+     params = []
+
+     nnet = get_nnet(**config.nnet)
+     params += nnet.parameters()
+     nnet_ema = get_nnet(**config.nnet)
+     nnet_ema.eval()
+     logging.info(f'nnet has {cnt_params(nnet)} parameters')
+
+     optimizer = get_optimizer(params, **config.optimizer)
+     lr_scheduler = get_lr_scheduler(optimizer, **config.lr_scheduler)
+
+     train_state = TrainState(optimizer=optimizer, lr_scheduler=lr_scheduler, step=0,
+                              nnet=nnet, nnet_ema=nnet_ema)
+     train_state.ema_update(0)
+     train_state.to(device)
+     return train_state
+
+
+ def amortize(n_samples, batch_size):
+     k = n_samples // batch_size
+     r = n_samples % batch_size
+     return k * [batch_size] if r == 0 else k * [batch_size] + [r]
+
+
+ def sample2dir(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None):
+     os.makedirs(path, exist_ok=True)
+     idx = 0
+     batch_size = mini_batch_size * accelerator.num_processes
+
+     for _batch_size in tqdm(amortize(n_samples, batch_size), disable=not accelerator.is_main_process, desc='sample2dir'):
+         samples = unpreprocess_fn(sample_fn(mini_batch_size))
+         samples = accelerator.gather(samples.contiguous())[:_batch_size]
+         if accelerator.is_main_process:
+             for sample in samples:
+                 save_image(sample, os.path.join(path, f"{idx}.png"))
+                 idx += 1
+
+
+ def grad_norm(model):
+     total_norm = 0.
+     for p in model.parameters():
+         param_norm = p.grad.data.norm(2)
+         total_norm += param_norm.item() ** 2
+     total_norm = total_norm ** (1. / 2)
+     return total_norm
+
+
+
+ def center_crop(width, height, img):
+     resample = {'box': Image.BOX, 'lanczos': Image.LANCZOS}['lanczos']
+     crop = np.min(img.shape[:2])
+     img = img[(img.shape[0] - crop) // 2: (img.shape[0] + crop) // 2,
+               (img.shape[1] - crop) // 2: (img.shape[1] + crop) // 2]  # center crop
+     try:
+         img = Image.fromarray(img, 'RGB')
+     except:
+         img = Image.fromarray(img)
+     img = img.resize((width, height), resample)  # resize the center crop from [crop, crop] to [width, height]
+
+     return np.array(img).astype(np.uint8)
+
+
+ def drawRoundRec(draw, color, x, y, w, h, r):
+     drawObject = draw
+
+     '''Rounds'''
+     drawObject.ellipse((x, y, x + r, y + r), fill=color)
+     drawObject.ellipse((x + w - r, y, x + w, y + r), fill=color)
+     drawObject.ellipse((x, y + h - r, x + r, y + h), fill=color)
+     drawObject.ellipse((x + w - r, y + h - r, x + w, y + h), fill=color)
+
+     '''rec.s'''
+     drawObject.rectangle((x + r / 2, y, x + w - (r / 2), y + h), fill=color)
+     drawObject.rectangle((x, y + r / 2, x + w, y + h - (r / 2)), fill=color)
+
+
+ def add_water(img, text='UniDiffuser', pos=3):
+     width, height = img.size
+     scale = 4
+     scale_size = 0.5
+     img = img.resize((width * scale, height * scale), Image.LANCZOS)
+     result = Image.new(img.mode, (width * scale, height * scale), color=(255, 255, 255))
+     result.paste(img, box=(0, 0))
+
+     delta_w = int(width * scale * 0.27 * scale_size)  # text width
+     delta_h = width * scale * 0.05 * scale_size  # text height
+     postions = np.array([[0, 0], [0, height * scale - delta_h], [width * scale - delta_w, 0],
+                          [width * scale - delta_w, height * scale - delta_h]])
+     postion = postions[pos]
+     # draw the watermark text
+     draw = ImageDraw.Draw(result)
+     fillColor = (107, 92, 231)
+     setFont = ImageFont.truetype("assets/ArialBoldMT.ttf", int(width * scale * 0.05 * scale_size))
+     delta = 20 * scale_size
+     padding = 15 * scale_size
+     drawRoundRec(draw, (223, 230, 233), postion[0] - delta - padding, postion[1] - delta - padding,
+                  w=delta_w + 2 * padding, h=delta_h + 2 * padding, r=50 * scale_size)
+     draw.text((postion[0] - delta, postion[1] - delta), text, font=setFont, fill=fillColor)
+
+     return result.resize((width, height), Image.LANCZOS)
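
Editor's note: `TrainState.ema_update` maintains a shadow copy of the network whose parameters follow p_ema <- rate * p_ema + (1 - rate) * p_online, and the training script samples from that EMA copy (`nnet_ema`). A toy sketch of the update on a small module, not part of the commit, is below; the rate of 0.9 is chosen only to make the blend visible.

```python
# Minimal sketch of the exponential-moving-average update used by the ema() helper above.
import torch
import torch.nn as nn

def ema(model_dest, model_src, rate):
    # Blend each destination parameter toward the source parameter.
    param_dict_src = dict(model_src.named_parameters())
    for p_name, p_dest in model_dest.named_parameters():
        p_src = param_dict_src[p_name]
        p_dest.data.mul_(rate).add_((1 - rate) * p_src.data)

torch.manual_seed(0)
online, shadow = nn.Linear(2, 2), nn.Linear(2, 2)
shadow.load_state_dict(online.state_dict())   # EMA copy starts equal to the online net
online.weight.data += 1.0                      # pretend one optimiser step moved the weights
ema(shadow, online, rate=0.9)                  # shadow <- 0.9 * shadow + 0.1 * online
print(shadow.weight - online.weight)           # small residual gap toward the old weights
```
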