Spaces: Running on Zero
mikonvergence committed
Commit: 82f1234 · Parent(s): f7d38e5
github code incorporated
Files changed:
- app.py +1 -1
- src/COP-GEN-Beta +0 -1
- src/COP-GEN-Beta/.gitignore +13 -0
- src/COP-GEN-Beta/README.md +341 -0
- src/COP-GEN-Beta/configs/majortom/discrete/lmdb/rome_dems1s2s2_cop_gen_beta.py +64 -0
- src/COP-GEN-Beta/configs/majortom/discrete/rome_dems1s2s2_cop_gen_beta.py +62 -0
- src/COP-GEN-Beta/create_lmdb.py +213 -0
- src/COP-GEN-Beta/datasets.py +885 -0
- src/COP-GEN-Beta/dpm_solver_pp.py +952 -0
- src/COP-GEN-Beta/encode_majortom_images.py +95 -0
- src/COP-GEN-Beta/libs/__init__.py +1 -0
- src/COP-GEN-Beta/libs/autoencoder.py +519 -0
- src/COP-GEN-Beta/libs/timm.py +112 -0
- src/COP-GEN-Beta/libs/triffuser_multi_post_ln.py +290 -0
- src/COP-GEN-Beta/majortom/NMajorTOM.py +170 -0
- src/COP-GEN-Beta/majortom/coverage_vis.py +149 -0
- src/COP-GEN-Beta/majortom/download_world.py +1009 -0
- src/COP-GEN-Beta/prepare_dataset_images.py +488 -0
- src/COP-GEN-Beta/sample_n_triffuser.py +652 -0
- src/COP-GEN-Beta/scripts/download_rome.sh +18 -0
- src/COP-GEN-Beta/tools/extract_parquet.py +115 -0
- src/COP-GEN-Beta/tools/fid_score.py +260 -0
- src/COP-GEN-Beta/tools/inception.py +328 -0
- src/COP-GEN-Beta/tools/inspect_parquet.py +92 -0
- src/COP-GEN-Beta/tools/print_parquet_urls.py +31 -0
- src/COP-GEN-Beta/train_triffuser_discrete.py +408 -0
- src/COP-GEN-Beta/utils.py +240 -0
app.py
CHANGED
@@ -11,7 +11,7 @@ with gr.Blocks(theme=theme) as demo:
     gr.Markdown("# 🔵 COP-GEN-Beta: Unified Generative Modelling of COPernicus Imagery Thumbnails")
     gr.Markdown("### Miguel Espinosa, Valerio Marsocci, Yuru Jia, Elliot J. Crowley, Mikolaj Czerkawski")
     gr.Markdown('[[Website](https://miquel-espinosa.github.io/cop-gen-beta/)] [[GitHub](https://github.com/miquel-espinosa/COP-GEN-Beta)] [[Model](https://huggingface.co/mespinosami/COP-GEN-Beta)] [[Dataset](https://huggingface.co/Major-TOM)]')
-    gr.Markdown('> ## ⚠️ NOTE: This is a prototype Beta model of COP-GEN. It is based on image thumbnails of Major TOM and does not yet support raw source data. The hillshade visualisation is used for elevation. The full model COP-GEN is coming soon.')
+    gr.Markdown('> ## ⚠️ NOTE: This is a prototype Beta model of COP-GEN. It is based on image thumbnails of [[Major TOM](https://huggingface.co/Major-TOM)] and does not yet support raw source data. The hillshade visualisation is used for elevation. The full model COP-GEN is coming soon.')

     with gr.Column(elem_classes="Main app"):
src/COP-GEN-Beta
DELETED
@@ -1 +0,0 @@
-Subproject commit eef71c50f3a233c30f1e6f8b87d7815494eb1ff2
src/COP-GEN-Beta/.gitignore
ADDED
@@ -0,0 +1,13 @@
data
./data
assets
./assets
workdir
./workdir
__pycache__/
**__pycache__/
*.out
*.pth
out_images/
models/
models
src/COP-GEN-Beta/README.md
ADDED
@@ -0,0 +1,341 @@
# [CVPRW 2025] 🌍 COP-GEN-Beta: Unified Generative Modelling of COPernicus Imagery Thumbnails

[](https://huggingface.co/mespinosami/COP-GEN-Beta)
[](https://github.com/miquel-espinosa/COP-GEN-Beta)
[](https://miquel-espinosa.github.io/cop-gen-beta/)
[](https://huggingface.co/mespinosami/COP-GEN-Beta)
[](https://www.arxiv.org/abs/2504.08548)
<a href="https://colab.research.google.com/github/ESA-PhiLab/Major-TOM/blob/main/03-Filtering-in-Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This repository contains the official implementation of our paper:

[COP-GEN-Beta: Unified Generative Modelling of COPernicus Imagery Thumbnails, *Miguel Espinosa*, *Valerio Marsocci*, *Yuru Jia*, *Elliot J. Crowley*, *Mikolaj Czerkawski*, CVPRW 2025](https://www.arxiv.org/pdf/2504.08548)

### Abstract
> _In remote sensing, multi-modal data from various sensors capturing the same scene offers rich opportunities, but learning a unified representation across these modalities remains a significant challenge. Traditional methods have often been limited to single or dual-modality approaches. In this paper, we introduce COP-GEN-Beta, a generative diffusion model trained on optical, radar, and elevation data from the Major TOM dataset. What sets COP-GEN-Beta apart is its ability to map any subset of modalities to any other, enabling zero-shot modality translation after training. This is achieved through a sequence-based diffusion transformer, where each modality is controlled by its own timestep embedding. We extensively evaluate COP-GEN-Beta on thumbnail images from the Major TOM dataset, demonstrating its effectiveness in generating high-quality samples. Qualitative and quantitative evaluations validate the model's performance, highlighting its potential as a powerful pre-trained model for future remote sensing tasks._

<!-- <details> -->
<!-- <summary><h3><b>Table of Contents</b></h3></summary> -->

### Table of Contents
- [Architecture Overview](#cop-gen-beta-architecture-overview)
- [Training Code Instructions](#training-code-instructions)
  - [0. Basic folder setup](#0-basic-folder-setup)
  - [1. Download training data](#1-download-training-data-subset-example-rome)
  - [2. Patchify and encode thumbnails](#2-patchify-and-encode-thumbnails)
  - [3. Pre-compute features with Stable Diffusion](#3-pre-compute-features-with-stable-diffusion-pretrained-autoencoder)
  - [4. Convert dataset to LMDB (optional)](#4-convert-dataset-to-lmdb-optional)
  - [5. Train the model](#5-train-the-model)
- [Inference Instructions](#cop-gen-beta-inference-instructions)
  - [1. Download model checkpoint](#1-download-model-checkpoint)
  - [2. Run inference on test set](#2-run-inference-on-test-set-rome-subset)
    - [Example 1: Unconditional generation](#example-1-unconditional-generation)
    - [Example 2: Single modality conditioning](#example-2-single-modality-conditioning)
    - [Example 3: 2 modality conditioning](#example-3-2-modality-conditioning)
    - [Example 4: 3 modality conditioning](#example-4-3-modality-conditioning)

<!-- </details> -->
# COP-GEN-Beta: Architecture Overview

We introduce COP-GEN-Beta, a diffusion model designed to handle multiple remote sensing modalities. Specifically, COP-GEN-Beta operates on four key EO modalities: Digital Elevation Model (DEM), Sentinel-1 Radar Terrain Corrected (S1 RTC), Sentinel-2 Level 1C (S2 L1C), and Sentinel-2 Level 2A (S2 L2A). Unlike previous approaches, which require separate models per modality, COP-GEN-Beta learns joint, conditional, and marginal distributions within a unified framework.

This is achieved by (a) sampling a global and dense dataset of these modalities from Major TOM and encoding all images with a pretrained Stable Diffusion autoencoder, and (b) training a sequence-based denoising diffusion model with a transformer backbone, where each modality is supplied with its own designated timestep. This makes it possible to (c) generate all modalities from any subset of them that is available.


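To make the "one timestep per modality" idea concrete, here is a minimal, self-contained sketch of per-modality timestep conditioning in a shared transformer. It is an illustration only, not the repository's `triffuser_multi_post_ln` network; all module and tensor names are assumptions.

```python
import torch
import torch.nn as nn

def timestep_embedding(t, dim):
    # Standard sinusoidal embedding of a (B,) tensor of timesteps into (B, dim).
    half = dim // 2
    freqs = torch.exp(-torch.arange(half, dtype=torch.float32)
                      * torch.log(torch.tensor(10000.0)) / half)
    args = t[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

class PerModalityConditioning(nn.Module):
    """Illustrative sketch: each modality gets its own timestep embedding, so a
    modality used as conditioning can, for example, be held at a fixed small
    timestep while the modalities being generated follow the noise schedule."""
    def __init__(self, embed_dim=1024, num_modalities=4):
        super().__init__()
        self.proj = nn.ModuleList(nn.Linear(embed_dim, embed_dim)
                                  for _ in range(num_modalities))
        self.embed_dim = embed_dim

    def forward(self, tokens_per_mod, t_per_mod):
        # tokens_per_mod: list of (B, L, D) token sequences, one per modality
        # t_per_mod:      list of (B,) timesteps, one per modality
        out = []
        for tokens, t, proj in zip(tokens_per_mod, t_per_mod, self.proj):
            emb = proj(timestep_embedding(t, self.embed_dim))   # (B, D)
            out.append(tokens + emb[:, None, :])                # broadcast over sequence
        # The per-modality sequences are then concatenated and fed to one transformer.
        return torch.cat(out, dim=1)
```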
# COP-GEN-Beta: Results

COP-GEN-Beta's flexible sampling capabilities enable a wide range of downstream applications through various modality translation combinations. By allowing generation of any subset of modalities conditioned on any other subset, our model unlocks numerous practical use cases in remote sensing, from atmospheric correction and DEM generation to dataset expansion.


# COP-GEN-Beta: Training code instructions

## 0. Basic folder setup

Data will be stored in `./data/`. Create a symlink if needed.

```bash
ln -s /path/to/disk/with/storage/ ./data
```

Download the Stable Diffusion autoencoder weights:

```bash
mkdir -p ./assets/stable-diffusion
```

Download the weights from [here](https://drive.google.com/drive/folders/1sV-IvcGUrZeIlTmtuKv9vDJiB4JEHL-f) and place them in `./assets/stable-diffusion`.
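If you want to sanity-check the download, here is a minimal sketch; the file name comes from the training configs (`config.autoencoder.pretrained_path`), and loading on CPU only inspects the checkpoint.

```python
import os
import torch

# Path expected by the training configs (config.autoencoder.pretrained_path)
ckpt_path = "assets/stable-diffusion/autoencoder_kl_ema.pth"
assert os.path.isfile(ckpt_path), f"missing {ckpt_path}"

# Load on CPU just to confirm the file is a readable checkpoint
state = torch.load(ckpt_path, map_location="cpu")
print(type(state).__name__, len(state) if hasattr(state, "__len__") else "")
```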
## 1. Download training data (subset example: Rome)

Select a subset of data to download for training. For example, let's download the region of Rome.

```bash
sh scripts/download_rome.sh
```

Run `python3 scripts/download_world.py --help` to see the available options.

The generated folder structure will look like this:
```
data/majorTOM
├── Core-DEM
│   ├── metadata.parquet
├── Core-S1RTC
│   ├── metadata.parquet
├── Core-S2L1C
│   ├── metadata.parquet
├── Core-S2L2A
│   ├── metadata.parquet
├── rome
│   ├── Core-DEM
│   │   ├── metadata.parquet (metadata.parquet for the Rome subset)
│   │   ├── <grid_cell>
│   │   │   ├── ...
│   │   │   ├── compressed.tif
│   │   │   ├── DEM.tif
│   │   │   └── thumbnail.png
│   ├── Core-S1RTC
│   │   ├── metadata.parquet (metadata.parquet for the Rome subset)
│   │   ├── <grid_cell>
│   │   │   ├── ...
│   │   │   ├── vh.tif
│   │   │   ├── vv.tif
│   │   │   ├── thumbnail.png
```
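If you want to peek at what was downloaded, the per-modality `metadata.parquet` files can be inspected with pandas (the repository also ships `tools/inspect_parquet.py`). A minimal sketch; the columns are whatever the Major TOM metadata provides, so this only prints them:

```python
import pandas as pd

# Inspect the Rome subset metadata for one modality (path from the tree above)
df = pd.read_parquet("data/majorTOM/rome/Core-DEM/metadata.parquet")
print(len(df), "rows")
print(df.columns.tolist())
print(df.head())
```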
## 2. Patchify and encode thumbnails

Align image modalities (find common grid cells), patchify into 256x256 patches (thumbnails are 1068x1068; we first crop to 1024x1024, then patchify), and create the train/test splits. A minimal sketch of the patchify step follows the command below.

```bash
python3 prepare_dataset_images.py --subset_path data/majorTOM/rome --output_dir data/majorTOM/rome/rome_thumbnail_png --bands thumbnail
```
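A minimal sketch of the crop-and-patchify logic described above, assuming a centre crop from 1068x1068 to 1024x1024; the actual `prepare_dataset_images.py` may differ in details such as crop placement and file naming.

```python
import numpy as np
from PIL import Image

def patchify_thumbnail(path, crop=1024, patch=256):
    """Centre-crop a 1068x1068 thumbnail to 1024x1024 and tile it into 256x256 patches."""
    img = np.array(Image.open(path).convert("RGB"))
    h, w = img.shape[:2]
    top, left = (h - crop) // 2, (w - crop) // 2
    img = img[top:top + crop, left:left + crop]
    patches = []
    for y in range(0, crop, patch):
        for x in range(0, crop, patch):
            patches.append(img[y:y + patch, x:x + patch])
    return patches  # 16 patches of shape (256, 256, 3)
```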
## 3. Pre-compute features with Stable Diffusion pretrained autoencoder

```bash
bands=(DEM_thumbnail S1RTC_thumbnail S2L1C_thumbnail S2L2A_thumbnail)
splits=(train test)
for band in "${bands[@]}"; do
    for split in "${splits[@]}"; do
        python3 encode_majortom_images.py \
            --path "data/majorTOM/rome/rome_thumbnail_png/${split}/${band}" \
            --resolution 256 \
            --output_dir "data/majorTOM/rome/rome_thumbnail_npy/${split}/${band}"
    done
done
```

Folder structure generated by the command above:
```
data/majorTOM/rome/
├── train
│   ├── DEM_thumbnail
│   │   ├── 0.npy
│   │   ├── 1.npy
│   │   ├── ...
│   ├── S1RTC_thumbnail
│   │   ├── 0.npy
│   │   ├── 1.npy
│   │   ├── ...
│   ├── S2L1C_thumbnail
│   │   ├── 0.npy
│   │   ├── 1.npy
│   │   ├── ...
│   ├── S2L2A_thumbnail
│   │   ├── 0.npy
│   │   ├── 1.npy
│   │   ├── ...
├── test
│   ├── DEM_thumbnail
│   │   ├── 0.npy
│   │   ├── 1.npy
│   │   ├── ...
```
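The saved `.npy` files hold encoder "moments" rather than decoded images. Downstream code (`datasets.py`) reads them back as float32 arrays of shape (8, 32, 32), and the training configs use `z_shape = (4, 32, 32)`, which is consistent with a 4-channel latent mean plus a 4-channel log-variance. A minimal sketch of drawing one latent sample from a saved moment, under that assumption (the file path is illustrative):

```python
import numpy as np

# Assumed layout: channels 0-3 are the latent mean, channels 4-7 the log-variance.
moment = np.load("data/majorTOM/rome/rome_thumbnail_npy/train/S2L2A_thumbnail/0.npy")
mean, logvar = moment[:4], moment[4:]

# Reparameterisation: draw one latent sample z ~ N(mean, exp(logvar))
z = mean + np.exp(0.5 * logvar) * np.random.randn(*mean.shape)
print(moment.shape, z.shape)  # expected: (8, 32, 32) (4, 32, 32)
```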
## 4. Convert dataset to LMDB (optional)

Convert the npy files to an LMDB dataset, for both the train and test splits. (Lower `--batch-size` if the conversion fails.)

```bash
python3 create_lmdb.py \
    --input-img-dir data/majorTOM/rome/rome_thumbnail_npy/train \
    --output-dir data/majorTOM/rome/rome_thumbnail_npy_lmdb/train \
    --input-type npy

python3 create_lmdb.py \
    --input-img-dir data/majorTOM/rome/rome_thumbnail_npy/test \
    --output-dir data/majorTOM/rome/rome_thumbnail_npy_lmdb/test \
    --input-type npy
```
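Each LMDB record is a pickled dictionary mapping the modality folder name to the raw feature bytes (this is how `create_lmdb.py` writes them and how the LMDB dataset class in `datasets.py` reads them). A minimal sketch for inspecting one record of the resulting database; the reshape to (8, 32, 32) mirrors what the training dataset assumes:

```python
import lmdb
import pickle
import numpy as np

env = lmdb.open("data/majorTOM/rome/rome_thumbnail_npy_lmdb/train",
                readonly=True, lock=False)
with env.begin(write=False) as txn:
    key, value = next(iter(txn.cursor()))   # first record
    record = pickle.loads(value)            # {modality: raw feature bytes}
    for modality, raw in record.items():
        features = np.frombuffer(raw, dtype=np.float32).reshape(8, 32, 32)
        print(key.decode(), modality, features.shape)
env.close()
```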
## 5. Train the model

Train the model with the following command (2 GPUs, 4 modalities). Adjust the number of GPUs and the config file as needed.

Notes on the config file:
- Since we are training on a toy dataset, the batch size is set to 8. Feel free to increase it for larger datasets.
- The logging, eval, and save frequencies are set to 2 for faster turnaround. Feel free to increase them for larger datasets.

Visual results, checkpoints, and logs are stored in a generated folder called `workdir`.

### Training with LMDB dataset

```bash
export NUM_GPUS=2
accelerate launch \
    --multi_gpu \
    --num_processes $NUM_GPUS \
    --mixed_precision fp16 \
    train_triffuser_discrete.py \
    --config="configs/majortom/discrete/lmdb/rome_dems1s2s2_cop_gen_beta.py"
```

### Training without LMDB (tuples dataset)

```bash
export NUM_GPUS=2
accelerate launch \
    --multi_gpu \
    --num_processes $NUM_GPUS \
    --mixed_precision fp16 \
    train_triffuser_discrete.py \
    --config="configs/majortom/discrete/rome_dems1s2s2_cop_gen_beta.py"
```
# COP-GEN-Beta: Inference instructions

COP-GEN-Beta is highly versatile at generation time. Given the 4 modalities (DEM, S1RTC, S2L1C, S2L2A), the following generation options exist:

- **Unconditional generation:** Generates tuples of all 4 modalities without any condition.
- **Conditional generation:**
  - **Single modality conditioning:** Generates the missing modalities conditioned on a single modality.
  - **2 modality conditioning:** Generates the missing modalities conditioned on 2 modalities.
  - **3 modality conditioning:** Generates the missing modality conditioned on 3 modalities.

## 1. Download model checkpoint

<!-- To upload the model to Hugging Face Hub just run in the pth folder: -->
<!-- huggingface-cli upload mespinosami/COP-GEN-Beta . -->

Download the model EMA checkpoint from [Hugging Face](https://huggingface.co/mespinosami/COP-GEN-Beta) ([download link](https://huggingface.co/mespinosami/COP-GEN-Beta/resolve/main/nnet_ema_114000.pth)) and place it in the `./models` folder.

This can be done by running:
```bash
mkdir -p models
wget https://huggingface.co/mespinosami/COP-GEN-Beta/resolve/main/nnet_ema_114000.pth -O models/nnet_ema_114000.pth
```

## 2. Run inference on test set (Rome subset)

To see all the available inference options, run `python3 sample_n_triffuser.py --help`. For instance:
- `--n_samples` controls the number of samples to generate for the same input condition (useful for evaluating generation variability),
- `--generate` is the comma-separated list of modalities to generate,
- `--condition` is the comma-separated list of modalities to condition on.
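In practice, `--generate` is usually just the complement of `--condition` over the four modality names used by the sampler. A small illustrative helper (not part of the repository) that builds both flags:

```python
MODALITIES = ["dem", "s1_rtc", "s2_l1c", "s2_l2a"]

def make_flags(condition):
    """Given the modalities to condition on, generate the rest (illustrative helper)."""
    generate = [m for m in MODALITIES if m not in condition]
    flags = []
    if condition:
        flags += ["--condition", ",".join(condition)]
    flags += ["--generate", ",".join(generate)]
    return flags

print(make_flags(["dem", "s1_rtc"]))
# ['--condition', 'dem,s1_rtc', '--generate', 's2_l1c,s2_l2a']
```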
### Example 1: Unconditional generation

Generates all modalities (DEM, S1RTC, S2L2A, S2L1C).
```bash
python3 sample_n_triffuser.py \
    --config configs/majortom/discrete/lmdb/rome_dems1s2s2_cop_gen_beta.py \
    --data_path data/majorTOM/rome/rome_thumbnail_npy_lmdb/test \
    --data_type lmdb \
    --nnet_path models/nnet_ema_114000.pth \
    --n_mod 4 \
    --generate dem,s1_rtc,s2_l2a,s2_l1c \
    --output_path out_images \
    --n_samples 4 \
    --save_as grid
```

### Example 2: Single modality conditioning

Conditioning on S1RTC to generate DEM, S2L2A, and S2L1C.

```bash
python3 sample_n_triffuser.py \
    --config configs/majortom/discrete/lmdb/rome_dems1s2s2_cop_gen_beta.py \
    --data_path data/majorTOM/rome/rome_thumbnail_npy_lmdb/test \
    --data_type lmdb \
    --nnet_path models/nnet_ema_114000.pth \
    --n_mod 4 \
    --condition s1_rtc \
    --generate dem,s2_l2a,s2_l1c \
    --output_path out_images \
    --n_samples 4 \
    --save_as grid
```

### Example 3: 2 modality conditioning

Conditioning on DEM and S1RTC to generate S2L2A and S2L1C.

```bash
python3 sample_n_triffuser.py \
    --config configs/majortom/discrete/lmdb/rome_dems1s2s2_cop_gen_beta.py \
    --data_path data/majorTOM/rome/rome_thumbnail_npy_lmdb/test \
    --data_type lmdb \
    --nnet_path models/nnet_ema_114000.pth \
    --n_mod 4 \
    --condition dem,s1_rtc \
    --generate s2_l2a,s2_l1c \
    --output_path out_images \
    --n_samples 4 \
    --save_as grid
```

### Example 4: 3 modality conditioning

Conditioning on DEM, S1RTC, and S2L2A to generate S2L1C.

```bash
python3 sample_n_triffuser.py \
    --config configs/majortom/discrete/lmdb/rome_dems1s2s2_cop_gen_beta.py \
    --data_path data/majorTOM/rome/rome_thumbnail_npy_lmdb/test \
    --data_type lmdb \
    --nnet_path models/nnet_ema_114000.pth \
    --n_mod 4 \
    --condition dem,s1_rtc,s2_l2a \
    --generate s2_l1c \
    --output_path out_images \
    --n_samples 4 \
    --save_as grid
```
# Citation

If you find this work useful, please cite it as follows:

```bibtex
@inproceedings{espinosa2025copgenbeta,
  title={COP-GEN-Beta: Unified Generative Modelling of COPernicus Imagery Thumbnails},
  author={Espinosa, Miguel and Marsocci, Valerio and Jia, Yuru and Crowley, Elliot J. and Czerkawski, Mikolaj},
  booktitle={CVPRW},
  year={2025}
}
```
src/COP-GEN-Beta/configs/majortom/discrete/lmdb/rome_dems1s2s2_cop_gen_beta.py
ADDED
@@ -0,0 +1,64 @@
import ml_collections


def d(**kwargs):
    """Helper of creating a config dict."""
    return ml_collections.ConfigDict(initial_dictionary=kwargs)


def get_config():
    config = ml_collections.ConfigDict()
    config.seed = 1234
    config.pred = "noise_pred"
    config.z_shape = (4, 32, 32)
    config.autoencoder = d(pretrained_path="assets/stable-diffusion/autoencoder_kl_ema.pth")
    config.train = d(
        n_steps=500000,
        batch_size=8,  # Increase to 512 for larger datasets
        mode="uncond",
        log_interval=2,  # Increase to 100 for larger datasets
        eval_interval=2,  # Increase to 1500 for larger datasets
        save_interval=2,  # Increase to 1500 for larger datasets
        multi_modal=True,
    )
    config.optimizer = d(
        name="adamw",
        lr=0.0002,
        weight_decay=0.03,
        betas=(0.99, 0.99),
    )
    config.lr_scheduler = d(name="customized", warmup_steps=5000)
    config.nnet = d(
        name="triffuser_multi_post_ln",
        img_size=32,
        in_chans=4,
        patch_size=2,
        embed_dim=1024,
        depth=20,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=False,
        pos_drop_rate=0.,
        drop_rate=0.,
        attn_drop_rate=0.,
        mlp_time_embed=False,
        num_modalities=4,
        use_checkpoint=True,
    )
    config.dataset = d(
        name="majorTOM_lmdb_256_features",
        path="data/majorTOM/rome/rome_thumbnail_npy_lmdb/train",
        # name="majorTOM_tuples_256_features",
        # paths=["data/majorTOM/northern_italy/northern_italy_thumbnail_npy/train/DEM_thumbnail",
        #        "data/majorTOM/northern_italy/northern_italy_thumbnail_npy/train/S1RTC_thumbnail",
        #        "data/majorTOM/northern_italy/northern_italy_thumbnail_npy/train/S2L1C_thumbnail",
        #        "data/majorTOM/northern_italy/northern_italy_thumbnail_npy/train/S2L2A_thumbnail"],
        cfg=False,
        p_uncond=0.1,  # 0.15
    )
    config.sample = d(
        sample_steps=50,
        n_samples=50000,
        mini_batch_size=50,  # the decoder is large
        algorithm="dpm_solver",
        cfg=True,
        scale=0.4,
        path="",
    )
    return config
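For orientation, the config above is a plain `ml_collections.ConfigDict`; it can be loaded and overridden programmatically before being handed to the training or sampling scripts (which consume it via their `--config` flag). A minimal sketch, assuming the working directory is the repository root:

```python
import importlib

# Load the config module by its dotted path (illustrative)
cfg_mod = importlib.import_module(
    "configs.majortom.discrete.lmdb.rome_dems1s2s2_cop_gen_beta")
config = cfg_mod.get_config()

# ConfigDict fields can be read and overridden before training
print(config.train.batch_size, config.nnet.num_modalities)  # 8 4
config.train.batch_size = 64    # e.g. for a larger dataset
config.train.log_interval = 100
```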
src/COP-GEN-Beta/configs/majortom/discrete/rome_dems1s2s2_cop_gen_beta.py
ADDED
@@ -0,0 +1,62 @@
import ml_collections


def d(**kwargs):
    """Helper of creating a config dict."""
    return ml_collections.ConfigDict(initial_dictionary=kwargs)


def get_config():
    config = ml_collections.ConfigDict()
    config.seed = 1234
    config.pred = "noise_pred"
    config.z_shape = (4, 32, 32)
    config.autoencoder = d(pretrained_path="assets/stable-diffusion/autoencoder_kl_ema.pth")
    config.train = d(
        n_steps=500000,
        batch_size=8,  # Increase to 512 for larger datasets
        mode="uncond",
        log_interval=2,  # Increase to 100 for larger datasets
        eval_interval=2,  # Increase to 1500 for larger datasets
        save_interval=2,  # Increase to 1500 for larger datasets
        multi_modal=True,
    )
    config.optimizer = d(
        name="adamw",
        lr=0.0002,
        weight_decay=0.03,
        betas=(0.99, 0.99),
    )
    config.lr_scheduler = d(name="customized", warmup_steps=5000)
    config.nnet = d(
        name="triffuser_multi_post_ln",
        img_size=32,
        in_chans=4,
        patch_size=2,
        embed_dim=1024,
        depth=20,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=False,
        pos_drop_rate=0.,
        drop_rate=0.,
        attn_drop_rate=0.,
        mlp_time_embed=False,
        num_modalities=4,
        use_checkpoint=True,
    )
    config.dataset = d(
        name="majorTOM_tuples_256_features",
        paths=["data/majorTOM/rome/rome_thumbnail_npy/train/DEM_thumbnail",
               "data/majorTOM/rome/rome_thumbnail_npy/train/S1RTC_thumbnail",
               "data/majorTOM/rome/rome_thumbnail_npy/train/S2L1C_thumbnail",
               "data/majorTOM/rome/rome_thumbnail_npy/train/S2L2A_thumbnail"],
        cfg=False,
        p_uncond=0.1,  # 0.15
    )
    config.sample = d(
        sample_steps=50,
        n_samples=50000,
        mini_batch_size=50,  # the decoder is large
        algorithm="dpm_solver",
        cfg=True,
        scale=0.4,
        path="",
    )
    return config
src/COP-GEN-Beta/create_lmdb.py
ADDED
@@ -0,0 +1,213 @@
"""
Author: Chenhongyi Yang
Reference: GPViT https://github.com/ChenhongyiYang/GPViT
"""

"""
This script will generate a paired LMDB database for all modalities found in the input directory.
Thus, the input directory should contain subdirectories for each modality, each containing a set of images.
The names for the paired images in the different subdirectories should be the same.

# Example:
python3 scripts/create_lmdb.py \
    --input-img-dir data/majorTOM/northern_italy/northern_italy_thumbnail_npy/train \
    --output-dir data/majorTOM/northern_italy/northern_italy_thumbnail_npy_lmdb/train \
    --input-type npy
"""


import glob
import blobfile as bf
import os
import re
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import Tuple
import pickle

import cv2
import lmdb
import numpy as np

import argparse
parser = argparse.ArgumentParser('Convert LMDB dataset')
parser.add_argument('--input-img-dir', help='Path to ImageNet training images')
parser.add_argument('--output-dir', help='Path to output training lmdb dataset')
parser.add_argument('--input-type', choices=['png', 'npy'],
                    help='Type of input to encode: "png" for PNG images or "npy" for NPY features')
parser.add_argument('--batch-size', type=int, default=10000,
                    help='Batch size for processing images')
# parser.add_argument('val-img-dir', 'Path to ImageNet validation images')
# parser.add_argument('val-out', 'Path to output validation lmdb dataset')
args = parser.parse_args()

_10TB = 10 * (1 << 40)


class LmdbDataExporter(object):
    """
    making LMDB database
    """
    # label_pattern = re.compile(r'/.*/.*?(\d+)$')

    def __init__(self,
                 img_dir=None,
                 output_path=None,
                 batch_size=None):
        """
        img_dir: imgs directory
        output_path: LMDB output path
        """
        self.img_dir = img_dir
        self.output_path = output_path
        self.batch_size = batch_size

        if not os.path.exists(img_dir):
            raise Exception(f'{img_dir} does not exist!')

        if not os.path.exists(output_path):
            os.makedirs(output_path)

        self.lmdb_env = lmdb.open(output_path, map_size=_10TB, max_dbs=4)
        self.modalities = self._get_modalities()

    def _get_modalities(self):
        """Get list of modalities (subdirectories) in the input directory"""
        return [d for d in os.listdir(self.img_dir)
                if os.path.isdir(os.path.join(self.img_dir, d))]

    def export(self):
        idx = 0
        results = []
        st = time.time()
        iter_img_lst = self.read_imgs()
        length = self.get_length()
        print(f'length: {length}')
        while True:
            items = []
            try:
                while len(items) < self.batch_size:
                    items.append(next(iter_img_lst))
            except StopIteration:
                break

            with ThreadPoolExecutor() as executor:
                results.extend(executor.map(self._extract_once, items))

            if len(results) >= self.batch_size:
                self.save_to_lmdb(results)
                idx += self.batch_size
                et = time.time()
                print(f'time: {(et-st)}(s) count: {idx}')
                st = time.time()
                # Progressively decrease batch size for remaining items
                remaining = length - idx
                if remaining < self.batch_size:
                    self.batch_size = max(remaining // 2, 1)
                    print(f'batch_size is reduced to: {self.batch_size}')
                del results[:]

        et = time.time()
        print(f'time: {(et-st)}(s) count: {idx}')
        self.save_to_lmdb(results)
        # self.save_total(idx)
        print('Total length:', len(results))
        del results[:]

    def save_to_lmdb(self, results):
        """
        persist to lmdb
        """
        with self.lmdb_env.begin(write=True) as txn:
            while results:
                img_key, img_byte = results.pop()
                if img_key is None or img_byte is None:
                    continue
                txn.put(img_key, img_byte)

    def save_total(self, total: int):
        """
        persist all numbers of imgs
        """
        with self.lmdb_env.begin(write=True, buffers=True) as txn:
            txn.put('total'.encode(), str(total).encode())

    def _extract_once(self, item) -> Tuple[bytes, bytes]:
        image_name = item[1]
        modality_data = item[2]  # Dictionary of modality -> file path

        # Create a dictionary to store all modality data
        data_dict = {}

        # Read each modality's data
        for modality, file_path in modality_data.items():
            if args.input_type == 'png':  # PNG images: store re-encoded PNG bytes
                img = cv2.imread(file_path)
                if img is None:
                    print(f'{file_path} is a bad img file.')
                    return None, None
                _, img_byte = cv2.imencode('.png', img)
                data_dict[modality] = img_byte.tobytes()
            else:  # npy feature files: store the raw array bytes
                try:
                    features = np.load(file_path)
                    data_dict[modality] = features.tobytes()
                except Exception as e:
                    print(f'Error loading {file_path}: {e}')
                    return None, None

        return (image_name.encode('ascii'), pickle.dumps(data_dict))

    def get_length(self):
        # Just count files in the first modality directory
        if not self.modalities:
            return 0
        first_modality_dir = os.path.join(self.img_dir, self.modalities[0])
        img_list = glob.glob(os.path.join(first_modality_dir, '*.npy'))
        return len(img_list)

    def _list_image_files_recursively(self, data_dir):
        results = []
        for entry in sorted(bf.listdir(data_dir)):
            full_path = bf.join(data_dir, entry)
            ext = entry.split(".")[-1]
            if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif", "npy"]:
                results.append(full_path)
            elif bf.isdir(full_path):
                results.extend(self._list_image_files_recursively(full_path))
        return results

    def read_imgs(self):
        # Create a dictionary to store files by their base name
        file_groups = defaultdict(dict)

        # File extension based on input type
        extensions = ['.png'] if args.input_type == 'png' else ['.npy']

        # Collect files from each modality
        for modality in self.modalities:
            modality_path = os.path.join(self.img_dir, modality)
            for file_path in self._list_image_files_recursively(modality_path):
                ext = os.path.splitext(file_path)[1].lower()
                if ext in extensions:
                    base_name = os.path.basename(file_path)
                    file_groups[base_name][modality] = file_path

        # Only yield complete groups
        for idx, (base_name, modality_files) in enumerate(file_groups.items()):
            if len(modality_files) == len(self.modalities):
                item = (idx, base_name, modality_files)
                yield item
            else:
                print(f"Skipping incomplete group {base_name}, found modalities: {list(modality_files.keys())}")


if __name__ == '__main__':
    input_img_dir = args.input_img_dir
    output_dir = args.output_dir

    exporter = LmdbDataExporter(
        input_img_dir,
        output_dir,
        batch_size=args.batch_size)
    exporter.export()
src/COP-GEN-Beta/datasets.py
ADDED
@@ -0,0 +1,885 @@
1 |
+
from torch.utils.data import Dataset
|
2 |
+
from torchvision import datasets
|
3 |
+
import torchvision.transforms as transforms
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import math
|
7 |
+
import random
|
8 |
+
from PIL import Image
|
9 |
+
import os
|
10 |
+
import glob
|
11 |
+
import einops
|
12 |
+
import torchvision.transforms.functional as F
|
13 |
+
|
14 |
+
|
15 |
+
class UnlabeledDataset(Dataset):
|
16 |
+
def __init__(self, dataset):
|
17 |
+
self.dataset = dataset
|
18 |
+
|
19 |
+
def __len__(self):
|
20 |
+
return len(self.dataset)
|
21 |
+
|
22 |
+
def __getitem__(self, item):
|
23 |
+
# data = tuple(self.dataset[item][:-1]) # remove label
|
24 |
+
data = self.dataset[item]
|
25 |
+
if len(data) == 1:
|
26 |
+
data = data[0]
|
27 |
+
return data
|
28 |
+
|
29 |
+
|
30 |
+
class LabeledDataset(Dataset):
|
31 |
+
def __init__(self, dataset, labels):
|
32 |
+
self.dataset = dataset
|
33 |
+
self.labels = labels
|
34 |
+
|
35 |
+
def __len__(self):
|
36 |
+
return len(self.dataset)
|
37 |
+
|
38 |
+
def __getitem__(self, item):
|
39 |
+
return self.dataset[item], self.labels[item]
|
40 |
+
|
41 |
+
|
42 |
+
class CFGDataset(Dataset): # for classifier free guidance
|
43 |
+
def __init__(self, dataset, p_uncond, empty_token):
|
44 |
+
self.dataset = dataset
|
45 |
+
self.p_uncond = p_uncond
|
46 |
+
self.empty_token = empty_token
|
47 |
+
|
48 |
+
def __len__(self):
|
49 |
+
return len(self.dataset)
|
50 |
+
|
51 |
+
def __getitem__(self, item):
|
52 |
+
x, y = self.dataset[item]
|
53 |
+
if random.random() < self.p_uncond:
|
54 |
+
y = self.empty_token
|
55 |
+
return x, y
|
56 |
+
|
57 |
+
|
58 |
+
class DatasetFactory(object):
|
59 |
+
|
60 |
+
def __init__(self):
|
61 |
+
self.train = None
|
62 |
+
self.test = None
|
63 |
+
|
64 |
+
def get_split(self, split, labeled=False, nosplit=False):
|
65 |
+
if nosplit:
|
66 |
+
return self.dataset
|
67 |
+
if split == "train":
|
68 |
+
dataset = self.train
|
69 |
+
elif split == "test":
|
70 |
+
dataset = self.test
|
71 |
+
else:
|
72 |
+
raise ValueError
|
73 |
+
|
74 |
+
if self.has_label:
|
75 |
+
return dataset if labeled else UnlabeledDataset(dataset)
|
76 |
+
else:
|
77 |
+
assert not labeled
|
78 |
+
return dataset
|
79 |
+
|
80 |
+
def unpreprocess(self, v): # to B C H W and [0, 1]
|
81 |
+
v = 0.5 * (v + 1.)
|
82 |
+
v.clamp_(0., 1.)
|
83 |
+
return v
|
84 |
+
|
85 |
+
@property
|
86 |
+
def has_label(self):
|
87 |
+
return True
|
88 |
+
|
89 |
+
@property
|
90 |
+
def data_shape(self):
|
91 |
+
raise NotImplementedError
|
92 |
+
|
93 |
+
@property
|
94 |
+
def data_dim(self):
|
95 |
+
return int(np.prod(self.data_shape))
|
96 |
+
|
97 |
+
@property
|
98 |
+
def fid_stat(self):
|
99 |
+
return None
|
100 |
+
|
101 |
+
def sample_label(self, n_samples, device):
|
102 |
+
raise NotImplementedError
|
103 |
+
|
104 |
+
def label_prob(self, k):
|
105 |
+
raise NotImplementedError
|
106 |
+
|
107 |
+
|
108 |
+
# CIFAR10
|
109 |
+
|
110 |
+
class CIFAR10(DatasetFactory):
|
111 |
+
r""" CIFAR10 dataset
|
112 |
+
|
113 |
+
Information of the raw dataset:
|
114 |
+
train: 50,000
|
115 |
+
test: 10,000
|
116 |
+
shape: 3 * 32 * 32
|
117 |
+
"""
|
118 |
+
|
119 |
+
def __init__(self, path, random_flip=False, cfg=False, p_uncond=None):
|
120 |
+
super().__init__()
|
121 |
+
|
122 |
+
transform_train = [transforms.ToTensor(), transforms.Normalize(0.5, 0.5)]
|
123 |
+
transform_test = [transforms.ToTensor(), transforms.Normalize(0.5, 0.5)]
|
124 |
+
if random_flip: # only for train
|
125 |
+
transform_train.append(transforms.RandomHorizontalFlip())
|
126 |
+
transform_train = transforms.Compose(transform_train)
|
127 |
+
transform_test = transforms.Compose(transform_test)
|
128 |
+
self.train = datasets.CIFAR10(path, train=True, transform=transform_train, download=True)
|
129 |
+
self.test = datasets.CIFAR10(path, train=False, transform=transform_test, download=True)
|
130 |
+
|
131 |
+
assert len(self.train.targets) == 50000
|
132 |
+
self.K = max(self.train.targets) + 1
|
133 |
+
self.cnt = torch.tensor([len(np.where(np.array(self.train.targets) == k)[0]) for k in range(self.K)]).float()
|
134 |
+
self.frac = [self.cnt[k] / 50000 for k in range(self.K)]
|
135 |
+
print(f'{self.K} classes')
|
136 |
+
print(f'cnt: {self.cnt}')
|
137 |
+
print(f'frac: {self.frac}')
|
138 |
+
|
139 |
+
if cfg: # classifier free guidance
|
140 |
+
assert p_uncond is not None
|
141 |
+
print(f'prepare the dataset for classifier free guidance with p_uncond={p_uncond}')
|
142 |
+
self.train = CFGDataset(self.train, p_uncond, self.K)
|
143 |
+
|
144 |
+
@property
|
145 |
+
def data_shape(self):
|
146 |
+
return 3, 32, 32
|
147 |
+
|
148 |
+
@property
|
149 |
+
def fid_stat(self):
|
150 |
+
return 'assets/fid_stats/fid_stats_cifar10_train_pytorch.npz'
|
151 |
+
|
152 |
+
def sample_label(self, n_samples, device):
|
153 |
+
return torch.multinomial(self.cnt, n_samples, replacement=True).to(device)
|
154 |
+
|
155 |
+
def label_prob(self, k):
|
156 |
+
return self.frac[k]
|
157 |
+
|
158 |
+
|
159 |
+
# ImageNet
|
160 |
+
|
161 |
+
|
162 |
+
class FeatureDataset(Dataset):
|
163 |
+
def __init__(self, path):
|
164 |
+
super().__init__()
|
165 |
+
self.path = path
|
166 |
+
# names = sorted(os.listdir(path))
|
167 |
+
# self.files = [os.path.join(path, name) for name in names]
|
168 |
+
|
169 |
+
def __len__(self):
|
170 |
+
return 1_281_167 * 2 # consider the random flip
|
171 |
+
|
172 |
+
def __getitem__(self, idx):
|
173 |
+
path = os.path.join(self.path, f'{idx}.npy')
|
174 |
+
z, label = np.load(path, allow_pickle=True)
|
175 |
+
return z, label
|
176 |
+
|
177 |
+
|
178 |
+
class MajorTOM_S2_FeatureDataset(Dataset):
|
179 |
+
|
180 |
+
def __init__(self, path, transform=None):
|
181 |
+
super().__init__()
|
182 |
+
self.path = path
|
183 |
+
self.transform = transform
|
184 |
+
# names = sorted(os.listdir(path))
|
185 |
+
# self.files = [os.path.join(path, name) for name in names]
|
186 |
+
|
187 |
+
def __len__(self):
|
188 |
+
return len(glob.glob(f"{self.path}/*.npy"))
|
189 |
+
|
190 |
+
def __getitem__(self, idx):
|
191 |
+
path = os.path.join(self.path, f"{idx}.npy")
|
192 |
+
moment = np.load(path, allow_pickle=True).copy()
|
193 |
+
if self.transform is not None:
|
194 |
+
moment = self.transform(moment)
|
195 |
+
return moment
|
196 |
+
|
197 |
+
|
198 |
+
class MajorTOM_Tuples_FeatureDataset(Dataset):
|
199 |
+
|
200 |
+
def __init__(self, paths, transform=None):
|
201 |
+
super().__init__()
|
202 |
+
self.paths = paths
|
203 |
+
self.transform = transform
|
204 |
+
print(f"Gathering filenames...")
|
205 |
+
self.filenames = [os.path.splitext(os.path.basename(f))[0] for f in glob.glob(f"{self.paths[0]}/*.npy")]
|
206 |
+
print(f"Found {len(self.filenames)} filenames across all paths")
|
207 |
+
|
208 |
+
def __len__(self):
|
209 |
+
return len(self.filenames)
|
210 |
+
|
211 |
+
def __getitem__(self, idx):
|
212 |
+
# Return npy files for each modality. Always in the same order
|
213 |
+
moments = []
|
214 |
+
for path in self.paths:
|
215 |
+
path = os.path.join(path, f"{self.filenames[idx]}.npy")
|
216 |
+
moment = np.load(path, allow_pickle=True).copy()
|
217 |
+
if self.transform is not None:
|
218 |
+
moment = self.transform(moment)
|
219 |
+
moments.append(moment)
|
220 |
+
return moments
|
221 |
+
|
222 |
+
|
223 |
+
class MajorTOM_S2_Features(DatasetFactory): # the moments calculated by Stable Diffusion image encoder
|
224 |
+
def __init__(self, path, cfg=False, p_uncond=None):
|
225 |
+
super().__init__()
|
226 |
+
print("Prepare dataset...")
|
227 |
+
# transform_train = [transforms.ToTensor()]
|
228 |
+
transform_train = []
|
229 |
+
self.train = MajorTOM_S2_FeatureDataset(
|
230 |
+
path, transform=transforms.Compose(transform_train)
|
231 |
+
)
|
232 |
+
self.path = path
|
233 |
+
print("Prepare dataset ok")
|
234 |
+
self.K = 1000
|
235 |
+
|
236 |
+
if cfg: # classifier free guidance
|
237 |
+
assert p_uncond is not None
|
238 |
+
print(f"prepare the dataset for classifier free guidance with p_uncond={p_uncond}")
|
239 |
+
self.train = CFGDataset(self.train, p_uncond, self.K)
|
240 |
+
|
241 |
+
def get_split(self, split, labeled=False):
|
242 |
+
if split == "train":
|
243 |
+
dataset = self.train
|
244 |
+
elif split == "test":
|
245 |
+
dataset = self.test
|
246 |
+
else:
|
247 |
+
raise ValueError
|
248 |
+
|
249 |
+
if self.has_label:
|
250 |
+
return dataset if labeled else UnlabeledDataset(dataset)
|
251 |
+
else:
|
252 |
+
assert not labeled
|
253 |
+
return dataset
|
254 |
+
|
255 |
+
@property
|
256 |
+
def data_shape(self):
|
257 |
+
return 4, 133, 133
|
258 |
+
|
259 |
+
@property
|
260 |
+
def fid_stat(self):
|
261 |
+
return f"assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz"
|
262 |
+
|
263 |
+
def sample_label(self, n_samples, device):
|
264 |
+
return torch.randint(0, 1000, (n_samples,), device=device)
|
265 |
+
|
266 |
+
|
267 |
+
class MajorTOM_Tuples_Features(DatasetFactory): # the moments calculated by Stable Diffusion image encoder
|
268 |
+
def __init__(self, paths, cfg=False, p_uncond=None):
|
269 |
+
super().__init__()
|
270 |
+
print("Prepare dataset...")
|
271 |
+
# transform_train = [transforms.ToTensor()]
|
272 |
+
transform_train = []
|
273 |
+
self.train = MajorTOM_Tuples_FeatureDataset(
|
274 |
+
paths, transform=transforms.Compose(transform_train)
|
275 |
+
)
|
276 |
+
self.paths = paths
|
277 |
+
print("Prepare dataset ok")
|
278 |
+
self.K = 1000
|
279 |
+
|
280 |
+
if cfg: # classifier free guidance
|
281 |
+
assert p_uncond is not None
|
282 |
+
print(f"prepare the dataset for classifier free guidance with p_uncond={p_uncond}")
|
283 |
+
self.train = CFGDataset(self.train, p_uncond, self.K)
|
284 |
+
|
285 |
+
def get_split(self, split, labeled=False):
|
286 |
+
if split == "train":
|
287 |
+
dataset = self.train
|
288 |
+
elif split == "test":
|
289 |
+
dataset = self.test
|
290 |
+
else:
|
291 |
+
raise ValueError
|
292 |
+
|
293 |
+
if self.has_label:
|
294 |
+
return dataset if labeled else UnlabeledDataset(dataset)
|
295 |
+
else:
|
296 |
+
assert not labeled
|
297 |
+
return dataset
|
298 |
+
|
299 |
+
@property
|
300 |
+
def data_shape(self):
|
301 |
+
return "blablabla"
|
302 |
+
|
303 |
+
@property
|
304 |
+
def fid_stat(self):
|
305 |
+
return f"assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz"
|
306 |
+
|
307 |
+
def sample_label(self, n_samples, device):
|
308 |
+
raise NotImplementedError
|
309 |
+
return torch.randint(0, 1000, (n_samples,), device=device)
|
310 |
+
|
311 |
+
|
312 |
+
class MajorTOM_Lmdb_FeatureDataset(Dataset):
|
313 |
+
def __init__(self, path, transform=None, return_filename=False):
|
314 |
+
super().__init__()
|
315 |
+
import pickle
|
316 |
+
|
317 |
+
self.transform = transform
|
318 |
+
self.path = path # Store the path instead of the environment
|
319 |
+
self.return_filename = return_filename
|
320 |
+
|
321 |
+
# Create a temporary environment just to get the stats and keys
|
322 |
+
import lmdb
|
323 |
+
env = lmdb.open(
|
324 |
+
path,
|
325 |
+
max_readers=1,
|
326 |
+
readonly=True,
|
327 |
+
lock=False,
|
328 |
+
readahead=False,
|
329 |
+
meminit=False,
|
330 |
+
)
|
331 |
+
|
332 |
+
# Get total number of entries
|
333 |
+
with env.begin(write=False) as txn:
|
334 |
+
self.length = txn.stat()["entries"]
|
335 |
+
|
336 |
+
# Load or create cache of keys
|
337 |
+
root_split = path.split("/")
|
338 |
+
cache_file = os.path.join("/".join(root_split[:-1]), f"_cache_{root_split[-1]}")
|
339 |
+
if os.path.isfile(cache_file):
|
340 |
+
self.keys = pickle.load(open(cache_file, "rb"))
|
341 |
+
else:
|
342 |
+
with env.begin(write=False) as txn:
|
343 |
+
self.keys = [key for key, _ in txn.cursor()]
|
344 |
+
pickle.dump(self.keys, open(cache_file, "wb"))
|
345 |
+
|
346 |
+
# Close the temporary environment
|
347 |
+
env.close()
|
348 |
+
|
349 |
+
# Create environment lazily in each worker
|
350 |
+
self._env = None
|
351 |
+
|
352 |
+
def _init_db(self):
|
353 |
+
"""Initialize LMDB environment"""
|
354 |
+
import lmdb
|
355 |
+
self._env = lmdb.open(
|
356 |
+
self.path,
|
357 |
+
max_readers=1,
|
358 |
+
readonly=True,
|
359 |
+
lock=False,
|
360 |
+
readahead=False,
|
361 |
+
meminit=False,
|
362 |
+
)
|
363 |
+
|
364 |
+
@property
|
365 |
+
def env(self):
|
366 |
+
"""Get LMDB environment, creating it if necessary"""
|
367 |
+
if self._env is None:
|
368 |
+
self._init_db()
|
369 |
+
return self._env
|
370 |
+
|
371 |
+
def __len__(self):
|
372 |
+
return self.length
|
373 |
+
|
374 |
+
def __getitem__(self, idx):
|
375 |
+
# Get data from LMDB
|
376 |
+
import pickle
|
377 |
+
import numpy as np
|
378 |
+
|
379 |
+
key = self.keys[idx]
|
380 |
+
filename = key.decode('utf-8') if isinstance(key, bytes) else key
|
381 |
+
filename = os.path.basename(filename) # get filename without path
|
382 |
+
filename = os.path.splitext(filename)[0] # remove .npy extension
|
383 |
+
|
384 |
+
with self.env.begin(write=False) as txn:
|
385 |
+
data = pickle.loads(txn.get(key))
|
386 |
+
|
387 |
+
# Convert bytes to data for each modality
|
388 |
+
decoded_data = {}
|
389 |
+
for k, bytes_data in data.items():
|
390 |
+
# Convert bytes back to numpy array with the expected shape (8, 32, 32).
|
391 |
+
# TODO: This is currently hardcoded.
|
392 |
+
features = np.frombuffer(bytes_data, dtype=np.float32).reshape(8, 32, 32).copy()
|
393 |
+
decoded_data[k] = features
|
394 |
+
|
395 |
+
# Apply transforms if any
|
396 |
+
if self.transform is not None:
|
397 |
+
decoded_data = {k: self.transform(v) for k, v in decoded_data.items()}
|
398 |
+
|
399 |
+
# Convert the dictionary values to a list in a consistent order
|
400 |
+
moments = [decoded_data[k] for k in sorted(decoded_data.keys())]
|
401 |
+
|
402 |
+
if self.return_filename:
|
403 |
+
return moments, filename
|
404 |
+
return moments
|
405 |
+
|
406 |
+
def __del__(self):
|
407 |
+
if self._env is not None:
|
408 |
+
self._env.close()
|
409 |
+
|
410 |
+
|
411 |
+
class MajorTOM_Lmdb_Features(DatasetFactory):
|
412 |
+
def __init__(self, path, cfg=False, p_uncond=None, return_filename=False):
|
413 |
+
super().__init__()
|
414 |
+
print("Prepare dataset...")
|
415 |
+
transform_train = []
|
416 |
+
self.return_filename = return_filename
|
417 |
+
self.train = MajorTOM_Lmdb_FeatureDataset(
|
418 |
+
path, transform=transforms.Compose(transform_train), return_filename=return_filename
|
419 |
+
)
|
420 |
+
self.path = path
|
421 |
+
print("Prepare dataset ok")
|
422 |
+
self.K = 1000
|
423 |
+
|
424 |
+
if cfg: # classifier free guidance
|
425 |
+
assert p_uncond is not None
|
426 |
+
print(f"prepare the dataset for classifier free guidance with p_uncond={p_uncond}")
|
427 |
+
self.train = CFGDataset(self.train, p_uncond, self.K)
|
428 |
+
|
429 |
+
def get_split(self, split, labeled=False):
|
430 |
+
if split == "train":
|
431 |
+
dataset = self.train
|
432 |
+
elif split == "test":
|
433 |
+
dataset = self.test
|
434 |
+
else:
|
435 |
+
raise ValueError
|
436 |
+
|
437 |
+
if self.has_label:
|
438 |
+
return dataset if labeled else UnlabeledDataset(dataset)
|
439 |
+
else:
|
440 |
+
assert not labeled
|
441 |
+
return dataset
|
442 |
+
|
443 |
+
@property
|
444 |
+
def data_shape(self):
|
445 |
+
return "blablabla"
|
446 |
+
|
447 |
+
@property
|
448 |
+
def fid_stat(self):
|
449 |
+
return f"assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz"
|
450 |
+
|
451 |
+
def sample_label(self, n_samples, device):
|
452 |
+
raise NotImplementedError
|
453 |
+
return torch.randint(0, 1000, (n_samples,), device=device)
|
454 |
+
|
455 |
+
|
456 |
+
class ImageNet256Features(DatasetFactory): # the moments calculated by Stable Diffusion image encoder
|
457 |
+
def __init__(self, path, cfg=False, p_uncond=None):
|
458 |
+
super().__init__()
|
459 |
+
print('Prepare dataset...')
|
460 |
+
self.train = FeatureDataset(path)
|
461 |
+
print('Prepare dataset ok')
|
462 |
+
self.K = 1000
|
463 |
+
|
464 |
+
if cfg: # classifier free guidance
|
465 |
+
assert p_uncond is not None
|
466 |
+
print(f'prepare the dataset for classifier free guidance with p_uncond={p_uncond}')
|
467 |
+
self.train = CFGDataset(self.train, p_uncond, self.K)
|
468 |
+
|
469 |
+
@property
|
470 |
+
def data_shape(self):
|
471 |
+
return 4, 32, 32
|
472 |
+
|
473 |
+
@property
|
474 |
+
def fid_stat(self):
|
475 |
+
return f'assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz'
|
476 |
+
|
477 |
+
def sample_label(self, n_samples, device):
|
478 |
+
return torch.randint(0, 1000, (n_samples,), device=device)
|
479 |
+
|
480 |
+
|
481 |
+
class ImageNet512Features(DatasetFactory): # the moments calculated by Stable Diffusion image encoder
|
482 |
+
def __init__(self, path, cfg=False, p_uncond=None):
|
483 |
+
super().__init__()
|
484 |
+
print('Prepare dataset...')
|
485 |
+
self.train = FeatureDataset(path)
|
486 |
+
print('Prepare dataset ok')
|
487 |
+
self.K = 1000
|
488 |
+
|
489 |
+
if cfg: # classifier free guidance
|
490 |
+
assert p_uncond is not None
|
491 |
+
print(f'prepare the dataset for classifier free guidance with p_uncond={p_uncond}')
|
492 |
+
self.train = CFGDataset(self.train, p_uncond, self.K)
|
493 |
+
|
494 |
+
@property
|
495 |
+
def data_shape(self):
|
496 |
+
return 4, 64, 64
|
497 |
+
|
498 |
+
@property
|
499 |
+
def fid_stat(self):
|
500 |
+
return f'assets/fid_stats/fid_stats_imagenet512_guided_diffusion.npz'
|
501 |
+
|
502 |
+
def sample_label(self, n_samples, device):
|
503 |
+
return torch.randint(0, 1000, (n_samples,), device=device)
|
504 |
+
|
505 |
+
|
506 |
+
class ImageNet(DatasetFactory):
|
507 |
+
def __init__(self, path, resolution, random_crop=False, random_flip=True):
|
508 |
+
super().__init__()
|
509 |
+
|
510 |
+
print(f'Counting ImageNet files from {path}')
|
511 |
+
train_files = _list_image_files_recursively(os.path.join(path, 'train'))
|
512 |
+
class_names = [os.path.basename(path).split("_")[0] for path in train_files]
|
513 |
+
sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
|
514 |
+
train_labels = [sorted_classes[x] for x in class_names]
|
515 |
+
print('Finish counting ImageNet files')
|
516 |
+
|
517 |
+
self.train = ImageDataset(resolution, train_files, labels=train_labels, random_crop=random_crop, random_flip=random_flip)
|
518 |
+
self.resolution = resolution
|
519 |
+
if len(self.train) != 1_281_167:
|
520 |
+
print(f'Missing train samples: {len(self.train)} < 1281167')
|
521 |
+
|
522 |
+
self.K = max(self.train.labels) + 1
|
523 |
+
cnt = dict(zip(*np.unique(self.train.labels, return_counts=True)))
|
524 |
+
self.cnt = torch.tensor([cnt[k] for k in range(self.K)]).float()
|
525 |
+
self.frac = [self.cnt[k] / len(self.train.labels) for k in range(self.K)]
|
526 |
+
print(f'{self.K} classes')
|
527 |
+
print(f'cnt[:10]: {self.cnt[:10]}')
|
528 |
+
print(f'frac[:10]: {self.frac[:10]}')
|
529 |
+
|
530 |
+
@property
|
531 |
+
def data_shape(self):
|
532 |
+
return 3, self.resolution, self.resolution
|
533 |
+
|
534 |
+
@property
|
535 |
+
def fid_stat(self):
|
536 |
+
return f'assets/fid_stats/fid_stats_imagenet{self.resolution}_guided_diffusion.npz'
|
537 |
+
|
538 |
+
def sample_label(self, n_samples, device):
|
539 |
+
return torch.multinomial(self.cnt, n_samples, replacement=True).to(device)
|
540 |
+
|
541 |
+
def label_prob(self, k):
|
542 |
+
return self.frac[k]
|
543 |
+
|
544 |
+
class MajorTOMThumbnail(DatasetFactory):
|
545 |
+
def __init__(self, path, resolution):
|
546 |
+
super().__init__()
|
547 |
+
|
548 |
+
print(f'Counting MajorTOM thumbnail files from {path}')
|
549 |
+
files_list = _list_image_files_recursively(path)
|
550 |
+
print('Finish counting MajorTOM thumbnail files')
|
551 |
+
|
552 |
+
self.dataset = MajorTOMThumbnailDataset(resolution, files_list)
|
553 |
+
self.resolution = resolution
|
554 |
+
if len(self.dataset) != 1_281_167:
|
555 |
+
print(f'Missing train samples: {len(self.dataset)} < 1281167')
|
556 |
+
|
557 |
+
@property
|
558 |
+
def data_shape(self):
|
559 |
+
return 3, self.resolution, self.resolution
|
560 |
+
|
561 |
+
@property
|
562 |
+
def has_label(self):
|
563 |
+
return False
|
564 |
+
|
565 |
+
@property
|
566 |
+
def fid_stat(self):
|
567 |
+
return f'assets/fid_stats/fid_stats_imagenet{self.resolution}_guided_diffusion.npz'
|
568 |
+
|
569 |
+
|
570 |
+
class MajorTOMThumbnailDataset(Dataset):
|
571 |
+
def __init__(
|
572 |
+
self,
|
573 |
+
resolution,
|
574 |
+
image_paths,
|
575 |
+
):
|
576 |
+
super().__init__()
|
577 |
+
self.resolution = resolution
|
578 |
+
self.image_paths = image_paths
|
579 |
+
|
580 |
+
def __len__(self):
|
581 |
+
return len(self.image_paths)
|
582 |
+
|
583 |
+
def __getitem__(self, idx):
|
584 |
+
path = self.image_paths[idx]
|
585 |
+
filename = os.path.basename(path).split('.')[0]
|
586 |
+
pil_image = Image.open(path)
|
587 |
+
pil_image.load()
|
588 |
+
pil_image = pil_image.convert("RGB")
|
589 |
+
|
590 |
+
# check that the image has the correct resolution
|
591 |
+
if pil_image.size != (self.resolution, self.resolution):
|
592 |
+
raise ValueError(f"Image at {path} has size {pil_image.size}, expected {self.resolution}x{self.resolution}")
|
593 |
+
|
594 |
+
arr = np.array(pil_image).astype(np.float32) / 127.5 - 1
|
595 |
+
|
596 |
+
return np.transpose(arr, [2, 0, 1]), filename
|
597 |
+
|
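A minimal usage sketch of the two thumbnail classes above, assuming a local folder of pre-resized Major TOM thumbnails (the path, batch size, and worker count are illustrative, not part of this commit):

```python
from torch.utils.data import DataLoader
from datasets import MajorTOMThumbnail

# Hypothetical folder of 256x256 thumbnails; MajorTOMThumbnailDataset raises
# a ValueError if any image does not match the requested resolution.
factory = MajorTOMThumbnail(path='data/majortom_thumbnails', resolution=256)
loader = DataLoader(factory.dataset, batch_size=8, shuffle=False, num_workers=4)

for images, filenames in loader:
    # images: float32 tensor of shape (B, 3, 256, 256), scaled to [-1, 1]
    # filenames: tuple of thumbnail basenames (without extension)
    break
```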
598 |
+
def _list_image_files_recursively(data_dir):
|
599 |
+
results = []
|
600 |
+
for entry in sorted(os.listdir(data_dir)):
|
601 |
+
full_path = os.path.join(data_dir, entry)
|
602 |
+
ext = entry.split(".")[-1]
|
603 |
+
if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
|
604 |
+
results.append(full_path)
|
605 |
+
elif os.listdir(full_path):
|
606 |
+
results.extend(_list_image_files_recursively(full_path))
|
607 |
+
return results
|
608 |
+
|
609 |
+
|
610 |
+
class ImageDataset(Dataset):
|
611 |
+
def __init__(
|
612 |
+
self,
|
613 |
+
resolution,
|
614 |
+
image_paths,
|
615 |
+
labels,
|
616 |
+
random_crop=False,
|
617 |
+
random_flip=True,
|
618 |
+
):
|
619 |
+
super().__init__()
|
620 |
+
self.resolution = resolution
|
621 |
+
self.image_paths = image_paths
|
622 |
+
self.labels = labels
|
623 |
+
self.random_crop = random_crop
|
624 |
+
self.random_flip = random_flip
|
625 |
+
|
626 |
+
def __len__(self):
|
627 |
+
return len(self.image_paths)
|
628 |
+
|
629 |
+
def __getitem__(self, idx):
|
630 |
+
path = self.image_paths[idx]
|
631 |
+
pil_image = Image.open(path)
|
632 |
+
pil_image.load()
|
633 |
+
pil_image = pil_image.convert("RGB")
|
634 |
+
|
635 |
+
if self.random_crop:
|
636 |
+
arr = random_crop_arr(pil_image, self.resolution)
|
637 |
+
else:
|
638 |
+
arr = center_crop_arr(pil_image, self.resolution)
|
639 |
+
|
640 |
+
if self.random_flip and random.random() < 0.5:
|
641 |
+
arr = arr[:, ::-1]
|
642 |
+
|
643 |
+
arr = arr.astype(np.float32) / 127.5 - 1
|
644 |
+
|
645 |
+
label = np.array(self.labels[idx], dtype=np.int64)
|
646 |
+
return np.transpose(arr, [2, 0, 1]), label
|
647 |
+
|
648 |
+
|
649 |
+
def center_crop_arr(pil_image, image_size):
|
650 |
+
# We are not on a new enough PIL to support the `reducing_gap`
|
651 |
+
# argument, which uses BOX downsampling at powers of two first.
|
652 |
+
# Thus, we do it by hand to improve downsample quality.
|
653 |
+
while min(*pil_image.size) >= 2 * image_size:
|
654 |
+
pil_image = pil_image.resize(
|
655 |
+
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
|
656 |
+
)
|
657 |
+
|
658 |
+
scale = image_size / min(*pil_image.size)
|
659 |
+
pil_image = pil_image.resize(
|
660 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
661 |
+
)
|
662 |
+
|
663 |
+
arr = np.array(pil_image)
|
664 |
+
crop_y = (arr.shape[0] - image_size) // 2
|
665 |
+
crop_x = (arr.shape[1] - image_size) // 2
|
666 |
+
return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
|
667 |
+
|
668 |
+
|
669 |
+
def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
|
670 |
+
min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
|
671 |
+
max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
|
672 |
+
smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)
|
673 |
+
|
674 |
+
# We are not on a new enough PIL to support the `reducing_gap`
|
675 |
+
# argument, which uses BOX downsampling at powers of two first.
|
676 |
+
# Thus, we do it by hand to improve downsample quality.
|
677 |
+
while min(*pil_image.size) >= 2 * smaller_dim_size:
|
678 |
+
pil_image = pil_image.resize(
|
679 |
+
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
|
680 |
+
)
|
681 |
+
|
682 |
+
scale = smaller_dim_size / min(*pil_image.size)
|
683 |
+
pil_image = pil_image.resize(
|
684 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
685 |
+
)
|
686 |
+
|
687 |
+
arr = np.array(pil_image)
|
688 |
+
crop_y = random.randrange(arr.shape[0] - image_size + 1)
|
689 |
+
crop_x = random.randrange(arr.shape[1] - image_size + 1)
|
690 |
+
return arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]
|
691 |
+
|
692 |
+
|
693 |
+
# CelebA
|
694 |
+
|
695 |
+
|
696 |
+
class Crop(object):
|
697 |
+
def __init__(self, x1, x2, y1, y2):
|
698 |
+
self.x1 = x1
|
699 |
+
self.x2 = x2
|
700 |
+
self.y1 = y1
|
701 |
+
self.y2 = y2
|
702 |
+
|
703 |
+
def __call__(self, img):
|
704 |
+
return F.crop(img, self.x1, self.y1, self.x2 - self.x1, self.y2 - self.y1)
|
705 |
+
|
706 |
+
def __repr__(self):
|
707 |
+
return self.__class__.__name__ + "(x1={}, x2={}, y1={}, y2={})".format(
|
708 |
+
self.x1, self.x2, self.y1, self.y2
|
709 |
+
)
|
710 |
+
|
711 |
+
|
712 |
+
class CelebA(DatasetFactory):
|
713 |
+
r""" train: 162,770
|
714 |
+
val: 19,867
|
715 |
+
test: 19,962
|
716 |
+
shape: 3 * width * width
|
717 |
+
"""
|
718 |
+
|
719 |
+
def __init__(self, path, resolution=64):
|
720 |
+
super().__init__()
|
721 |
+
|
722 |
+
self.resolution = resolution
|
723 |
+
|
724 |
+
cx = 89
|
725 |
+
cy = 121
|
726 |
+
x1 = cy - 64
|
727 |
+
x2 = cy + 64
|
728 |
+
y1 = cx - 64
|
729 |
+
y2 = cx + 64
|
730 |
+
|
731 |
+
transform = transforms.Compose([Crop(x1, x2, y1, y2), transforms.Resize(self.resolution),
|
732 |
+
transforms.RandomHorizontalFlip(), transforms.ToTensor(),
|
733 |
+
transforms.Normalize(0.5, 0.5)])
|
734 |
+
self.train = datasets.CelebA(root=path, split="train", target_type=[], transform=transform, download=True)
|
735 |
+
self.train = UnlabeledDataset(self.train)
|
736 |
+
|
737 |
+
@property
|
738 |
+
def data_shape(self):
|
739 |
+
return 3, self.resolution, self.resolution
|
740 |
+
|
741 |
+
@property
|
742 |
+
def fid_stat(self):
|
743 |
+
return 'assets/fid_stats/fid_stats_celeba64_train_50000_ddim.npz'
|
744 |
+
|
745 |
+
@property
|
746 |
+
def has_label(self):
|
747 |
+
return False
|
748 |
+
|
749 |
+
|
750 |
+
# MS COCO
|
751 |
+
|
752 |
+
|
753 |
+
def center_crop(width, height, img):
|
754 |
+
resample = {'box': Image.BOX, 'lanczos': Image.LANCZOS}['lanczos']
|
755 |
+
crop = np.min(img.shape[:2])
|
756 |
+
img = img[(img.shape[0] - crop) // 2: (img.shape[0] + crop) // 2,
|
757 |
+
(img.shape[1] - crop) // 2: (img.shape[1] + crop) // 2]
|
758 |
+
try:
|
759 |
+
img = Image.fromarray(img, 'RGB')
|
760 |
+
except:
|
761 |
+
img = Image.fromarray(img)
|
762 |
+
img = img.resize((width, height), resample)
|
763 |
+
|
764 |
+
return np.array(img).astype(np.uint8)
|
765 |
+
|
766 |
+
|
767 |
+
class MSCOCODatabase(Dataset):
|
768 |
+
def __init__(self, root, annFile, size=None):
|
769 |
+
from pycocotools.coco import COCO
|
770 |
+
self.root = root
|
771 |
+
self.height = self.width = size
|
772 |
+
|
773 |
+
self.coco = COCO(annFile)
|
774 |
+
self.keys = list(sorted(self.coco.imgs.keys()))
|
775 |
+
|
776 |
+
def _load_image(self, key: int):
|
777 |
+
path = self.coco.loadImgs(key)[0]["file_name"]
|
778 |
+
return Image.open(os.path.join(self.root, path)).convert("RGB")
|
779 |
+
|
780 |
+
def _load_target(self, key: int):
|
781 |
+
return self.coco.loadAnns(self.coco.getAnnIds(key))
|
782 |
+
|
783 |
+
def __len__(self):
|
784 |
+
return len(self.keys)
|
785 |
+
|
786 |
+
def __getitem__(self, index):
|
787 |
+
key = self.keys[index]
|
788 |
+
image = self._load_image(key)
|
789 |
+
image = np.array(image).astype(np.uint8)
|
790 |
+
image = center_crop(self.width, self.height, image).astype(np.float32)
|
791 |
+
image = (image / 127.5 - 1.0).astype(np.float32)
|
792 |
+
image = einops.rearrange(image, 'h w c -> c h w')
|
793 |
+
|
794 |
+
anns = self._load_target(key)
|
795 |
+
target = []
|
796 |
+
for ann in anns:
|
797 |
+
target.append(ann['caption'])
|
798 |
+
|
799 |
+
return image, target
|
800 |
+
|
801 |
+
|
802 |
+
def get_feature_dir_info(root):
|
803 |
+
files = glob.glob(os.path.join(root, '*.npy'))
|
804 |
+
files_caption = glob.glob(os.path.join(root, '*_*.npy'))
|
805 |
+
num_data = len(files) - len(files_caption)
|
806 |
+
n_captions = {k: 0 for k in range(num_data)}
|
807 |
+
for f in files_caption:
|
808 |
+
name = os.path.split(f)[-1]
|
809 |
+
k1, k2 = os.path.splitext(name)[0].split('_')
|
810 |
+
n_captions[int(k1)] += 1
|
811 |
+
return num_data, n_captions
|
812 |
+
|
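For reference, a sketch of the flat directory layout that `get_feature_dir_info` above and `MSCOCOFeatureDataset` below expect, following the `{index}.npy` / `{index}_{caption}.npy` naming read in `__getitem__` (the folder path is hypothetical):

```python
# features/train/
#   0.npy      latent moments of image 0
#   0_0.npy    text embedding of caption 0 for image 0
#   0_1.npy    text embedding of caption 1 for image 0
#   1.npy      latent moments of image 1
#   ...
from datasets import get_feature_dir_info

num_data, n_captions = get_feature_dir_info('features/train')
print(num_data)        # number of images (files without an underscore)
print(n_captions[0])   # number of caption files found for image 0
```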
813 |
+
|
814 |
+
class MSCOCOFeatureDataset(Dataset):
|
815 |
+
# the image features are got through sample
|
816 |
+
def __init__(self, root):
|
817 |
+
self.root = root
|
818 |
+
self.num_data, self.n_captions = get_feature_dir_info(root)
|
819 |
+
|
820 |
+
def __len__(self):
|
821 |
+
return self.num_data
|
822 |
+
|
823 |
+
def __getitem__(self, index):
|
824 |
+
z = np.load(os.path.join(self.root, f'{index}.npy'))
|
825 |
+
k = random.randint(0, self.n_captions[index] - 1)
|
826 |
+
c = np.load(os.path.join(self.root, f'{index}_{k}.npy'))
|
827 |
+
return z, c
|
828 |
+
|
829 |
+
|
830 |
+
class MSCOCO256Features(DatasetFactory): # the moments calculated by Stable Diffusion image encoder & the contexts calculated by clip
|
831 |
+
def __init__(self, path, cfg=False, p_uncond=None):
|
832 |
+
super().__init__()
|
833 |
+
print('Prepare dataset...')
|
834 |
+
self.train = MSCOCOFeatureDataset(os.path.join(path, 'train'))
|
835 |
+
self.test = MSCOCOFeatureDataset(os.path.join(path, 'val'))
|
836 |
+
assert len(self.train) == 82783
|
837 |
+
assert len(self.test) == 40504
|
838 |
+
print('Prepare dataset ok')
|
839 |
+
|
840 |
+
self.empty_context = np.load(os.path.join(path, 'empty_context.npy'))
|
841 |
+
|
842 |
+
if cfg: # classifier free guidance
|
843 |
+
assert p_uncond is not None
|
844 |
+
print(f'prepare the dataset for classifier free guidance with p_uncond={p_uncond}')
|
845 |
+
self.train = CFGDataset(self.train, p_uncond, self.empty_context)
|
846 |
+
|
847 |
+
# text embedding extracted by clip
|
848 |
+
# for visualization in t2i
|
849 |
+
self.prompts, self.contexts = [], []
|
850 |
+
for f in sorted(os.listdir(os.path.join(path, 'run_vis')), key=lambda x: int(x.split('.')[0])):
|
851 |
+
prompt, context = np.load(os.path.join(path, 'run_vis', f), allow_pickle=True)
|
852 |
+
self.prompts.append(prompt)
|
853 |
+
self.contexts.append(context)
|
854 |
+
self.contexts = np.array(self.contexts)
|
855 |
+
|
856 |
+
@property
|
857 |
+
def data_shape(self):
|
858 |
+
return 4, 32, 32
|
859 |
+
|
860 |
+
@property
|
861 |
+
def fid_stat(self):
|
862 |
+
return f'assets/fid_stats/fid_stats_mscoco256_val.npz'
|
863 |
+
|
864 |
+
|
865 |
+
def get_dataset(name, **kwargs):
|
866 |
+
if name == 'cifar10':
|
867 |
+
return CIFAR10(**kwargs)
|
868 |
+
elif name == 'imagenet':
|
869 |
+
return ImageNet(**kwargs)
|
870 |
+
elif name == 'imagenet256_features':
|
871 |
+
return ImageNet256Features(**kwargs)
|
872 |
+
elif name == 'imagenet512_features':
|
873 |
+
return ImageNet512Features(**kwargs)
|
874 |
+
elif name == "majorTOM_S2_256_features":
|
875 |
+
return MajorTOM_S2_Features(**kwargs)
|
876 |
+
elif name == "majorTOM_tuples_256_features":
|
877 |
+
return MajorTOM_Tuples_Features(**kwargs)
|
878 |
+
elif name == "majorTOM_lmdb_256_features":
|
879 |
+
return MajorTOM_Lmdb_Features(**kwargs)
|
880 |
+
elif name == 'celeba':
|
881 |
+
return CelebA(**kwargs)
|
882 |
+
elif name == 'mscoco256_features':
|
883 |
+
return MSCOCO256Features(**kwargs)
|
884 |
+
else:
|
885 |
+
raise NotImplementedError(name)
|
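A short usage sketch of the factory entry point above; the dataset name must be one of the keys handled in `get_dataset`, and the path and classifier-free-guidance settings here are placeholders:

```python
from datasets import get_dataset

# Latent ImageNet-256 features with a 10% unconditional dropout for CFG.
dataset = get_dataset('imagenet256_features',
                      path='data/imagenet256_features',
                      cfg=True, p_uncond=0.1)
print(dataset.data_shape)   # (4, 32, 32) latent moments from the SD encoder
print(dataset.fid_stat)     # path to the matching pre-computed FID statistics
```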
src/COP-GEN-Beta/dpm_solver_pp.py
ADDED
@@ -0,0 +1,952 @@
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import math
|
4 |
+
import numpy as np
|
5 |
+
import torch.distributed as dist
|
6 |
+
|
7 |
+
|
8 |
+
def interpolate_fn(x: torch.Tensor, xp: torch.Tensor, yp: torch.Tensor) -> torch.Tensor:
|
9 |
+
"""Performs piecewise linear interpolation for x, using xp and yp keypoints (knots).
|
10 |
+
Performs separate interpolation for each channel.
|
11 |
+
Args:
|
12 |
+
x: [N, C] points to be calibrated (interpolated). Batch with C channels.
|
13 |
+
xp: [C, K] x coordinates of the PWL knots. C is the number of channels, K is the number of knots.
|
14 |
+
yp: [C, K] y coordinates of the PWL knots. C is the number of channels, K is the number of knots.
|
15 |
+
Returns:
|
16 |
+
Interpolated points of the shape [N, C].
|
17 |
+
The piecewise linear function extends for the whole x axis (the outermost keypoints define the outermost
|
18 |
+
infinite lines).
|
19 |
+
For example:
|
20 |
+
>>> interpolate_fn(torch.tensor([[0.5]]), torch.tensor([[0.0, 1.0]]), torch.tensor([[0.0, 2.0]]))
|
21 |
+
tensor([[1.0000]])
|
22 |
+
>>> interpolate_fn(torch.tensor([[-10]]), torch.tensor([[0.0, 1.0]]), torch.tensor([[0.0, 2.0]]))
|
23 |
+
tensor([[-20.0000]])
|
24 |
+
"""
|
25 |
+
x_breakpoints = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((x.shape[0], 1, 1))], dim=2)
|
26 |
+
num_x_points = xp.shape[1]
|
27 |
+
sorted_x_breakpoints, x_indices = torch.sort(x_breakpoints, dim=2)
|
28 |
+
x_idx = torch.argmin(x_indices, dim=2)
|
29 |
+
cand_start_idx = x_idx - 1
|
30 |
+
start_idx = torch.where(
|
31 |
+
torch.eq(x_idx, 0),
|
32 |
+
torch.tensor(1, device=x.device),
|
33 |
+
torch.where(
|
34 |
+
torch.eq(x_idx, num_x_points), torch.tensor(num_x_points - 2, device=x.device), cand_start_idx,
|
35 |
+
),
|
36 |
+
)
|
37 |
+
end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
|
38 |
+
start_x = torch.gather(sorted_x_breakpoints, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
|
39 |
+
end_x = torch.gather(sorted_x_breakpoints, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
|
40 |
+
start_idx2 = torch.where(
|
41 |
+
torch.eq(x_idx, 0),
|
42 |
+
torch.tensor(0, device=x.device),
|
43 |
+
torch.where(
|
44 |
+
torch.eq(x_idx, num_x_points), torch.tensor(num_x_points - 2, device=x.device), cand_start_idx,
|
45 |
+
),
|
46 |
+
)
|
47 |
+
y_positions_expanded = yp.unsqueeze(0).expand(x.shape[0], -1, -1)
|
48 |
+
start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
|
49 |
+
end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
|
50 |
+
cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
|
51 |
+
return cand
|
52 |
+
|
53 |
+
|
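A self-contained check of `interpolate_fn`, reproducing the two docstring cases above (one channel, knots at x = [0, 1], y = [0, 2]); the import path assumes this module is used as `dpm_solver_pp`:

```python
import torch
from dpm_solver_pp import interpolate_fn

x  = torch.tensor([[0.5], [-10.0]])   # [N, C] query points, N=2, C=1
xp = torch.tensor([[0.0, 1.0]])       # [C, K] knot x-coordinates
yp = torch.tensor([[0.0, 2.0]])       # [C, K] knot y-coordinates

print(interpolate_fn(x, xp, yp))
# tensor([[  1.],
#         [-20.]])   # x = -10 extrapolates along the outermost segment
```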
54 |
+
class NoiseScheduleVP:
|
55 |
+
def __init__(self, schedule='discrete', beta_0=1e-4, beta_1=2e-2, total_N=1000, betas=None, alphas_cumprod=None):
|
56 |
+
"""Create a wrapper class for the forward SDE (VP type).
|
57 |
+
|
58 |
+
The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
|
59 |
+
We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
|
60 |
+
Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
|
61 |
+
|
62 |
+
log_alpha_t = self.marginal_log_mean_coeff(t)
|
63 |
+
sigma_t = self.marginal_std(t)
|
64 |
+
lambda_t = self.marginal_lambda(t)
|
65 |
+
|
66 |
+
Moreover, as lambda(t) is an invertible function, we also support its inverse function:
|
67 |
+
|
68 |
+
t = self.inverse_lambda(lambda_t)
|
69 |
+
|
70 |
+
===============================================================
|
71 |
+
|
72 |
+
We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
|
73 |
+
schedule are the default settings in DDPM and improved-DDPM:
|
74 |
+
|
75 |
+
beta_min: A `float` number. The smallest beta for the linear schedule.
|
76 |
+
beta_max: A `float` number. The largest beta for the linear schedule.
|
77 |
+
cosine_s: A `float` number. The hyperparameter in the cosine schedule.
|
78 |
+
cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
|
79 |
+
T: A `float` number. The ending time of the forward process.
|
80 |
+
|
81 |
+
Note that the original DDPM (linear schedule) used the discrete-time label (0 to 999). We convert the discrete-time
|
82 |
+
label to continuous time (following Song et al., 2021), so the beta here is 1000x larger than those in DDPM.
|
83 |
+
|
84 |
+
===============================================================
|
85 |
+
|
86 |
+
Args:
|
87 |
+
schedule: A `str`. The noise schedule of the forward SDE ('discrete', 'linear' or 'cosine').
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
A wrapper object of the forward SDE (VP type).
|
91 |
+
"""
|
92 |
+
if schedule not in ['linear', 'discrete', 'cosine']:
|
93 |
+
raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'linear' or 'cosine'".format(schedule))
|
94 |
+
self.total_N = total_N
|
95 |
+
self.beta_0 = beta_0 * 1000.
|
96 |
+
self.beta_1 = beta_1 * 1000.
|
97 |
+
|
98 |
+
if schedule == 'discrete':
|
99 |
+
if betas is not None:
|
100 |
+
log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
|
101 |
+
else:
|
102 |
+
assert alphas_cumprod is not None
|
103 |
+
log_alphas = 0.5 * torch.log(alphas_cumprod)
|
104 |
+
self.total_N = len(log_alphas)
|
105 |
+
self.t_discrete = torch.linspace(1. / self.total_N, 1., self.total_N).reshape((1, -1))
|
106 |
+
self.log_alpha_discrete = log_alphas.reshape((1, -1))
|
107 |
+
|
108 |
+
self.cosine_s = 0.008
|
109 |
+
self.cosine_beta_max = 999.
|
110 |
+
self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
|
111 |
+
self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
|
112 |
+
self.schedule = schedule
|
113 |
+
if schedule == 'cosine':
|
114 |
+
# For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
|
115 |
+
# Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
|
116 |
+
self.T = 0.9946
|
117 |
+
else:
|
118 |
+
self.T = 1.
|
119 |
+
|
120 |
+
def marginal_log_mean_coeff(self, t):
|
121 |
+
"""
|
122 |
+
Compute log(alpha_t) of a given continuous-time label t in [0, T].
|
123 |
+
"""
|
124 |
+
if self.schedule == 'linear':
|
125 |
+
return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
|
126 |
+
elif self.schedule == 'discrete':
|
127 |
+
return interpolate_fn(t.reshape((-1, 1)), self.t_discrete.clone().to(t.device), self.log_alpha_discrete.clone().to(t.device)).reshape((-1,))
|
128 |
+
elif self.schedule == 'cosine':
|
129 |
+
log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
|
130 |
+
log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
|
131 |
+
return log_alpha_t
|
132 |
+
else:
|
133 |
+
raise ValueError("Unsupported ")
|
134 |
+
|
135 |
+
def marginal_alpha(self, t):
|
136 |
+
return torch.exp(self.marginal_log_mean_coeff(t))
|
137 |
+
|
138 |
+
def marginal_std(self, t):
|
139 |
+
"""
|
140 |
+
Compute sigma_t of a given continuous-time label t in [0, T].
|
141 |
+
"""
|
142 |
+
return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
|
143 |
+
|
144 |
+
def marginal_lambda(self, t):
|
145 |
+
"""
|
146 |
+
Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
|
147 |
+
"""
|
148 |
+
log_mean_coeff = self.marginal_log_mean_coeff(t)
|
149 |
+
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
|
150 |
+
return log_mean_coeff - log_std
|
151 |
+
|
152 |
+
def inverse_lambda(self, lamb):
|
153 |
+
"""
|
154 |
+
Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
|
155 |
+
"""
|
156 |
+
if self.schedule == 'linear':
|
157 |
+
tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
|
158 |
+
Delta = self.beta_0**2 + tmp
|
159 |
+
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
|
160 |
+
elif self.schedule == 'discrete':
|
161 |
+
log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
|
162 |
+
t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_discrete.clone().to(lamb.device), [1]), torch.flip(self.t_discrete.clone().to(lamb.device), [1]))
|
163 |
+
return t.reshape((-1,))
|
164 |
+
else:
|
165 |
+
log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
|
166 |
+
t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
|
167 |
+
t = t_fn(log_alpha)
|
168 |
+
return t
|
169 |
+
|
170 |
+
|
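A minimal sketch of building a discrete-time noise schedule and querying the quantities defined above (the linear beta range is an illustrative DDPM-style choice, not a value fixed by this file):

```python
import torch
from dpm_solver_pp import NoiseScheduleVP

betas = torch.linspace(1e-4, 2e-2, 1000)          # assumed training betas
ns = NoiseScheduleVP(schedule='discrete', betas=betas)

t = torch.tensor([0.5])
alpha_t  = ns.marginal_alpha(t)          # alpha_t of q_{t|0}(x_t | x_0)
sigma_t  = ns.marginal_std(t)            # sigma_t of q_{t|0}(x_t | x_0)
lambda_t = ns.marginal_lambda(t)         # half-logSNR: log(alpha_t) - log(sigma_t)
t_back   = ns.inverse_lambda(lambda_t)   # recovers t up to interpolation error
```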
171 |
+
def model_wrapper(model, noise_schedule=None, is_cond_classifier=False, classifier_fn=None, classifier_scale=1., time_input_type='1', total_N=1000, model_kwargs={}, is_deis=False):
|
172 |
+
"""Create a wrapper function for the noise prediction model.
|
173 |
+
|
174 |
+
DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
|
175 |
+
first wrap the model function into a function that accepts the continuous time as the input.
|
176 |
+
|
177 |
+
The input `model` has the following format:
|
178 |
+
|
179 |
+
``
|
180 |
+
model(x, t_input, **model_kwargs) -> noise
|
181 |
+
``
|
182 |
+
|
183 |
+
where `x` and `noise` have the same shape, and `t_input` is the time label of the model.
|
184 |
+
(may be discrete-time labels (i.e. 0 to 999) or continuous-time labels (i.e. epsilon to T).)
|
185 |
+
|
186 |
+
We wrap the model function to the following format:
|
187 |
+
|
188 |
+
``
|
189 |
+
def model_fn(x, t_continuous) -> noise:
|
190 |
+
t_input = get_model_input_time(t_continuous)
|
191 |
+
return model(x, t_input, **model_kwargs)
|
192 |
+
``
|
193 |
+
|
194 |
+
where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
|
195 |
+
|
196 |
+
For DPMs with classifier guidance, we also combine the model output with the classifier gradient as used in [1].
|
197 |
+
|
198 |
+
[1] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," in Advances in Neural
|
199 |
+
Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
|
200 |
+
|
201 |
+
===============================================================
|
202 |
+
|
203 |
+
Args:
|
204 |
+
model: A noise prediction model with the following format:
|
205 |
+
``
|
206 |
+
def model(x, t_input, **model_kwargs):
|
207 |
+
return noise
|
208 |
+
``
|
209 |
+
noise_schedule: A noise schedule object, such as NoiseScheduleVP. Only used for the classifier guidance.
|
210 |
+
is_cond_classifier: A `bool`. Whether to use the classifier guidance.
|
211 |
+
classifier_fn: A classifier function. Only used for the classifier guidance. The format is:
|
212 |
+
``
|
213 |
+
def classifier_fn(x, t_input):
|
214 |
+
return logits
|
215 |
+
``
|
216 |
+
classifier_scale: A `float`. The scale for the classifier guidance.
|
217 |
+
time_input_type: A `str`. The type for the time input of the model. We support three types:
|
218 |
+
- '0': The continuous-time type. In this case, the model is trained on the continuous time,
|
219 |
+
so `t_input` = `t_continuous`.
|
220 |
+
- '1': The Type-1 discrete type described in the Appendix of DPM-Solver paper.
|
221 |
+
**For discrete-time DPMs, we recommend to use this type for DPM-Solver**.
|
222 |
+
- '2': The Type-2 discrete type described in the Appendix of DPM-Solver paper.
|
223 |
+
total_N: A `int`. The total number of the discrete-time DPMs (default is 1000), used when `time_input_type`
|
224 |
+
is '1' or '2'.
|
225 |
+
model_kwargs: A `dict`. A dict for the other inputs of the model function.
|
226 |
+
Returns:
|
227 |
+
A function that accepts the continuous time as the input, with the following format:
|
228 |
+
``
|
229 |
+
def model_fn(x, t_continuous):
|
230 |
+
t_input = get_model_input_time(t_continuous)
|
231 |
+
return model(x, t_input, **model_kwargs)
|
232 |
+
``
|
233 |
+
"""
|
234 |
+
def get_model_input_time(t_continuous):
|
235 |
+
"""
|
236 |
+
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
|
237 |
+
"""
|
238 |
+
if time_input_type == '0':
|
239 |
+
# time_input_type == '0' means that the model is a continuous-time model.
|
240 |
+
# For continuous-time DPMs, the continuous time equals to the discrete time.
|
241 |
+
return t_continuous
|
242 |
+
elif time_input_type == '1':
|
243 |
+
# Type-1 discrete label, as detailed in the Appendix of DPM-Solver.
|
244 |
+
return 1000. * torch.max(t_continuous - 1. / total_N, torch.zeros_like(t_continuous).to(t_continuous))
|
245 |
+
elif time_input_type == '2':
|
246 |
+
# Type-2 discrete label, as detailed in the Appendix of DPM-Solver.
|
247 |
+
max_N = (total_N - 1) / total_N * 1000.
|
248 |
+
return max_N * t_continuous
|
249 |
+
else:
|
250 |
+
raise ValueError("Unsupported time input type {}, must be '0' or '1' or '2'".format(time_input_type))
|
251 |
+
|
252 |
+
def cond_fn(x, t_discrete, y):
|
253 |
+
"""
|
254 |
+
Compute the gradient of the classifier, multiplied by the scale of the classifier guidance.
|
255 |
+
"""
|
256 |
+
assert y is not None
|
257 |
+
with torch.enable_grad():
|
258 |
+
x_in = x.detach().requires_grad_(True)
|
259 |
+
logits = classifier_fn(x_in, t_discrete)
|
260 |
+
log_probs = F.log_softmax(logits, dim=-1)
|
261 |
+
selected = log_probs[range(len(logits)), y.view(-1)]
|
262 |
+
return classifier_scale * torch.autograd.grad(selected.sum(), x_in)[0]
|
263 |
+
|
264 |
+
def model_fn(x, t_continuous):
|
265 |
+
"""
|
266 |
+
The noise prediction model function that is used for DPM-Solver.
|
267 |
+
"""
|
268 |
+
if t_continuous.reshape((-1,)).shape[0] == 1:
|
269 |
+
t_continuous = torch.ones((x.shape[0],)).to(x.device) * t_continuous
|
270 |
+
if is_cond_classifier:
|
271 |
+
y = model_kwargs.get("y", None)
|
272 |
+
if y is None:
|
273 |
+
raise ValueError("For classifier guidance, the label y has to be in the input.")
|
274 |
+
t_discrete = get_model_input_time(t_continuous)
|
275 |
+
noise_uncond = model(x, t_discrete, **model_kwargs)
|
276 |
+
cond_grad = cond_fn(x, t_discrete, y)
|
277 |
+
if is_deis:
|
278 |
+
sigma_t = noise_schedule.marginal_std(t_continuous / 1000.)
|
279 |
+
else:
|
280 |
+
sigma_t = noise_schedule.marginal_std(t_continuous)
|
281 |
+
dims = len(cond_grad.shape) - 1
|
282 |
+
return noise_uncond - sigma_t[(...,) + (None,)*dims] * cond_grad
|
283 |
+
else:
|
284 |
+
t_discrete = get_model_input_time(t_continuous)
|
285 |
+
return model(x, t_discrete, **model_kwargs)
|
286 |
+
|
287 |
+
return model_fn
|
288 |
+
|
289 |
+
|
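A hedged sketch of wrapping a discrete-time noise prediction network with `model_wrapper` so it can be passed to `DPM_Solver` below; `toy_noise_model` is a stand-in for a trained network and is not part of this commit:

```python
import torch
from dpm_solver_pp import NoiseScheduleVP, model_wrapper, DPM_Solver

def toy_noise_model(x, t_input):
    # Placeholder for a trained network: must return a tensor shaped like x.
    return torch.zeros_like(x)

betas = torch.linspace(1e-4, 2e-2, 1000)
ns = NoiseScheduleVP(schedule='discrete', betas=betas)

# Type-1 discrete time labels, as recommended in the docstring above.
model_fn = model_wrapper(toy_noise_model, noise_schedule=ns,
                         time_input_type='1', total_N=1000)

x = torch.randn(2, 4, 32, 32)
noise = model_fn(x, torch.tensor([0.5]))   # continuous time in (0, 1]
solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
```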
290 |
+
class DPM_Solver:
|
291 |
+
def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
|
292 |
+
"""Construct a DPM-Solver.
|
293 |
+
|
294 |
+
Args:
|
295 |
+
model_fn: A noise prediction model function which accepts the continuous-time input
|
296 |
+
(t in [epsilon, T]):
|
297 |
+
``
|
298 |
+
def model_fn(x, t_continuous):
|
299 |
+
return noise
|
300 |
+
``
|
301 |
+
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
|
302 |
+
"""
|
303 |
+
self.model = model_fn
|
304 |
+
self.noise_schedule = noise_schedule
|
305 |
+
self.predict_x0 = predict_x0
|
306 |
+
self.thresholding = thresholding
|
307 |
+
self.max_val = max_val
|
308 |
+
|
309 |
+
def model_fn(self, x, t):
|
310 |
+
if self.predict_x0:
|
311 |
+
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
|
312 |
+
noise = self.model(x, t)
|
313 |
+
dims = len(x.shape) - 1
|
314 |
+
x0 = (x - sigma_t[(...,) + (None,)*dims] * noise) / alpha_t[(...,) + (None,)*dims]
|
315 |
+
if self.thresholding:
|
316 |
+
p = 0.995
|
317 |
+
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
318 |
+
s = torch.maximum(s, torch.ones_like(s).to(s.device))[(...,) + (None,)*dims]
|
319 |
+
x0 = torch.clamp(x0, -s, s) / (s / self.max_val)
|
320 |
+
return x0
|
321 |
+
else:
|
322 |
+
return self.model(x, t)
|
323 |
+
|
324 |
+
def get_time_steps(self, skip_type, t_T, t_0, N, device):
|
325 |
+
"""Compute the intermediate time steps for sampling.
|
326 |
+
|
327 |
+
Args:
|
328 |
+
skip_type: A `str`. The type for the spacing of the time steps. We support three types:
|
329 |
+
- 'logSNR': uniform logSNR for the time steps, **recommended for DPM-Solver**.
|
330 |
+
- 'time_uniform': uniform time for the time steps. (Used in DDIM and DDPM.)
|
331 |
+
- 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
|
332 |
+
t_T: A `float`. The starting time of the sampling (default is T).
|
333 |
+
t_0: A `float`. The ending time of the sampling (default is epsilon).
|
334 |
+
N: A `int`. The total number of the spacing of the time steps.
|
335 |
+
device: A torch device.
|
336 |
+
Returns:
|
337 |
+
A pytorch tensor of the time steps, with the shape (N + 1,).
|
338 |
+
"""
|
339 |
+
if skip_type == 'logSNR':
|
340 |
+
lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
|
341 |
+
lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
|
342 |
+
logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
|
343 |
+
# print(torch.min(torch.abs(logSNR_steps - self.noise_schedule.marginal_lambda(self.noise_schedule.inverse_lambda(logSNR_steps)))).item())
|
344 |
+
return self.noise_schedule.inverse_lambda(logSNR_steps)
|
345 |
+
elif skip_type == 't2':
|
346 |
+
t_order = 2
|
347 |
+
t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
|
348 |
+
return t
|
349 |
+
elif skip_type == 'time_uniform':
|
350 |
+
return torch.linspace(t_T, t_0, N + 1).to(device)
|
351 |
+
elif skip_type == 'time_quadratic':
|
352 |
+
t = torch.linspace(t_0, t_T, 10000000).to(device)
|
353 |
+
quadratic_t = torch.sqrt(t)
|
354 |
+
quadratic_steps = torch.linspace(quadratic_t[0], quadratic_t[-1], N + 1).to(device)
|
355 |
+
return torch.flip(torch.cat([t[torch.searchsorted(quadratic_t, quadratic_steps)[:-1]], t_T * torch.ones((1,)).to(device)], dim=0), dims=[0])
|
356 |
+
else:
|
357 |
+
raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
|
358 |
+
|
359 |
+
def get_time_steps_for_dpm_solver_fast(self, skip_type, t_T, t_0, steps, order, device):
|
360 |
+
"""
|
361 |
+
Compute the intermediate time steps and the order of each step for sampling by DPM-Solver-fast.
|
362 |
+
|
363 |
+
We recommend DPM-Solver-fast for fast sampling of DPMs. Given a fixed number of function evaluations by `steps`,
|
364 |
+
the sampling procedure by DPM-Solver-fast is:
|
365 |
+
- Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
|
366 |
+
- If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
|
367 |
+
- If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
|
368 |
+
- If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
|
369 |
+
|
370 |
+
============================================
|
371 |
+
Args:
|
372 |
+
t_T: A `float`. The starting time of the sampling (default is T).
|
373 |
+
t_0: A `float`. The ending time of the sampling (default is epsilon).
|
374 |
+
steps: A `int`. The total number of function evaluations (NFE).
|
375 |
+
device: A torch device.
|
376 |
+
Returns:
|
377 |
+
orders: A list of the solver order of each step.
|
378 |
+
timesteps: A pytorch tensor of the time steps, with the shape of (K + 1,).
|
379 |
+
"""
|
380 |
+
if order == 3:
|
381 |
+
K = steps // 3 + 1
|
382 |
+
if steps % 3 == 0:
|
383 |
+
orders = [3,] * (K - 2) + [2, 1]
|
384 |
+
elif steps % 3 == 1:
|
385 |
+
orders = [3,] * (K - 1) + [1]
|
386 |
+
else:
|
387 |
+
orders = [3,] * (K - 1) + [2]
|
388 |
+
timesteps = self.get_time_steps(skip_type, t_T, t_0, K, device)
|
389 |
+
return orders, timesteps
|
390 |
+
elif order == 2:
|
391 |
+
K = steps // 2
|
392 |
+
if steps % 2 == 0:
|
393 |
+
orders = [2,] * K
|
394 |
+
else:
|
395 |
+
orders = [2,] * K + [1]
|
396 |
+
timesteps = self.get_time_steps(skip_type, t_T, t_0, K, device)
|
397 |
+
return orders, timesteps
|
398 |
+
else:
|
399 |
+
raise ValueError("order must >= 2")
|
400 |
+
|
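A standalone illustration of the order-splitting rule documented above for the third-order fast solver; `split_orders` is a helper written here only to make the rule concrete:

```python
def split_orders(steps):
    # Mirrors the order == 3 branch of get_time_steps_for_dpm_solver_fast.
    K = steps // 3 + 1
    if steps % 3 == 0:
        return [3] * (K - 2) + [2, 1]
    elif steps % 3 == 1:
        return [3] * (K - 1) + [1]
    else:
        return [3] * (K - 1) + [2]

print(split_orders(10))  # [3, 3, 3, 1]     -> 10 model evaluations in total
print(split_orders(12))  # [3, 3, 3, 2, 1]  -> 12 model evaluations in total
```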
401 |
+
def denoise_fn(self, x, s, noise_s=None):
|
402 |
+
ns = self.noise_schedule
|
403 |
+
dims = len(x.shape) - 1
|
404 |
+
log_alpha_s = ns.marginal_log_mean_coeff(s)
|
405 |
+
sigma_s = ns.marginal_std(s)
|
406 |
+
|
407 |
+
if noise_s is None:
|
408 |
+
noise_s = self.model_fn(x, s)
|
409 |
+
x_0 = (
|
410 |
+
(x - sigma_s[(...,) + (None,)*dims] * noise_s) / torch.exp(log_alpha_s)[(...,) + (None,)*dims]
|
411 |
+
)
|
412 |
+
return x_0
|
413 |
+
|
414 |
+
def dpm_solver_first_update(self, x, s, t, noise_s=None, return_noise=False):
|
415 |
+
"""
|
416 |
+
A single step for DPM-Solver-1.
|
417 |
+
|
418 |
+
Args:
|
419 |
+
x: A pytorch tensor. The initial value at time `s`.
|
420 |
+
s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
|
421 |
+
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
|
422 |
+
return_noise: A `bool`. If true, also return the predicted noise at time `s`.
|
423 |
+
Returns:
|
424 |
+
x_t: A pytorch tensor. The approximated solution at time `t`.
|
425 |
+
"""
|
426 |
+
ns = self.noise_schedule
|
427 |
+
dims = len(x.shape) - 1
|
428 |
+
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
|
429 |
+
h = lambda_t - lambda_s
|
430 |
+
log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
|
431 |
+
sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
|
432 |
+
alpha_t = torch.exp(log_alpha_t)
|
433 |
+
|
434 |
+
if self.predict_x0:
|
435 |
+
phi_1 = (torch.exp(-h) - 1.) / (-1.)
|
436 |
+
if noise_s is None:
|
437 |
+
noise_s = self.model_fn(x, s)
|
438 |
+
x_t = (
|
439 |
+
(sigma_t / sigma_s)[(...,) + (None,)*dims] * x
|
440 |
+
+ (alpha_t * phi_1)[(...,) + (None,)*dims] * noise_s
|
441 |
+
)
|
442 |
+
if return_noise:
|
443 |
+
return x_t, {'noise_s': noise_s}
|
444 |
+
else:
|
445 |
+
return x_t
|
446 |
+
else:
|
447 |
+
phi_1 = torch.expm1(h)
|
448 |
+
if noise_s is None:
|
449 |
+
noise_s = self.model_fn(x, s)
|
450 |
+
x_t = (
|
451 |
+
torch.exp(log_alpha_t - log_alpha_s)[(...,) + (None,)*dims] * x
|
452 |
+
- (sigma_t * phi_1)[(...,) + (None,)*dims] * noise_s
|
453 |
+
)
|
454 |
+
if return_noise:
|
455 |
+
return x_t, {'noise_s': noise_s}
|
456 |
+
else:
|
457 |
+
return x_t
|
458 |
+
|
459 |
+
def dpm_solver_second_update(self, x, s, t, r1=0.5, noise_s=None, return_noise=False, solver_type='dpm_solver'):
|
460 |
+
"""
|
461 |
+
A single step for DPM-Solver-2.
|
462 |
+
|
463 |
+
Args:
|
464 |
+
x: A pytorch tensor. The initial value at time `s`.
|
465 |
+
s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
|
466 |
+
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
|
467 |
+
r1: A `float`. The hyperparameter of the second-order solver. We recommend the default setting `0.5`.
|
468 |
+
noise_s: A pytorch tensor. The predicted noise at time `s`.
|
469 |
+
If `noise_s` is None, we compute the predicted noise by `x` and `s`; otherwise we directly use it.
|
470 |
+
return_noise: A `bool`. If true, also return the predicted noise at time `s` and `s1` (the intermediate time).
|
471 |
+
Returns:
|
472 |
+
x_t: A pytorch tensor. The approximated solution at time `t`.
|
473 |
+
"""
|
474 |
+
if r1 is None:
|
475 |
+
r1 = 0.5
|
476 |
+
ns = self.noise_schedule
|
477 |
+
dims = len(x.shape) - 1
|
478 |
+
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
|
479 |
+
h = lambda_t - lambda_s
|
480 |
+
lambda_s1 = lambda_s + r1 * h
|
481 |
+
s1 = ns.inverse_lambda(lambda_s1)
|
482 |
+
log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t)
|
483 |
+
sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
|
484 |
+
alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
|
485 |
+
|
486 |
+
if self.predict_x0:
|
487 |
+
phi_11 = torch.expm1(-r1 * h)
|
488 |
+
phi_1 = torch.expm1(-h)
|
489 |
+
|
490 |
+
if noise_s is None:
|
491 |
+
noise_s = self.model_fn(x, s)
|
492 |
+
x_s1 = (
|
493 |
+
(sigma_s1 / sigma_s)[(...,) + (None,)*dims] * x
|
494 |
+
- (alpha_s1 * phi_11)[(...,) + (None,)*dims] * noise_s
|
495 |
+
)
|
496 |
+
noise_s1 = self.model_fn(x_s1, s1)
|
497 |
+
if solver_type == 'dpm_solver':
|
498 |
+
x_t = (
|
499 |
+
(sigma_t / sigma_s)[(...,) + (None,)*dims] * x
|
500 |
+
- (alpha_t * phi_1)[(...,) + (None,)*dims] * noise_s
|
501 |
+
- (0.5 / r1) * (alpha_t * phi_1)[(...,) + (None,)*dims] * (noise_s1 - noise_s)
|
502 |
+
)
|
503 |
+
elif solver_type == 'taylor':
|
504 |
+
x_t = (
|
505 |
+
(sigma_t / sigma_s)[(...,) + (None,)*dims] * x
|
506 |
+
- (alpha_t * phi_1)[(...,) + (None,)*dims] * noise_s
|
507 |
+
+ (1. / r1) * (alpha_t * ((torch.exp(-h) - 1.) / h + 1.))[(...,) + (None,)*dims] * (noise_s1 - noise_s)
|
508 |
+
)
|
509 |
+
else:
|
510 |
+
raise ValueError("solver_type must be either dpm_solver or taylor, got {}".format(solver_type))
|
511 |
+
else:
|
512 |
+
phi_11 = torch.expm1(r1 * h)
|
513 |
+
phi_1 = torch.expm1(h)
|
514 |
+
|
515 |
+
if noise_s is None:
|
516 |
+
noise_s = self.model_fn(x, s)
|
517 |
+
x_s1 = (
|
518 |
+
torch.exp(log_alpha_s1 - log_alpha_s)[(...,) + (None,)*dims] * x
|
519 |
+
- (sigma_s1 * phi_11)[(...,) + (None,)*dims] * noise_s
|
520 |
+
)
|
521 |
+
noise_s1 = self.model_fn(x_s1, s1)
|
522 |
+
if solver_type == 'dpm_solver':
|
523 |
+
x_t = (
|
524 |
+
torch.exp(log_alpha_t - log_alpha_s)[(...,) + (None,)*dims] * x
|
525 |
+
- (sigma_t * phi_1)[(...,) + (None,)*dims] * noise_s
|
526 |
+
- (0.5 / r1) * (sigma_t * phi_1)[(...,) + (None,)*dims] * (noise_s1 - noise_s)
|
527 |
+
)
|
528 |
+
elif solver_type == 'taylor':
|
529 |
+
x_t = (
|
530 |
+
torch.exp(log_alpha_t - log_alpha_s)[(...,) + (None,)*dims] * x
|
531 |
+
- (sigma_t * phi_1)[(...,) + (None,)*dims] * noise_s
|
532 |
+
- (1. / r1) * (sigma_t * ((torch.exp(h) - 1.) / h - 1.))[(...,) + (None,)*dims] * (noise_s1 - noise_s)
|
533 |
+
)
|
534 |
+
else:
|
535 |
+
raise ValueError("solver_type must be either dpm_solver or taylor, got {}".format(solver_type))
|
536 |
+
if return_noise:
|
537 |
+
return x_t, {'noise_s': noise_s, 'noise_s1': noise_s1}
|
538 |
+
else:
|
539 |
+
return x_t
|
540 |
+
|
541 |
+
|
542 |
+
def dpm_multistep_second_update(self, x, noise_prev_list, t_prev_list, t, solver_type="dpm_solver"):
|
543 |
+
ns = self.noise_schedule
|
544 |
+
dims = len(x.shape) - 1
|
545 |
+
noise_prev_1, noise_prev_0 = noise_prev_list
|
546 |
+
t_prev_1, t_prev_0 = t_prev_list
|
547 |
+
lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
|
548 |
+
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
549 |
+
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
550 |
+
alpha_t = torch.exp(log_alpha_t)
|
551 |
+
|
552 |
+
h_0 = lambda_prev_0 - lambda_prev_1
|
553 |
+
h = lambda_t - lambda_prev_0
|
554 |
+
r0 = h_0 / h
|
555 |
+
D1_0 = (1. / r0)[(...,) + (None,)*dims] * (noise_prev_0 - noise_prev_1)
|
556 |
+
if self.predict_x0:
|
557 |
+
if solver_type == 'taylor':
|
558 |
+
x_t = (
|
559 |
+
(sigma_t / sigma_prev_0)[(...,) + (None,)*dims] * x
|
560 |
+
- (alpha_t * (torch.exp(-h) - 1.))[(...,) + (None,)*dims] * noise_prev_0
|
561 |
+
+ (alpha_t * ((torch.exp(-h) - 1.) / h + 1.))[(...,) + (None,)*dims] * D1_0
|
562 |
+
)
|
563 |
+
elif solver_type == 'dpm_solver':
|
564 |
+
x_t = (
|
565 |
+
(sigma_t / sigma_prev_0)[(...,) + (None,)*dims] * x
|
566 |
+
- (alpha_t * (torch.exp(-h) - 1.))[(...,) + (None,)*dims] * noise_prev_0
|
567 |
+
- 0.5 * (alpha_t * (torch.exp(-h) - 1.))[(...,) + (None,)*dims] * D1_0
|
568 |
+
)
|
569 |
+
else:
|
570 |
+
if solver_type == 'taylor':
|
571 |
+
x_t = (
|
572 |
+
torch.exp(log_alpha_t - log_alpha_prev_0)[(...,) + (None,)*dims] * x
|
573 |
+
- (sigma_t * (torch.exp(h) - 1.))[(...,) + (None,)*dims] * noise_prev_0
|
574 |
+
- (sigma_t * ((torch.exp(h) - 1.) / h - 1.))[(...,) + (None,)*dims] * D1_0
|
575 |
+
)
|
576 |
+
elif solver_type == 'dpm_solver':
|
577 |
+
x_t = (
|
578 |
+
torch.exp(log_alpha_t - log_alpha_prev_0)[(...,) + (None,)*dims] * x
|
579 |
+
- (sigma_t * (torch.exp(h) - 1.))[(...,) + (None,)*dims] * noise_prev_0
|
580 |
+
- 0.5 * (sigma_t * (torch.exp(h) - 1.))[(...,) + (None,)*dims] * D1_0
|
581 |
+
)
|
582 |
+
return x_t
|
583 |
+
|
584 |
+
|
585 |
+
def dpm_multistep_third_update(self, x, noise_prev_list, t_prev_list, t, solver_type='dpm_solver'):
|
586 |
+
ns = self.noise_schedule
|
587 |
+
dims = len(x.shape) - 1
|
588 |
+
noise_prev_2, noise_prev_1, noise_prev_0 = noise_prev_list
|
589 |
+
t_prev_2, t_prev_1, t_prev_0 = t_prev_list
|
590 |
+
lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
|
591 |
+
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
592 |
+
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
593 |
+
alpha_t = torch.exp(log_alpha_t)
|
594 |
+
|
595 |
+
h_1 = lambda_prev_1 - lambda_prev_2
|
596 |
+
h_0 = lambda_prev_0 - lambda_prev_1
|
597 |
+
h = lambda_t - lambda_prev_0
|
598 |
+
r0, r1 = h_0 / h, h_1 / h
|
599 |
+
D1_0 = (1. / r0)[(...,) + (None,)*dims] * (noise_prev_0 - noise_prev_1)
|
600 |
+
D1_1 = (1. / r1)[(...,) + (None,)*dims] * (noise_prev_1 - noise_prev_2)
|
601 |
+
D1 = D1_0 + (r0 / (r0 + r1))[(...,) + (None,)*dims] * (D1_0 - D1_1)
|
602 |
+
D2 = (1. / (r0 + r1))[(...,) + (None,)*dims] * (D1_0 - D1_1)
|
603 |
+
if self.predict_x0:
|
604 |
+
x_t = (
|
605 |
+
(sigma_t / sigma_prev_0)[(...,) + (None,)*dims] * x
|
606 |
+
- (alpha_t * (torch.exp(-h) - 1.))[(...,) + (None,)*dims] * noise_prev_0
|
607 |
+
+ (alpha_t * ((torch.exp(-h) - 1.) / h + 1.))[(...,) + (None,)*dims] * D1
|
608 |
+
- (alpha_t * ((torch.exp(-h) - 1. + h) / h**2 - 0.5))[(...,) + (None,)*dims] * D2
|
609 |
+
)
|
610 |
+
else:
|
611 |
+
x_t = (
|
612 |
+
torch.exp(log_alpha_t - log_alpha_prev_0)[(...,) + (None,)*dims] * x
|
613 |
+
- (sigma_t * (torch.exp(h) - 1.))[(...,) + (None,)*dims] * noise_prev_0
|
614 |
+
- (sigma_t * ((torch.exp(h) - 1.) / h - 1.))[(...,) + (None,)*dims] * D1
|
615 |
+
- (sigma_t * ((torch.exp(h) - 1. - h) / h**2 - 0.5))[(...,) + (None,)*dims] * D2
|
616 |
+
)
|
617 |
+
return x_t
|
618 |
+
|
619 |
+
def dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., noise_s=None, noise_s1=None, noise_s2=None, return_noise=False, solver_type='dpm_solver'):
|
620 |
+
"""
|
621 |
+
A single step for DPM-Solver-3.
|
622 |
+
|
623 |
+
Args:
|
624 |
+
x: A pytorch tensor. The initial value at time `s`.
|
625 |
+
s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
|
626 |
+
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
|
627 |
+
r1: A `float`. The hyperparameter of the third-order solver. We recommend the default setting `1 / 3`.
|
628 |
+
r2: A `float`. The hyperparameter of the third-order solver. We recommend the default setting `2 / 3`.
|
629 |
+
noise_s: A pytorch tensor. The predicted noise at time `s`.
|
630 |
+
If `noise_s` is None, we compute the predicted noise by `x` and `s`; otherwise we directly use it.
|
631 |
+
noise_s1: A pytorch tensor. The predicted noise at time `s1` (the intermediate time given by `r1`).
|
632 |
+
If `noise_s1` is None, we compute the predicted noise by `s1`; otherwise we directly use it.
|
633 |
+
Returns:
|
634 |
+
x_t: A pytorch tensor. The approximated solution at time `t`.
|
635 |
+
"""
|
636 |
+
if r1 is None:
|
637 |
+
r1 = 1. / 3.
|
638 |
+
if r2 is None:
|
639 |
+
r2 = 2. / 3.
|
640 |
+
ns = self.noise_schedule
|
641 |
+
dims = len(x.shape) - 1
|
642 |
+
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
|
643 |
+
h = lambda_t - lambda_s
|
644 |
+
lambda_s1 = lambda_s + r1 * h
|
645 |
+
lambda_s2 = lambda_s + r2 * h
|
646 |
+
s1 = ns.inverse_lambda(lambda_s1)
|
647 |
+
s2 = ns.inverse_lambda(lambda_s2)
|
648 |
+
log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
|
649 |
+
sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t)
|
650 |
+
alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
|
651 |
+
|
652 |
+
if self.predict_x0:
|
653 |
+
phi_11 = torch.expm1(-r1 * h)
|
654 |
+
phi_12 = torch.expm1(-r2 * h)
|
655 |
+
phi_1 = torch.expm1(-h)
|
656 |
+
phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
|
657 |
+
phi_2 = phi_1 / h + 1.
|
658 |
+
phi_3 = phi_2 / h - 0.5
|
659 |
+
|
660 |
+
if noise_s is None:
|
661 |
+
noise_s = self.model_fn(x, s)
|
662 |
+
if noise_s1 is None:
|
663 |
+
x_s1 = (
|
664 |
+
(sigma_s1 / sigma_s)[(...,) + (None,)*dims] * x
|
665 |
+
- (alpha_s1 * phi_11)[(...,) + (None,)*dims] * noise_s
|
666 |
+
)
|
667 |
+
noise_s1 = self.model_fn(x_s1, s1)
|
668 |
+
if noise_s2 is None:
|
669 |
+
x_s2 = (
|
670 |
+
(sigma_s2 / sigma_s)[(...,) + (None,)*dims] * x
|
671 |
+
- (alpha_s2 * phi_12)[(...,) + (None,)*dims] * noise_s
|
672 |
+
+ r2 / r1 * (alpha_s2 * phi_22)[(...,) + (None,)*dims] * (noise_s1 - noise_s)
|
673 |
+
)
|
674 |
+
noise_s2 = self.model_fn(x_s2, s2)
|
675 |
+
if solver_type == 'dpm_solver':
|
676 |
+
x_t = (
|
677 |
+
(sigma_t / sigma_s)[(...,) + (None,)*dims] * x
|
678 |
+
- (alpha_t * phi_1)[(...,) + (None,)*dims] * noise_s
|
679 |
+
+ (1. / r2) * (alpha_t * phi_2)[(...,) + (None,)*dims] * (noise_s2 - noise_s)
|
680 |
+
)
|
681 |
+
elif solver_type == 'taylor':
|
682 |
+
D1_0 = (1. / r1) * (noise_s1 - noise_s)
|
683 |
+
D1_1 = (1. / r2) * (noise_s2 - noise_s)
|
684 |
+
D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
|
685 |
+
D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
|
686 |
+
x_t = (
|
687 |
+
(sigma_t / sigma_s)[(...,) + (None,)*dims] * x
|
688 |
+
- (alpha_t * phi_1)[(...,) + (None,)*dims] * noise_s
|
689 |
+
+ (alpha_t * phi_2)[(...,) + (None,)*dims] * D1
|
690 |
+
- (alpha_t * phi_3)[(...,) + (None,)*dims] * D2
|
691 |
+
)
|
692 |
+
else:
|
693 |
+
raise ValueError("solver_type must be either dpm_solver or dpm_solver++, got {}".format(solver_type))
|
694 |
+
else:
|
695 |
+
phi_11 = torch.expm1(r1 * h)
|
696 |
+
phi_12 = torch.expm1(r2 * h)
|
697 |
+
phi_1 = torch.expm1(h)
|
698 |
+
phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
|
699 |
+
phi_2 = phi_1 / h - 1.
|
700 |
+
phi_3 = phi_2 / h - 0.5
|
701 |
+
|
702 |
+
if noise_s is None:
|
703 |
+
noise_s = self.model_fn(x, s)
|
704 |
+
if noise_s1 is None:
|
705 |
+
x_s1 = (
|
706 |
+
torch.exp(log_alpha_s1 - log_alpha_s)[(...,) + (None,)*dims] * x
|
707 |
+
- (sigma_s1 * phi_11)[(...,) + (None,)*dims] * noise_s
|
708 |
+
)
|
709 |
+
noise_s1 = self.model_fn(x_s1, s1)
|
710 |
+
if noise_s2 is None:
|
711 |
+
x_s2 = (
|
712 |
+
torch.exp(log_alpha_s2 - log_alpha_s)[(...,) + (None,)*dims] * x
|
713 |
+
- (sigma_s2 * phi_12)[(...,) + (None,)*dims] * noise_s
|
714 |
+
- r2 / r1 * (sigma_s2 * phi_22)[(...,) + (None,)*dims] * (noise_s1 - noise_s)
|
715 |
+
)
|
716 |
+
noise_s2 = self.model_fn(x_s2, s2)
|
717 |
+
if solver_type == 'dpm_solver':
|
718 |
+
x_t = (
|
719 |
+
torch.exp(log_alpha_t - log_alpha_s)[(...,) + (None,)*dims] * x
|
720 |
+
- (sigma_t * phi_1)[(...,) + (None,)*dims] * noise_s
|
721 |
+
- (1. / r2) * (sigma_t * phi_2)[(...,) + (None,)*dims] * (noise_s2 - noise_s)
|
722 |
+
)
|
723 |
+
elif solver_type == 'taylor':
|
724 |
+
D1_0 = (1. / r1) * (noise_s1 - noise_s)
|
725 |
+
D1_1 = (1. / r2) * (noise_s2 - noise_s)
|
726 |
+
D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
|
727 |
+
D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
|
728 |
+
x_t = (
|
729 |
+
torch.exp(log_alpha_t - log_alpha_s)[(...,) + (None,)*dims] * x
|
730 |
+
- (sigma_t * phi_1)[(...,) + (None,)*dims] * noise_s
|
731 |
+
- (sigma_t * phi_2)[(...,) + (None,)*dims] * D1
|
732 |
+
- (sigma_t * phi_3)[(...,) + (None,)*dims] * D2
|
733 |
+
)
|
734 |
+
else:
|
735 |
+
raise ValueError("solver_type must be either dpm_solver or dpm_solver++, got {}".format(solver_type))
|
736 |
+
|
737 |
+
if return_noise:
|
738 |
+
return x_t, {'noise_s': noise_s, 'noise_s1': noise_s1, 'noise_s2': noise_s2}
|
739 |
+
else:
|
740 |
+
return x_t
|
741 |
+
|
742 |
+
def dpm_solver_update(self, x, s, t, order, return_noise=False, solver_type='dpm_solver', r1=None, r2=None):
|
743 |
+
"""
|
744 |
+
A single step for DPM-Solver of the given order `order`.
|
745 |
+
|
746 |
+
Args:
|
747 |
+
x: A pytorch tensor. The initial value at time `s`.
|
748 |
+
s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
|
749 |
+
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
|
750 |
+
order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
|
751 |
+
Returns:
|
752 |
+
x_t: A pytorch tensor. The approximated solution at time `t`.
|
753 |
+
"""
|
754 |
+
if order == 1:
|
755 |
+
return self.dpm_solver_first_update(x, s, t, return_noise=return_noise)
|
756 |
+
elif order == 2:
|
757 |
+
return self.dpm_solver_second_update(x, s, t, return_noise=return_noise, solver_type=solver_type, r1=r1)
|
758 |
+
elif order == 3:
|
759 |
+
return self.dpm_solver_third_update(x, s, t, return_noise=return_noise, solver_type=solver_type, r1=r1, r2=r2)
|
760 |
+
else:
|
761 |
+
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
|
762 |
+
|
763 |
+
def dpm_multistep_update(self, x, noise_prev_list, t_prev_list, t, order, solver_type='taylor'):
|
764 |
+
"""
|
765 |
+
A single step for DPM-Solver of the given order `order`.
|
766 |
+
|
767 |
+
Args:
|
768 |
+
x: A pytorch tensor. The initial value at time `s`.
|
769 |
+
s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
|
770 |
+
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
|
771 |
+
order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
|
772 |
+
Returns:
|
773 |
+
x_t: A pytorch tensor. The approximated solution at time `t`.
|
774 |
+
"""
|
775 |
+
if order == 1:
|
776 |
+
return self.dpm_solver_first_update(x, t_prev_list[-1], t, noise_s=noise_prev_list[-1])
|
777 |
+
elif order == 2:
|
778 |
+
return self.dpm_multistep_second_update(x, noise_prev_list, t_prev_list, t, solver_type=solver_type)
|
779 |
+
elif order == 3:
|
780 |
+
return self.dpm_multistep_third_update(x, noise_prev_list, t_prev_list, t, solver_type=solver_type)
|
781 |
+
else:
|
782 |
+
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
|
783 |
+
|
784 |
+
def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpm_solver'):
|
785 |
+
"""
|
786 |
+
The adaptive step size solver based on DPM-Solver.
|
787 |
+
|
788 |
+
Args:
|
789 |
+
x: A pytorch tensor. The initial value at time `t_T`.
|
790 |
+
order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
|
791 |
+
t_T: A `float`. The starting time of the sampling (default is T).
|
792 |
+
t_0: A `float`. The ending time of the sampling (default is epsilon).
|
793 |
+
h_init: A `float`. The initial step size (for logSNR).
|
794 |
+
atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1].
|
795 |
+
rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
|
796 |
+
theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1].
|
797 |
+
t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
|
798 |
+
current time and `t_0` is less than `t_err`. The default setting is 1e-5.
|
799 |
+
Returns:
|
800 |
+
x_0: A pytorch tensor. The approximated solution at time `t_0`.
|
801 |
+
|
802 |
+
[1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
|
803 |
+
"""
|
804 |
+
ns = self.noise_schedule
|
805 |
+
s = t_T * torch.ones((x.shape[0],)).to(x)
|
806 |
+
lambda_s = ns.marginal_lambda(s)
|
807 |
+
lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
|
808 |
+
h = h_init * torch.ones_like(s).to(x)
|
809 |
+
x_prev = x
|
810 |
+
nfe = 0
|
811 |
+
if order == 2:
|
812 |
+
r1 = 0.5
|
813 |
+
lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_noise=True)
|
814 |
+
higher_update = lambda x, s, t, **kwargs: self.dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs)
|
815 |
+
elif order == 3:
|
816 |
+
r1, r2 = 1. / 3., 2. / 3.
|
817 |
+
lower_update = lambda x, s, t: self.dpm_solver_second_update(x, s, t, r1=r1, return_noise=True, solver_type=solver_type)
|
818 |
+
higher_update = lambda x, s, t, **kwargs: self.dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs)
|
819 |
+
else:
|
820 |
+
raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
|
821 |
+
while torch.abs((s - t_0)).mean() > t_err:
|
822 |
+
t = ns.inverse_lambda(lambda_s + h)
|
823 |
+
x_lower, lower_noise_kwargs = lower_update(x, s, t)
|
824 |
+
x_higher = higher_update(x, s, t, **lower_noise_kwargs)
|
825 |
+
delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
|
826 |
+
norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
|
827 |
+
E = norm_fn((x_higher - x_lower) / delta).max()
|
828 |
+
if torch.all(E <= 1.):
|
829 |
+
x = x_higher
|
830 |
+
s = t
|
831 |
+
x_prev = x_lower
|
832 |
+
lambda_s = ns.marginal_lambda(s)
|
833 |
+
h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
|
834 |
+
nfe += order
|
835 |
+
print('adaptive solver nfe', nfe)
|
836 |
+
return x
|
837 |
+
|
838 |
+
def sample(self, x, steps=10, eps=1e-4, T=None, order=3, skip_type='time_uniform',
|
839 |
+
denoise=False, method='fast', solver_type='dpm_solver', atol=0.0078,
|
840 |
+
rtol=0.05, timesteps=None,
|
841 |
+
):
|
842 |
+
"""
|
843 |
+
Compute the sample at time `eps` by DPM-Solver, given the initial `x` at time `T`.
|
844 |
+
|
845 |
+
We support the following algorithms:
|
846 |
+
|
847 |
+
- Adaptive step size DPM-Solver (i.e. DPM-Solver-12 and DPM-Solver-23)
|
848 |
+
|
849 |
+
- Fixed order DPM-Solver (i.e. DPM-Solver-1, DPM-Solver-2 and DPM-Solver-3).
|
850 |
+
|
851 |
+
- Fast version of DPM-Solver (i.e. DPM-Solver-fast), which uses uniform logSNR steps and combines
|
852 |
+
different orders of DPM-Solver.
|
853 |
+
|
854 |
+
**We recommend DPM-Solver-fast for both fast sampling in few steps (<=20) and fast convergence in many steps (50 to 100).**
|
855 |
+
|
856 |
+
Choosing the algorithms:
|
857 |
+
|
858 |
+
- If `method` is 'adaptive':
|
859 |
+
We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
|
860 |
+
If `order`=2, we use DPM-Solver-12 which combines DPM-Solver-1 and DPM-Solver-2.
|
861 |
+
If `order`=3, we use DPM-Solver-23 which combines DPM-Solver-2 and DPM-Solver-3.
|
862 |
+
You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
|
863 |
+
(NFE) and the sample quality.
|
864 |
+
|
865 |
+
- If `method` is 'fast':
|
866 |
+
We ignore `order` and use DPM-Solver-fast with number of function evaluations (NFE) = `steps`.
|
867 |
+
We ignore `skip_type` and use uniform logSNR steps for DPM-Solver-fast.
|
868 |
+
Given a fixed NFE=`steps`, the sampling procedure by DPM-Solver-fast is:
|
869 |
+
- Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
|
870 |
+
- If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
|
871 |
+
- If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
|
872 |
+
- If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
|
873 |
+
|
874 |
+
- If `method` is 'singlestep':
|
875 |
+
We use DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
|
876 |
+
We support three types of `skip_type`:
|
877 |
+
- 'logSNR': uniform logSNR for the time steps, **recommended for DPM-Solver**.
|
878 |
+
- 'time_uniform': uniform time for the time steps. (Used in DDIM and DDPM.)
|
879 |
+
- 'time_quadratic': quadratic time for the time steps. (Used in DDIM.)
|
880 |
+
|
881 |
+
=====================================================
|
882 |
+
Args:
|
883 |
+
x: A pytorch tensor. The initial value at time `T` (a sample from the normal distribution).
|
884 |
+
steps: An `int`. The total number of function evaluations (NFE).
|
885 |
+
eps: A `float`. The ending time of the sampling.
|
886 |
+
We recommend `eps`=1e-3 when `steps` <= 15; and `eps`=1e-4 when `steps` > 15.
|
887 |
+
T: A `float`. The starting time of the sampling. Default is `None`.
|
888 |
+
If `T` is None, we use self.noise_schedule.T.
|
889 |
+
order: An `int`. The order of DPM-Solver.
|
890 |
+
skip_type: A `str`. The type for the spacing of the time steps. Default is 'time_uniform'.
|
891 |
+
method: A `str`. The sampling method: 'adaptive', 'multistep', 'fast' (recommended) or 'singlestep'.
|
892 |
+
denoise: A `bool`. Whether to apply an extra denoising step at the final time `eps`.
|
893 |
+
atol: A `float`. The absolute tolerance of the adaptive step size solver.
|
894 |
+
rtol: A `float`. The relative tolerance of the adaptive step size solver.
|
895 |
+
Returns:
|
896 |
+
x_0: A pytorch tensor. The approximated solution at time `t_0`.
|
897 |
+
|
898 |
+
[1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
|
899 |
+
"""
|
900 |
+
t_0 = eps
|
901 |
+
t_T = self.noise_schedule.T if T is None else T
|
902 |
+
device = x.device
|
903 |
+
if method == 'adaptive':
|
904 |
+
with torch.no_grad():
|
905 |
+
x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type)
|
906 |
+
elif method == 'multistep':
|
907 |
+
assert steps >= order
|
908 |
+
if timesteps is None:
|
909 |
+
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
910 |
+
assert timesteps.shape[0] - 1 == steps
|
911 |
+
with torch.no_grad():
|
912 |
+
vec_t = timesteps[0].expand((x.shape[0]))
|
913 |
+
noise_prev_list = [self.model_fn(x, vec_t)]
|
914 |
+
t_prev_list = [vec_t]
|
915 |
+
for init_order in range(1, order):
|
916 |
+
vec_t = timesteps[init_order].expand(x.shape[0])
|
917 |
+
x = self.dpm_multistep_update(x, noise_prev_list, t_prev_list, vec_t, init_order, solver_type=solver_type)
|
918 |
+
noise_prev_list.append(self.model_fn(x, vec_t))
|
919 |
+
t_prev_list.append(vec_t)
|
920 |
+
for step in range(order, steps + 1):
|
921 |
+
vec_t = timesteps[step].expand(x.shape[0])
|
922 |
+
x = self.dpm_multistep_update(x, noise_prev_list, t_prev_list, vec_t, order, solver_type=solver_type)
|
923 |
+
for i in range(order - 1):
|
924 |
+
t_prev_list[i] = t_prev_list[i + 1]
|
925 |
+
noise_prev_list[i] = noise_prev_list[i + 1]
|
926 |
+
t_prev_list[-1] = vec_t
|
927 |
+
if step < steps:
|
928 |
+
noise_prev_list[-1] = self.model_fn(x, vec_t)
|
929 |
+
elif method == 'fast':
|
930 |
+
orders, _ = self.get_time_steps_for_dpm_solver_fast(skip_type=skip_type, t_T=t_T, t_0=t_0, steps=steps, order=order, device=device)
|
931 |
+
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
932 |
+
with torch.no_grad():
|
933 |
+
i = 0
|
934 |
+
for order in orders:
|
935 |
+
vec_s, vec_t = torch.ones((x.shape[0],)).to(device) * timesteps[i], torch.ones((x.shape[0],)).to(device) * timesteps[i + order]
|
936 |
+
h = self.noise_schedule.marginal_lambda(timesteps[i + order]) - self.noise_schedule.marginal_lambda(timesteps[i])
|
937 |
+
r1 = None if order <= 1 else (self.noise_schedule.marginal_lambda(timesteps[i + 1]) - self.noise_schedule.marginal_lambda(timesteps[i])) / h
|
938 |
+
r2 = None if order <= 2 else (self.noise_schedule.marginal_lambda(timesteps[i + 2]) - self.noise_schedule.marginal_lambda(timesteps[i])) / h
|
939 |
+
x = self.dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
|
940 |
+
i += order
|
941 |
+
elif method == 'singlestep':
|
942 |
+
N_steps = steps // order
|
943 |
+
orders = [order,] * N_steps
|
944 |
+
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=N_steps, device=device)
|
945 |
+
assert len(timesteps) - 1 == N_steps
|
946 |
+
with torch.no_grad():
|
947 |
+
for i, order in enumerate(orders):
|
948 |
+
vec_s, vec_t = torch.ones((x.shape[0],)).to(device) * timesteps[i], torch.ones((x.shape[0],)).to(device) * timesteps[i + 1]
|
949 |
+
x = self.dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type)
|
950 |
+
if denoise:
|
951 |
+
x = self.denoise_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
|
952 |
+
return x
|
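A minimal usage sketch of the `sample()` entry point above. The `NoiseScheduleVP` and `DPM_Solver` constructors and the `model_fn(x, t)` interface are assumptions carried over from the upstream DPM-Solver code this file adapts; the noise predictor below is a placeholder, not the COP-GEN-Beta network, and the keyword arguments to `sample()` follow the signature shown in the diff.

import torch
from dpm_solver_pp import NoiseScheduleVP, DPM_Solver

# Hypothetical discrete beta schedule; the real schedule comes from the trained model.
betas = torch.linspace(1e-4, 2e-2, 1000)
ns = NoiseScheduleVP(schedule='discrete', betas=betas)

def model_fn(x, t_continuous):
    # Placeholder for the wrapped diffusion model: returns the predicted noise.
    return torch.zeros_like(x)

solver = DPM_Solver(model_fn, ns)
x_T = torch.randn(4, 4, 32, 32)  # latent noise at time T
x_0 = solver.sample(x_T, steps=50, eps=1e-4, method='fast', order=3, solver_type='dpm_solver')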
src/COP-GEN-Beta/encode_majortom_images.py
ADDED
@@ -0,0 +1,95 @@
|
1 |
+
import torch.nn as nn
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from datasets import MajorTOMThumbnail
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
from libs.autoencoder import get_model
|
7 |
+
import argparse
|
8 |
+
from tqdm import tqdm
|
9 |
+
import os
|
10 |
+
|
11 |
+
torch.manual_seed(0)
|
12 |
+
np.random.seed(0)
|
13 |
+
|
14 |
+
|
15 |
+
def get_existing_encoded_files(output_dir, image_paths):
|
16 |
+
"""Returns a set of filenames that already have their encoded features saved"""
|
17 |
+
existing_files = set()
|
18 |
+
missing_files = set()
|
19 |
+
|
20 |
+
for img_path in image_paths:
|
21 |
+
filename = os.path.basename(img_path).split('.')[0]
|
22 |
+
npy_path = os.path.join(output_dir, f'{filename}.npy')
|
23 |
+
|
24 |
+
if os.path.exists(npy_path):
|
25 |
+
existing_files.add(filename)
|
26 |
+
else:
|
27 |
+
missing_files.add(filename)
|
28 |
+
|
29 |
+
print(f"\nFound {len(existing_files)} already encoded files")
|
30 |
+
print(f"Missing {len(missing_files)} files to encode")
|
31 |
+
|
32 |
+
return existing_files, missing_files
|
33 |
+
|
34 |
+
def main(resolution=256):
|
35 |
+
parser = argparse.ArgumentParser()
|
36 |
+
parser.add_argument('--path', type=str)
|
37 |
+
parser.add_argument('--resolution', type=int, default=256)
|
38 |
+
parser.add_argument('--output_dir', type=str)
|
39 |
+
args = parser.parse_args()
|
40 |
+
|
41 |
+
# Create output directory if it doesn't exist
|
42 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
43 |
+
|
44 |
+
datafactory = MajorTOMThumbnail(path=args.path, resolution=args.resolution)
|
45 |
+
dataset = datafactory.get_split(split=None, labeled=False, nosplit=True)
|
46 |
+
image_paths = dataset.image_paths
|
47 |
+
|
48 |
+
# Check for existing encoded files
|
49 |
+
# existing_files, missing_files = get_existing_encoded_files(args.output_dir, image_paths)
|
50 |
+
# TODO: Restart is not working yet
|
51 |
+
existing_files = set()
|
52 |
+
missing_files = set(image_paths)
|
53 |
+
|
54 |
+
if len(missing_files) == 0:
|
55 |
+
print("All files have already been encoded. Exiting...")
|
56 |
+
return
|
57 |
+
|
58 |
+
# Filter dataset to only process missing files
|
59 |
+
filtered_indices = [i for i, path in enumerate(image_paths)
|
60 |
+
if os.path.basename(path).split('.')[0] not in existing_files]
|
61 |
+
dataset.image_paths = [image_paths[i] for i in filtered_indices]
|
62 |
+
|
63 |
+
dataset_loader = DataLoader(dataset, batch_size=128, shuffle=False, drop_last=False,
|
64 |
+
num_workers=8, pin_memory=True, persistent_workers=True)
|
65 |
+
|
66 |
+
model = get_model('assets/stable-diffusion/autoencoder_kl.pth')
|
67 |
+
# model = nn.DataParallel(model)
|
68 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
69 |
+
model.to(device)
|
70 |
+
|
71 |
+
processed_count = 0
|
72 |
+
for img, filename in tqdm(dataset_loader, desc="Encoding images", unit="batch"):
|
73 |
+
img = img.to(device)
|
74 |
+
moments = model(img, fn='encode_moments')
|
75 |
+
moments = moments.detach().cpu().numpy()
|
76 |
+
|
77 |
+
for moment, fname in zip(moments, filename):
|
78 |
+
np.save(f'{args.output_dir}/{fname}.npy', moment)
|
79 |
+
processed_count += 1
|
80 |
+
|
81 |
+
print(f'\nProcessed {processed_count} new files')
|
82 |
+
print(f'Total encoded files: {len(existing_files) + processed_count}')
|
83 |
+
|
84 |
+
# features = []
|
85 |
+
# labels = []
|
86 |
+
# features = np.concatenate(features, axis=0)
|
87 |
+
# labels = np.concatenate(labels, axis=0)
|
88 |
+
# print(f'features.shape={features.shape}')
|
89 |
+
# print(f'labels.shape={labels.shape}')
|
90 |
+
# np.save(f'imagenet{resolution}_features.npy', features)
|
91 |
+
# np.save(f'imagenet{resolution}_labels.npy', labels)
|
92 |
+
|
93 |
+
|
94 |
+
if __name__ == "__main__":
|
95 |
+
main()
|
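The script above writes one `.npy` file per thumbnail containing the autoencoder "moments" (mean and log-variance of the latent). The sketch below shows how a downstream loader can turn such a file back into a scaled latent, mirroring `FrozenAutoencoderKL.sample()` in libs/autoencoder.py; the file path is a placeholder.

import numpy as np
import torch

moments = torch.from_numpy(np.load('out_features/example_tile.npy'))[None]  # (1, 8, H/8, W/8)
mean, logvar = torch.chunk(moments, 2, dim=1)
std = torch.exp(0.5 * torch.clamp(logvar, -30.0, 20.0))
z = 0.18215 * (mean + std * torch.randn_like(mean))  # latent as consumed by the diffusion model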
src/COP-GEN-Beta/libs/__init__.py
ADDED
@@ -0,0 +1 @@
|
1 |
+
# codes from third party
|
src/COP-GEN-Beta/libs/autoencoder.py
ADDED
@@ -0,0 +1,519 @@
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import numpy as np
|
4 |
+
from einops import rearrange
|
5 |
+
|
6 |
+
|
7 |
+
class LinearAttention(nn.Module):
|
8 |
+
def __init__(self, dim, heads=4, dim_head=32):
|
9 |
+
super().__init__()
|
10 |
+
self.heads = heads
|
11 |
+
hidden_dim = dim_head * heads
|
12 |
+
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
|
13 |
+
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
14 |
+
|
15 |
+
def forward(self, x):
|
16 |
+
b, c, h, w = x.shape
|
17 |
+
qkv = self.to_qkv(x)
|
18 |
+
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
|
19 |
+
k = k.softmax(dim=-1)
|
20 |
+
context = torch.einsum('bhdn,bhen->bhde', k, v)
|
21 |
+
out = torch.einsum('bhde,bhdn->bhen', context, q)
|
22 |
+
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
|
23 |
+
return self.to_out(out)
|
24 |
+
|
25 |
+
|
26 |
+
def nonlinearity(x):
|
27 |
+
# swish
|
28 |
+
return x*torch.sigmoid(x)
|
29 |
+
|
30 |
+
|
31 |
+
def Normalize(in_channels, num_groups=32):
|
32 |
+
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
33 |
+
|
34 |
+
|
35 |
+
class Upsample(nn.Module):
|
36 |
+
def __init__(self, in_channels, with_conv):
|
37 |
+
super().__init__()
|
38 |
+
self.with_conv = with_conv
|
39 |
+
if self.with_conv:
|
40 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
41 |
+
in_channels,
|
42 |
+
kernel_size=3,
|
43 |
+
stride=1,
|
44 |
+
padding=1)
|
45 |
+
|
46 |
+
def forward(self, x):
|
47 |
+
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
48 |
+
if self.with_conv:
|
49 |
+
x = self.conv(x)
|
50 |
+
return x
|
51 |
+
|
52 |
+
|
53 |
+
class Downsample(nn.Module):
|
54 |
+
def __init__(self, in_channels, with_conv):
|
55 |
+
super().__init__()
|
56 |
+
self.with_conv = with_conv
|
57 |
+
if self.with_conv:
|
58 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
59 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
60 |
+
in_channels,
|
61 |
+
kernel_size=3,
|
62 |
+
stride=2,
|
63 |
+
padding=0)
|
64 |
+
|
65 |
+
def forward(self, x):
|
66 |
+
if self.with_conv:
|
67 |
+
pad = (0,1,0,1)
|
68 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
69 |
+
x = self.conv(x)
|
70 |
+
else:
|
71 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
72 |
+
return x
|
73 |
+
|
74 |
+
|
75 |
+
class ResnetBlock(nn.Module):
|
76 |
+
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
77 |
+
dropout, temb_channels=512):
|
78 |
+
super().__init__()
|
79 |
+
self.in_channels = in_channels
|
80 |
+
out_channels = in_channels if out_channels is None else out_channels
|
81 |
+
self.out_channels = out_channels
|
82 |
+
self.use_conv_shortcut = conv_shortcut
|
83 |
+
|
84 |
+
self.norm1 = Normalize(in_channels)
|
85 |
+
self.conv1 = torch.nn.Conv2d(in_channels,
|
86 |
+
out_channels,
|
87 |
+
kernel_size=3,
|
88 |
+
stride=1,
|
89 |
+
padding=1)
|
90 |
+
if temb_channels > 0:
|
91 |
+
self.temb_proj = torch.nn.Linear(temb_channels,
|
92 |
+
out_channels)
|
93 |
+
self.norm2 = Normalize(out_channels)
|
94 |
+
self.dropout = torch.nn.Dropout(dropout)
|
95 |
+
self.conv2 = torch.nn.Conv2d(out_channels,
|
96 |
+
out_channels,
|
97 |
+
kernel_size=3,
|
98 |
+
stride=1,
|
99 |
+
padding=1)
|
100 |
+
if self.in_channels != self.out_channels:
|
101 |
+
if self.use_conv_shortcut:
|
102 |
+
self.conv_shortcut = torch.nn.Conv2d(in_channels,
|
103 |
+
out_channels,
|
104 |
+
kernel_size=3,
|
105 |
+
stride=1,
|
106 |
+
padding=1)
|
107 |
+
else:
|
108 |
+
self.nin_shortcut = torch.nn.Conv2d(in_channels,
|
109 |
+
out_channels,
|
110 |
+
kernel_size=1,
|
111 |
+
stride=1,
|
112 |
+
padding=0)
|
113 |
+
|
114 |
+
def forward(self, x, temb):
|
115 |
+
h = x
|
116 |
+
h = self.norm1(h)
|
117 |
+
h = nonlinearity(h)
|
118 |
+
h = self.conv1(h)
|
119 |
+
|
120 |
+
if temb is not None:
|
121 |
+
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
|
122 |
+
|
123 |
+
h = self.norm2(h)
|
124 |
+
h = nonlinearity(h)
|
125 |
+
h = self.dropout(h)
|
126 |
+
h = self.conv2(h)
|
127 |
+
|
128 |
+
if self.in_channels != self.out_channels:
|
129 |
+
if self.use_conv_shortcut:
|
130 |
+
x = self.conv_shortcut(x)
|
131 |
+
else:
|
132 |
+
x = self.nin_shortcut(x)
|
133 |
+
|
134 |
+
return x+h
|
135 |
+
|
136 |
+
|
137 |
+
class LinAttnBlock(LinearAttention):
|
138 |
+
"""to match AttnBlock usage"""
|
139 |
+
def __init__(self, in_channels):
|
140 |
+
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
|
141 |
+
|
142 |
+
|
143 |
+
class AttnBlock(nn.Module):
|
144 |
+
def __init__(self, in_channels):
|
145 |
+
super().__init__()
|
146 |
+
self.in_channels = in_channels
|
147 |
+
|
148 |
+
self.norm = Normalize(in_channels)
|
149 |
+
self.q = torch.nn.Conv2d(in_channels,
|
150 |
+
in_channels,
|
151 |
+
kernel_size=1,
|
152 |
+
stride=1,
|
153 |
+
padding=0)
|
154 |
+
self.k = torch.nn.Conv2d(in_channels,
|
155 |
+
in_channels,
|
156 |
+
kernel_size=1,
|
157 |
+
stride=1,
|
158 |
+
padding=0)
|
159 |
+
self.v = torch.nn.Conv2d(in_channels,
|
160 |
+
in_channels,
|
161 |
+
kernel_size=1,
|
162 |
+
stride=1,
|
163 |
+
padding=0)
|
164 |
+
self.proj_out = torch.nn.Conv2d(in_channels,
|
165 |
+
in_channels,
|
166 |
+
kernel_size=1,
|
167 |
+
stride=1,
|
168 |
+
padding=0)
|
169 |
+
|
170 |
+
|
171 |
+
def forward(self, x):
|
172 |
+
h_ = x
|
173 |
+
h_ = self.norm(h_)
|
174 |
+
q = self.q(h_)
|
175 |
+
k = self.k(h_)
|
176 |
+
v = self.v(h_)
|
177 |
+
|
178 |
+
# compute attention
|
179 |
+
b,c,h,w = q.shape
|
180 |
+
q = q.reshape(b,c,h*w)
|
181 |
+
q = q.permute(0,2,1) # b,hw,c
|
182 |
+
k = k.reshape(b,c,h*w) # b,c,hw
|
183 |
+
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
184 |
+
w_ = w_ * (int(c)**(-0.5))
|
185 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
186 |
+
|
187 |
+
# attend to values
|
188 |
+
v = v.reshape(b,c,h*w)
|
189 |
+
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
|
190 |
+
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
191 |
+
h_ = h_.reshape(b,c,h,w)
|
192 |
+
|
193 |
+
h_ = self.proj_out(h_)
|
194 |
+
|
195 |
+
return x+h_
|
196 |
+
|
197 |
+
|
198 |
+
def make_attn(in_channels, attn_type="vanilla"):
|
199 |
+
assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
|
200 |
+
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
201 |
+
if attn_type == "vanilla":
|
202 |
+
return AttnBlock(in_channels)
|
203 |
+
elif attn_type == "none":
|
204 |
+
return nn.Identity(in_channels)
|
205 |
+
else:
|
206 |
+
return LinAttnBlock(in_channels)
|
207 |
+
|
208 |
+
|
209 |
+
class Encoder(nn.Module):
|
210 |
+
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
211 |
+
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
212 |
+
resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
|
213 |
+
**ignore_kwargs):
|
214 |
+
super().__init__()
|
215 |
+
if use_linear_attn: attn_type = "linear"
|
216 |
+
self.ch = ch
|
217 |
+
self.temb_ch = 0
|
218 |
+
self.num_resolutions = len(ch_mult)
|
219 |
+
self.num_res_blocks = num_res_blocks
|
220 |
+
self.resolution = resolution
|
221 |
+
self.in_channels = in_channels
|
222 |
+
|
223 |
+
# downsampling
|
224 |
+
self.conv_in = torch.nn.Conv2d(in_channels,
|
225 |
+
self.ch,
|
226 |
+
kernel_size=3,
|
227 |
+
stride=1,
|
228 |
+
padding=1)
|
229 |
+
|
230 |
+
curr_res = resolution
|
231 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
232 |
+
self.in_ch_mult = in_ch_mult
|
233 |
+
self.down = nn.ModuleList()
|
234 |
+
for i_level in range(self.num_resolutions):
|
235 |
+
block = nn.ModuleList()
|
236 |
+
attn = nn.ModuleList()
|
237 |
+
block_in = ch*in_ch_mult[i_level]
|
238 |
+
block_out = ch*ch_mult[i_level]
|
239 |
+
for i_block in range(self.num_res_blocks):
|
240 |
+
block.append(ResnetBlock(in_channels=block_in,
|
241 |
+
out_channels=block_out,
|
242 |
+
temb_channels=self.temb_ch,
|
243 |
+
dropout=dropout))
|
244 |
+
block_in = block_out
|
245 |
+
if curr_res in attn_resolutions:
|
246 |
+
attn.append(make_attn(block_in, attn_type=attn_type))
|
247 |
+
down = nn.Module()
|
248 |
+
down.block = block
|
249 |
+
down.attn = attn
|
250 |
+
if i_level != self.num_resolutions-1:
|
251 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
252 |
+
curr_res = curr_res // 2
|
253 |
+
self.down.append(down)
|
254 |
+
|
255 |
+
# middle
|
256 |
+
self.mid = nn.Module()
|
257 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
258 |
+
out_channels=block_in,
|
259 |
+
temb_channels=self.temb_ch,
|
260 |
+
dropout=dropout)
|
261 |
+
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
262 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
263 |
+
out_channels=block_in,
|
264 |
+
temb_channels=self.temb_ch,
|
265 |
+
dropout=dropout)
|
266 |
+
|
267 |
+
# end
|
268 |
+
self.norm_out = Normalize(block_in)
|
269 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
270 |
+
2*z_channels if double_z else z_channels,
|
271 |
+
kernel_size=3,
|
272 |
+
stride=1,
|
273 |
+
padding=1)
|
274 |
+
|
275 |
+
def forward(self, x):
|
276 |
+
# timestep embedding
|
277 |
+
temb = None
|
278 |
+
|
279 |
+
# downsampling
|
280 |
+
hs = [self.conv_in(x)]
|
281 |
+
for i_level in range(self.num_resolutions):
|
282 |
+
for i_block in range(self.num_res_blocks):
|
283 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
284 |
+
if len(self.down[i_level].attn) > 0:
|
285 |
+
h = self.down[i_level].attn[i_block](h)
|
286 |
+
hs.append(h)
|
287 |
+
if i_level != self.num_resolutions-1:
|
288 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
289 |
+
|
290 |
+
# middle
|
291 |
+
h = hs[-1]
|
292 |
+
h = self.mid.block_1(h, temb)
|
293 |
+
h = self.mid.attn_1(h)
|
294 |
+
h = self.mid.block_2(h, temb)
|
295 |
+
|
296 |
+
# end
|
297 |
+
h = self.norm_out(h)
|
298 |
+
h = nonlinearity(h)
|
299 |
+
h = self.conv_out(h)
|
300 |
+
return h
|
301 |
+
|
302 |
+
|
303 |
+
class Decoder(nn.Module):
|
304 |
+
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
305 |
+
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
306 |
+
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
|
307 |
+
attn_type="vanilla", **ignorekwargs):
|
308 |
+
super().__init__()
|
309 |
+
if use_linear_attn: attn_type = "linear"
|
310 |
+
self.ch = ch
|
311 |
+
self.temb_ch = 0
|
312 |
+
self.num_resolutions = len(ch_mult)
|
313 |
+
self.num_res_blocks = num_res_blocks
|
314 |
+
self.resolution = resolution
|
315 |
+
self.in_channels = in_channels
|
316 |
+
self.give_pre_end = give_pre_end
|
317 |
+
self.tanh_out = tanh_out
|
318 |
+
|
319 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
320 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
321 |
+
block_in = ch*ch_mult[self.num_resolutions-1]
|
322 |
+
curr_res = resolution // 2**(self.num_resolutions-1)
|
323 |
+
self.z_shape = (1,z_channels,curr_res,curr_res)
|
324 |
+
print("Working with z of shape {} = {} dimensions.".format(
|
325 |
+
self.z_shape, np.prod(self.z_shape)))
|
326 |
+
|
327 |
+
# z to block_in
|
328 |
+
self.conv_in = torch.nn.Conv2d(z_channels,
|
329 |
+
block_in,
|
330 |
+
kernel_size=3,
|
331 |
+
stride=1,
|
332 |
+
padding=1)
|
333 |
+
|
334 |
+
# middle
|
335 |
+
self.mid = nn.Module()
|
336 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
337 |
+
out_channels=block_in,
|
338 |
+
temb_channels=self.temb_ch,
|
339 |
+
dropout=dropout)
|
340 |
+
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
341 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
342 |
+
out_channels=block_in,
|
343 |
+
temb_channels=self.temb_ch,
|
344 |
+
dropout=dropout)
|
345 |
+
|
346 |
+
# upsampling
|
347 |
+
self.up = nn.ModuleList()
|
348 |
+
for i_level in reversed(range(self.num_resolutions)):
|
349 |
+
block = nn.ModuleList()
|
350 |
+
attn = nn.ModuleList()
|
351 |
+
block_out = ch*ch_mult[i_level]
|
352 |
+
for i_block in range(self.num_res_blocks+1):
|
353 |
+
block.append(ResnetBlock(in_channels=block_in,
|
354 |
+
out_channels=block_out,
|
355 |
+
temb_channels=self.temb_ch,
|
356 |
+
dropout=dropout))
|
357 |
+
block_in = block_out
|
358 |
+
if curr_res in attn_resolutions:
|
359 |
+
attn.append(make_attn(block_in, attn_type=attn_type))
|
360 |
+
up = nn.Module()
|
361 |
+
up.block = block
|
362 |
+
up.attn = attn
|
363 |
+
if i_level != 0:
|
364 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
365 |
+
curr_res = curr_res * 2
|
366 |
+
self.up.insert(0, up) # prepend to get consistent order
|
367 |
+
|
368 |
+
# end
|
369 |
+
self.norm_out = Normalize(block_in)
|
370 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
371 |
+
out_ch,
|
372 |
+
kernel_size=3,
|
373 |
+
stride=1,
|
374 |
+
padding=1)
|
375 |
+
|
376 |
+
def forward(self, z):
|
377 |
+
#assert z.shape[1:] == self.z_shape[1:]
|
378 |
+
self.last_z_shape = z.shape
|
379 |
+
|
380 |
+
# timestep embedding
|
381 |
+
temb = None
|
382 |
+
|
383 |
+
# z to block_in
|
384 |
+
h = self.conv_in(z)
|
385 |
+
|
386 |
+
# middle
|
387 |
+
h = self.mid.block_1(h, temb)
|
388 |
+
h = self.mid.attn_1(h)
|
389 |
+
h = self.mid.block_2(h, temb)
|
390 |
+
|
391 |
+
# upsampling
|
392 |
+
for i_level in reversed(range(self.num_resolutions)):
|
393 |
+
for i_block in range(self.num_res_blocks+1):
|
394 |
+
h = self.up[i_level].block[i_block](h, temb)
|
395 |
+
if len(self.up[i_level].attn) > 0:
|
396 |
+
h = self.up[i_level].attn[i_block](h)
|
397 |
+
if i_level != 0:
|
398 |
+
h = self.up[i_level].upsample(h)
|
399 |
+
|
400 |
+
# end
|
401 |
+
if self.give_pre_end:
|
402 |
+
return h
|
403 |
+
|
404 |
+
h = self.norm_out(h)
|
405 |
+
h = nonlinearity(h)
|
406 |
+
h = self.conv_out(h)
|
407 |
+
if self.tanh_out:
|
408 |
+
h = torch.tanh(h)
|
409 |
+
return h
|
410 |
+
|
411 |
+
|
412 |
+
class FrozenAutoencoderKL(nn.Module):
|
413 |
+
def __init__(self, ddconfig, embed_dim, pretrained_path, scale_factor=0.18215):
|
414 |
+
super().__init__()
|
415 |
+
print(f'Create autoencoder with scale_factor={scale_factor}')
|
416 |
+
self.encoder = Encoder(**ddconfig)
|
417 |
+
self.decoder = Decoder(**ddconfig)
|
418 |
+
assert ddconfig["double_z"]
|
419 |
+
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
|
420 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
421 |
+
self.embed_dim = embed_dim
|
422 |
+
self.scale_factor = scale_factor
|
423 |
+
m, u = self.load_state_dict(torch.load(pretrained_path, map_location='cpu'))
|
424 |
+
assert len(m) == 0 and len(u) == 0
|
425 |
+
self.eval()
|
426 |
+
self.requires_grad_(False)
|
427 |
+
|
428 |
+
def encode_moments(self, x):
|
429 |
+
h = self.encoder(x)
|
430 |
+
moments = self.quant_conv(h)
|
431 |
+
return moments
|
432 |
+
|
433 |
+
def sample(self, moments):
|
434 |
+
mean, logvar = torch.chunk(moments, 2, dim=1)
|
435 |
+
logvar = torch.clamp(logvar, -30.0, 20.0)
|
436 |
+
std = torch.exp(0.5 * logvar)
|
437 |
+
z = mean + std * torch.randn_like(mean)
|
438 |
+
z = self.scale_factor * z
|
439 |
+
return z
|
440 |
+
|
441 |
+
def encode(self, x):
|
442 |
+
moments = self.encode_moments(x)
|
443 |
+
z = self.sample(moments)
|
444 |
+
return z
|
445 |
+
|
446 |
+
def decode(self, z):
|
447 |
+
z = (1. / self.scale_factor) * z
|
448 |
+
z = self.post_quant_conv(z)
|
449 |
+
dec = self.decoder(z)
|
450 |
+
return dec
|
451 |
+
|
452 |
+
def forward(self, inputs, fn):
|
453 |
+
if fn == 'encode_moments':
|
454 |
+
return self.encode_moments(inputs)
|
455 |
+
elif fn == 'encode':
|
456 |
+
return self.encode(inputs)
|
457 |
+
elif fn == 'decode':
|
458 |
+
return self.decode(inputs)
|
459 |
+
else:
|
460 |
+
raise NotImplementedError
|
461 |
+
|
462 |
+
|
463 |
+
def get_model(pretrained_path, scale_factor=0.18215):
|
464 |
+
ddconfig = dict(
|
465 |
+
double_z=True,
|
466 |
+
z_channels=4,
|
467 |
+
resolution=256,
|
468 |
+
in_channels=3,
|
469 |
+
out_ch=3,
|
470 |
+
ch=128,
|
471 |
+
ch_mult=[1, 2, 4, 4],
|
472 |
+
num_res_blocks=2,
|
473 |
+
attn_resolutions=[],
|
474 |
+
dropout=0.0
|
475 |
+
)
|
476 |
+
return FrozenAutoencoderKL(ddconfig, 4, pretrained_path, scale_factor)
|
477 |
+
|
478 |
+
|
479 |
+
def main():
|
480 |
+
import torchvision.transforms as transforms
|
481 |
+
from torchvision.utils import save_image
|
482 |
+
import os
|
483 |
+
from PIL import Image
|
484 |
+
|
485 |
+
model = get_model('assets/stable-diffusion/autoencoder_kl.pth')
|
486 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
487 |
+
model = model.to(device)
|
488 |
+
|
489 |
+
scale_factor = 0.18215
|
490 |
+
T = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(256), transforms.ToTensor()])
|
491 |
+
path = 'imgs'
|
492 |
+
fnames = os.listdir(path)
|
493 |
+
for fname in fnames:
|
494 |
+
p = os.path.join(path, fname)
|
495 |
+
img = Image.open(p)
|
496 |
+
img = T(img)
|
497 |
+
img = img * 2. - 1
|
498 |
+
img = img[None, ...]
|
499 |
+
img = img.to(device)
|
500 |
+
|
501 |
+
# with torch.cuda.amp.autocast():
|
502 |
+
# moments = model.encode_moments(img)
|
503 |
+
# mean, logvar = torch.chunk(moments, 2, dim=1)
|
504 |
+
# logvar = torch.clamp(logvar, -30.0, 20.0)
|
505 |
+
# std = torch.exp(0.5 * logvar)
|
506 |
+
# zs = [(mean + std * torch.randn_like(mean)) * scale_factor for _ in range(4)]
|
507 |
+
# recons = [model.decode(z) for z in zs]
|
508 |
+
|
509 |
+
with torch.cuda.amp.autocast():
|
510 |
+
print('test encode & decode')
|
511 |
+
recons = [model.decode(model.encode(img)) for _ in range(4)]
|
512 |
+
|
513 |
+
out = torch.cat([img, *recons], dim=0)
|
514 |
+
out = (out + 1) * 0.5
|
515 |
+
save_image(out, f'recons_{fname}')
|
516 |
+
|
517 |
+
|
518 |
+
if __name__ == "__main__":
|
519 |
+
main()
|
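As a quick shape check for `get_model()` above: with `ch_mult=[1, 2, 4, 4]` the encoder downsamples three times, so a 256x256 RGB image yields an 8-channel 32x32 moments tensor and a 4-channel 32x32 latent. A minimal sketch, assuming the pretrained checkpoint sits at the path the repository's scripts expect:

import torch
from libs.autoencoder import get_model

model = get_model('assets/stable-diffusion/autoencoder_kl.pth')
x = torch.randn(1, 3, 256, 256) * 2 - 1    # image scaled to [-1, 1]
moments = model(x, fn='encode_moments')    # torch.Size([1, 8, 32, 32])
z = model.sample(moments)                  # torch.Size([1, 4, 32, 32])
recon = model(z, fn='decode')              # torch.Size([1, 3, 256, 256])
print(moments.shape, z.shape, recon.shape)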
src/COP-GEN-Beta/libs/timm.py
ADDED
@@ -0,0 +1,112 @@
|
1 |
+
# code from timm 0.3.2
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import math
|
5 |
+
import warnings
|
6 |
+
|
7 |
+
|
8 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
9 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
10 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
11 |
+
def norm_cdf(x):
|
12 |
+
# Computes standard normal cumulative distribution function
|
13 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
14 |
+
|
15 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
16 |
+
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
17 |
+
"The distribution of values may be incorrect.",
|
18 |
+
stacklevel=2)
|
19 |
+
|
20 |
+
with torch.no_grad():
|
21 |
+
# Values are generated by using a truncated uniform distribution and
|
22 |
+
# then using the inverse CDF for the normal distribution.
|
23 |
+
# Get upper and lower cdf values
|
24 |
+
l = norm_cdf((a - mean) / std)
|
25 |
+
u = norm_cdf((b - mean) / std)
|
26 |
+
|
27 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
28 |
+
# [2l-1, 2u-1].
|
29 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
30 |
+
|
31 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
32 |
+
# standard normal
|
33 |
+
tensor.erfinv_()
|
34 |
+
|
35 |
+
# Transform to proper mean, std
|
36 |
+
tensor.mul_(std * math.sqrt(2.))
|
37 |
+
tensor.add_(mean)
|
38 |
+
|
39 |
+
# Clamp to ensure it's in the proper range
|
40 |
+
tensor.clamp_(min=a, max=b)
|
41 |
+
return tensor
|
42 |
+
|
43 |
+
|
44 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
45 |
+
# type: (Tensor, float, float, float, float) -> Tensor
|
46 |
+
r"""Fills the input Tensor with values drawn from a truncated
|
47 |
+
normal distribution. The values are effectively drawn from the
|
48 |
+
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
49 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
50 |
+
the bounds. The method used for generating the random values works
|
51 |
+
best when :math:`a \leq \text{mean} \leq b`.
|
52 |
+
Args:
|
53 |
+
tensor: an n-dimensional `torch.Tensor`
|
54 |
+
mean: the mean of the normal distribution
|
55 |
+
std: the standard deviation of the normal distribution
|
56 |
+
a: the minimum cutoff value
|
57 |
+
b: the maximum cutoff value
|
58 |
+
Examples:
|
59 |
+
>>> w = torch.empty(3, 5)
|
60 |
+
>>> nn.init.trunc_normal_(w)
|
61 |
+
"""
|
62 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
63 |
+
|
64 |
+
|
65 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
66 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
67 |
+
|
68 |
+
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
69 |
+
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
70 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
71 |
+
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
72 |
+
'survival rate' as the argument.
|
73 |
+
|
74 |
+
"""
|
75 |
+
if drop_prob == 0. or not training:
|
76 |
+
return x
|
77 |
+
keep_prob = 1 - drop_prob
|
78 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
79 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
80 |
+
random_tensor.floor_() # binarize
|
81 |
+
output = x.div(keep_prob) * random_tensor
|
82 |
+
return output
|
83 |
+
|
84 |
+
|
85 |
+
class DropPath(nn.Module):
|
86 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
87 |
+
"""
|
88 |
+
def __init__(self, drop_prob=None):
|
89 |
+
super(DropPath, self).__init__()
|
90 |
+
self.drop_prob = drop_prob
|
91 |
+
|
92 |
+
def forward(self, x):
|
93 |
+
return drop_path(x, self.drop_prob, self.training)
|
94 |
+
|
95 |
+
|
96 |
+
class Mlp(nn.Module):
|
97 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
98 |
+
super().__init__()
|
99 |
+
out_features = out_features or in_features
|
100 |
+
hidden_features = hidden_features or in_features
|
101 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
102 |
+
self.act = act_layer()
|
103 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
104 |
+
self.drop = nn.Dropout(drop)
|
105 |
+
|
106 |
+
def forward(self, x):
|
107 |
+
x = self.fc1(x)
|
108 |
+
x = self.act(x)
|
109 |
+
x = self.drop(x)
|
110 |
+
x = self.fc2(x)
|
111 |
+
x = self.drop(x)
|
112 |
+
return x
|
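A minimal sketch of how `Mlp`, `DropPath` and `trunc_normal_` above are typically combined into a residual block with stochastic depth, similar to the `Block` class in libs/triffuser_multi_post_ln.py; the dimensions are illustrative.

import torch
import torch.nn as nn
from libs.timm import DropPath, Mlp, trunc_normal_

class ResidualMlp(nn.Module):
    def __init__(self, dim=768, drop_path=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=4 * dim)
        self.drop_path = DropPath(drop_path)         # randomly skips the branch per sample
        trunc_normal_(self.mlp.fc1.weight, std=.02)  # init convention used by the backbone

    def forward(self, x):
        return x + self.drop_path(self.mlp(self.norm(x)))

tokens = torch.randn(2, 16, 768)    # (batch, tokens, dim)
print(ResidualMlp()(tokens).shape)  # torch.Size([2, 16, 768])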
src/COP-GEN-Beta/libs/triffuser_multi_post_ln.py
ADDED
@@ -0,0 +1,290 @@
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import math
|
4 |
+
from .timm import trunc_normal_, DropPath, Mlp
|
5 |
+
import einops
|
6 |
+
import torch.utils.checkpoint
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
10 |
+
ATTENTION_MODE = 'flash'
|
11 |
+
else:
|
12 |
+
try:
|
13 |
+
import xformers
|
14 |
+
import xformers.ops
|
15 |
+
ATTENTION_MODE = 'xformers'
|
16 |
+
except:
|
17 |
+
ATTENTION_MODE = 'math'
|
18 |
+
print(f'attention mode is {ATTENTION_MODE}')
|
19 |
+
|
20 |
+
|
21 |
+
def timestep_embedding(timesteps, dim, max_period=10000):
|
22 |
+
"""
|
23 |
+
Create sinusoidal timestep embeddings.
|
24 |
+
|
25 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
26 |
+
These may be fractional.
|
27 |
+
:param dim: the dimension of the output.
|
28 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
29 |
+
:return: an [N x dim] Tensor of positional embeddings.
|
30 |
+
"""
|
31 |
+
half = dim // 2
|
32 |
+
freqs = torch.exp(
|
33 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
34 |
+
).to(device=timesteps.device)
|
35 |
+
args = timesteps[:, None].float() * freqs[None]
|
36 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
37 |
+
if dim % 2:
|
38 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
39 |
+
return embedding
|
40 |
+
|
41 |
+
|
42 |
+
def patchify(imgs, patch_size):
|
43 |
+
x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
|
44 |
+
return x
|
45 |
+
|
46 |
+
|
47 |
+
def unpatchify(x, in_chans):
|
48 |
+
patch_size = int((x.shape[2] // in_chans) ** 0.5)
|
49 |
+
h = w = int(x.shape[1] ** .5)
|
50 |
+
assert h * w == x.shape[1] and patch_size ** 2 * in_chans == x.shape[2]
|
51 |
+
x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size)
|
52 |
+
return x
|
53 |
+
|
54 |
+
|
55 |
+
def interpolate_pos_emb(pos_emb, old_shape, new_shape):
|
56 |
+
pos_emb = einops.rearrange(pos_emb, 'B (H W) C -> B C H W', H=old_shape[0], W=old_shape[1])
|
57 |
+
pos_emb = F.interpolate(pos_emb, new_shape, mode='bilinear')
|
58 |
+
pos_emb = einops.rearrange(pos_emb, 'B C H W -> B (H W) C')
|
59 |
+
return pos_emb
|
60 |
+
|
61 |
+
|
62 |
+
class Attention(nn.Module):
|
63 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
64 |
+
super().__init__()
|
65 |
+
self.num_heads = num_heads
|
66 |
+
head_dim = dim // num_heads
|
67 |
+
self.scale = qk_scale or head_dim ** -0.5
|
68 |
+
|
69 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
70 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
71 |
+
self.proj = nn.Linear(dim, dim)
|
72 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
73 |
+
|
74 |
+
def forward(self, x):
|
75 |
+
B, L, C = x.shape
|
76 |
+
|
77 |
+
qkv = self.qkv(x)
|
78 |
+
if ATTENTION_MODE == 'flash':
|
79 |
+
qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
|
80 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
|
81 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
82 |
+
x = einops.rearrange(x, 'B H L D -> B L (H D)')
|
83 |
+
elif ATTENTION_MODE == 'xformers':
|
84 |
+
qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
|
85 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
|
86 |
+
x = xformers.ops.memory_efficient_attention(q, k, v)
|
87 |
+
x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
|
88 |
+
elif ATTENTION_MODE == 'math':
|
89 |
+
with torch.amp.autocast(device_type='cuda', enabled=False):
|
90 |
+
qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
|
91 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
|
92 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
93 |
+
attn = attn.softmax(dim=-1)
|
94 |
+
attn = self.attn_drop(attn)
|
95 |
+
x = (attn @ v).transpose(1, 2).reshape(B, L, C)
|
96 |
+
else:
|
97 |
+
raise NotImplementedError
|
98 |
+
|
99 |
+
x = self.proj(x)
|
100 |
+
x = self.proj_drop(x)
|
101 |
+
return x
|
102 |
+
|
103 |
+
|
104 |
+
class Block(nn.Module):
|
105 |
+
|
106 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
107 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
|
108 |
+
super().__init__()
|
109 |
+
self.norm1 = norm_layer(dim) if skip else None
|
110 |
+
self.norm2 = norm_layer(dim)
|
111 |
+
|
112 |
+
self.attn = Attention(
|
113 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
114 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
115 |
+
self.norm3 = norm_layer(dim)
|
116 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
117 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
118 |
+
self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
|
119 |
+
self.use_checkpoint = use_checkpoint
|
120 |
+
|
121 |
+
def forward(self, x, skip=None):
|
122 |
+
if self.use_checkpoint:
|
123 |
+
return torch.utils.checkpoint.checkpoint(self._forward, x, skip)
|
124 |
+
else:
|
125 |
+
return self._forward(x, skip)
|
126 |
+
|
127 |
+
def _forward(self, x, skip=None):
|
128 |
+
if self.skip_linear is not None:
|
129 |
+
x = self.skip_linear(torch.cat([x, skip], dim=-1))
|
130 |
+
x = self.norm1(x)
|
131 |
+
x = x + self.drop_path(self.attn(x))
|
132 |
+
x = self.norm2(x)
|
133 |
+
|
134 |
+
x = x + self.drop_path(self.mlp(x))
|
135 |
+
x = self.norm3(x)
|
136 |
+
|
137 |
+
return x
|
138 |
+
|
139 |
+
|
140 |
+
class PatchEmbed(nn.Module):
|
141 |
+
""" Image to Patch Embedding
|
142 |
+
"""
|
143 |
+
def __init__(self, patch_size, in_chans=3, embed_dim=768):
|
144 |
+
super().__init__()
|
145 |
+
self.patch_size = patch_size
|
146 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
147 |
+
|
148 |
+
def forward(self, x):
|
149 |
+
B, C, H, W = x.shape
|
150 |
+
assert H % self.patch_size == 0 and W % self.patch_size == 0
|
151 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
152 |
+
return x
|
153 |
+
|
154 |
+
|
155 |
+
class Triffuser(nn.Module):
|
156 |
+
def __init__(self, img_size, in_chans, patch_size, embed_dim=768, depth=12,
|
157 |
+
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, pos_drop_rate=0., drop_rate=0., attn_drop_rate=0.,
|
158 |
+
norm_layer=nn.LayerNorm, mlp_time_embed=False, use_checkpoint=False,
|
159 |
+
num_modalities=None,
|
160 |
+
# text_dim=None,
|
161 |
+
# num_text_tokens=None,
|
162 |
+
clip_img_dim=None # All modalities with the same clip dimension
|
163 |
+
):
|
164 |
+
super().__init__()
|
165 |
+
self.in_chans = in_chans
|
166 |
+
self.patch_size = patch_size
|
167 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
168 |
+
self.num_modalities = num_modalities
|
169 |
+
if num_modalities is None:
|
170 |
+
raise ValueError("num_modalities must be provided")
|
171 |
+
|
172 |
+
self.patch_embeds = nn.ModuleList([PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) for _ in range(num_modalities)])
|
173 |
+
self.img_size = (img_size, img_size) if isinstance(img_size, int) else img_size # the default img size
|
174 |
+
assert self.img_size[0] % patch_size == 0 and self.img_size[1] % patch_size == 0
|
175 |
+
self.num_patches = (self.img_size[0] // patch_size) * (self.img_size[1] // patch_size)
|
176 |
+
|
177 |
+
self.time_img_embeds = nn.ModuleList([nn.Sequential(
|
178 |
+
nn.Linear(embed_dim, 4 * embed_dim),
|
179 |
+
nn.SiLU(),
|
180 |
+
nn.Linear(4 * embed_dim, embed_dim),
|
181 |
+
) if mlp_time_embed else nn.Identity() for _ in range(num_modalities)])
|
182 |
+
|
183 |
+
# self.text_embed = nn.Linear(text_dim, embed_dim)
|
184 |
+
# self.text_out = nn.Linear(embed_dim, text_dim)
|
185 |
+
|
186 |
+
# TODO: We skip clip embedding for now
|
187 |
+
# self.clip_img_embed = nn.Linear(clip_img_dim, embed_dim)
|
188 |
+
# self.clip_img_out = nn.Linear(embed_dim, clip_img_dim)
|
189 |
+
|
190 |
+
# self.num_text_tokens = num_text_tokens
|
191 |
+
# TODO: ATM we assume the same num_patches for all modalities
|
192 |
+
# 1 for time embedding token of each modality
|
193 |
+
# num_patches for each modality (assuming the same number of patches for all modalities)
|
194 |
+
self.num_tokens = 1 * self.num_modalities + self.num_patches * self.num_modalities
|
195 |
+
|
196 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
|
197 |
+
self.pos_drop = nn.Dropout(p=pos_drop_rate)
|
198 |
+
|
199 |
+
self.in_blocks = nn.ModuleList([
|
200 |
+
Block(
|
201 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
202 |
+
drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer, use_checkpoint=use_checkpoint)
|
203 |
+
for _ in range(depth // 2)])
|
204 |
+
|
205 |
+
self.mid_block = Block(
|
206 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
207 |
+
drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer, use_checkpoint=use_checkpoint)
|
208 |
+
|
209 |
+
self.out_blocks = nn.ModuleList([
|
210 |
+
Block(
|
211 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
212 |
+
drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer, skip=True, use_checkpoint=use_checkpoint)
|
213 |
+
for _ in range(depth // 2)])
|
214 |
+
|
215 |
+
self.norm = norm_layer(embed_dim)
|
216 |
+
self.patch_dim = patch_size ** 2 * in_chans
|
217 |
+
self.decoder_preds = nn.ModuleList([nn.Linear(embed_dim, self.patch_dim, bias=True) for _ in range(num_modalities)])
|
218 |
+
|
219 |
+
trunc_normal_(self.pos_embed, std=.02)
|
220 |
+
self.apply(self._init_weights)
|
221 |
+
|
222 |
+
def _init_weights(self, m):
|
223 |
+
if isinstance(m, nn.Linear):
|
224 |
+
trunc_normal_(m.weight, std=.02)
|
225 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
226 |
+
nn.init.constant_(m.bias, 0)
|
227 |
+
elif isinstance(m, nn.LayerNorm):
|
228 |
+
nn.init.constant_(m.bias, 0)
|
229 |
+
nn.init.constant_(m.weight, 1.0)
|
230 |
+
|
231 |
+
@torch.jit.ignore
|
232 |
+
def no_weight_decay(self):
|
233 |
+
return {'pos_embed'}
|
234 |
+
|
235 |
+
def forward(self, imgs, t_imgs):
|
236 |
+
|
237 |
+
assert len(imgs) == len(t_imgs) == self.num_modalities
|
238 |
+
|
239 |
+
# TODO: We are still assuming all images have the same shape
|
240 |
+
_, _, H, W = imgs[0].shape
|
241 |
+
|
242 |
+
imgs = [self.patch_embeds[i](img) for i, img in enumerate(imgs)]
|
243 |
+
|
244 |
+
t_imgs_token = [self.time_img_embeds[i](timestep_embedding(t_img, self.embed_dim)) for i, t_img in enumerate(t_imgs)]
|
245 |
+
t_imgs_token = [t_img_token.unsqueeze(dim=1) for t_img_token in t_imgs_token]
|
246 |
+
|
247 |
+
# text = self.text_embed(text)
|
248 |
+
# clip_img = self.clip_img_embed(clip_img)
|
249 |
+
x = torch.cat((*t_imgs_token, *imgs), dim=1)
|
250 |
+
|
251 |
+
num_img_tokens = [img.size(1) for img in imgs] # Each image might have different number of tokens
|
252 |
+
num_t_tokens = [1] * self.num_modalities # There is only one time token for each modality
|
253 |
+
|
254 |
+
# TODO: ATM assume all modality images have the same shape
|
255 |
+
if H == self.img_size[0] and W == self.img_size[1]:
|
256 |
+
pos_embed = self.pos_embed
|
257 |
+
else: # interpolate the positional embedding when the input image is not of the default shape
|
258 |
+
raise NotImplementedError("Why are we here? Images are not of the default shape. Interpolate positional embedding.")
|
259 |
+
pos_embed_others, pos_embed_patches = torch.split(self.pos_embed, [1 + 1 + num_text_tokens + 1, self.num_patches], dim=1)
|
260 |
+
pos_embed_patches = interpolate_pos_emb(pos_embed_patches, (self.img_size[0] // self.patch_size, self.img_size[1] // self.patch_size),
|
261 |
+
(H // self.patch_size, W // self.patch_size))
|
262 |
+
pos_embed = torch.cat((pos_embed_others, pos_embed_patches), dim=1)
|
263 |
+
|
264 |
+
x = x + pos_embed
|
265 |
+
x = self.pos_drop(x)
|
266 |
+
|
267 |
+
skips = []
|
268 |
+
for blk in self.in_blocks:
|
269 |
+
x = blk(x)
|
270 |
+
skips.append(x)
|
271 |
+
|
272 |
+
x = self.mid_block(x)
|
273 |
+
|
274 |
+
for blk in self.out_blocks:
|
275 |
+
x = blk(x, skips.pop())
|
276 |
+
|
277 |
+
x = self.norm(x)
|
278 |
+
|
279 |
+
all_t_imgs = x.split((*num_t_tokens, *num_img_tokens), dim=1)
|
280 |
+
|
281 |
+
t_imgs_token_out = all_t_imgs[:self.num_modalities]
|
282 |
+
imgs_out = all_t_imgs[self.num_modalities:]
|
283 |
+
|
284 |
+
imgs_out = [self.decoder_preds[i](img_out) for i, img_out in enumerate(imgs_out)]
|
285 |
+
imgs_out = [unpatchify(img_out, self.in_chans) for img_out in imgs_out]
|
286 |
+
|
287 |
+
# clip_img_out = self.clip_img_out(clip_img_out)
|
288 |
+
# text_out = self.text_out(text_out)
|
289 |
+
|
290 |
+
return imgs_out
|
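A minimal sketch of instantiating the `Triffuser` backbone above for three modalities of 4-channel 32x32 latents and running one forward pass; the hyper-parameters are illustrative, not the training configuration from the configs/ directory.

import torch
from libs.triffuser_multi_post_ln import Triffuser

net = Triffuser(img_size=32, in_chans=4, patch_size=2, embed_dim=512,
                depth=12, num_heads=8, mlp_time_embed=True, num_modalities=3)

imgs = [torch.randn(2, 4, 32, 32) for _ in range(3)]       # one latent per modality
t_imgs = [torch.randint(0, 1000, (2,)) for _ in range(3)]   # per-modality timesteps
outs = net(imgs, t_imgs)
print([o.shape for o in outs])                              # three (2, 4, 32, 32) tensors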
src/COP-GEN-Beta/majortom/NMajorTOM.py
ADDED
@@ -0,0 +1,170 @@
|
1 |
+
import os
|
2 |
+
import pandas as pd
|
3 |
+
import torch
|
4 |
+
from torch.utils.data import Dataset
|
5 |
+
from pathlib import Path
|
6 |
+
import rasterio as rio
|
7 |
+
from PIL import Image
|
8 |
+
import torchvision.transforms as transforms
|
9 |
+
import random
|
10 |
+
|
11 |
+
class NMajorTOM(Dataset):
|
12 |
+
"""NMajorTOM Dataset with multiple modalities (https://huggingface.co/Major-TOM)
|
13 |
+
|
14 |
+
Args:
|
15 |
+
modalities (dict): Dictionary of modality configurations, where each key is a modality name
|
16 |
+
and value is a dict containing:
|
17 |
+
- df: Metadata dataframe for that modality
|
18 |
+
- local_dir: Root directory for that modality
|
19 |
+
- tif_bands: List of tif bands to read
|
20 |
+
- png_bands: List of png bands to read
|
21 |
+
- tif_transforms: List of transforms for tif files
|
22 |
+
- png_transforms: List of transforms for png files
|
23 |
+
random_flip (bool): Whether to randomly flip all modalities together
|
24 |
+
ratio_train_test (float): Ratio of training samples (e.g., 0.8 for 80% train, 20% test)
|
25 |
+
seed (int): Random seed for reproducible train/test splits
|
26 |
+
"""
|
27 |
+
|
28 |
+
def __init__(self, modalities, random_flip=True, ratio_train_test=0.8, seed=42):
|
29 |
+
super().__init__()
|
30 |
+
self.modalities = {}
|
31 |
+
|
32 |
+
# Set random seed for reproducibility
|
33 |
+
random.seed(seed)
|
34 |
+
|
35 |
+
# Process each modality's configuration
|
36 |
+
for modality_name, config in modalities.items():
|
37 |
+
# Drop rows that are complete duplicates across all relevant columns
|
38 |
+
num_rows = len(config['df'])
|
39 |
+
if modality_name == 'S1RTC':
|
40 |
+
relevant_cols = ['grid_cell', 'grid_row_u', 'grid_col_r', 'product_id',
|
41 |
+
'timestamp', 'nodata', 'orbit_state', 'centre_lat',
|
42 |
+
'centre_lon', 'crs', 'parquet_url', 'geometry']
|
43 |
+
else:
|
44 |
+
relevant_cols = list(config['df'].keys())
|
45 |
+
config['df'] = config['df'].drop_duplicates(subset=relevant_cols)
|
46 |
+
print(f"Dropped {num_rows - len(config['df'])} duplicates from {modality_name}")
|
47 |
+
|
48 |
+
# By now, we should have no duplicate grid_cells
|
49 |
+
if config['df']['grid_cell'].duplicated().any():
|
50 |
+
raise ValueError(f"Found rows with duplicate grid_cells but different values in modality {modality_name}")
|
51 |
+
|
52 |
+
self.modalities[modality_name] = {
|
53 |
+
'df': config['df'],
|
54 |
+
'local_dir': Path(config['local_dir']) if isinstance(config['local_dir'], str) else config['local_dir'],
|
55 |
+
'tif_bands': config['tif_bands'] if not isinstance(config['tif_bands'], str) else [config['tif_bands']],
|
56 |
+
'png_bands': config['png_bands'] if not isinstance(config['png_bands'], str) else [config['png_bands']],
|
57 |
+
'tif_transforms': transforms.Compose(config['tif_transforms']) if config['tif_transforms'] is not None else None,
|
58 |
+
'png_transforms': transforms.Compose(config['png_transforms']) if config['png_transforms'] is not None else None
|
59 |
+
}
|
60 |
+
|
61 |
+
self.random_flip = random_flip
|
62 |
+
|
63 |
+
# Get the set of grid_cells for each modality
|
64 |
+
grid_cells_by_modality = {
|
65 |
+
name: set(mod['df']['grid_cell'].values)
|
66 |
+
for name, mod in self.modalities.items()
|
67 |
+
}
|
68 |
+
|
69 |
+
# Check that all modalities share the same grid_cells
|
70 |
+
if len(grid_cells_by_modality) > 0:
|
71 |
+
reference_grid_cells = grid_cells_by_modality[list(grid_cells_by_modality.keys())[0]]
|
72 |
+
for modality_name, grid_cells in grid_cells_by_modality.items():
|
73 |
+
if grid_cells != reference_grid_cells:
|
74 |
+
missing = reference_grid_cells - grid_cells
|
75 |
+
extra = grid_cells - reference_grid_cells
|
76 |
+
error_msg = f"Modality {modality_name} has mismatched grid_cells.\n"
|
77 |
+
if missing:
|
78 |
+
error_msg += f"Missing grid_cells: {missing}\n"
|
79 |
+
if extra:
|
80 |
+
error_msg += f"Extra grid_cells: {extra}"
|
81 |
+
raise ValueError(error_msg)
|
82 |
+
|
83 |
+
# Sort all dataframes by grid_cell for consistent sampling
|
84 |
+
for modality in self.modalities.values():
|
85 |
+
modality['df'] = modality['df'].sort_values('grid_cell').reset_index(drop=True)
|
86 |
+
|
87 |
+
|
88 |
+
print("Creating train/test split...")
|
89 |
+
|
90 |
+
# After sorting dataframes, create train/test split
|
91 |
+
all_grid_cells = list(reference_grid_cells)
|
92 |
+
random.shuffle(all_grid_cells)
|
93 |
+
|
94 |
+
n_train = int(len(all_grid_cells) * ratio_train_test)
|
95 |
+
self.train_grid_cells = set(all_grid_cells[:n_train])
|
96 |
+
self.test_grid_cells = set(all_grid_cells[n_train:])
|
97 |
+
|
98 |
+
# Let's create a dictionary of grid_cells to split
|
99 |
+
self.grid_cell_to_split = {grid_cell: 'train' if grid_cell in self.train_grid_cells else 'test' for grid_cell in reference_grid_cells}
|
100 |
+
|
101 |
+
print(f"Split dataset into {len(self.train_grid_cells)} train and {len(self.test_grid_cells)} test grid cells")
|
102 |
+
|
103 |
+
def __len__(self):
|
104 |
+
# Return length of any modality (they should all be the same)
|
105 |
+
assert len(self.modalities) > 0, "No modalities provided"
|
106 |
+
# Get len for each modality and make sure they are the same
|
107 |
+
lengths = [len(mod['df']) for mod in self.modalities.values()]
|
108 |
+
if not all(x == lengths[0] for x in lengths):
|
109 |
+
raise ValueError("All modalities must have the same number of samples")
|
110 |
+
return lengths[0]
|
111 |
+
|
112 |
+
def __getitem__(self, idx):
|
113 |
+
result = {}
|
114 |
+
|
115 |
+
# Generate the same random flip decision for all modalities
|
116 |
+
do_flip = self.random_flip and random.random() < 0.5
|
117 |
+
|
118 |
+
# Get the grid cell for this index (they're all the same across modalities)
|
119 |
+
first_modality = list(self.modalities.keys())[0]
|
120 |
+
current_grid_cell = self.modalities[first_modality]['df'].iloc[idx]['grid_cell']
|
121 |
+
|
122 |
+
# Determine if this sample is in train or test set
|
123 |
+
split = self.grid_cell_to_split[current_grid_cell]
|
124 |
+
|
125 |
+
for modality_name, modality in self.modalities.items():
|
126 |
+
meta = modality['df'].iloc[idx]
|
127 |
+
product_id = meta.product_id if 'product_id' in meta.index else "id"
|
128 |
+
grid_cell = meta.grid_cell
|
129 |
+
row = grid_cell.split('_')[0]
|
130 |
+
|
131 |
+
path = modality['local_dir'] / Path(f"{row}/{grid_cell}/{product_id}")
|
132 |
+
out_dict = {}
|
133 |
+
|
134 |
+
# Process TIF bands
|
135 |
+
for band in modality['tif_bands']:
|
136 |
+
with rio.open(path / f'{band}.tif') as f:
|
137 |
+
out = f.read() # out = torch.from_numpy(f.read()).float()
|
138 |
+
if modality['tif_transforms'] is not None:
|
139 |
+
out = modality['tif_transforms'](out)
|
140 |
+
out_dict[band] = out
|
141 |
+
|
142 |
+
# Process PNG bands
|
143 |
+
for band in modality['png_bands']:
|
144 |
+
out = Image.open(path / f'{band}.png')
|
145 |
+
if modality['png_transforms'] is not None:
|
146 |
+
out = modality['png_transforms'](out)
|
147 |
+
out_dict[band] = out
|
148 |
+
|
149 |
+
# Apply the same random flip to all bands in this modality
|
150 |
+
if do_flip:
|
151 |
+
out_dict = {k: v.flip(-1) for k, v in out_dict.items()}
|
152 |
+
|
153 |
+
# Add split information to the output dictionary
|
154 |
+
out_dict['split'] = split
|
155 |
+
out_dict['grid_cell'] = current_grid_cell
|
156 |
+
|
157 |
+
result[modality_name] = out_dict
|
158 |
+
|
159 |
+
# Assert the grid_cells are the same for all modalities in the resulting dictionary
|
160 |
+
if len(result) > 0:
|
161 |
+
first_modality = list(result.keys())[0]
|
162 |
+
first_grid_cell = self.modalities[first_modality]['df'].iloc[idx]['grid_cell']
|
163 |
+
for modality_name in result.keys():
|
164 |
+
current_grid_cell = self.modalities[modality_name]['df'].iloc[idx]['grid_cell']
|
165 |
+
if current_grid_cell != first_grid_cell:
|
166 |
+
raise ValueError(f"Mismatched grid_cells found: {current_grid_cell} != {first_grid_cell}")
|
167 |
+
# Add grid_cell to the output dictionary for verification
|
168 |
+
result[modality_name]['grid_cell'] = current_grid_cell
|
169 |
+
|
170 |
+
return result
|
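A hedged usage sketch of the class above, assuming two thumbnail-based modalities downloaded into the {local_dir}/{row}/{grid_cell}/{product_id} layout that __getitem__ expects (all paths, modality keys and the import path below are illustrative):

# Illustrative configuration; paths and names are assumptions, not part of the repository.
import pandas as pd
import torchvision.transforms as transforms
from majortom.NMajorTOM import NMajorTOM

s2_df = pd.read_parquet('data/majorTOM/rome/Core-S2L2A/metadata.parquet')
dem_df = pd.read_parquet('data/majorTOM/rome/Core-DEM/metadata.parquet')

modalities = {
    'S2L2A': dict(df=s2_df, local_dir='data/majorTOM/rome/Core-S2L2A',
                  tif_bands=[], png_bands=['thumbnail'],
                  tif_transforms=None, png_transforms=[transforms.ToTensor()]),
    'DEM': dict(df=dem_df, local_dir='data/majorTOM/rome/Core-DEM',
                tif_bands=[], png_bands=['thumbnail'],
                tif_transforms=None, png_transforms=[transforms.ToTensor()]),
}
dataset = NMajorTOM(modalities, random_flip=True, ratio_train_test=0.8, seed=42)
sample = dataset[0]   # {'S2L2A': {...}, 'DEM': {...}}, each with 'split' and 'grid_cell'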
src/COP-GEN-Beta/majortom/coverage_vis.py
ADDED
@@ -0,0 +1,149 @@
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
from mpl_toolkits.basemap import Basemap
|
5 |
+
import PIL
|
6 |
+
|
7 |
+
def get_mask(df):
|
8 |
+
"""
|
9 |
+
Take a Major TOM dataframe and create a mask corresponding to available cells
|
10 |
+
"""
|
11 |
+
|
12 |
+
mask = np.zeros((2004,4008), dtype=np.uint8)
|
13 |
+
row_offset = -1002
|
14 |
+
col_offset = -2004
|
15 |
+
|
16 |
+
nodata = df['nodata'].values > 0.5
|
17 |
+
|
18 |
+
yy = mask.shape[0] - (np.array(df['grid_row_u']) - row_offset) - 1
|
19 |
+
xx = np.array(df['grid_col_r']) - col_offset
|
20 |
+
|
21 |
+
yy = yy[~nodata]
|
22 |
+
xx = xx[~nodata]
|
23 |
+
|
24 |
+
mask[yy, xx] = 255
|
25 |
+
|
26 |
+
return PIL.Image.fromarray(mask)
|
27 |
+
|
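# Sketch of the indexing above (coordinates are made up): where one
# (grid_row_u, grid_col_r) pair lands in the 2004x4008 mask.
grid_row_u, grid_col_r = 433, 1023
y = 2004 - (grid_row_u - (-1002)) - 1    # -> 568
x = grid_col_r - (-2004)                 # -> 3027
# mask[y, x] is set to 255 for this cell, provided its nodata fraction is <= 0.5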
28 |
+
def fig2img(fig):
|
29 |
+
"""Convert a Matplotlib figure to a PIL Image and return it"""
|
30 |
+
import io
|
31 |
+
buf = io.BytesIO()
|
32 |
+
fig.savefig(buf)
|
33 |
+
buf.seek(0)
|
34 |
+
img = PIL.Image.open(buf)
|
35 |
+
return img
|
36 |
+
|
37 |
+
def light_basemap():
|
38 |
+
"""
|
39 |
+
Bright coloured contours
|
40 |
+
"""
|
41 |
+
|
42 |
+
with plt.ioff():
|
43 |
+
fig, ax = plt.subplots(figsize=(48,24), dpi=167)
|
44 |
+
|
45 |
+
m = Basemap(projection='sinu', lat_0=0, lon_0=0, resolution='l', ax=ax)
|
46 |
+
m.fillcontinents(color="#9eba9b", lake_color='#CCDDFF')
|
47 |
+
m.drawmapboundary(fill_color="#CCDDFF")
|
48 |
+
m.drawcountries(color="#666666", linewidth=1)
|
49 |
+
m.drawcoastlines(color="#666666", linewidth=1)
|
50 |
+
|
51 |
+
plt.gca().set_axis_off()
|
52 |
+
plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0,
|
53 |
+
hspace = 0, wspace = 0)
|
54 |
+
plt.margins(0,0)
|
55 |
+
|
56 |
+
return fig2img(fig)
|
57 |
+
|
58 |
+
def dark_basemap():
|
59 |
+
"""
|
60 |
+
Dark contours
|
61 |
+
"""
|
62 |
+
|
63 |
+
with plt.ioff():
|
64 |
+
fig, ax = plt.subplots(figsize=(48,24), dpi=167)
|
65 |
+
|
66 |
+
m = Basemap(projection='sinu', lat_0=0, lon_0=0, resolution='l', ax=ax)
|
67 |
+
m.fillcontinents(color="#242424", lake_color='#242424')
|
68 |
+
m.drawmapboundary(fill_color="#242424")
|
69 |
+
m.drawcountries(color="#000000", linewidth=1)
|
70 |
+
m.drawcoastlines(color="#000000", linewidth=1)
|
71 |
+
|
72 |
+
plt.gca().set_axis_off()
|
73 |
+
plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0,
|
74 |
+
hspace = 0, wspace = 0)
|
75 |
+
plt.margins(0,0)
|
76 |
+
|
77 |
+
return fig2img(fig)
|
78 |
+
|
79 |
+
def get_coveragemap(input, input2=None):
|
80 |
+
"""
|
81 |
+
Creates a complete coloured Major TOM coverage figure in the same style as in the official documentation
|
82 |
+
|
83 |
+
Optionally, input2 can be provided; the map is then drawn with extra colours indicating cells available only in input (green) or only in input2 (blue)
|
84 |
+
"""
|
85 |
+
|
86 |
+
if input2 is None:
|
87 |
+
return single_coveragemap(input)
|
88 |
+
else:
|
89 |
+
cmap1 = single_coveragemap(input)
|
90 |
+
cmap2 = single_coveragemap(input2)
|
91 |
+
|
92 |
+
# arrays for mixing
|
93 |
+
inp1_arr = np.array(cmap1)[...,:3]
|
94 |
+
inp2_arr = np.array(cmap2)[...,:3]
|
95 |
+
|
96 |
+
common_arr = inp1_arr*(inp1_arr.sum(-1) == inp2_arr.sum(-1))[:,:,None]
|
97 |
+
common_arr[:,:,(1,2)] = 0
|
98 |
+
inp1_arr[:,:,(0,2)] = 0 # Green - indicates presence of S2 only
|
99 |
+
inp2_arr[:,:,(0,1)] = 0 # Blue - indicates presence of DEM only
|
100 |
+
|
101 |
+
return PIL.Image.fromarray(((common_arr + inp1_arr + inp2_arr)).astype(np.uint8))
|
102 |
+
|
103 |
+
|
104 |
+
def single_coveragemap(input):
|
105 |
+
"""
|
106 |
+
Creates a complete coloured Major TOM coverage figure in the same style as in the official documentation
|
107 |
+
"""
|
108 |
+
|
109 |
+
# compute mask if df is provided
|
110 |
+
if isinstance(input, pd.DataFrame):
|
111 |
+
mask = get_mask(input)
|
112 |
+
else:
|
113 |
+
mask = input
|
114 |
+
|
115 |
+
basemap = light_basemap()
|
116 |
+
basemap_d = dark_basemap()
|
117 |
+
|
118 |
+
outside_earth = np.array(basemap.convert('RGBA'))[:, :, 0] == 255
|
119 |
+
outside_earth = PIL.Image.fromarray(outside_earth)
|
120 |
+
|
121 |
+
mask = mask.resize(basemap.size, PIL.Image.NEAREST)
|
122 |
+
|
123 |
+
basemap.putalpha(mask)
|
124 |
+
|
125 |
+
# Mask outside of earth
|
126 |
+
basemap.paste(outside_earth, (0,0), outside_earth)
|
127 |
+
|
128 |
+
basemap_d.paste(basemap, (0,0), basemap)
|
129 |
+
|
130 |
+
return basemap_d
|
131 |
+
|
132 |
+
if __name__ == '__main__':
|
133 |
+
DATASET_NAME = 'Major-TOM/Core-S2L2A'
|
134 |
+
meta_path = 'https://huggingface.co/datasets/{}/resolve/main/metadata.parquet'.format(DATASET_NAME)
|
135 |
+
df = pd.read_parquet(meta_path)
|
136 |
+
|
137 |
+
# This is how you make a coverage figure!
|
138 |
+
coverage_img = get_coveragemap(df)
|
139 |
+
|
140 |
+
coverage_img.save('coverage-example.png', format='PNG')
|
141 |
+
|
142 |
+
# and this is how you can create an overlap map for 2 datasets!
|
143 |
+
DATASET_NAME = 'Major-TOM/Core-DEM'
|
144 |
+
meta_path = 'https://huggingface.co/datasets/{}/resolve/main/metadata.parquet'.format(DATASET_NAME)
|
145 |
+
dem_df = pd.read_parquet(meta_path)
|
146 |
+
|
147 |
+
coverage_img = get_coveragemap(df,dem_df)
|
148 |
+
|
149 |
+
coverage_img.save('overlap-coverage-example.png', format='PNG')
|
src/COP-GEN-Beta/majortom/download_world.py
ADDED
@@ -0,0 +1,1009 @@
1 |
+
import argparse
|
2 |
+
from shapely.geometry import box
|
3 |
+
from concurrent.futures import ThreadPoolExecutor
|
4 |
+
from pathlib import Path
|
5 |
+
import os
|
6 |
+
import geopandas as gpd
|
7 |
+
import pandas as pd
|
8 |
+
import pyarrow.parquet as pq
|
9 |
+
from typing import List, Dict, Set
|
10 |
+
import logging
|
11 |
+
import urllib.request
|
12 |
+
from concurrent import futures
|
13 |
+
import fsspec
|
14 |
+
from tqdm import tqdm
|
15 |
+
import tempfile
|
16 |
+
import time
|
17 |
+
import random
|
18 |
+
|
19 |
+
|
20 |
+
S2L2A_METADATA = ['grid_cell', 'grid_row_u', 'grid_col_r', 'product_id', 'timestamp', 'cloud_cover', 'nodata', 'centre_lat', 'centre_lon', 'crs', 'parquet_url', 'parquet_row', 'geometry']
|
21 |
+
S2L1C_METADATA = ['grid_cell', 'grid_row_u', 'grid_col_r', 'product_id', 'timestamp', 'cloud_cover', 'nodata', 'centre_lat', 'centre_lon', 'crs', 'parquet_url', 'parquet_row', 'geometry']
|
22 |
+
S1RTC_METADATA = ['grid_cell', 'grid_row_u', 'grid_col_r', 'product_id', 'timestamp', 'nodata', 'orbit_state', 'centre_lat', 'centre_lon', 'crs', 'parquet_url', 'parquet_row']
|
23 |
+
DEM_METADATA = ['grid_cell', 'grid_row_u', 'grid_col_r', 'nodata', 'max_val', 'min_val', 'centre_lat', 'centre_lon', 'crs', 'parquet_url', 'parquet_row', '__index_level_0__']
|
24 |
+
|
25 |
+
METADATA_COLUMNS = {
|
26 |
+
'Core-S2L2A': S2L2A_METADATA,
|
27 |
+
'Core-S2L1C': S2L1C_METADATA,
|
28 |
+
'Core-S1RTC': S1RTC_METADATA,
|
29 |
+
'Core-DEM': DEM_METADATA
|
30 |
+
}
|
31 |
+
|
32 |
+
CONTENT = {
|
33 |
+
'Core-S2L2A': ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12', 'cloud_mask'],
|
34 |
+
'Core-S2L1C': ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12', 'cloud_mask'],
|
35 |
+
'Core-S1RTC': ['vv', 'vh'],
|
36 |
+
'Core-DEM': ['DEM', 'compressed']
|
37 |
+
}
|
38 |
+
|
39 |
+
# Default max workers for extraction (can be higher as it's CPU-bound)
|
40 |
+
MAX_WORKERS = 32
|
41 |
+
# Default max workers for download (more conservative to avoid network issues)
|
42 |
+
DEFAULT_DOWNLOAD_WORKERS = 8
|
43 |
+
|
44 |
+
def parse_args():
|
45 |
+
|
46 |
+
if "INTERACTIVE" in os.environ: # Set INTERACTIVE=1 when running manually
|
47 |
+
return argparse.Namespace(
|
48 |
+
data_dir="./data/majorTOM",
|
49 |
+
bbox=[-180.0, -90.0, 180.0, 90.0],
|
50 |
+
sources=['Core-S2L2A', 'Core-S2L1C', 'Core-S1RTC', 'Core-DEM'],
|
51 |
+
subset_name="world",
|
52 |
+
start_date="2017-01-01",
|
53 |
+
end_date="2025-01-01",
|
54 |
+
cloud_cover=[0, 10],
|
55 |
+
preview=True,
|
56 |
+
mode="full",
|
57 |
+
delete_parquets=False,
|
58 |
+
download_workers=DEFAULT_DOWNLOAD_WORKERS,
|
59 |
+
revalidate=False
|
60 |
+
)
|
61 |
+
else:
|
62 |
+
parser = argparse.ArgumentParser(description='Download satellite imagery from Major-TOM dataset')
|
63 |
+
parser.add_argument('--data-dir', type=str, default='./data/majorTOM',
|
64 |
+
help='Data directory for downloaded files')
|
65 |
+
parser.add_argument('--bbox', type=float, nargs=4,
|
66 |
+
default=[2.9559111595, 43.8179931641, 55.4920501709, 65.808380127],
|
67 |
+
help='Bounding box coordinates: minx miny maxx maxy')
|
68 |
+
parser.add_argument('--sources', type=str, nargs='+',
|
69 |
+
default=['Core-S2L2A', 'Core-S2L1C', 'Core-S1RTC'],
|
70 |
+
help='List of source names for the datasets')
|
71 |
+
parser.add_argument('--subset-name', type=str, required=True,
|
72 |
+
help='Name for the geographical subset being created')
|
73 |
+
parser.add_argument('--start-date', type=str, default='2017-01-01',
|
74 |
+
help='Start date for temporal range (YYYY-MM-DD)')
|
75 |
+
parser.add_argument('--end-date', type=str, default='2025-01-01',
|
76 |
+
help='End date for temporal range (YYYY-MM-DD)')
|
77 |
+
parser.add_argument('--cloud-cover', type=float, nargs=2, default=[0, 10],
|
78 |
+
help='Cloud cover range (min max)')
|
79 |
+
parser.add_argument('--criteria', type=str, default=None,
|
80 |
+
help='Criterion for timestamp deduplication. Currently only "latest" is supported')
|
81 |
+
parser.add_argument('--n-samples', type=int, default=None,
|
82 |
+
help='Number of samples to download')
|
83 |
+
parser.add_argument('--seed', type=int, default=None,
|
84 |
+
help='Random seed for reproducibility')
|
85 |
+
parser.add_argument('--preview', action='store_true',
|
86 |
+
help='If True, only print the number of samples for each source that will be downloaded')
|
87 |
+
parser.add_argument('--mode', type=str, choices=['full', 'download', 'extract'], default='full',
|
88 |
+
help='Mode of operation: full (download and extract), download (download parquets only), extract (extract from downloaded parquets)')
|
89 |
+
parser.add_argument('--delete-parquets', action='store_true',
|
90 |
+
help='Delete parquet files after extraction (only used with extract mode)')
|
91 |
+
parser.add_argument('--download-workers', type=int, default=DEFAULT_DOWNLOAD_WORKERS,
|
92 |
+
help=f'Number of parallel workers for downloading files. Default: {DEFAULT_DOWNLOAD_WORKERS}. Reduce this number if downloads are slow.')
|
93 |
+
parser.add_argument('--revalidate', action='store_true',
|
94 |
+
help='Force revalidation of all parquet files and redownload if corrupted')
|
95 |
+
return parser.parse_args()
|
96 |
+
|
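# Usage sketch (illustrative): setting INTERACTIVE makes parse_args() return the
# hard-coded Namespace above instead of parsing CLI flags. A normal run would instead
# look roughly like:
#   python download_world.py --subset-name rome --bbox 12.3 41.7 12.7 42.0 --mode download
# (the bbox values here are made up).
import os
os.environ["INTERACTIVE"] = "1"
args = parse_args()
print(args.subset_name, args.bbox)   # "world", [-180.0, -90.0, 180.0, 90.0]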
97 |
+
|
98 |
+
def fix_crs(df):
|
99 |
+
if df['crs'].iloc[0].startswith('EPSG:EPSG:'):
|
100 |
+
df['crs'] = df['crs'].str.replace('EPSG:EPSG:', 'EPSG:', regex=False)
|
101 |
+
return df
|
102 |
+
|
103 |
+
def my_filter_metadata(df,
|
104 |
+
region=None,
|
105 |
+
daterange=None,
|
106 |
+
cloud_cover=(0,100),
|
107 |
+
nodata=(0, 1.0)
|
108 |
+
):
|
109 |
+
"""Filters the Major-TOM dataframe based on several parameters
|
110 |
+
|
111 |
+
Args:
|
112 |
+
df (geopandas dataframe): Parent dataframe
|
113 |
+
region (shapely geometry object) : Region of interest
|
114 |
+
daterange (tuple) : Inclusive range of dates (example format: '2020-01-01')
|
115 |
+
cloud_cover (tuple) : Inclusive percentage range (0-100) of cloud cover
|
116 |
+
nodata (tuple) : Inclusive fraction (0.0-1.0) of no data allowed in a sample
|
117 |
+
|
118 |
+
Returns:
|
119 |
+
df: a filtered dataframe
|
120 |
+
"""
|
121 |
+
# temporal filtering
|
122 |
+
if daterange is not None and 'timestamp' in df.columns:
|
123 |
+
assert (isinstance(daterange, list) or isinstance(daterange, tuple)) and len(daterange)==2
|
124 |
+
df = df[df.timestamp >= daterange[0]]
|
125 |
+
df = df[df.timestamp <= daterange[1]]
|
126 |
+
|
127 |
+
# spatial filtering
|
128 |
+
if region is not None:
|
129 |
+
idxs = df.sindex.query(region)
|
130 |
+
df = df.take(idxs)
|
131 |
+
# cloud filtering
|
132 |
+
if cloud_cover is not None:
|
133 |
+
df = df[df.cloud_cover >= cloud_cover[0]]
|
134 |
+
df = df[df.cloud_cover <= cloud_cover[1]]
|
135 |
+
|
136 |
+
# nodata filtering
|
137 |
+
if nodata is not None:
|
138 |
+
df = df[df.nodata >= nodata[0]]
|
139 |
+
df = df[df.nodata <= nodata[1]]
|
140 |
+
|
141 |
+
return df
|
142 |
+
|
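# Usage sketch (illustrative values): filter a Major TOM metadata GeoDataFrame to a
# bounding box, a date range and low cloud cover using my_filter_metadata().
from shapely.geometry import box
region = box(12.3, 41.7, 12.7, 42.0)          # minx, miny, maxx, maxy (made up)
filtered = my_filter_metadata(
    gdf,                                      # GeoDataFrame, e.g. from get_metadata()
    region=region,
    daterange=('2020-01-01', '2021-01-01'),
    cloud_cover=(0, 10),
    nodata=(0.0, 0.0),
)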
143 |
+
def my_filter_download(df, local_dir, source_name, by_row=False, verbose=False, tif_columns=None, download_workers=DEFAULT_DOWNLOAD_WORKERS):
|
144 |
+
"""Downloads and unpacks the data of Major-TOM based on a metadata dataframe"""
|
145 |
+
if isinstance(local_dir, str):
|
146 |
+
local_dir = Path(local_dir)
|
147 |
+
|
148 |
+
# identify all parquets that need to be downloaded (group them)
|
149 |
+
urls = df.parquet_url.unique()
|
150 |
+
print(f'Starting parallel download of {len(urls)} parquet files.') if verbose else None
|
151 |
+
|
152 |
+
def process_parquet(url):
|
153 |
+
# Create a unique temporary file for each thread
|
154 |
+
temp_file = tempfile.NamedTemporaryFile(suffix=".parquet", dir=local_dir).name
|
155 |
+
|
156 |
+
# identify all relevant rows for this parquet
|
157 |
+
rows = df[df.parquet_url == url].parquet_row.unique()
|
158 |
+
|
159 |
+
max_retries = 3
|
160 |
+
retry_delay = 5 # seconds
|
161 |
+
success = False
|
162 |
+
last_error = None
|
163 |
+
|
164 |
+
for attempt in range(max_retries):
|
165 |
+
try:
|
166 |
+
if not by_row:
|
167 |
+
# Create an opener with a longer timeout
|
168 |
+
opener = urllib.request.build_opener()
|
169 |
+
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
|
170 |
+
urllib.request.install_opener(opener)
|
171 |
+
|
172 |
+
# Download with timeout using urlopen (30 minutes timeout)
|
173 |
+
with urllib.request.urlopen(url, timeout=1800) as response:
|
174 |
+
with open(temp_file, 'wb') as out_file:
|
175 |
+
out_file.write(response.read())
|
176 |
+
temp_path = temp_file
|
177 |
+
else:
|
178 |
+
f = fsspec.open(url)
|
179 |
+
temp_path = f.open()
|
180 |
+
|
181 |
+
# Process the downloaded parquet file
|
182 |
+
try:
|
183 |
+
with pq.ParquetFile(temp_path) as pf:
|
184 |
+
for row_idx in rows:
|
185 |
+
table = pf.read_row_group(row_idx)
|
186 |
+
|
187 |
+
product_id = table['product_id'][0].as_py() if 'product_id' in table.column_names else "id"
|
188 |
+
grid_cell = table['grid_cell'][0].as_py()
|
189 |
+
row = grid_cell.split('_')[0]
|
190 |
+
|
191 |
+
dest = local_dir / Path(f"{source_name}/{row}/{grid_cell}/{product_id}")
|
192 |
+
dest.mkdir(exist_ok=True, parents=True)
|
193 |
+
|
194 |
+
if tif_columns == 'all':
|
195 |
+
columns = [col for col in table.column_names if col[0] == 'B']
|
196 |
+
if source_name in ['Core-S2L1C', 'Core-S2L2A']:
|
197 |
+
columns.append('cloud_mask')
|
198 |
+
elif tif_columns is None:
|
199 |
+
columns = []
|
200 |
+
else:
|
201 |
+
columns = tif_columns
|
202 |
+
|
203 |
+
# Save tifs
|
204 |
+
for col in columns:
|
205 |
+
with open(dest / f"{col}.tif", "wb") as f:
|
206 |
+
f.write(table[col][0].as_py())
|
207 |
+
|
208 |
+
# Save thumbnail
|
209 |
+
with open(dest / "thumbnail.png", "wb") as f:
|
210 |
+
f.write(table['thumbnail'][0].as_py())
|
211 |
+
|
212 |
+
success = True
|
213 |
+
break # Successfully processed the file, exit retry loop
|
214 |
+
|
215 |
+
except Exception as e:
|
216 |
+
last_error = f"Error processing parquet content: {str(e)}"
|
217 |
+
if attempt < max_retries - 1:
|
218 |
+
print(f"Error processing parquet content for {url}, attempt {attempt + 1}/{max_retries}: {str(e)}")
|
219 |
+
time.sleep(retry_delay)
|
220 |
+
continue
|
221 |
+
|
222 |
+
finally:
|
223 |
+
# Cleanup
|
224 |
+
if not by_row:
|
225 |
+
try:
|
226 |
+
os.remove(temp_path)
|
227 |
+
except:
|
228 |
+
pass
|
229 |
+
else:
|
230 |
+
try:
|
231 |
+
f.close()
|
232 |
+
except:
|
233 |
+
pass
|
234 |
+
|
235 |
+
except urllib.error.HTTPError as e:
|
236 |
+
last_error = f"HTTP Error {e.code}: {str(e)}"
|
237 |
+
if e.code == 504 and attempt < max_retries - 1:
|
238 |
+
print(f"Timeout error for {url}, attempt {attempt + 1}/{max_retries}. Retrying in {retry_delay} seconds...")
|
239 |
+
time.sleep(retry_delay)
|
240 |
+
continue
|
241 |
+
except Exception as e:
|
242 |
+
last_error = str(e)
|
243 |
+
if attempt < max_retries - 1:
|
244 |
+
print(f"Error downloading {url}, attempt {attempt + 1}/{max_retries}: {str(e)}")
|
245 |
+
time.sleep(retry_delay)
|
246 |
+
continue
|
247 |
+
|
248 |
+
return {
|
249 |
+
'url': url,
|
250 |
+
'success': success,
|
251 |
+
'error': last_error if not success else None
|
252 |
+
}
|
253 |
+
|
254 |
+
# Use ThreadPoolExecutor for parallel downloads
|
255 |
+
# max_workers = min(len(urls), MAX_WORKERS*4) # Use more workers since it's I/O bound
|
256 |
+
max_workers = min(len(urls), download_workers)
|
257 |
+
print(f"Using {max_workers} workers for parallel downloads") if verbose else None
|
258 |
+
results = []
|
259 |
+
|
260 |
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
261 |
+
future_to_url = {executor.submit(process_parquet, url): url for url in urls}
|
262 |
+
|
263 |
+
for future in tqdm(
|
264 |
+
futures.as_completed(future_to_url),
|
265 |
+
total=len(urls),
|
266 |
+
desc=f'Downloading {source_name} parquets'
|
267 |
+
):
|
268 |
+
results.append(future.result())
|
269 |
+
|
270 |
+
# Process results and handle failures
|
271 |
+
failed_downloads = [r for r in results if not r['success']]
|
272 |
+
if failed_downloads:
|
273 |
+
print(f"\nWarning: Failed to download {len(failed_downloads)} parquet files for {source_name}")
|
274 |
+
print("\nFailed downloads:")
|
275 |
+
for fail in failed_downloads:
|
276 |
+
print(f"URL: {fail['url']}")
|
277 |
+
print(f"Error: {fail['error']}")
|
278 |
+
print("---")
|
279 |
+
raise RuntimeError(f"Some parquet files failed to download for {source_name}. Please retry the download.")
|
280 |
+
|
281 |
+
print(f"Successfully downloaded and processed {len(urls) - len(failed_downloads)} parquet files for {source_name}")
|
282 |
+
|
283 |
+
|
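# For reference, the extraction above writes each sample to
# {local_dir}/{source_name}/{row}/{grid_cell}/{product_id}/, holding the selected
# band .tif files plus thumbnail.png; an illustrative destination (all names made up):
from pathlib import Path
dest = Path('./data/majorTOM/rome') / 'Core-S2L2A' / '433U' / '433U_1023R' / '<product_id>'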
284 |
+
def my_metadata_from_url(access_url, local_url):
|
285 |
+
local_url, response = urllib.request.urlretrieve(access_url, local_url)
|
286 |
+
df = pq.read_table(local_url).to_pandas()
|
287 |
+
if 'timestamp' in df.columns:
|
288 |
+
df['timestamp'] = pd.to_datetime(df.timestamp)
|
289 |
+
df = fix_crs(df) # Fix CRS typo if present
|
290 |
+
gdf = gpd.GeoDataFrame(
|
291 |
+
df, geometry=gpd.points_from_xy(df.centre_lon, df.centre_lat), crs=df.crs.iloc[0]
|
292 |
+
)
|
293 |
+
return gdf
|
294 |
+
|
295 |
+
def get_metadata(source: str, output_dir: Path) -> gpd.GeoDataFrame:
|
296 |
+
"""Fetch metadata from HuggingFace dataset for a specific source"""
|
297 |
+
access_url = f"https://huggingface.co/datasets/Major-TOM/{source}/resolve/main/metadata.parquet?download=true"
|
298 |
+
local_url = output_dir / source / "metadata.parquet"
|
299 |
+
local_url.parent.mkdir(exist_ok=True, parents=True)
|
300 |
+
|
301 |
+
if local_url.exists():
|
302 |
+
print(f"Using cached metadata for {source}")
|
303 |
+
df = pq.read_table(local_url).to_pandas()
|
304 |
+
if 'timestamp' in df.columns:
|
305 |
+
df['timestamp'] = pd.to_datetime(df.timestamp)
|
306 |
+
df = fix_crs(df)
|
307 |
+
gdf = gpd.GeoDataFrame(
|
308 |
+
df, geometry=gpd.points_from_xy(df.centre_lon, df.centre_lat), crs=df.crs.iloc[0]
|
309 |
+
)
|
310 |
+
else:
|
311 |
+
print(f"Downloading metadata for {source}...")
|
312 |
+
gdf = my_metadata_from_url(access_url, local_url)
|
313 |
+
|
314 |
+
return gdf
|
315 |
+
|
316 |
+
|
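# Usage sketch: fetch (or reuse the cached copy of) the metadata parquet for one
# source and get it back as a GeoDataFrame.
from pathlib import Path
gdf = get_metadata('Core-S2L2A', Path('./data/majorTOM'))
print(len(gdf), list(gdf.columns)[:5])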
317 |
+
def filter_data(gdf, bbox, cloud_cover, date_range):
|
318 |
+
"""Filter metadata based on given parameters"""
|
319 |
+
region = box(*bbox)
|
320 |
+
return my_filter_metadata(
|
321 |
+
gdf,
|
322 |
+
cloud_cover=cloud_cover,
|
323 |
+
region=region,
|
324 |
+
daterange=date_range,
|
325 |
+
nodata=(0.0, 0.0)
|
326 |
+
)
|
327 |
+
|
328 |
+
|
329 |
+
def find_common_samples(filtered_dfs: Dict[str, gpd.GeoDataFrame]) -> Dict[str, gpd.GeoDataFrame]:
|
330 |
+
"""Find samples that share common grid cells across all datasets"""
|
331 |
+
# Create sets of grid_cells for each dataset
|
332 |
+
grid_cell_sets = {
|
333 |
+
source: set(df['grid_cell'].unique())
|
334 |
+
for source, df in filtered_dfs.items()
|
335 |
+
}
|
336 |
+
|
337 |
+
# Find intersection of all grid cell sets
|
338 |
+
common_grid_cells = set.intersection(*grid_cell_sets.values())
|
339 |
+
print(f"\033[92mFound {len(common_grid_cells)} common grid cells across all sources\033[0m")
|
340 |
+
|
341 |
+
# Filter dataframes to keep only rows with common grid cells
|
342 |
+
filtered_common = {}
|
343 |
+
for source, df in filtered_dfs.items():
|
344 |
+
filtered_common[source] = df[df['grid_cell'].isin(common_grid_cells)]
|
345 |
+
print(f"{source}: {len(filtered_common[source])} samples for common grid cells")
|
346 |
+
|
347 |
+
return filtered_common
|
348 |
+
|
349 |
+
|
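# The intersection above, reduced to plain sets (grid cells are made up):
grid_cells = {
    'Core-S2L2A': {'433U_1023R', '434U_1023R'},
    'Core-DEM': {'433U_1023R'},
}
common = set.intersection(*grid_cells.values())
print(common)   # {'433U_1023R'}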
350 |
+
def download_source_files(df: gpd.GeoDataFrame, output_dir: Path, source: str, mode: str = 'full', delete_parquets: bool = False, download_workers: int = DEFAULT_DOWNLOAD_WORKERS, revalidate: bool = False):
|
351 |
+
"""Download files for a specific source"""
|
352 |
+
print(f"Processing files for {source}...")
|
353 |
+
|
354 |
+
if mode == 'download':
|
355 |
+
# Only download parquet files without extracting
|
356 |
+
download_parquet_files(
|
357 |
+
df,
|
358 |
+
local_dir=output_dir,
|
359 |
+
source_name=source,
|
360 |
+
download_workers=download_workers,
|
361 |
+
revalidate=revalidate,
|
362 |
+
verbose=True
|
363 |
+
)
|
364 |
+
elif mode == 'extract':
|
365 |
+
# Extract data from already downloaded parquet files
|
366 |
+
extract_from_parquet_files(
|
367 |
+
df,
|
368 |
+
local_dir=output_dir,
|
369 |
+
source_name=source,
|
370 |
+
delete_parquets=delete_parquets,
|
371 |
+
verbose=True,
|
372 |
+
tif_columns=CONTENT[source]
|
373 |
+
)
|
374 |
+
else: # mode == 'full'
|
375 |
+
# Use the original function for backwards compatibility
|
376 |
+
my_filter_download(
|
377 |
+
df,
|
378 |
+
local_dir=output_dir,
|
379 |
+
source_name=source,
|
380 |
+
by_row=False,
|
381 |
+
verbose=True,
|
382 |
+
tif_columns=CONTENT[source],
|
383 |
+
download_workers=download_workers
|
384 |
+
)
|
385 |
+
|
386 |
+
def get_and_filter_source(args, source: str, data_dir: Path) -> gpd.GeoDataFrame:
|
387 |
+
"""Process a single source: get metadata and filter it"""
|
388 |
+
source_dir = data_dir / source
|
389 |
+
source_dir.mkdir(exist_ok=True, parents=True)
|
390 |
+
|
391 |
+
# Get and filter metadata for each source
|
392 |
+
gdf = get_metadata(source, data_dir)
|
393 |
+
|
394 |
+
# Only apply cloud cover filter for Sentinel-2 sources
|
395 |
+
cloud_cover_filter = tuple(args.cloud_cover) if source.startswith('Core-S2') else None
|
396 |
+
|
397 |
+
filtered_df = filter_data(
|
398 |
+
gdf,
|
399 |
+
bbox=args.bbox,
|
400 |
+
cloud_cover=cloud_cover_filter,
|
401 |
+
date_range=(args.start_date, args.end_date)
|
402 |
+
)
|
403 |
+
print(f"Found {len(filtered_df)} samples for {source} in the specified region")
|
404 |
+
return filtered_df
|
405 |
+
|
406 |
+
def download_source_parallel(source_df_tuple: tuple, subset_dir: Path, mode: str = 'full', delete_parquets: bool = False, download_workers: int = DEFAULT_DOWNLOAD_WORKERS, revalidate: bool = False):
|
407 |
+
"""Download files for a source sequentially, with resume capability"""
|
408 |
+
source, df = source_df_tuple
|
409 |
+
source_subset_dir = subset_dir / source
|
410 |
+
source_subset_dir.mkdir(exist_ok=True, parents=True)
|
411 |
+
|
412 |
+
# Save filtered metadata
|
413 |
+
metadata_path = source_subset_dir / "metadata.parquet"
|
414 |
+
df.to_parquet(metadata_path)
|
415 |
+
print(f"Saved filtered metadata for {source} to {metadata_path}")
|
416 |
+
|
417 |
+
# If we're only downloading parquet files, we don't need to check for existing tif files
|
418 |
+
if mode == 'download':
|
419 |
+
download_source_files(df, subset_dir, source, mode=mode, delete_parquets=delete_parquets, download_workers=download_workers, revalidate=revalidate)
|
420 |
+
print(f"Completed parquet downloads for {source}")
|
421 |
+
return
|
422 |
+
|
423 |
+
# If we're extracting and the extraction metadata exists, we don't need to check for existing files
|
424 |
+
parquet_dir = subset_dir / source / "parquets"
|
425 |
+
extraction_file = parquet_dir / "extraction_metadata.parquet"
|
426 |
+
filtered_df_file = parquet_dir / "filtered_df.parquet"
|
427 |
+
|
428 |
+
if mode == 'extract' and extraction_file.exists() and filtered_df_file.exists():
|
429 |
+
print(f"Using saved extraction metadata for {source}")
|
430 |
+
download_source_files(df, subset_dir, source, mode=mode, delete_parquets=delete_parquets, download_workers=download_workers, revalidate=revalidate)
|
431 |
+
print(f"Completed extraction for {source}")
|
432 |
+
return
|
433 |
+
|
434 |
+
# Filter out already processed grid cells more efficiently
|
435 |
+
def get_existing_files(df, subset_dir, source):
|
436 |
+
# Create all possible paths
|
437 |
+
# For DEM, use 'id' as product_id, for other sources use actual product_id
|
438 |
+
product_ids = df['product_id'] if 'product_id' in df.columns else pd.Series(['id'] * len(df))
|
439 |
+
grid_cells = df['grid_cell']
|
440 |
+
row_dirs = grid_cells.str.split('_').str[0]
|
441 |
+
|
442 |
+
# Vectorized path creation
|
443 |
+
paths = [
|
444 |
+
subset_dir / source / row_dir / grid_cell / product_id / "thumbnail.png"
|
445 |
+
for row_dir, grid_cell, product_id in zip(row_dirs, grid_cells, product_ids)
|
446 |
+
]
|
447 |
+
|
448 |
+
# Batch existence check
|
449 |
+
exists_mask = [path.exists() for path in tqdm(paths, desc=f"Checking existing files for {source}", unit="file")]
|
450 |
+
return pd.Series(exists_mask, index=df.index)
|
451 |
+
|
452 |
+
# Create mask of unprocessed files
|
453 |
+
exists_mask = get_existing_files(df, subset_dir, source)
|
454 |
+
df_to_process = df[~exists_mask]
|
455 |
+
|
456 |
+
if len(df_to_process) == 0:
|
457 |
+
print(f"All files for {source} are already processed. Skipping.")
|
458 |
+
return
|
459 |
+
|
460 |
+
print(f"Found {len(df) - len(df_to_process)} already processed files")
|
461 |
+
print(f"Processing remaining {len(df_to_process)} files for {source}...")
|
462 |
+
|
463 |
+
# Process the remaining data files
|
464 |
+
download_source_files(df_to_process, subset_dir, source, mode=mode, delete_parquets=delete_parquets, download_workers=download_workers, revalidate=revalidate)
|
465 |
+
print(f"Completed processing for {source}")
|
466 |
+
|
467 |
+
def remove_duplicates(common_dfs: Dict[str, gpd.GeoDataFrame],
|
468 |
+
criteria: str = None) -> Dict[str, gpd.GeoDataFrame]:
|
469 |
+
"""Remove duplicates from common dataframes based on source-specific relevant columns."""
|
470 |
+
for source, df in common_dfs.items():
|
471 |
+
num_rows = len(df)
|
472 |
+
|
473 |
+
if 'timestamp' in df.columns:
|
474 |
+
if criteria == "latest":
|
475 |
+
# Sort by timestamp and keep the latest
|
476 |
+
df = df.sort_values(by='timestamp', ascending=False)
|
477 |
+
elif criteria is None:
|
478 |
+
raise ValueError("Please specify a criterion for deduplication. Currently we do not support multiple timestamps for the same grid_cell.")
|
479 |
+
else:
|
480 |
+
raise ValueError("Criteria not supported")
|
481 |
+
|
482 |
+
# TODO:
|
483 |
+
# Product_id includes the timestamp.
|
484 |
+
# We ignore one of the two orbit_states to avoid duplicates.
|
485 |
+
# We can also ignore cloud_cover since we have already filtered by cloud_cover
|
486 |
+
# We also ignore crs. Apparently, there are rows that are entirely duplicates except for the crs (which is odd)
|
487 |
+
# We also ignore centre_lat and centre_lon since they are not always aligned
|
488 |
+
subset_columns = [col for col in df.columns if col not in [
|
489 |
+
'parquet_row', 'parquet_url', 'geometry', 'timestamp', 'product_id',
|
490 |
+
'orbit_state', 'cloud_cover', 'crs', 'centre_lat', 'centre_lon'
|
491 |
+
]]
|
492 |
+
df = df.drop_duplicates(subset=subset_columns)
|
493 |
+
|
494 |
+
# Verify no remaining duplicates in grid_cell
|
495 |
+
if df['grid_cell'].duplicated().any():
|
496 |
+
print(df[df['grid_cell'].duplicated()])
|
497 |
+
raise ValueError(f"Found rows with duplicate grid_cells but different values in source {source}")
|
498 |
+
|
499 |
+
common_dfs[source] = df
|
500 |
+
print(f"\033[94mDropped {num_rows - len(df)} duplicates from {source}\033[0m")
|
501 |
+
|
502 |
+
return common_dfs
|
503 |
+
|
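# Sketch of the criteria="latest" behaviour on a toy dataframe: sorting by timestamp
# (newest first) before drop_duplicates keeps the most recent row per grid cell.
# (The real function deduplicates on a broader column subset than this sketch.)
import pandas as pd
toy = pd.DataFrame({
    'grid_cell': ['433U_1023R', '433U_1023R'],
    'timestamp': pd.to_datetime(['2020-01-01', '2021-06-01']),
})
latest = toy.sort_values('timestamp', ascending=False).drop_duplicates(subset=['grid_cell'])
print(latest['timestamp'].iloc[0])   # 2021-06-01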
504 |
+
def sample_common_dfs(common_dfs: Dict[str, gpd.GeoDataFrame], n_samples: int, seed: int) -> Dict[str, gpd.GeoDataFrame]:
|
505 |
+
"""Sample common dataframes to have n_samples samples per source"""
|
506 |
+
# Get all unique grid cells that appear in all dataframes
|
507 |
+
grid_cells_sets = [set(df['grid_cell'].unique()) for df in common_dfs.values()]
|
508 |
+
all_grid_cells = list(set.intersection(*grid_cells_sets))
|
509 |
+
if not all_grid_cells:
|
510 |
+
raise ValueError("No common grid cells found across all sources")
|
511 |
+
|
512 |
+
# Sort grid cells for reproducibility before sampling
|
513 |
+
all_grid_cells.sort()
|
514 |
+
|
515 |
+
# Randomly sample grid cells
|
516 |
+
random.seed(seed)
|
517 |
+
sampled_grid_cells = set(random.sample(all_grid_cells, min(n_samples, len(all_grid_cells))))
|
518 |
+
|
519 |
+
# Filter each dataframe to only include the sampled grid cells
|
520 |
+
result = {}
|
521 |
+
for source, df in common_dfs.items():
|
522 |
+
result[source] = df[df['grid_cell'].isin(sampled_grid_cells)]
|
523 |
+
print(f"Sampled {len(result[source])} rows for {source}")
|
524 |
+
|
525 |
+
return result
|
526 |
+
|
527 |
+
def is_valid_parquet(parquet_path):
|
528 |
+
"""
|
529 |
+
Checks if a parquet file is valid and not empty.
|
530 |
+
|
531 |
+
Args:
|
532 |
+
parquet_path: Path to the parquet file
|
533 |
+
|
534 |
+
Returns:
|
535 |
+
bool: True if the parquet file is valid, False otherwise
|
536 |
+
"""
|
537 |
+
try:
|
538 |
+
# Check if file exists and has a non-zero size (not empty)
|
539 |
+
if not os.path.exists(parquet_path) or os.path.getsize(parquet_path) == 0:
|
540 |
+
return False
|
541 |
+
|
542 |
+
# Try to open and read metadata from the parquet file
|
543 |
+
with pq.ParquetFile(parquet_path) as pf:
|
544 |
+
# Check if there's at least one row group
|
545 |
+
if pf.num_row_groups == 0:
|
546 |
+
return False
|
547 |
+
|
548 |
+
# Try to read metadata of the first row group to verify basic integrity
|
549 |
+
pf.metadata
|
550 |
+
|
551 |
+
# Optionally, try reading a small sample of data to further verify
|
552 |
+
table = pf.read_row_group(0, columns=['grid_cell'])
|
553 |
+
|
554 |
+
return True
|
555 |
+
except Exception as e:
|
556 |
+
print(f"Error validating parquet file {parquet_path}: {str(e)}")
|
557 |
+
return False
|
558 |
+
|
559 |
+
def download_parquet_files(df, local_dir, source_name, download_workers=DEFAULT_DOWNLOAD_WORKERS, revalidate=False, verbose=False):
|
560 |
+
"""Downloads only the parquet files without extracting data, saving them to disk"""
|
561 |
+
if isinstance(local_dir, str):
|
562 |
+
local_dir = Path(local_dir)
|
563 |
+
|
564 |
+
# Create a directory to store parquet files
|
565 |
+
parquet_dir = local_dir / source_name / "parquets"
|
566 |
+
parquet_dir.mkdir(exist_ok=True, parents=True)
|
567 |
+
|
568 |
+
# Identify all parquets that need to be downloaded
|
569 |
+
urls = df.parquet_url.unique()
|
570 |
+
print(f'Starting parallel download of {len(urls)} parquet files.') if verbose else None
|
571 |
+
|
572 |
+
def download_parquet(url):
|
573 |
+
# Get the filename from the URL
|
574 |
+
filename = url.split('/')[-1].split('?')[0]
|
575 |
+
parquet_path = parquet_dir / filename
|
576 |
+
|
577 |
+
# Skip if file already exists and is valid (and we're not forcing revalidation)
|
578 |
+
if parquet_path.exists() and not revalidate:
|
579 |
+
if is_valid_parquet(parquet_path):
|
580 |
+
return {
|
581 |
+
'url': url,
|
582 |
+
'path': parquet_path,
|
583 |
+
'success': True,
|
584 |
+
'error': None,
|
585 |
+
'skipped': True
|
586 |
+
}
|
587 |
+
else:
|
588 |
+
# File exists but is corrupted or empty, delete it for redownload
|
589 |
+
print(f"Found corrupted or invalid parquet file: {parquet_path}. Will redownload.")
|
590 |
+
try:
|
591 |
+
os.remove(parquet_path)
|
592 |
+
except Exception as e:
|
593 |
+
print(f"Warning: Failed to delete corrupted file {parquet_path}: {str(e)}")
|
594 |
+
elif parquet_path.exists() and revalidate:
|
595 |
+
# If we're revalidating, check the file and delete if invalid
|
596 |
+
if not is_valid_parquet(parquet_path):
|
597 |
+
print(f"Revalidation: Found corrupted parquet file: {parquet_path}. Will redownload.")
|
598 |
+
try:
|
599 |
+
os.remove(parquet_path)
|
600 |
+
except Exception as e:
|
601 |
+
print(f"Warning: Failed to delete corrupted file {parquet_path}: {str(e)}")
|
602 |
+
else:
|
603 |
+
# File is valid, skip download
|
604 |
+
print(f"Revalidation: Confirmed valid parquet file: {parquet_path}")
|
605 |
+
return {
|
606 |
+
'url': url,
|
607 |
+
'path': parquet_path,
|
608 |
+
'success': True,
|
609 |
+
'error': None,
|
610 |
+
'skipped': True
|
611 |
+
}
|
612 |
+
|
613 |
+
max_retries = 3
|
614 |
+
retry_delay = 5 # seconds
|
615 |
+
success = False
|
616 |
+
last_error = None
|
617 |
+
|
618 |
+
for attempt in range(max_retries):
|
619 |
+
try:
|
620 |
+
# Create an opener with a longer timeout
|
621 |
+
opener = urllib.request.build_opener()
|
622 |
+
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
|
623 |
+
urllib.request.install_opener(opener)
|
624 |
+
|
625 |
+
# Download with timeout using urlopen (30 minutes timeout)
|
626 |
+
with urllib.request.urlopen(url, timeout=1800) as response:
|
627 |
+
with open(parquet_path, 'wb') as out_file:
|
628 |
+
out_file.write(response.read())
|
629 |
+
|
630 |
+
# Verify the downloaded file is valid
|
631 |
+
if not is_valid_parquet(parquet_path):
|
632 |
+
last_error = "Downloaded file is corrupted or invalid"
|
633 |
+
if attempt < max_retries - 1:
|
634 |
+
print(f"Error: Downloaded parquet file is corrupted, attempt {attempt + 1}/{max_retries}. Retrying...")
|
635 |
+
os.remove(parquet_path)
|
636 |
+
time.sleep(retry_delay)
|
637 |
+
continue
|
638 |
+
|
639 |
+
success = True
|
640 |
+
break # Successfully downloaded, exit retry loop
|
641 |
+
|
642 |
+
except urllib.error.HTTPError as e:
|
643 |
+
last_error = f"HTTP Error {e.code}: {str(e)}"
|
644 |
+
if e.code == 504 and attempt < max_retries - 1:
|
645 |
+
print(f"Timeout error for {url}, attempt {attempt + 1}/{max_retries}. Retrying in {retry_delay} seconds...")
|
646 |
+
time.sleep(retry_delay)
|
647 |
+
continue
|
648 |
+
except Exception as e:
|
649 |
+
last_error = str(e)
|
650 |
+
if attempt < max_retries - 1:
|
651 |
+
print(f"Error downloading {url}, attempt {attempt + 1}/{max_retries}: {str(e)}")
|
652 |
+
time.sleep(retry_delay)
|
653 |
+
continue
|
654 |
+
# Make sure the file is deleted if it was partially downloaded
|
655 |
+
if parquet_path.exists():
|
656 |
+
try:
|
657 |
+
os.remove(parquet_path)
|
658 |
+
except:
|
659 |
+
pass
|
660 |
+
|
661 |
+
return {
|
662 |
+
'url': url,
|
663 |
+
'path': parquet_path if success else None,
|
664 |
+
'success': success,
|
665 |
+
'error': last_error if not success else None,
|
666 |
+
'skipped': False
|
667 |
+
}
|
668 |
+
|
669 |
+
# Use ThreadPoolExecutor for parallel downloads with the specified number of workers
|
670 |
+
max_workers = min(len(urls), download_workers)
|
671 |
+
print(f"Using {max_workers} workers for parallel downloads") if verbose else None
|
672 |
+
results = []
|
673 |
+
|
674 |
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
675 |
+
future_to_url = {executor.submit(download_parquet, url): url for url in urls}
|
676 |
+
|
677 |
+
for future in tqdm(
|
678 |
+
futures.as_completed(future_to_url),
|
679 |
+
total=len(urls),
|
680 |
+
desc=f'Downloading {source_name} parquets'
|
681 |
+
):
|
682 |
+
results.append(future.result())
|
683 |
+
|
684 |
+
# Process results and handle failures
|
685 |
+
failed_downloads = [r for r in results if not r['success']]
|
686 |
+
skipped_downloads = [r for r in results if r['skipped']]
|
687 |
+
|
688 |
+
if failed_downloads:
|
689 |
+
print(f"\nWarning: Failed to download {len(failed_downloads)} parquet files for {source_name}")
|
690 |
+
print("\nFailed downloads:")
|
691 |
+
for fail in failed_downloads:
|
692 |
+
print(f"URL: {fail['url']}")
|
693 |
+
print(f"Error: {fail['error']}")
|
694 |
+
print("---")
|
695 |
+
raise RuntimeError(f"Some parquet files failed to download for {source_name}. Please retry the download.")
|
696 |
+
|
697 |
+
print(f"Successfully downloaded {len(results) - len(failed_downloads) - len(skipped_downloads)} parquet files for {source_name}")
|
698 |
+
print(f"Skipped {len(skipped_downloads)} valid existing parquet files")
|
699 |
+
|
700 |
+
# Create a mapping from URLs to local file paths
|
701 |
+
url_to_path = {r['url']: r['path'] for r in results if r['success']}
|
702 |
+
|
703 |
+
# Save the URL to path mapping
|
704 |
+
mapping_file = parquet_dir / "url_to_path.parquet"
|
705 |
+
mapping_df = pd.DataFrame({
|
706 |
+
'url': list(url_to_path.keys()),
|
707 |
+
'path': [str(path) for path in url_to_path.values()]
|
708 |
+
})
|
709 |
+
mapping_df.to_parquet(mapping_file)
|
710 |
+
|
711 |
+
# Save the extraction metadata - creating a dataframe that maps URLs to the rows that need to be extracted
|
712 |
+
extraction_meta = []
|
713 |
+
for url in urls:
|
714 |
+
rows = df[df.parquet_url == url].parquet_row.unique()
|
715 |
+
for row in rows:
|
716 |
+
extraction_meta.append({
|
717 |
+
'url': url,
|
718 |
+
'row': row
|
719 |
+
})
|
720 |
+
|
721 |
+
extraction_df = pd.DataFrame(extraction_meta)
|
722 |
+
extraction_file = parquet_dir / "extraction_metadata.parquet"
|
723 |
+
extraction_df.to_parquet(extraction_file)
|
724 |
+
|
725 |
+
# Also save the full filtered dataframe for reference
|
726 |
+
filtered_df_file = parquet_dir / "filtered_df.parquet"
|
727 |
+
df.to_parquet(filtered_df_file)
|
728 |
+
|
729 |
+
print(f"Saved extraction metadata for {len(extraction_meta)} rows across {len(urls)} parquet files")
|
730 |
+
|
731 |
+
return url_to_path
|
732 |
+
|
733 |
+
def extract_from_parquet_files(df, local_dir, source_name, delete_parquets=False, verbose=False, tif_columns=None):
|
734 |
+
"""Extracts data from already downloaded parquet files"""
|
735 |
+
if isinstance(local_dir, str):
|
736 |
+
local_dir = Path(local_dir)
|
737 |
+
|
738 |
+
# Path to the directory where parquet files are stored
|
739 |
+
parquet_dir = local_dir / source_name / "parquets"
|
740 |
+
|
741 |
+
# Check if the URL to path mapping exists
|
742 |
+
mapping_file = parquet_dir / "url_to_path.parquet"
|
743 |
+
if not mapping_file.exists():
|
744 |
+
raise FileNotFoundError(f"URL to path mapping file not found at {mapping_file}. Please download parquet files first.")
|
745 |
+
|
746 |
+
# Load the URL to path mapping
|
747 |
+
mapping_df = pd.read_parquet(mapping_file)
|
748 |
+
url_to_path = dict(zip(mapping_df['url'], mapping_df['path']))
|
749 |
+
|
750 |
+
# Try to load the extraction metadata if it exists, otherwise use the provided dataframe
|
751 |
+
extraction_file = parquet_dir / "extraction_metadata.parquet"
|
752 |
+
filtered_df_file = parquet_dir / "filtered_df.parquet"
|
753 |
+
|
754 |
+
if extraction_file.exists() and filtered_df_file.exists():
|
755 |
+
print("Using saved extraction metadata")
|
756 |
+
extraction_df = pd.read_parquet(extraction_file)
|
757 |
+
|
758 |
+
# We need to load the original filtered df to get all the metadata
|
759 |
+
saved_df = pd.read_parquet(filtered_df_file)
|
760 |
+
|
761 |
+
# If a specific subset of df was provided, filter extraction_df to only those URLs
|
762 |
+
if df is not None:
|
763 |
+
urls_to_extract = df.parquet_url.unique()
|
764 |
+
extraction_df = extraction_df[extraction_df['url'].isin(urls_to_extract)]
|
765 |
+
saved_df = saved_df[saved_df.parquet_url.isin(urls_to_extract)]
|
766 |
+
|
767 |
+
# Replace the input df with the saved one
|
768 |
+
df = saved_df
|
769 |
+
else:
|
770 |
+
# If no saved metadata, create extraction_df from the provided df
|
771 |
+
print("No saved extraction metadata found, using provided dataframe")
|
772 |
+
extraction_df = []
|
773 |
+
for url in df.parquet_url.unique():
|
774 |
+
rows = df[df.parquet_url == url].parquet_row.unique()
|
775 |
+
for row in rows:
|
776 |
+
extraction_df.append({
|
777 |
+
'url': url,
|
778 |
+
'row': row
|
779 |
+
})
|
780 |
+
extraction_df = pd.DataFrame(extraction_df)
|
781 |
+
|
782 |
+
# Get all unique URLs that need to be processed
|
783 |
+
urls = extraction_df['url'].unique()
|
784 |
+
print(f'Starting extraction from {len(urls)} parquet files.') if verbose else None
|
785 |
+
|
786 |
+
# Check if all required parquet files exist and are valid
|
787 |
+
missing_or_invalid_urls = []
|
788 |
+
for url in urls:
|
789 |
+
if url not in url_to_path:
|
790 |
+
missing_or_invalid_urls.append((url, "Missing"))
|
791 |
+
elif not is_valid_parquet(url_to_path[url]):
|
792 |
+
missing_or_invalid_urls.append((url, "Invalid/Corrupted"))
|
793 |
+
|
794 |
+
if missing_or_invalid_urls:
|
795 |
+
print(f"Warning: {len(missing_or_invalid_urls)} parquet files are missing or corrupted. Please download them first.")
|
796 |
+
print("Issues with URLs:")
|
797 |
+
for url, issue in missing_or_invalid_urls[:5]: # Show first 5 problem URLs
|
798 |
+
print(f" {url} - {issue}")
|
799 |
+
if len(missing_or_invalid_urls) > 5:
|
800 |
+
print(f" ... and {len(missing_or_invalid_urls) - 5} more")
|
801 |
+
raise FileNotFoundError("Some required parquet files are missing or corrupted. Please run the download step again.")
|
802 |
+
|
803 |
+
def process_parquet(url):
|
804 |
+
# Get the local path of the parquet file
|
805 |
+
parquet_path = url_to_path[url]
|
806 |
+
|
807 |
+
# Get the rows in this parquet file that we need to extract
|
808 |
+
rows = extraction_df[extraction_df['url'] == url]['row'].unique()
|
809 |
+
|
810 |
+
success = False
|
811 |
+
last_error = None
|
812 |
+
|
813 |
+
try:
|
814 |
+
with pq.ParquetFile(parquet_path) as pf:
|
815 |
+
for row_idx in rows:
|
816 |
+
table = pf.read_row_group(row_idx)
|
817 |
+
|
818 |
+
product_id = table['product_id'][0].as_py() if 'product_id' in table.column_names else "id"
|
819 |
+
grid_cell = table['grid_cell'][0].as_py()
|
820 |
+
row = grid_cell.split('_')[0]
|
821 |
+
|
822 |
+
dest = local_dir / Path(f"{source_name}/{row}/{grid_cell}/{product_id}")
|
823 |
+
dest.mkdir(exist_ok=True, parents=True)
|
824 |
+
|
825 |
+
if tif_columns == 'all':
|
826 |
+
columns = [col for col in table.column_names if col[0] == 'B']
|
827 |
+
if source_name in ['Core-S2L1C', 'Core-S2L2A']:
|
828 |
+
columns.append('cloud_mask')
|
829 |
+
elif tif_columns is None:
|
830 |
+
columns = []
|
831 |
+
else:
|
832 |
+
columns = tif_columns
|
833 |
+
|
834 |
+
# Save tifs
|
835 |
+
for col in columns:
|
836 |
+
with open(dest / f"{col}.tif", "wb") as f:
|
837 |
+
f.write(table[col][0].as_py())
|
838 |
+
|
839 |
+
# Save thumbnail
|
840 |
+
with open(dest / "thumbnail.png", "wb") as f:
|
841 |
+
f.write(table['thumbnail'][0].as_py())
|
842 |
+
|
843 |
+
success = True
|
844 |
+
|
845 |
+
# Delete the parquet file if requested
|
846 |
+
if delete_parquets:
|
847 |
+
try:
|
848 |
+
os.remove(parquet_path)
|
849 |
+
except Exception as e:
|
850 |
+
print(f"Warning: Failed to delete parquet file {parquet_path}: {str(e)}")
|
851 |
+
|
852 |
+
except Exception as e:
|
853 |
+
last_error = str(e)
|
854 |
+
|
855 |
+
return {
|
856 |
+
'url': url,
|
857 |
+
'path': parquet_path,
|
858 |
+
'success': success,
|
859 |
+
'error': last_error if not success else None
|
860 |
+
}
|
861 |
+
|
862 |
+
# Use ThreadPoolExecutor for parallel processing
|
863 |
+
max_workers = min(len(urls), MAX_WORKERS)
|
864 |
+
results = []
|
865 |
+
|
866 |
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
867 |
+
future_to_url = {executor.submit(process_parquet, url): url for url in urls}
|
868 |
+
|
869 |
+
for future in tqdm(
|
870 |
+
futures.as_completed(future_to_url),
|
871 |
+
total=len(urls),
|
872 |
+
desc=f'Extracting from {source_name} parquets'
|
873 |
+
):
|
874 |
+
results.append(future.result())
|
875 |
+
|
876 |
+
# Process results and handle failures
|
877 |
+
failed_extractions = [r for r in results if not r['success']]
|
878 |
+
if failed_extractions:
|
879 |
+
print(f"\nWarning: Failed to extract from {len(failed_extractions)} parquet files for {source_name}")
|
880 |
+
print("\nFailed extractions:")
|
881 |
+
for fail in failed_extractions:
|
882 |
+
print(f"URL: {fail['url']}")
|
883 |
+
print(f"Path: {fail['path']}")
|
884 |
+
print(f"Error: {fail['error']}")
|
885 |
+
print("---")
|
886 |
+
raise RuntimeError(f"Some parquet extractions failed for {source_name}.")
|
887 |
+
|
888 |
+
print(f"Successfully extracted data from {len(results) - len(failed_extractions)} parquet files for {source_name}")
|
889 |
+
|
890 |
+
# Clean up the metadata files if all parquet files were deleted
|
891 |
+
if delete_parquets and not any(os.path.exists(r['path']) for r in results):
|
892 |
+
try:
|
893 |
+
# Delete all metadata files
|
894 |
+
for meta_file in [mapping_file, extraction_file, filtered_df_file]:
|
895 |
+
if meta_file.exists():
|
896 |
+
os.remove(meta_file)
|
897 |
+
|
898 |
+
# Try to remove the parquets directory if it's empty
|
899 |
+
if os.path.exists(parquet_dir) and not os.listdir(parquet_dir):
|
900 |
+
os.rmdir(parquet_dir)
|
901 |
+
|
902 |
+
print(f"Cleaned up metadata files and directory for {source_name}")
|
903 |
+
except Exception as e:
|
904 |
+
print(f"Warning: Failed to clean up metadata files or directory: {str(e)}")
|
905 |
+
|
906 |
+
return results
|
907 |
+
|
908 |
+
def main():
|
909 |
+
args = parse_args()
|
910 |
+
logging.basicConfig(level=logging.INFO)
|
911 |
+
|
912 |
+
data_dir = Path(args.data_dir)
|
913 |
+
subset_dir = data_dir / args.subset_name
|
914 |
+
|
915 |
+
# Always process metadata and filtering for all modes
|
916 |
+
print("\033[92mFetching and filtering metadata...\033[0m")
|
917 |
+
|
918 |
+
# Parallel processing of metadata fetching and filtering
|
919 |
+
max_workers = min(len(args.sources), MAX_WORKERS)
|
920 |
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
921 |
+
future_to_source = {
|
922 |
+
executor.submit(get_and_filter_source, args, source, data_dir): source
|
923 |
+
for source in args.sources
|
924 |
+
}
|
925 |
+
|
926 |
+
# Collect results while maintaining order
|
927 |
+
filtered_dfs = {}
|
928 |
+
for future in futures.as_completed(future_to_source):
|
929 |
+
source = future_to_source[future]
|
930 |
+
try:
|
931 |
+
filtered_dfs[source] = future.result()
|
932 |
+
except Exception as e:
|
933 |
+
print(f"Error processing {source}: {e}")
|
934 |
+
raise e
|
935 |
+
|
936 |
+
# Synchronization point: find common samples across all sources
|
937 |
+
common_dfs = find_common_samples(filtered_dfs)
|
938 |
+
|
939 |
+
# Remove duplicates for each of the common_dfs
|
940 |
+
common_dfs = remove_duplicates(common_dfs, criteria=args.criteria)
|
941 |
+
|
942 |
+
# After removing duplicates, print the number of samples for each source
|
943 |
+
print("\033[92mAfter removing duplicates:\033[0m")
|
944 |
+
for source, df in common_dfs.items():
|
945 |
+
print(f"{source}: {len(df)} samples for common grid cells")
|
946 |
+
|
947 |
+
if args.preview:
|
948 |
+
return
|
949 |
+
|
950 |
+
if args.n_samples is not None: # Else, we download all samples.
|
951 |
+
print(f"Sampling {args.n_samples} samples per source...")
|
952 |
+
common_dfs = sample_common_dfs(common_dfs, args.n_samples, args.seed)
|
953 |
+
print(f"Done sampling {args.n_samples} grid cells per source!")
|
954 |
+
|
955 |
+
# Remove Core-DEM from common_dfs, because it is already downloaded.
|
956 |
+
# Comment / Uncomment when needed.
|
957 |
+
# common_dfs.pop('Core-DEM')
|
958 |
+
# common_dfs.pop('Core-S1RTC')
|
959 |
+
# common_dfs.pop('Core-S2L1C')
|
960 |
+
# common_dfs.pop('Core-S2L2A')
|
961 |
+
print(f"We will only process the following modalities: {list(common_dfs.keys())}")
|
962 |
+
|
963 |
+
# Print information about download workers
|
964 |
+
if args.mode in ['download', 'full']:
|
965 |
+
print(f"\033[94mUsing {args.download_workers} workers for parallel downloads\033[0m")
|
966 |
+
print("If downloads are slow, try reducing this number with the --download-workers parameter")
|
967 |
+
|
968 |
+
if args.revalidate:
|
969 |
+
print("\033[94mRevalidating all parquet files (will check for corrupted files)\033[0m")
|
970 |
+
else:
|
971 |
+
print("Use --revalidate to force checking of existing parquet files for corruption")
|
972 |
+
|
973 |
+
# Execute the appropriate action based on mode
|
974 |
+
if args.mode == 'download':
|
975 |
+
print("\033[92mStarting download of parquet files...\033[0m")
|
976 |
+
for source, df in common_dfs.items():
|
977 |
+
print(f"\033[94mDownloading parquets for modality: {source}\033[0m")
|
978 |
+
download_source_parallel((source, df), subset_dir, mode='download',
|
979 |
+
delete_parquets=args.delete_parquets,
|
980 |
+
download_workers=args.download_workers,
|
981 |
+
revalidate=args.revalidate)
|
982 |
+
print("\033[92mParquet file download complete.\033[0m")
|
983 |
+
print("To extract data from these parquet files, run this script with --mode extract")
|
984 |
+
|
985 |
+
elif args.mode == 'extract':
|
986 |
+
print("\033[92mStarting extraction from parquet files...\033[0m")
|
987 |
+
for source, df in common_dfs.items():
|
988 |
+
print(f"\033[94mExtracting data for modality: {source}\033[0m")
|
989 |
+
download_source_parallel((source, df), subset_dir, mode='extract',
|
990 |
+
delete_parquets=args.delete_parquets,
|
991 |
+
download_workers=args.download_workers,
|
992 |
+
revalidate=args.revalidate)
|
993 |
+
print("\033[92mData extraction complete.\033[0m")
|
994 |
+
if args.delete_parquets:
|
995 |
+
print("Parquet files have been deleted.")
|
996 |
+
else:
|
997 |
+
print("To delete the parquet files, run this script with --mode extract --delete-parquets")
|
998 |
+
|
999 |
+
else: # mode == 'full'
|
1000 |
+
print("\033[92mStarting full download and extraction process...\033[0m")
|
1001 |
+
for source, df in common_dfs.items():
|
1002 |
+
print(f"\033[94mProcessing modality: {source}\033[0m")
|
1003 |
+
download_source_parallel((source, df), subset_dir, mode='full',
|
1004 |
+
download_workers=args.download_workers,
|
1005 |
+
revalidate=args.revalidate)
|
1006 |
+
print("\033[92mDownload and extraction complete.\033[0m")
|
1007 |
+
|
1008 |
+
if __name__ == "__main__":
|
1009 |
+
main()
|
src/COP-GEN-Beta/prepare_dataset_images.py
ADDED
@@ -0,0 +1,488 @@
1 |
+
from libs.autoencoder import get_model
|
2 |
+
import torch.nn as nn
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
import torchvision
|
7 |
+
import torchvision.transforms as transforms
|
8 |
+
from tqdm import tqdm
|
9 |
+
import os
|
10 |
+
import argparse
|
11 |
+
from pathlib import Path
|
12 |
+
import glob
|
13 |
+
from majortom.NMajorTOM import NMajorTOM
|
14 |
+
import pyarrow.parquet as pq
|
15 |
+
import geopandas as gpd
|
16 |
+
import pandas as pd
|
17 |
+
from majortom.coverage_vis import get_coveragemap
|
18 |
+
|
19 |
+
torch.manual_seed(0)
|
20 |
+
np.random.seed(0)
|
21 |
+
|
22 |
+
PATCH_SIZE = 256
|
23 |
+
GRID_SIZE = 4 # 4x4 grid of patches
|
24 |
+
|
25 |
+
SATELLITE_CONFIGS = {
|
26 |
+
'S2L2A': {
|
27 |
+
'tif_bands': ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12', 'cloud_mask'],
|
28 |
+
'png_bands': ['thumbnail'],
|
29 |
+
'tif_transforms': [],
|
30 |
+
'png_transforms': [
|
31 |
+
transforms.CenterCrop(PATCH_SIZE * GRID_SIZE), # Crop to 1024x1024
|
32 |
+
transforms.ToTensor(),
|
33 |
+
transforms.Normalize(mean=(0.5,), std=(0.5,))
|
34 |
+
]
|
35 |
+
},
|
36 |
+
'S2L1C': {
|
37 |
+
'tif_bands': ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12', 'cloud_mask'],
|
38 |
+
'png_bands': ['thumbnail'],
|
39 |
+
'tif_transforms': [],
|
40 |
+
'png_transforms': [
|
41 |
+
transforms.CenterCrop(PATCH_SIZE * GRID_SIZE),
|
42 |
+
transforms.ToTensor(),
|
43 |
+
transforms.Normalize(mean=(0.5,), std=(0.5,))
|
44 |
+
]
|
45 |
+
},
|
46 |
+
'S1RTC': {
|
47 |
+
'tif_bands': ['vv', 'vh'],
|
48 |
+
'png_bands': ['thumbnail'],
|
49 |
+
'tif_transforms': [],
|
50 |
+
'png_transforms': [
|
51 |
+
transforms.CenterCrop(PATCH_SIZE * GRID_SIZE),
|
52 |
+
transforms.ToTensor(),
|
53 |
+
transforms.Normalize(mean=(0.5,), std=(0.5,))
|
54 |
+
]
|
55 |
+
},
|
56 |
+
'DEM': {
|
57 |
+
'tif_bands': ['DEM', 'compressed'],
|
58 |
+
'png_bands': ['thumbnail'],
|
59 |
+
'tif_transforms': [],
|
60 |
+
'png_transforms': [
|
61 |
+
transforms.Resize(1068), # First, interpolate to match the resolution of the other modalities (1068x1068)
|
62 |
+
transforms.CenterCrop(PATCH_SIZE * GRID_SIZE),
|
63 |
+
transforms.ToTensor(),
|
64 |
+
transforms.Normalize(mean=(0.5,), std=(0.5,))
|
65 |
+
]
|
66 |
+
}
|
67 |
+
}
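# Note: only the 'thumbnail' PNGs are patched further down; the tif band lists above are used to verify that a sample is complete on disk.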
|
68 |
+
|
69 |
+
def fix_crs(df):
|
70 |
+
if df['crs'].iloc[0].startswith('EPSG:EPSG:'):
|
71 |
+
df['crs'] = df['crs'].str.replace('EPSG:EPSG:', 'EPSG:', regex=False)
|
72 |
+
return df
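# e.g. a malformed metadata value such as 'EPSG:EPSG:4326' (illustrative) is normalised to 'EPSG:4326'.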
|
73 |
+
|
74 |
+
def load_metadata(path):
|
75 |
+
df = pq.read_table(path).to_pandas()
|
76 |
+
if 'timestamp' in df.columns:
|
77 |
+
df['timestamp'] = pd.to_datetime(df.timestamp)
|
78 |
+
df = fix_crs(df)
|
79 |
+
gdf = gpd.GeoDataFrame(
|
80 |
+
df, geometry=gpd.points_from_xy(df.centre_lon, df.centre_lat), crs=df.crs.iloc[0]
|
81 |
+
)
|
82 |
+
return gdf
|
83 |
+
|
84 |
+
def process_satellite(subset_path, satellite_types, bands_per_type, ratio_train_test, seed):
|
85 |
+
"""Process multiple satellite types simultaneously while ensuring they're paired"""
|
86 |
+
modalities = {}
|
87 |
+
filtered_dfs = {}
|
88 |
+
|
89 |
+
# First, load metadata for all satellite types
|
90 |
+
for sat_type in satellite_types:
|
91 |
+
metadata_path = os.path.join(subset_path, f"Core-{sat_type}", "metadata.parquet")
|
92 |
+
if not os.path.exists(metadata_path):
|
93 |
+
print(f"Skipping {sat_type}: metadata not found at {metadata_path}")
|
94 |
+
continue
|
95 |
+
|
96 |
+
gdf = load_metadata(metadata_path)
|
97 |
+
local_dir = os.path.join(subset_path, f"Core-{sat_type}")
|
98 |
+
|
99 |
+
# Split bands into tif and png based on configuration
|
100 |
+
tif_bands = [b for b in bands_per_type[sat_type] if b in SATELLITE_CONFIGS[sat_type]['tif_bands']]
|
101 |
+
png_bands = [b for b in bands_per_type[sat_type] if b in SATELLITE_CONFIGS[sat_type]['png_bands']]
|
102 |
+
|
103 |
+
print(f"\nChecking files for {sat_type}...")
|
104 |
+
|
105 |
+
# Check which indices have all required files
|
106 |
+
valid_indices = []
|
107 |
+
|
108 |
+
for idx in tqdm(range(len(gdf)), desc=f"Validating {sat_type} samples", unit="samples"):
|
109 |
+
row = gdf.iloc[idx]
|
110 |
+
grid_cell = row.grid_cell
|
111 |
+
row_id = grid_cell.split('_')[0]
|
112 |
+
product_id = row.product_id if 'product_id' in row.index else "id"
|
113 |
+
|
114 |
+
base_path = os.path.join(local_dir, row_id, grid_cell, product_id)
|
115 |
+
all_files_exist = True
|
116 |
+
|
117 |
+
# Check TIF files
|
118 |
+
for band in tif_bands:
|
119 |
+
if not os.path.exists(os.path.join(base_path, f"{band}.tif")):
|
120 |
+
all_files_exist = False
|
121 |
+
break
|
122 |
+
|
123 |
+
# Check PNG files
|
124 |
+
if all_files_exist: # Only check PNGs if TIFs exist
|
125 |
+
for band in png_bands:
|
126 |
+
if not os.path.exists(os.path.join(base_path, f"{band}.png")):
|
127 |
+
all_files_exist = False
|
128 |
+
break
|
129 |
+
|
130 |
+
if all_files_exist:
|
131 |
+
valid_indices.append(idx)
|
132 |
+
|
133 |
+
filtered_df = gdf.iloc[valid_indices].copy()
|
134 |
+
print(f"Found {len(filtered_df)} valid samples out of {len(gdf)} for {sat_type}")
|
135 |
+
filtered_dfs[sat_type] = filtered_df
|
136 |
+
|
137 |
+
# Find common grid cells across all modalities
|
138 |
+
grid_cell_sets = {
|
139 |
+
source: set(df['grid_cell'].unique())
|
140 |
+
for source, df in filtered_dfs.items()
|
141 |
+
}
|
142 |
+
|
143 |
+
# Find intersection of all grid cell sets
|
144 |
+
common_grid_cells = set.intersection(*grid_cell_sets.values())
|
145 |
+
print(f"\nFound {len(common_grid_cells)} common grid cells across all modalities")
|
146 |
+
|
147 |
+
# Filter all modalities to keep only common grid cells
|
148 |
+
for sat_type in satellite_types:
|
149 |
+
if sat_type not in filtered_dfs:
|
150 |
+
continue
|
151 |
+
|
152 |
+
df = filtered_dfs[sat_type]
|
153 |
+
df = df[df['grid_cell'].isin(common_grid_cells)]
|
154 |
+
print(f"{sat_type}: {len(df)} samples for common grid cells")
|
155 |
+
|
156 |
+
modalities[sat_type] = {
|
157 |
+
'df': df,
|
158 |
+
'local_dir': os.path.join(subset_path, f"Core-{sat_type}"),
|
159 |
+
'tif_bands': tif_bands,
|
160 |
+
'png_bands': png_bands,
|
161 |
+
'tif_transforms': SATELLITE_CONFIGS[sat_type]['tif_transforms'],
|
162 |
+
'png_transforms': SATELLITE_CONFIGS[sat_type]['png_transforms']
|
163 |
+
}
|
164 |
+
|
165 |
+
dataset = NMajorTOM(modalities=modalities, ratio_train_test=ratio_train_test, seed=seed)
|
166 |
+
|
167 |
+
return dataset, len(common_grid_cells)
|
168 |
+
|
169 |
+
def is_valid_image(filepath):
|
170 |
+
"""Check if an image file is valid and can be opened. Deletes the file if corrupted."""
|
171 |
+
try:
|
172 |
+
from PIL import Image
|
173 |
+
with Image.open(filepath) as img:
|
174 |
+
img.verify() # Verify it's actually an image
|
175 |
+
return True
|
176 |
+
except Exception:
|
177 |
+
print(f" Warning: Corrupted or invalid image found: {filepath}")
|
178 |
+
try:
|
179 |
+
os.remove(filepath)
|
180 |
+
print(f" Deleted corrupted file: {filepath}")
|
181 |
+
except Exception as e:
|
182 |
+
print(f" Failed to delete corrupted file {filepath}: {e}")
|
183 |
+
return False
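# Deleting the corrupted file lets the affected grid cell be detected as incomplete and regenerated on a later run.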
|
184 |
+
|
185 |
+
def get_existing_complete_grid_cells(output_dir, satellite_types, bands_per_type, num_grid_cells, expected_patches=16):
|
186 |
+
"""Returns a set of grid_cells that already have all their patches for all modalities"""
|
187 |
+
complete_grid_cells_by_sat = {}
|
188 |
+
corrupted_grid_cells = set() # Track grid cells with corrupted files
|
189 |
+
|
190 |
+
for sat_type in satellite_types:
|
191 |
+
sat_base_dir = f"{sat_type}_{'_'.join(bands_per_type[sat_type])}"
|
192 |
+
complete_grid_cells_by_sat[sat_type] = set()
|
193 |
+
|
194 |
+
# Check both train and test directories
|
195 |
+
for split in ['train', 'test']:
|
196 |
+
dir_path = os.path.join(output_dir, split, sat_base_dir)
|
197 |
+
print(f" Checking {dir_path} for existing complete grid cells")
|
198 |
+
if not os.path.exists(dir_path):
|
199 |
+
print(f" Warning: Directory {dir_path} does not exist")
|
200 |
+
continue
|
201 |
+
|
202 |
+
# Get all PNG files and extract their grid cells
|
203 |
+
png_files = glob.glob(os.path.join(dir_path, "*.png"))
|
204 |
+
print(f" Found {len(png_files)} PNG files in {dir_path}")
|
205 |
+
current_grid_cells = {}
|
206 |
+
|
207 |
+
for f in png_files:
|
208 |
+
# This will now delete the file if it's corrupted
|
209 |
+
if not is_valid_image(f):
|
210 |
+
# Get the grid cell from the corrupted file
|
211 |
+
base_name = os.path.basename(f)
|
212 |
+
corrupted_grid_cell = "_".join(base_name.split("_")[:-1])
|
213 |
+
# Add to set of corrupted grid cells
|
214 |
+
corrupted_grid_cells.add(corrupted_grid_cell)
|
215 |
+
# Remove this grid cell from our complete cells since we'll need to regenerate it
|
216 |
+
if corrupted_grid_cell in current_grid_cells:
|
217 |
+
del current_grid_cells[corrupted_grid_cell]
|
218 |
+
continue
|
219 |
+
|
220 |
+
base_name = os.path.basename(f)
|
221 |
+
grid_cell = "_".join(base_name.split("_")[:-1]) # Remove patch number
|
222 |
+
current_grid_cells[grid_cell] = current_grid_cells.get(grid_cell, 0) + 1
|
223 |
+
|
224 |
+
# Keep only grid cells with exactly the expected number of patches
|
225 |
+
complete_cells = {gc for gc, count in current_grid_cells.items() if count == expected_patches}
|
226 |
+
print(f" Found {len(complete_cells)} complete grid cells in {split} split for {sat_type}")
|
227 |
+
complete_grid_cells_by_sat[sat_type].update(complete_cells)
|
228 |
+
|
229 |
+
print(f"Total complete grid cells for {sat_type}: {len(complete_grid_cells_by_sat[sat_type])}")
|
230 |
+
|
231 |
+
# Find grid cells that are complete across all satellite types
|
232 |
+
if not complete_grid_cells_by_sat:
|
233 |
+
return set()
|
234 |
+
|
235 |
+
complete_grid_cells = set.intersection(*complete_grid_cells_by_sat.values())
|
236 |
+
|
237 |
+
# Remove any grid cells that had corrupted files
|
238 |
+
complete_grid_cells = complete_grid_cells - corrupted_grid_cells
|
239 |
+
|
240 |
+
# Print detailed debugging information
|
241 |
+
print("\nComplete grid cells by satellite type:")
|
242 |
+
for sat_type, cells in complete_grid_cells_by_sat.items():
|
243 |
+
print(f"{sat_type}: {len(cells)} grid cells")
|
244 |
+
print(f"\nGrid cells complete across all types: {len(complete_grid_cells)}")
|
245 |
+
if corrupted_grid_cells:
|
246 |
+
print(f"Removed {len(corrupted_grid_cells)} grid cells due to corrupted files")
|
247 |
+
|
248 |
+
if len(complete_grid_cells) < num_grid_cells:
|
249 |
+
# Find which grid cells are missing from which satellite types
|
250 |
+
all_grid_cells = set.union(*complete_grid_cells_by_sat.values())
|
251 |
+
print("\nAnalyzing missing grid cells:")
|
252 |
+
for grid_cell in all_grid_cells:
|
253 |
+
missing_from = [sat_type for sat_type in satellite_types
|
254 |
+
if grid_cell not in complete_grid_cells_by_sat[sat_type]]
|
255 |
+
if missing_from:
|
256 |
+
print(f"Grid cell {grid_cell} is missing from: {', '.join(missing_from)}")
|
257 |
+
|
258 |
+
return complete_grid_cells
|
259 |
+
|
260 |
+
def crop_images(dataset, satellite_types, bands_per_type, output_dir, num_grid_cells, flip=False, center_crop=False):
|
261 |
+
"""Extract features for all modalities simultaneously while ensuring they're paired"""
|
262 |
+
from concurrent.futures import ThreadPoolExecutor
|
263 |
+
import itertools
|
264 |
+
|
265 |
+
# Create output directories if saving PNGs
|
266 |
+
for sat_type in satellite_types:
|
267 |
+
sat_base_dir = f"{sat_type}_{'_'.join(bands_per_type[sat_type])}"
|
268 |
+
os.makedirs(os.path.join(output_dir, 'train', sat_base_dir), exist_ok=True)
|
269 |
+
os.makedirs(os.path.join(output_dir, 'test', sat_base_dir), exist_ok=True)
|
270 |
+
|
271 |
+
# Get already processed grid cells
|
272 |
+
print("Checking for existing complete grid cells...")
|
273 |
+
# Adjust the expected patch count based on center_crop mode
|
274 |
+
expected_patches = 1 if center_crop else GRID_SIZE * GRID_SIZE
|
275 |
+
# Resume support is currently disabled; uncomment the call below to skip grid cells
# that already have all of their patches on disk.
# complete_grid_cells = get_existing_complete_grid_cells(output_dir, satellite_types, bands_per_type, num_grid_cells, expected_patches)
complete_grid_cells = set()
|
277 |
+
print(f"Found {len(complete_grid_cells)} already processed grid cells")
|
278 |
+
|
279 |
+
# Pre-calculate patch positions (only used if not center_crop)
|
280 |
+
patch_positions = list(itertools.product(range(GRID_SIZE), range(GRID_SIZE)))
|
281 |
+
|
282 |
+
def process_sample(sample):
|
283 |
+
"""Process a single sample (large image) and return metadata for all its patches"""
|
284 |
+
# Check if this grid cell is already processed
|
285 |
+
grid_cell = sample[satellite_types[0]]['grid_cell']
|
286 |
+
if grid_cell in complete_grid_cells:
|
287 |
+
print(f"Skipping {grid_cell} because it already has all its patches")
|
288 |
+
return []
|
289 |
+
|
290 |
+
sample_metadata = []
|
291 |
+
|
292 |
+
for sat_type in satellite_types:
|
293 |
+
modality_data = sample[sat_type]
|
294 |
+
split = modality_data['split']
|
295 |
+
grid_cell = modality_data['grid_cell']
|
296 |
+
|
297 |
+
img = modality_data['thumbnail']
|
298 |
+
|
299 |
+
if center_crop:
|
300 |
+
# Calculate center crop coordinates
|
301 |
+
h, w = img.shape[-2:]
|
302 |
+
start_h = (h - PATCH_SIZE) // 2
|
303 |
+
start_w = (w - PATCH_SIZE) // 2
|
304 |
+
patch = img[:, start_h:start_h + PATCH_SIZE, start_w:start_w + PATCH_SIZE]
|
305 |
+
patches = patch.unsqueeze(0) # Add batch dimension
|
306 |
+
else:
|
307 |
+
# Original patchifying logic
|
308 |
+
C = img.size(0)
|
309 |
+
patches = img.unfold(1, PATCH_SIZE, PATCH_SIZE).unfold(2, PATCH_SIZE, PATCH_SIZE)
|
310 |
+
patches = patches.permute(0, 1, 2, 3, 4).reshape(C, -1, PATCH_SIZE, PATCH_SIZE)
|
311 |
+
patches = patches.permute(1, 0, 2, 3) # [N_patches, C, H, W]
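# With the 1024x1024 centre crop above this produces GRID_SIZE * GRID_SIZE = 16 non-overlapping 256x256 patches.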
|
312 |
+
|
313 |
+
if sat_type == 'DEM':
|
314 |
+
patches = patches.repeat(1, 3, 1, 1)
|
315 |
+
|
316 |
+
# Compute paths once
|
317 |
+
sat_base_dir = f"{sat_type}_thumbnail"
|
318 |
+
save_dir = os.path.join(output_dir, split, sat_base_dir)
|
319 |
+
|
320 |
+
# Batch denormalize
|
321 |
+
patches_denorm = (patches.detach().cpu() + 1) / 2
|
322 |
+
|
323 |
+
# Save images
|
324 |
+
for patch_idx, patch in enumerate(patches_denorm):
|
325 |
+
if center_crop:
|
326 |
+
filename = f"{grid_cell}_center.png"
|
327 |
+
metadata = {
|
328 |
+
'grid_cell': grid_cell,
|
329 |
+
'satellite': sat_type,
|
330 |
+
'bands': 'thumbnail',
|
331 |
+
'split': split,
|
332 |
+
'patch_num': 0,
|
333 |
+
'patch_row': (GRID_SIZE - 1) // 2,
|
334 |
+
'patch_col': (GRID_SIZE - 1) // 2
|
335 |
+
}
|
336 |
+
else:
|
337 |
+
filename = f"{grid_cell}_{patch_idx}.png"
|
338 |
+
metadata = {
|
339 |
+
'grid_cell': grid_cell,
|
340 |
+
'satellite': sat_type,
|
341 |
+
'bands': 'thumbnail',
|
342 |
+
'split': split,
|
343 |
+
'patch_num': patch_idx,
|
344 |
+
'patch_row': patch_positions[patch_idx][0],
|
345 |
+
'patch_col': patch_positions[patch_idx][1]
|
346 |
+
}
|
347 |
+
|
348 |
+
torchvision.utils.save_image(patch, os.path.join(save_dir, filename))
|
349 |
+
sample_metadata.append(metadata)
|
350 |
+
|
351 |
+
return sample_metadata
|
352 |
+
|
353 |
+
# Process samples in parallel
|
354 |
+
all_metadata = []
|
355 |
+
total_samples = len(dataset)
|
356 |
+
|
357 |
+
print(f"Processing {total_samples} samples...")
|
358 |
+
|
359 |
+
with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
|
360 |
+
# Create a list to store futures
|
361 |
+
futures = []
|
362 |
+
|
363 |
+
# Submit tasks with progress bar around the dataset iteration
|
364 |
+
for sample in tqdm(dataset, total=total_samples,
|
365 |
+
desc="Processing samples",
|
366 |
+
unit="sample",
|
367 |
+
dynamic_ncols=True):
|
368 |
+
future = executor.submit(process_sample, sample)
|
369 |
+
futures.append(future)
|
370 |
+
|
371 |
+
# Collect results
|
372 |
+
for future in futures:
|
373 |
+
metadata = future.result()
|
374 |
+
if metadata: # Only add metadata for newly processed samples
|
375 |
+
all_metadata.extend(metadata)
|
376 |
+
|
377 |
+
# Convert to DataFrame and split by train/test
|
378 |
+
if all_metadata: # Only process if we have new metadata
|
379 |
+
df = pd.DataFrame(all_metadata)
|
380 |
+
train_df = df[df['split'] == 'train'].drop('split', axis=1)
|
381 |
+
test_df = df[df['split'] == 'test'].drop('split', axis=1)
|
382 |
+
|
383 |
+
# Load existing metadata if it exists and append new data
|
384 |
+
train_path = os.path.join(output_dir, 'train_metadata.parquet')
|
385 |
+
test_path = os.path.join(output_dir, 'test_metadata.parquet')
|
386 |
+
|
387 |
+
if os.path.exists(train_path):
|
388 |
+
existing_train = pd.read_parquet(train_path)
|
389 |
+
train_df = pd.concat([existing_train, train_df], ignore_index=True)
|
390 |
+
# Deduplicate based on all columns
|
391 |
+
train_df = train_df.drop_duplicates(subset=['grid_cell', 'satellite', 'patch_num'])
|
392 |
+
|
393 |
+
if os.path.exists(test_path):
|
394 |
+
existing_test = pd.read_parquet(test_path)
|
395 |
+
test_df = pd.concat([existing_test, test_df], ignore_index=True)
|
396 |
+
# Deduplicate based on all columns
|
397 |
+
test_df = test_df.drop_duplicates(subset=['grid_cell', 'satellite', 'patch_num'])
|
398 |
+
|
399 |
+
# Save metadata
|
400 |
+
train_df.to_parquet(train_path)
|
401 |
+
test_df.to_parquet(test_path)
|
402 |
+
|
403 |
+
print(f"Processed {len(all_metadata) // (16 * len(satellite_types))} new grid cells")
|
404 |
+
print(f"Total metadata: {len(train_df)} training and {len(test_df)} testing samples")
|
405 |
+
else:
|
406 |
+
print("No new grid cells to process")
|
407 |
+
|
408 |
+
def visualize_patches(dataset, satellite_types, bands_per_type, output_dir):
|
409 |
+
"""Visualize the coverage of patches in a world map"""
|
410 |
+
# Take the first satellite type since they're all paired
|
411 |
+
sat_type = satellite_types[0]
|
412 |
+
modality = dataset.modalities[sat_type]
|
413 |
+
df = modality['df']
|
414 |
+
|
415 |
+
# Let's split into train and test.
|
416 |
+
# First, add the split column to the dataframe based on the grid_cell_to_split dictionary
|
417 |
+
df['split'] = df['grid_cell'].map(dataset.grid_cell_to_split)
|
418 |
+
|
419 |
+
# Create coverage map
|
420 |
+
coverage_img_all = get_coveragemap(df)
|
421 |
+
coverage_img_train = get_coveragemap(df[df['split'] == 'train'])
|
422 |
+
coverage_img_test = get_coveragemap(df[df['split'] == 'test'])
|
423 |
+
coverage_img_train_test = get_coveragemap(df[df['split'] == 'train'], df[df['split'] == 'test'])
|
424 |
+
|
425 |
+
# Save the coverage map
|
426 |
+
coverage_path_all = os.path.join(output_dir, 'coverage_map_all.png')
|
427 |
+
coverage_path_train = os.path.join(output_dir, 'coverage_map_train.png')
|
428 |
+
coverage_path_test = os.path.join(output_dir, 'coverage_map_test.png')
|
429 |
+
coverage_path_train_test = os.path.join(output_dir, 'coverage_map_train_test.png')
|
430 |
+
coverage_img_all.save(coverage_path_all, format='PNG')
|
431 |
+
coverage_img_train.save(coverage_path_train, format='PNG')
|
432 |
+
coverage_img_test.save(coverage_path_test, format='PNG')
|
433 |
+
coverage_img_train_test.save(coverage_path_train_test, format='PNG')
|
434 |
+
print(f"Saved coverage maps to {coverage_path_all}, {coverage_path_train}, {coverage_path_test} and {coverage_path_train_test}")
|
435 |
+
|
436 |
+
|
437 |
+
def main():
|
438 |
+
parser = argparse.ArgumentParser(description='Extract features from MajorTOM dataset')
|
439 |
+
parser.add_argument('--subset_path', required=True, help='Path to the subset folder')
|
440 |
+
parser.add_argument('--output_dir', required=True, help='Path to the output directory')
|
441 |
+
parser.add_argument('--bands', nargs='+', required=True, help='Bands to process (e.g., B1 B2 B3 DEM vv vh)')
|
442 |
+
parser.add_argument('--ratio_train_test', type=float, default=0.95, help='Ratio of training to testing data')
|
443 |
+
parser.add_argument('--flip', action='store_true', help='Flip the patches')
|
444 |
+
parser.add_argument('--visualize', action='store_true', help='Visualize the patches in a world map')
|
445 |
+
parser.add_argument('--seed', type=int, default=42, help='Random seed')
|
446 |
+
parser.add_argument('--center_crop', action='store_true', help='Use center crop instead of patchifying')
|
447 |
+
args = parser.parse_args()
|
448 |
+
|
449 |
+
# Get subset name from path
|
450 |
+
subset_name = Path(args.subset_path).name
|
451 |
+
|
452 |
+
print("Flip is set to", args.flip)
|
453 |
+
print("Seed is set to", args.seed)
|
454 |
+
print("Subset path is", args.subset_path)
|
455 |
+
print("Bands are", args.bands)
|
456 |
+
print("Ratio train test is", args.ratio_train_test)
|
457 |
+
print("Visualize is set to", args.visualize)
|
458 |
+
print("Center crop is set to", args.center_crop)
|
459 |
+
# Create the main output directory
|
460 |
+
all_bands = '_'.join(sorted(args.bands))
|
461 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
462 |
+
|
463 |
+
# Group bands by satellite type
|
464 |
+
bands_per_type = {}
|
465 |
+
satellite_types = []
|
466 |
+
for sat_type, config in SATELLITE_CONFIGS.items():
|
467 |
+
all_sat_bands = config['tif_bands'] + config['png_bands']
|
468 |
+
sat_bands = [b for b in args.bands if b in all_sat_bands]
|
469 |
+
if sat_bands:
|
470 |
+
bands_per_type[sat_type] = sat_bands
|
471 |
+
satellite_types.append(sat_type)
|
472 |
+
|
473 |
+
if satellite_types:
|
474 |
+
# Process all satellite types together
|
475 |
+
dataset, num_grid_cells = process_satellite(args.subset_path, satellite_types, bands_per_type, args.ratio_train_test, args.seed)
|
476 |
+
|
477 |
+
if args.visualize:
|
478 |
+
print("==> Visualizing patches...")
|
479 |
+
visualize_patches(dataset, satellite_types, bands_per_type, args.output_dir)
|
480 |
+
print("==> Done visualizing patches! Exiting...")
|
481 |
+
# exit()
|
482 |
+
|
483 |
+
print("==> Cropping images...")
|
484 |
+
crop_images(dataset, satellite_types, bands_per_type, args.output_dir, num_grid_cells,
|
485 |
+
flip=args.flip, center_crop=args.center_crop)
|
486 |
+
|
487 |
+
if __name__ == "__main__":
|
488 |
+
main()
|
src/COP-GEN-Beta/sample_n_triffuser.py
ADDED
@@ -0,0 +1,652 @@
1 |
+
import ml_collections
|
2 |
+
import torch
|
3 |
+
import random
|
4 |
+
import utils
|
5 |
+
from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
|
6 |
+
from absl import logging
|
7 |
+
import einops
|
8 |
+
import libs.autoencoder
|
9 |
+
from torchvision.utils import save_image, make_grid
|
10 |
+
import torchvision.transforms as standard_transforms
|
11 |
+
import numpy as np
|
12 |
+
from PIL import Image
|
13 |
+
import time
|
14 |
+
import copy
|
15 |
+
from datasets import get_dataset
|
16 |
+
from torch.utils.data import Dataset, DataLoader
|
17 |
+
from tqdm.auto import tqdm
|
18 |
+
from torch.utils._pytree import tree_map
|
19 |
+
import glob
|
20 |
+
import os
|
21 |
+
import functools
|
22 |
+
from concurrent.futures import ThreadPoolExecutor
|
23 |
+
|
24 |
+
# Add profiling tools
|
25 |
+
class Profiler:
|
26 |
+
def __init__(self):
|
27 |
+
self.times = {}
|
28 |
+
|
29 |
+
def profile(self, func):
|
30 |
+
@functools.wraps(func)
|
31 |
+
def wrapper(*args, **kwargs):
|
32 |
+
start_time = time.time()
|
33 |
+
result = func(*args, **kwargs)
|
34 |
+
end_time = time.time()
|
35 |
+
|
36 |
+
func_name = func.__name__
|
37 |
+
if func_name not in self.times:
|
38 |
+
self.times[func_name] = []
|
39 |
+
self.times[func_name].append(end_time - start_time)
|
40 |
+
|
41 |
+
return result
|
42 |
+
return wrapper
|
43 |
+
|
44 |
+
def summary(self):
|
45 |
+
print("\n----- Profiling Summary -----")
|
46 |
+
for func_name, times in self.times.items():
|
47 |
+
avg_time = sum(times) / len(times)
|
48 |
+
total_time = sum(times)
|
49 |
+
calls = len(times)
|
50 |
+
print(f"{func_name}: {total_time:.2f}s total, {avg_time:.4f}s avg, {calls} calls")
|
51 |
+
print("----------------------------\n")
|
52 |
+
|
53 |
+
profiler = Profiler()
|
54 |
+
|
55 |
+
MODALITIES = {
|
56 |
+
4: ['dem', 's1_rtc', 's2_l1c', 's2_l2a'],
|
57 |
+
3: ['dem', 's1_rtc', 's2_l1c'],
|
58 |
+
2: ['s1_rtc', 's2_l2a'],
|
59 |
+
}
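# Keys are the number of modalities a checkpoint was trained on; --n_mod selects the matching ordered modality list.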
|
60 |
+
|
61 |
+
MODEL_RESOLUTION = 256
|
62 |
+
|
63 |
+
"""
|
64 |
+
Sampling for any n modalities
|
65 |
+
|
66 |
+
> python3 sample_n_triffuser.py --config=path --data_path=path --nnet_path=path \
|
67 |
+
--n_mod=int --n_samples=int \
|
68 |
+
--generate=[modalities] --condition=[modalities]
|
69 |
+
|
70 |
+
|
71 |
+
Generate all modalities unconditional (joint):
|
72 |
+
python3 sample_n_triffuser.py --n_mod=4 --generate=s2_l1c,s2_l2a,s1_rtc,dem
|
73 |
+
|
74 |
+
Generate a pair unconditional (joint):
|
75 |
+
python3 sample_n_triffuser.py --n_mod=4 --generate=s1_rtc,s2_l2a
|
76 |
+
|
77 |
+
Generate s1_rtc and s2_l2a, conditioned on dem and s2_l1c (conditional):
|
78 |
+
python3 sample_n_triffuser.py --n_mod=4 --generate=s1_rtc,s2_l2a --condition=dem,s2_l1c
|
79 |
+
|
80 |
+
Generate dem conditioned on s1_rtc, s2_l1c, s2_l2a (conditional):
|
81 |
+
python3 sample_n_triffuser.py --n_mod=4 --generate=dem --condition=s1_rtc,s2_l1c,s2_l2a
|
82 |
+
|
83 |
+
Generate dem conditioned on s1_rtc (conditional) (the rest are automatically ignored: s2_l1c, s2_l2a):
|
84 |
+
python3 sample_n_triffuser.py --n_mod=4 --generate=dem --condition=s1_rtc
|
85 |
+
|
86 |
+
Generate dem unconditional (marginal) (no condition, the rest are ignored):
|
87 |
+
python3 sample_n_triffuser.py --n_mod=4 --generate=dem
|
88 |
+
|
89 |
+
|
90 |
+
Note:
|
91 |
+
--generate flag is mandatory
|
92 |
+
"generate" modalities and "condition" modalities should always be different
|
93 |
+
|
94 |
+
|
95 |
+
"""
|
96 |
+
|
97 |
+
class CustomImageDataset(Dataset):
|
98 |
+
def __init__(self, folder_path, transform=None):
|
99 |
+
self.folder_path = folder_path
|
100 |
+
self.transform = transform
|
101 |
+
self.image_files = glob.glob(os.path.join(folder_path, "*.png"))
|
102 |
+
print("There are", len(self.image_files), "images in the dataset")
|
103 |
+
|
104 |
+
def __len__(self):
|
105 |
+
return len(self.image_files)
|
106 |
+
|
107 |
+
def __getitem__(self, idx):
|
108 |
+
image = Image.open(self.image_files[idx]).convert("RGB")
|
109 |
+
if self.transform:
|
110 |
+
image = self.transform(image)
|
111 |
+
# Return both the image and the filename
|
112 |
+
return image, os.path.basename(self.image_files[idx])
|
113 |
+
|
114 |
+
|
115 |
+
|
116 |
+
def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
|
117 |
+
_betas = (
|
118 |
+
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
|
119 |
+
)
|
120 |
+
return _betas.numpy()
|
121 |
+
|
122 |
+
@profiler.profile
|
123 |
+
def prepare_contexts(config, images, filenames, device, autoencoder=None):
|
124 |
+
"""
|
125 |
+
If a modality is conditional, we need to return the npy feature encodings
|
126 |
+
If a modality is unconditional, we need to return random noise
|
127 |
+
|
128 |
+
batch_shape = (n_modalities, B, C, H, W)
|
129 |
+
|
130 |
+
Returns:
|
131 |
+
img_contexts: Tensor containing contexts for each modality
|
132 |
+
processed_filenames: List of filenames, duplicated and labeled with version suffixes if n_samples > 1
|
133 |
+
"""
|
134 |
+
|
135 |
+
# Create a noise tensor with the same shape as the images batch
|
136 |
+
if config.data_type == 'lmdb':
|
137 |
+
effective_batch_size = images[0].shape[0] * config.n_samples if config.n_samples > 1 else images[0].shape[0]
|
138 |
+
img_contexts = torch.randn(config.num_modalities, effective_batch_size, *images[0].shape[1:], device=device)
|
139 |
+
elif config.data_type == 'folder-img':
|
140 |
+
# Calculate effective batch size (original batch size * n_samples)
|
141 |
+
effective_batch_size = images.shape[0] * config.n_samples if config.n_samples > 1 else images.shape[0]
|
142 |
+
# Multiply the images_batch shape by 2 because we have both mean and variance
|
143 |
+
# as output from the autoencoder
|
144 |
+
img_contexts = torch.randn(config.num_modalities, effective_batch_size, 2 * config.z_shape[0],
|
145 |
+
config.z_shape[1], config.z_shape[2], device=device)
|
146 |
+
|
147 |
+
# Process filenames - duplicate them if n_samples > 1 and add version suffixes
|
148 |
+
processed_filenames = []
|
149 |
+
if config.n_samples > 1:
|
150 |
+
for filename in filenames:
|
151 |
+
for i in range(config.n_samples):
|
152 |
+
processed_filenames.append(f"{filename}_v{i+1}")
|
153 |
+
else:
|
154 |
+
processed_filenames = filenames
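# e.g. with n_samples=3 a conditioning image 'tile_0001' (illustrative name) yields outputs
# 'tile_0001_v1', 'tile_0001_v2', 'tile_0001_v3'.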
|
155 |
+
|
156 |
+
# For each modality in the images_batch, if it is conditional, load and duplicate the npy feature encodings
|
157 |
+
for i, modality in enumerate(config.modalities):
|
158 |
+
if config.condition_modalities_mask[i]:
|
159 |
+
if config.data_type == 'lmdb':
|
160 |
+
# Duplicate each conditional input n_samples times
|
161 |
+
img_contexts[i] = images[i].repeat_interleave(config.n_samples, dim=0)
|
162 |
+
elif config.data_type == 'folder-img':
|
163 |
+
assert autoencoder is not None, "Autoencoder must be provided for folder-img data type"
|
164 |
+
# Duplicate each conditional input n_samples times
|
165 |
+
duplicated_batch = images.repeat_interleave(config.n_samples, dim=0)
|
166 |
+
img_contexts[i] = autoencoder.encode_moments(duplicated_batch)
|
167 |
+
|
168 |
+
# Padding the latents experiment
|
169 |
+
# duplicated_batch = images.repeat_interleave(config.n_samples, dim=0)
|
170 |
+
# intermediate_latents = autoencoder.encode_moments(duplicated_batch)
|
171 |
+
# padded_latents = torch.nn.functional.pad(intermediate_latents, (8, 8, 8, 8), mode='reflect')
|
172 |
+
# img_contexts[i] = padded_latents
|
173 |
+
|
174 |
+
return img_contexts, processed_filenames
|
175 |
+
|
176 |
+
def unpreprocess(v): # to B C H W and [0, 1]
|
177 |
+
v = 0.5 * (v + 1.)
|
178 |
+
v.clamp_(0., 1.)
|
179 |
+
return v
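# Inverse of the Normalize(mean=0.5, std=0.5) preprocessing: maps model outputs from [-1, 1] back to [0, 1].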
|
180 |
+
|
181 |
+
|
182 |
+
def set_seed(seed: int):
|
183 |
+
random.seed(seed)
|
184 |
+
np.random.seed(seed)
|
185 |
+
torch.manual_seed(seed)
|
186 |
+
torch.cuda.manual_seed_all(seed)
|
187 |
+
|
188 |
+
|
189 |
+
|
190 |
+
def evaluate(config):
|
191 |
+
if config.get('benchmark', False):
|
192 |
+
torch.backends.cudnn.benchmark = True
|
193 |
+
torch.backends.cudnn.deterministic = False
|
194 |
+
|
195 |
+
# Create output directory once at the start
|
196 |
+
os.makedirs(config.output_path, exist_ok=True)
|
197 |
+
# Create a directory for each modality if we are saving as pngs
|
198 |
+
if config.save_as == 'pngs':
|
199 |
+
for modality in config.generate_modalities:
|
200 |
+
os.makedirs(os.path.join(config.output_path, modality), exist_ok=True)
|
201 |
+
|
202 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
203 |
+
set_seed(config.seed)
|
204 |
+
|
205 |
+
config = ml_collections.FrozenConfigDict(config)
|
206 |
+
utils.set_logger(log_level='info')
|
207 |
+
|
208 |
+
_betas = stable_diffusion_beta_schedule()
|
209 |
+
N = len(_betas)
|
210 |
+
|
211 |
+
nnet = utils.get_nnet(**config.nnet)
|
212 |
+
logging.info(f'load nnet from {config.nnet_path}')
|
213 |
+
nnet.load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
|
214 |
+
nnet.to(device)
|
215 |
+
nnet.eval()
|
216 |
+
|
217 |
+
if config.data_type == 'lmdb':
|
218 |
+
# Edit the dataset path to the data path from the command line arguments
|
219 |
+
dataset_config = ml_collections.ConfigDict(config.to_dict())
|
220 |
+
dataset_config.dataset.path = config.data_path
|
221 |
+
|
222 |
+
# Always return the filename
|
223 |
+
dataset_config.dataset.return_filename = True
|
224 |
+
|
225 |
+
dataset = get_dataset(**dataset_config.dataset)
|
226 |
+
# TODO: This is not intuitive. Split is train but it is returning the test set. See datasets.py
|
227 |
+
test_dataset = dataset.get_split(split='train', labeled=False)
|
228 |
+
# Create a generator with fixed seed for reproducible shuffling
|
229 |
+
g = torch.Generator()
|
230 |
+
g.manual_seed(config.seed) # Using the same seed as set earlier in the code
|
231 |
+
dataloader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=True, drop_last=False,
|
232 |
+
num_workers=8, pin_memory=True, persistent_workers=True, generator=g)
|
233 |
+
|
234 |
+
elif config.data_type == 'folder-img':
|
235 |
+
print("config.data_path", config.data_path)
|
236 |
+
|
237 |
+
if config.resolution >= MODEL_RESOLUTION:
|
238 |
+
transform = standard_transforms.Compose([
|
239 |
+
standard_transforms.CenterCrop(MODEL_RESOLUTION),
|
240 |
+
standard_transforms.ToTensor(),
|
241 |
+
standard_transforms.Normalize(mean=(0.5,), std=(0.5,)),
|
242 |
+
])
|
243 |
+
else:
|
244 |
+
padding_4sides = (MODEL_RESOLUTION - config.resolution) // 2
|
245 |
+
transform = standard_transforms.Compose([
|
246 |
+
standard_transforms.CenterCrop(config.resolution),
|
247 |
+
standard_transforms.ToTensor(),
|
248 |
+
torch.nn.ReflectionPad2d(padding_4sides),
|
249 |
+
standard_transforms.Normalize(mean=(0.5,), std=(0.5,)),
|
250 |
+
])
|
251 |
+
dataset = CustomImageDataset(config.data_path, transform=transform)
|
252 |
+
dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=False, drop_last=False,
|
253 |
+
num_workers=8, pin_memory=True, persistent_workers=True)
|
254 |
+
else:
|
255 |
+
raise ValueError(f"Invalid data type: {config.data_type}. Must be one of ['lmdb', 'folder-img']")
|
256 |
+
|
257 |
+
autoencoder = libs.autoencoder.get_model(**config.autoencoder)
|
258 |
+
autoencoder.to(device)
|
259 |
+
|
260 |
+
|
261 |
+
@profiler.profile
|
262 |
+
def split_joint(x, z_imgs, config):
|
263 |
+
"""
|
264 |
+
Input:
|
265 |
+
x: (B, C, H, W)
|
266 |
+
is only the modalities that are being denoised
|
267 |
+
z_imgs: (M, B, C, H, W)
|
268 |
+
the original img_latents for all modalities
|
269 |
+
(but we only use the ones for the modalities that are being denoised)
|
270 |
+
config: config
|
271 |
+
|
272 |
+
First, split the input into the modalities into correct shape
|
273 |
+
Second, return a full list of the modalities,
|
274 |
+
including the ones being conditioned on and the ones being ignored.
|
275 |
+
|
276 |
+
Returns list of all modalities (some are denoised, some are conditioned on, some are ignored)
|
277 |
+
|
278 |
+
"""
|
279 |
+
|
280 |
+
C, H, W = config.z_shape
|
281 |
+
z_dim = C * H * W
|
282 |
+
z_generated = x.split([z_dim] * len(config.generate_modalities), dim=1)
|
283 |
+
z_generated = {modality: einops.rearrange(z_i, 'B (C H W) -> B C H W', C=C, H=H, W=W)
|
284 |
+
for z_i, modality in zip(z_generated, config.generate_modalities)}
|
285 |
+
|
286 |
+
z = []
|
287 |
+
for i, modality in enumerate(config.modalities):
|
288 |
+
# Modalities that are being denoised
|
289 |
+
if modality in config.generate_modalities:
|
290 |
+
z.append(z_generated[modality])
|
291 |
+
# Modalities that are being conditioned on
|
292 |
+
elif modality in config.condition_modalities:
|
293 |
+
z.append(z_imgs[i])
|
294 |
+
# Modalities that are ignored
|
295 |
+
else:
|
296 |
+
z.append(torch.randn(x.shape[0], C, H, W, device=device))
|
297 |
+
|
298 |
+
return z
|
299 |
+
|
300 |
+
|
301 |
+
@profiler.profile
|
302 |
+
def combine_joint(z):
|
303 |
+
"""
|
304 |
+
Input:
|
305 |
+
z: list of ONLY the modalities that are being denoised
|
306 |
+
Returns:
|
307 |
+
z: (B, C * H * W)
|
308 |
+
"""
|
309 |
+
z = torch.concat([einops.rearrange(z_i, 'B C H W -> B (C H W)') for z_i in z], dim=-1)
|
310 |
+
return z
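# combine_joint is the inverse of split_joint restricted to the generated modalities:
# e.g. with z_shape=(4, 32, 32) each generated modality contributes a 4096-dim slice,
# so two generated modalities give a (B, 8192) tensor for the solver.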
|
311 |
+
|
312 |
+
@torch.cuda.amp.autocast()
|
313 |
+
@profiler.profile
|
314 |
+
def encode(_batch):
|
315 |
+
return autoencoder.encode(_batch)
|
316 |
+
|
317 |
+
@torch.cuda.amp.autocast()
|
318 |
+
@profiler.profile
|
319 |
+
def decode(_batch):
|
320 |
+
return autoencoder.decode(_batch)
|
321 |
+
|
322 |
+
def get_data_generator():
|
323 |
+
# Run single epoch
|
324 |
+
for data in tqdm(dataloader, desc='epoch'):
|
325 |
+
yield data
|
326 |
+
|
327 |
+
logging.info("Num of modalities: %d", config.num_modalities)
|
328 |
+
logging.info("Num of images in dataloader: %d", len(dataloader))
|
329 |
+
logging.info("Generate modalities: %s", config.generate_modalities)
|
330 |
+
logging.info("Condition modalities: %s", config.condition_modalities)
|
331 |
+
logging.info("Condition modalities mask: %s", config.condition_modalities_mask)
|
332 |
+
logging.info("Generate modalities mask: %s", config.generate_modalities_mask)
|
333 |
+
logging.info(f'N={N}')
|
334 |
+
|
335 |
+
|
336 |
+
@profiler.profile
|
337 |
+
def run_nnet(x, t, z_imgs):
|
338 |
+
|
339 |
+
timesteps = [t if mask else torch.zeros_like(t) for mask in config.generate_modalities_mask]
|
340 |
+
|
341 |
+
# ==== EXPAND TO ALL MODALITIES ====
|
342 |
+
z = split_joint(x, z_imgs, config=config)
|
343 |
+
# z = {modality1: z_generated_modality1, modality2: z_conditioned_modality2, ...}
|
344 |
+
|
345 |
+
# == DEBUG CODE: Decode, unprocess, and save both modalities side by side
|
346 |
+
# z_decoded_1 = decode(z[0])
|
347 |
+
# z_decoded_2 = decode(z[1])
|
348 |
+
# z_decoded_1 = unpreprocess(z_decoded_1)
|
349 |
+
# z_decoded_2 = unpreprocess(z_decoded_2)
|
350 |
+
# z_decoded_combined = torch.cat([z_decoded_1, z_decoded_2], dim=-1) # Concatenate along width dimension
|
351 |
+
# print(f"saving image z_decoded_combined_{t}.png")
|
352 |
+
# save_image(z_decoded_combined, os.path.join(config.output_path, f"z_supeeerdecoded_{t}.png"))
|
353 |
+
# == DEBUG CODE END ==
|
354 |
+
|
355 |
+
"""
|
356 |
+
nnet expects:
|
357 |
+
- z: (M, B, C, H, W)
|
358 |
+
- t_imgs: (M, B)
|
359 |
+
where M is the number of modalities.
|
360 |
+
|
361 |
+
That is, z should be a list of M batches, each batch corresponding to a modality.
|
362 |
+
E.g. num_modalities(M)=3, batch_size(B)=16, z_shape(C, H, W)=(4, 32, 32) ->
|
363 |
+
z = [(16, 4, 32, 32), (16, 4, 32, 32), (16, 4, 32, 32)]
|
364 |
+
t_imgs = [(16,), (16,), (16,)]
|
365 |
+
"""
|
366 |
+
|
367 |
+
z_out = nnet(z, t_imgs=timesteps)
|
368 |
+
|
369 |
+
# ==== SELECT ONLY THE GENERATED MODALITIES for the denoising process ====
|
370 |
+
z_out_generated = [z_out[i]
|
371 |
+
for i, modality in enumerate(config.modalities)
|
372 |
+
if modality in config.generate_modalities]
|
373 |
+
|
374 |
+
x_out = combine_joint(z_out_generated)
|
375 |
+
|
376 |
+
if config.sample.scale == 0.:
|
377 |
+
return x_out
|
378 |
+
|
379 |
+
return x_out # TODO: Implement classifier-free guidance if there is time
|
380 |
+
|
381 |
+
@profiler.profile
|
382 |
+
def sample_fn(z_imgs, **kwargs):
|
383 |
+
# Calculate effective batch size
|
384 |
+
effective_batch_size = z_imgs[0].shape[0]
|
385 |
+
|
386 |
+
# Generate random initial noise for the modalities being generated/denoised
|
387 |
+
_z_init = torch.randn(len(config.generate_modalities), effective_batch_size, *z_imgs[0].shape[1:], device=device)
|
388 |
+
|
389 |
+
_x_init = combine_joint(_z_init)
|
390 |
+
noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
|
391 |
+
|
392 |
+
@profiler.profile
|
393 |
+
def model_fn(x, t_continuous):
|
394 |
+
t = t_continuous * N
|
395 |
+
return run_nnet(x, t, z_imgs)
|
396 |
+
|
397 |
+
dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
|
398 |
+
with torch.no_grad():
|
399 |
+
with torch.autocast(device_type=device):
|
400 |
+
start_time = time.time()
|
401 |
+
x = dpm_solver.sample(_x_init, steps=config.sample.sample_steps, eps=1. / N, T=1.)
|
402 |
+
end_time = time.time()
|
403 |
+
print(f'\ngenerate {config.batch_size} samples with {config.sample.sample_steps} steps takes {end_time - start_time:.2f}s')
|
404 |
+
|
405 |
+
_zs = split_joint(x, z_imgs, config=config)
|
406 |
+
|
407 |
+
# Replace the conditional modalities with the original images
|
408 |
+
for i, mask in enumerate(config.condition_modalities_mask):
|
409 |
+
if mask:
|
410 |
+
_zs[i] = z_imgs[i]
|
411 |
+
|
412 |
+
return _zs
|
413 |
+
|
414 |
+
|
415 |
+
data_generator = get_data_generator()
|
416 |
+
for idx_batch, batch in enumerate(data_generator):
|
417 |
+
|
418 |
+
batch_start_time = time.time()
|
419 |
+
|
420 |
+
# Unpack the batch into images and filenames
|
421 |
+
original_images, original_filenames = batch
|
422 |
+
|
423 |
+
# print(filenames)
|
424 |
+
|
425 |
+
# Track data loading and preprocessing time
|
426 |
+
preprocess_start = time.time()
|
427 |
+
images = tree_map(lambda x: x.to(device), original_images)
|
428 |
+
# In addition to preparing the contexts (returns mean and variance),
|
429 |
+
# we need to actually sample the values from the distribution
|
430 |
+
img_contexts, filenames = prepare_contexts(config, images, original_filenames, device=device, autoencoder=autoencoder)
|
431 |
+
z_imgs = torch.stack([autoencoder.sample(img_context) for img_context in img_contexts])
|
432 |
+
preprocess_time = time.time() - preprocess_start
|
433 |
+
|
434 |
+
# Track sampling time
|
435 |
+
sample_start = time.time()
|
436 |
+
_zs = sample_fn(z_imgs)
|
437 |
+
sample_time = time.time() - sample_start
|
438 |
+
|
439 |
+
# Track decoding time
|
440 |
+
decode_start = time.time()
|
441 |
+
samples_unstacked = [unpreprocess(decode(_z)) for _z in _zs]
|
442 |
+
|
443 |
+
# Crop back to input resolution if it is smaller than MODEL_RESOLUTION
|
444 |
+
if config.resolution < MODEL_RESOLUTION:
|
445 |
+
samples_unstacked = [standard_transforms.functional.center_crop(sample, output_size=config.resolution)
|
446 |
+
for sample in samples_unstacked]
|
447 |
+
|
448 |
+
samples = torch.stack(samples_unstacked, dim=0)
|
449 |
+
decode_time = time.time() - decode_start
|
450 |
+
|
451 |
+
# Track saving time
|
452 |
+
save_start = time.time()
|
453 |
+
|
454 |
+
if config.save_as == 'grid':
|
455 |
+
|
456 |
+
b = samples.shape[1] # batch size
|
457 |
+
# Properly interleave samples from all modalities
|
458 |
+
# For each sample index, get all modalities before moving to next sample
|
459 |
+
samples = torch.stack([samples[j, i] for i in range(b) for j in range(config.nnet.num_modalities)]).view(-1, *samples.shape[2:])
|
460 |
+
# If the number of modalities is 3 then we plot in 9 columns
|
461 |
+
n_cols = 9 if config.nnet.num_modalities == 3 else 8
|
462 |
+
samples = make_grid(samples, n_cols)
|
463 |
+
save_path = os.path.join(config.output_path, f'grid_{idx_batch}.png')
|
464 |
+
save_image(samples, save_path)
|
465 |
+
|
466 |
+
|
467 |
+
# plot_real_images = '/home/s2254242/projects/pangaea_terramind/data/test_set_1/test' # We want to plot into a grid_real_images_{idx_batch}.png the real images
|
468 |
+
plot_real_images = ''
|
469 |
+
|
470 |
+
if plot_real_images != '':
|
471 |
+
# Load real images from files
|
472 |
+
real_images_list = []
|
473 |
+
for filename in original_filenames:
|
474 |
+
for modality in config.modalities:
|
475 |
+
img_path = os.path.join(plot_real_images, modality, f"{filename}.png")
|
476 |
+
img = Image.open(img_path).convert("RGB")
|
477 |
+
img_tensor = standard_transforms.ToTensor()(img)
|
478 |
+
real_images_list.append(img_tensor)
|
479 |
+
|
480 |
+
# Stack and create grid
|
481 |
+
real_images = torch.stack(real_images_list)
|
482 |
+
real_grid = make_grid(real_images, n_cols)
|
483 |
+
real_save_path = os.path.join(config.output_path, f'grid_real_{idx_batch}.png')
|
484 |
+
save_image(real_grid, real_save_path)
|
485 |
+
|
486 |
+
elif config.save_as == 'pngs':
|
487 |
+
# Define a helper function to save a single image
|
488 |
+
def save_single_image(args):
|
489 |
+
modality_idx, modality, b_idx = args
|
490 |
+
filename = filenames[b_idx] if isinstance(filenames, list) else filenames
|
491 |
+
save_path = os.path.join(os.path.join(config.output_path, modality), f"{filename}.png")
|
492 |
+
save_image(samples[modality_idx][b_idx], save_path)
|
493 |
+
|
494 |
+
# Create a list of all save operations needed
|
495 |
+
save_tasks = []
|
496 |
+
for i, modality in enumerate(config.modalities):
|
497 |
+
if modality in config.generate_modalities:
|
498 |
+
modality_dir = os.path.join(config.output_path, modality)
|
499 |
+
for b_idx in range(samples[i].shape[0]):
|
500 |
+
save_tasks.append((i, modality, b_idx))
|
501 |
+
|
502 |
+
# Use ThreadPoolExecutor to parallelize the saving process
|
503 |
+
max_workers = min(16, len(save_tasks)) # Limit to 16 threads max
|
504 |
+
if max_workers > 0: # Only create pool if there are tasks
|
505 |
+
print(f"Saving {len(save_tasks)} images using {max_workers} threads...")
|
506 |
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
507 |
+
list(tqdm(executor.map(save_single_image, save_tasks), total=len(save_tasks), desc="Saving images"))
|
508 |
+
|
509 |
+
elif config.data_type == 'folder-img':
|
510 |
+
# Get indices for all modalities we want to save
|
511 |
+
save_modalities = ['s1_rtc', 's2_l2a']
|
512 |
+
append_real_from_paths = []  # e.g. ['data/pastis_pngs/sar/'] to append matching real images next to the samples
|
513 |
+
modality_indices = [config.modalities.index(m) for m in save_modalities]
|
514 |
+
|
515 |
+
for i in range(min(config.batch_size, len(filenames))):
|
516 |
+
# Stack the samples from different modalities horizontally
|
517 |
+
concat_samples = torch.cat([samples[idx, i] for idx in modality_indices], dim=2)
|
518 |
+
|
519 |
+
# Append real images from specified paths
|
520 |
+
real_images = []
|
521 |
+
for real_path in append_real_from_paths:
|
522 |
+
real_img_path = os.path.join(real_path, filenames[i])
|
523 |
+
real_img = Image.open(real_img_path).convert("RGB")
|
524 |
+
real_img_tensor = standard_transforms.ToTensor()(real_img)
|
525 |
+
real_images.append(real_img_tensor)
|
526 |
+
|
527 |
+
real_images_tensor = torch.cat(real_images, dim=2) if len(real_images) > 1 else real_images[0]
|
528 |
+
concat_samples = torch.cat([concat_samples, real_images_tensor.to(device)], dim=2)
|
529 |
+
|
530 |
+
save_path = os.path.join(config.output_path, filenames[i])
|
531 |
+
save_image(concat_samples, save_path)
|
532 |
+
|
533 |
+
save_time = time.time() - save_start
|
534 |
+
|
535 |
+
batch_total_time = time.time() - batch_start_time
|
536 |
+
|
537 |
+
print(f'\nBatch {idx_batch} timing:')
|
538 |
+
print(f' Preprocessing: {preprocess_time:.2f}s ({preprocess_time/batch_total_time*100:.1f}%)')
|
539 |
+
print(f' Sampling: {sample_time:.2f}s ({sample_time/batch_total_time*100:.1f}%)')
|
540 |
+
print(f' Decoding: {decode_time:.2f}s ({decode_time/batch_total_time*100:.1f}%)')
|
541 |
+
print(f' Saving: {save_time:.2f}s ({save_time/batch_total_time*100:.1f}%)')
|
542 |
+
print(f' Total: {batch_total_time:.2f}s')
|
543 |
+
|
544 |
+
print(f'\nGPU memory usage: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB')
|
545 |
+
print(f'\nresults are saved in {os.path.join(config.output_path)} :)')
|
546 |
+
|
547 |
+
# After processing, display the profiling summary
|
548 |
+
if idx_batch % 5 == 0 or idx_batch == len(dataloader) - 1:
|
549 |
+
profiler.summary()
|
550 |
+
|
551 |
+
|
552 |
+
from absl import flags
|
553 |
+
from absl import app
|
554 |
+
from ml_collections import config_flags
|
555 |
+
import os
|
556 |
+
|
557 |
+
|
558 |
+
FLAGS = flags.FLAGS
|
559 |
+
config_flags.DEFINE_config_file(
|
560 |
+
"config", None, "Configuration.", lock_config=False)
|
561 |
+
flags.DEFINE_string("data_path", None, "Path to the data")
|
562 |
+
flags.DEFINE_string("data_type", 'lmdb', "Type of data to load (lmdb, folder-img)")
|
563 |
+
flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
|
564 |
+
flags.DEFINE_string("output_path", None, "The path to save the generated images")
|
565 |
+
flags.DEFINE_integer("n_mod", None, "Number of modalities")
|
566 |
+
flags.DEFINE_integer("n_samples", 1, "The number of samples to generate with the same condition")
|
567 |
+
flags.DEFINE_string("generate", None, "Comma-separated list of modalities to generate (s2_l1c,s2_l2a,s1_rtc,dem)")
|
568 |
+
flags.DEFINE_string("condition", None, "Comma-separated list of modalities to condition on (s2_l1c,s2_l2a,s1_rtc,dem)")
|
569 |
+
flags.DEFINE_string("save_as", 'grid', "How to save the generated images (grid, pngs)")
|
570 |
+
flags.DEFINE_integer("resolution", 256, "The resolution of the images to generate")
|
571 |
+
flags.DEFINE_integer("seed", None, "Random seed for reproducibility (overrides config seed)")
|
572 |
+
|
573 |
+
|
574 |
+
def main(argv):
|
575 |
+
config = FLAGS.config
|
576 |
+
config.nnet_path = FLAGS.nnet_path
|
577 |
+
config.data_path = FLAGS.data_path
|
578 |
+
config.save_as = FLAGS.save_as
|
579 |
+
config.n_samples = FLAGS.n_samples if FLAGS.n_samples else 1
|
580 |
+
config.resolution = FLAGS.resolution
|
581 |
+
|
582 |
+
# Override seed if provided from command line
|
583 |
+
if FLAGS.seed is not None:
|
584 |
+
config.seed = FLAGS.seed
|
585 |
+
|
586 |
+
# batch_size controls the number of unique conditional images we use
|
587 |
+
config.batch_size = 6
|
588 |
+
|
589 |
+
config.modalities = MODALITIES[FLAGS.n_mod]
|
590 |
+
|
591 |
+
if FLAGS.generate is None:
|
592 |
+
raise ValueError("--generate flag is mandatory")
|
593 |
+
|
594 |
+
# Parse generate and condition modalities
|
595 |
+
config.generate_modalities = FLAGS.generate.split(',')
|
596 |
+
config.condition_modalities = FLAGS.condition.split(',') if FLAGS.condition else []
|
597 |
+
|
598 |
+
# Sort the modalities by the order of the config.modalities
|
599 |
+
config.generate_modalities = sorted(config.generate_modalities, key=lambda x: config.modalities.index(x))
|
600 |
+
config.condition_modalities = sorted(config.condition_modalities, key=lambda x: config.modalities.index(x))
|
601 |
+
|
602 |
+
config.generate_modalities_mask = [mod in config.generate_modalities for mod in config.modalities]
|
603 |
+
config.condition_modalities_mask = [mod in config.condition_modalities for mod in config.modalities]
|
604 |
+
|
605 |
+
# Validate modalities
|
606 |
+
valid_modalities = {'s2_l1c', 's2_l2a', 's1_rtc', 'dem'}
|
607 |
+
for mod in config.generate_modalities + config.condition_modalities:
|
608 |
+
if mod not in valid_modalities:
|
609 |
+
raise ValueError(f"Invalid modality: {mod}. Must be one of {valid_modalities}")
|
610 |
+
|
611 |
+
# Check that generate and condition modalities don't overlap
|
612 |
+
if set(config.generate_modalities) & set(config.condition_modalities):
|
613 |
+
raise ValueError("Generate and condition modalities must be different")
|
614 |
+
|
615 |
+
if FLAGS.data_type == 'lmdb':
|
616 |
+
# Check that there exists a data.mdb and a lock.mdb in the data path
|
617 |
+
if not os.path.exists(os.path.join(config.data_path, 'data.mdb')):
|
618 |
+
raise ValueError(f"data.mdb does not exist in {config.data_path}")
|
619 |
+
if not os.path.exists(os.path.join(config.data_path, 'lock.mdb')):
|
620 |
+
raise ValueError(f"lock.mdb does not exist in {config.data_path}")
|
621 |
+
elif FLAGS.data_type == 'folder-img':
|
622 |
+
# raise NotImplementedError("Folder-img data type not implemented")
|
623 |
+
pass
|
624 |
+
else:
|
625 |
+
raise ValueError(f"Invalid data type: {FLAGS.data_type}. Must be one of ['lmdb', 'folder-img']")
|
626 |
+
config.data_type = FLAGS.data_type
|
627 |
+
|
628 |
+
assert config.nnet.num_modalities == FLAGS.n_mod, "Number of modalities in the nnet must match the number of modalities in the command line arguments"
|
629 |
+
config.num_modalities = FLAGS.n_mod
|
630 |
+
|
631 |
+
# Format the output path based on conditions and modalities
|
632 |
+
clean_generate = [mod.replace('_', '') for mod in config.generate_modalities]
|
633 |
+
if config.condition_modalities:
|
634 |
+
clean_condition = [mod.replace('_', '') for mod in config.condition_modalities]
|
635 |
+
output_dir = f"condition_{'_'.join(clean_condition)}_generate_{'_'.join(clean_generate)}_{config.n_samples}samples"
|
636 |
+
else:
|
637 |
+
output_dir = f"generate_{'_'.join(clean_generate)}_{config.n_samples}samples"
|
638 |
+
|
639 |
+
if config.save_as == 'grid':
|
640 |
+
config.output_path = os.path.join(FLAGS.output_path, 'grids', output_dir)
|
641 |
+
else:
|
642 |
+
config.output_path = os.path.join(FLAGS.output_path, output_dir)
|
643 |
+
|
644 |
+
evaluate(config)
|
645 |
+
|
646 |
+
# Print final profiling summary
|
647 |
+
print("\n===== FINAL PROFILING SUMMARY =====")
|
648 |
+
profiler.summary()
|
649 |
+
|
650 |
+
|
651 |
+
if __name__ == "__main__":
|
652 |
+
app.run(main)
|
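
Note (not part of the commit): a minimal sketch of how the sampling entry point above could be launched from Python, using only the flags defined in this file; the checkpoint, data and output paths are placeholders, and n_mod=4 assumes the four-modality setup used by the Rome config added in this commit.

    # sketch: invoke sample_n_triffuser.py with the flags defined above
    import subprocess

    subprocess.run([
        "python", "sample_n_triffuser.py",
        "--config=configs/majortom/discrete/rome_dems1s2s2_cop_gen_beta.py",  # config added in this commit
        "--nnet_path=models/cop_gen_beta.pth",   # placeholder checkpoint path
        "--data_path=data/rome_lmdb",            # placeholder LMDB dir (must contain data.mdb and lock.mdb)
        "--data_type=lmdb",
        "--output_path=out_images",
        "--n_mod=4",                             # assumed: 4 modalities (s2_l1c, s2_l2a, s1_rtc, dem)
        "--generate=s2_l2a",
        "--condition=s1_rtc",
        "--save_as=grid",
    ], check=True)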
src/COP-GEN-Beta/scripts/download_rome.sh
ADDED
@@ -0,0 +1,18 @@
#!/bin/bash

DATA_DIR="./data/majorTOM"
START_DATE="2017-01-01"
END_DATE="2025-01-01"
SOURCES=("Core-S2L2A" "Core-S2L1C" "Core-S1RTC" "Core-DEM")

python3 majortom/download_world.py \
    --data-dir $DATA_DIR \
    --sources "${SOURCES[@]}" \
    --start-date $START_DATE \
    --end-date $END_DATE \
    --cloud-cover 0 10 \
    --subset-name "rome" \
    --bbox 12.2 41.6 13.0 42.2 \
    --criteria "latest" \
    --n-samples 10 \
    --seed 42
src/COP-GEN-Beta/tools/extract_parquet.py
ADDED
@@ -0,0 +1,115 @@
import os
import sys
import pyarrow.parquet as pq
import pandas as pd
import json
from pathlib import Path
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='Extract all content from a parquet file')
    parser.add_argument('--parquet-file', type=str, required=True,
                        help='Name of the parquet file to extract')
    parser.add_argument('--output-dir', type=str, default='./extracted_data',
                        help='Directory to save extracted data (default: ./extracted_data)')
    return parser.parse_args()

def extract_parquet_content(parquet_path, output_dir):
    """Extract all content from a parquet file and save it to the output directory"""
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    print(f"Extracting data from {parquet_path} to {output_dir}")

    # Open the parquet file
    pf = pq.ParquetFile(parquet_path)
    print(f"File contains {pf.num_row_groups} row groups")

    # Process each row group
    for rg_idx in range(pf.num_row_groups):
        print(f"\nProcessing row group {rg_idx+1}/{pf.num_row_groups}")

        # Read the row group
        table = pf.read_row_group(rg_idx)
        df = table.to_pandas()

        # Create a directory for this row group
        if pf.num_row_groups > 1:
            rg_dir = output_dir / f"row_group_{rg_idx}"
        else:
            rg_dir = output_dir
        rg_dir.mkdir(exist_ok=True)

        # Get metadata to create more meaningful directory names if possible
        product_id = df['product_id'][0] if 'product_id' in df.columns else f"sample_{rg_idx}"
        grid_cell = df['grid_cell'][0] if 'grid_cell' in df.columns else ""

        # Create a more descriptive directory name if possible
        sample_dir = rg_dir / f"{grid_cell}_{product_id}" if grid_cell else rg_dir / product_id
        sample_dir.mkdir(exist_ok=True)

        # Extract and save metadata to JSON
        metadata = {}
        for col in df.columns:
            if df[col].dtype != 'object' or (len(df[col]) > 0 and not isinstance(df[col].iloc[0], bytes)):
                # Convert non-binary data to JSON-serializable format
                try:
                    if col == 'timestamp' and pd.api.types.is_datetime64_any_dtype(df[col]):
                        metadata[col] = df[col].iloc[0].strftime('%Y-%m-%d %H:%M:%S')
                    else:
                        value = df[col].iloc[0]
                        # Handle numpy types
                        if hasattr(value, 'item'):
                            metadata[col] = value.item()
                        else:
                            metadata[col] = value
                except Exception as e:
                    metadata[col] = f"Error converting: {str(e)}"

        # Save metadata
        with open(sample_dir / "metadata.json", "w") as f:
            json.dump(metadata, f, indent=2, default=str)

        # Extract and save binary data
        binary_columns = []
        for col in df.columns:
            if df[col].dtype == 'object' and len(df[col]) > 0 and isinstance(df[col].iloc[0], bytes):
                binary_columns.append(col)
                binary_data = df[col].iloc[0]

                # Determine file extension based on common column naming conventions
                if col == 'thumbnail':
                    extension = '.png'
                elif col.startswith('B') and col[1:].isdigit():  # Sentinel-2 bands
                    extension = '.tif'
                elif col in ['vv', 'vh']:  # Sentinel-1 bands
                    extension = '.tif'
                elif col == 'DEM':  # DEM data
                    extension = '.tif'
                elif col == 'cloud_mask':
                    extension = '.tif'
                else:
                    extension = '.bin'  # Generic binary data

                # Save binary data
                file_path = sample_dir / f"{col}{extension}"
                with open(file_path, "wb") as f:
                    f.write(binary_data)
                print(f"  Saved {col}{extension}, size: {len(binary_data)/1024:.1f} KB")

        print(f"  Extracted metadata and {len(binary_columns)} binary files to {sample_dir}")

def main():
    args = parse_args()
    parquet_path = Path(args.parquet_file)

    if not parquet_path.exists():
        print(f"Error: File {parquet_path} not found")
        sys.exit(1)

    # Extract all content
    extract_parquet_content(parquet_path, args.output_dir)
    print("\nExtraction complete!")

if __name__ == "__main__":
    main()
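
Note (not part of the commit): extract_parquet_content above can also be called directly from Python; the parquet path below is a placeholder.

    # sketch: programmatic use of the extractor defined above
    from pathlib import Path
    from tools.extract_parquet import extract_parquet_content

    sample_parquet = Path("data/majorTOM/rome/part_00000.parquet")  # placeholder path
    if sample_parquet.exists():
        # writes metadata.json plus one file per binary column (thumbnail, bands, DEM, ...)
        extract_parquet_content(sample_parquet, "./extracted_data")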
src/COP-GEN-Beta/tools/fid_score.py
ADDED
@@ -0,0 +1,260 @@
"""Calculates the Frechet Inception Distance (FID) to evaluate GANs

The FID metric calculates the distance between two distributions of images.
Typically, we have summary statistics (mean & covariance matrix) of one
of these distributions, while the 2nd distribution is given by a GAN.

When run as a stand-alone program, it compares the distribution of
images that are stored as PNG/JPEG at a specified location with a
distribution given by summary statistics (in pickle format).

The FID is calculated by assuming that X_1 and X_2 are the activations of
the pool_3 layer of the inception net for generated samples and real world
samples respectively.

See --help to see further details.

Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
of Tensorflow

Copyright 2018 Institute of Bioinformatics, JKU Linz

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import pathlib

import numpy as np
import torch
import torchvision.transforms as TF
from PIL import Image
from scipy import linalg
from torch.nn.functional import adaptive_avg_pool2d

try:
    from tqdm import tqdm
except ImportError:
    # If tqdm is not available, provide a mock version of it
    def tqdm(x):
        return x

from .inception import InceptionV3


IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
                    'tif', 'tiff', 'webp'}


class ImagePathDataset(torch.utils.data.Dataset):
    def __init__(self, files, transforms=None):
        self.files = files
        self.transforms = transforms

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        path = self.files[i]
        img = Image.open(path).convert('RGB')
        if self.transforms is not None:
            img = self.transforms(img)
        return img


def get_activations(files, model, batch_size=50, dims=2048, device='cpu', num_workers=8):
    """Calculates the activations of the pool_3 layer for all images.

    Params:
    -- files       : List of image files paths
    -- model       : Instance of inception model
    -- batch_size  : Batch size of images for the model to process at once.
                     Make sure that the number of samples is a multiple of
                     the batch size, otherwise some samples are ignored. This
                     behavior is retained to match the original FID score
                     implementation.
    -- dims        : Dimensionality of features returned by Inception
    -- device      : Device to run calculations
    -- num_workers : Number of parallel dataloader workers

    Returns:
    -- A numpy array of dimension (num images, dims) that contains the
       activations of the given tensor when feeding inception with the
       query tensor.
    """
    model.eval()

    if batch_size > len(files):
        print(('Warning: batch size is bigger than the data size. '
               'Setting batch size to data size'))
        batch_size = len(files)

    dataset = ImagePathDataset(files, transforms=TF.ToTensor())
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             drop_last=False,
                                             num_workers=num_workers)

    pred_arr = np.empty((len(files), dims))

    start_idx = 0

    for batch in tqdm(dataloader):
        batch = batch.to(device)

        with torch.no_grad():
            pred = model(batch)[0]

        # If model output is not scalar, apply global spatial average pooling.
        # This happens if you choose a dimensionality not equal 2048.
        if pred.size(2) != 1 or pred.size(3) != 1:
            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))

        pred = pred.squeeze(3).squeeze(2).cpu().numpy()

        pred_arr[start_idx:start_idx + pred.shape[0]] = pred

        start_idx = start_idx + pred.shape[0]

    return pred_arr


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Stable version by Dougal J. Sutherland.

    Params:
    -- mu1   : Numpy array containing the activations of a layer of the
               inception net (like returned by the function 'get_predictions')
               for generated samples.
    -- mu2   : The sample mean over activations, precalculated on a
               representative data set.
    -- sigma1: The covariance matrix over activations for generated samples.
    -- sigma2: The covariance matrix over activations, precalculated on a
               representative data set.

    Returns:
    --   : The Frechet Distance.
    """

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    diff = mu1 - mu2

    # Product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(sigma1)
            + np.trace(sigma2) - 2 * tr_covmean)


def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
                                    device='cpu', num_workers=8):
    """Calculation of the statistics used by the FID.
    Params:
    -- files       : List of image files paths
    -- model       : Instance of inception model
    -- batch_size  : The images numpy array is split into batches with
                     batch size batch_size. A reasonable batch size
                     depends on the hardware.
    -- dims        : Dimensionality of features returned by Inception
    -- device      : Device to run calculations
    -- num_workers : Number of parallel dataloader workers

    Returns:
    -- mu    : The mean over samples of the activations of the pool_3 layer of
               the inception model.
    -- sigma : The covariance matrix of the activations of the pool_3 layer of
               the inception model.
    """
    act = get_activations(files, model, batch_size, dims, device, num_workers)
    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma


def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=8):
    if path.endswith('.npz'):
        with np.load(path) as f:
            m, s = f['mu'][:], f['sigma'][:]
    else:
        path = pathlib.Path(path)
        files = sorted([file for ext in IMAGE_EXTENSIONS
                        for file in path.glob('*.{}'.format(ext))])
        m, s = calculate_activation_statistics(files, model, batch_size,
                                               dims, device, num_workers)

    return m, s


def save_statistics_of_path(path, out_path, device=None, batch_size=50, dims=2048, num_workers=8):
    if device is None:
        device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
    else:
        device = torch.device(device)
    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
    model = InceptionV3([block_idx]).to(device)
    m1, s1 = compute_statistics_of_path(path, model, batch_size, dims, device, num_workers)
    np.savez(out_path, mu=m1, sigma=s1)


def calculate_fid_given_paths(paths, device=None, batch_size=50, dims=2048, num_workers=8):
    """Calculates the FID of two paths"""
    if device is None:
        device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
    else:
        device = torch.device(device)

    for p in paths:
        if not os.path.exists(p):
            raise RuntimeError('Invalid path: %s' % p)

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]

    model = InceptionV3([block_idx]).to(device)

    m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
                                        dims, device, num_workers)
    m2, s2 = compute_statistics_of_path(paths[1], model, batch_size,
                                        dims, device, num_workers)
    fid_value = calculate_frechet_distance(m1, s1, m2, s2)

    return fid_value
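
Note (not part of the commit): a minimal sketch of the two public entry points above; both folder paths are placeholders and must contain images with one of the extensions listed in IMAGE_EXTENSIONS.

    # sketch: cache reference statistics once, then score a folder of generated samples
    from tools.fid_score import save_statistics_of_path, calculate_fid_given_paths

    save_statistics_of_path("data/real_thumbnails", "workdir/fid_stats_real.npz")  # placeholder dirs
    fid = calculate_fid_given_paths(("workdir/fid_stats_real.npz", "out_images/generated"))
    print(f"FID: {fid:.2f}")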
src/COP-GEN-Beta/tools/inception.py
ADDED
@@ -0,0 +1,328 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

# Inception weights ported to Pytorch from
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'  # noqa: E501


class InceptionV3(nn.Module):
    """Pretrained InceptionV3 network returning feature maps"""

    # Index of default block of inception to return,
    # corresponds to output of final average pooling
    DEFAULT_BLOCK_INDEX = 3

    # Maps feature dimensionality to their output blocks indices
    BLOCK_INDEX_BY_DIM = {
        64: 0,    # First max pooling features
        192: 1,   # Second max pooling features
        768: 2,   # Pre-aux classifier features
        2048: 3   # Final average pooling features
    }

    def __init__(self,
                 output_blocks=(DEFAULT_BLOCK_INDEX,),
                 resize_input=True,
                 normalize_input=True,
                 requires_grad=False,
                 use_fid_inception=True):
        """Build pretrained InceptionV3

        Parameters
        ----------
        output_blocks : list of int
            Indices of blocks to return features of. Possible values are:
                - 0: corresponds to output of first max pooling
                - 1: corresponds to output of second max pooling
                - 2: corresponds to output which is fed to aux classifier
                - 3: corresponds to output of final average pooling
        resize_input : bool
            If true, bilinearly resizes input to width and height 299 before
            feeding input to model. As the network without fully connected
            layers is fully convolutional, it should be able to handle inputs
            of arbitrary size, so resizing might not be strictly needed
        normalize_input : bool
            If true, scales the input from range (0, 1) to the range the
            pretrained Inception network expects, namely (-1, 1)
        requires_grad : bool
            If true, parameters of the model require gradients. Possibly useful
            for finetuning the network
        use_fid_inception : bool
            If true, uses the pretrained Inception model used in Tensorflow's
            FID implementation. If false, uses the pretrained Inception model
            available in torchvision. The FID Inception model has different
            weights and a slightly different structure from torchvision's
            Inception model. If you want to compute FID scores, you are
            strongly advised to set this parameter to true to get comparable
            results.
        """
        super(InceptionV3, self).__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)

        assert self.last_needed_block <= 3, \
            'Last possible output block index is 3'

        self.blocks = nn.ModuleList()

        if use_fid_inception:
            inception = fid_inception_v3()
        else:
            inception = _inception_v3(pretrained=True)

        # Block 0: input to maxpool1
        block0 = [
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2)
        ]
        self.blocks.append(nn.Sequential(*block0))

        # Block 1: maxpool1 to maxpool2
        if self.last_needed_block >= 1:
            block1 = [
                inception.Conv2d_3b_1x1,
                inception.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2)
            ]
            self.blocks.append(nn.Sequential(*block1))

        # Block 2: maxpool2 to aux classifier
        if self.last_needed_block >= 2:
            block2 = [
                inception.Mixed_5b,
                inception.Mixed_5c,
                inception.Mixed_5d,
                inception.Mixed_6a,
                inception.Mixed_6b,
                inception.Mixed_6c,
                inception.Mixed_6d,
                inception.Mixed_6e,
            ]
            self.blocks.append(nn.Sequential(*block2))

        # Block 3: aux classifier to final avgpool
        if self.last_needed_block >= 3:
            block3 = [
                inception.Mixed_7a,
                inception.Mixed_7b,
                inception.Mixed_7c,
                nn.AdaptiveAvgPool2d(output_size=(1, 1))
            ]
            self.blocks.append(nn.Sequential(*block3))

        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, inp):
        """Get Inception feature maps

        Parameters
        ----------
        inp : torch.autograd.Variable
            Input tensor of shape Bx3xHxW. Values are expected to be in
            range (0, 1)

        Returns
        -------
        List of torch.autograd.Variable, corresponding to the selected output
        block, sorted ascending by index
        """
        outp = []
        x = inp

        if self.resize_input:
            x = F.interpolate(x,
                              size=(299, 299),
                              mode='bilinear',
                              align_corners=False)

        if self.normalize_input:
            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)

        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx in self.output_blocks:
                outp.append(x)

            if idx == self.last_needed_block:
                break

        return outp


def _inception_v3(*args, **kwargs):
    """Wraps `torchvision.models.inception_v3`

    Skips default weight initialization if supported by torchvision version.
    See https://github.com/mseitzer/pytorch-fid/issues/28.
    """
    try:
        version = tuple(map(int, torchvision.__version__.split('.')[:2]))
    except ValueError:
        # Just a caution against weird version strings
        version = (0,)

    if version >= (0, 6):
        kwargs['init_weights'] = False

    return torchvision.models.inception_v3(*args, **kwargs)


def fid_inception_v3():
    """Build pretrained Inception model for FID computation

    The Inception model for FID computation uses a different set of weights
    and has a slightly different structure than torchvision's Inception.

    This method first constructs torchvision's Inception and then patches the
    necessary parts that are different in the FID Inception model.
    """
    inception = _inception_v3(num_classes=1008,
                              aux_logits=False,
                              pretrained=False)
    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
    inception.Mixed_7b = FIDInceptionE_1(1280)
    inception.Mixed_7c = FIDInceptionE_2(2048)

    state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
    inception.load_state_dict(state_dict)
    return inception


class FIDInceptionA(torchvision.models.inception.InceptionA):
    """InceptionA block patched for FID computation"""
    def __init__(self, in_channels, pool_features):
        super(FIDInceptionA, self).__init__(in_channels, pool_features)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                   count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionC(torchvision.models.inception.InceptionC):
    """InceptionC block patched for FID computation"""
    def __init__(self, in_channels, channels_7x7):
        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                   count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_1(torchvision.models.inception.InceptionE):
    """First InceptionE block patched for FID computation"""
    def __init__(self, in_channels):
        super(FIDInceptionE_1, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                   count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_2(torchvision.models.inception.InceptionE):
    """Second InceptionE block patched for FID computation"""
    def __init__(self, in_channels):
        super(FIDInceptionE_2, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: The FID Inception model uses max pooling instead of average
        # pooling. This is likely an error in this specific Inception
        # implementation, as other Inception models use average pooling here
        # (which matches the description in the paper).
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)
src/COP-GEN-Beta/tools/inspect_parquet.py
ADDED
@@ -0,0 +1,92 @@
import os
import sys
import pyarrow.parquet as pq
import pandas as pd
from pathlib import Path
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='Extract information from a parquet file')
    parser.add_argument('--parquet-file', type=str, required=True,
                        help='Name of the parquet file in the current directory')
    parser.add_argument('--row-group', type=int, default=None,
                        help='Specific row group to extract (default: all row groups)')
    parser.add_argument('--sample-binary', action='store_true',
                        help='Print sample of binary content (first 100 bytes)')
    return parser.parse_args()

def main():
    args = parse_args()
    parquet_path = Path(args.parquet_file)

    if not parquet_path.exists():
        print(f"Error: File {parquet_path} not found")
        sys.exit(1)

    print(f"\n--- Analyzing parquet file: {parquet_path} ---\n")

    # Open the parquet file
    pf = pq.ParquetFile(parquet_path)

    # Print basic file information
    print(f"File size: {parquet_path.stat().st_size / (1024*1024):.2f} MB")
    print(f"Number of row groups: {pf.num_row_groups}")
    print(f"Number of rows: {pf.metadata.num_rows}")
    print(f"Number of columns: {len(pf.schema_arrow)}")

    # Print schema information
    print("\nSchema:")
    for i, field in enumerate(pf.schema_arrow):
        print(f"  {i+1}. {field.name}: {field.type}")

    # Process row groups
    row_groups = [args.row_group] if args.row_group is not None else range(pf.num_row_groups)

    for rg_idx in row_groups:
        if rg_idx >= pf.num_row_groups:
            print(f"Error: Row group {rg_idx} does not exist (max: {pf.num_row_groups-1})")
            continue

        print(f"\n--- Row Group {rg_idx} ---")
        # Get row group metadata
        rg_metadata = pf.metadata.row_group(rg_idx)
        print(f"Row count: {rg_metadata.num_rows}")

        # Read the row group
        table = pf.read_row_group(rg_idx)
        df = table.to_pandas()

        # Display information about each column
        print("\nColumn information:")
        for col_name in df.columns:
            col_data = df[col_name]
            dtype = col_data.dtype

            if dtype == 'object':
                # Check if it's binary data
                if len(col_data) > 0 and isinstance(col_data.iloc[0], bytes):
                    item_size = len(col_data.iloc[0])
                    print(f"  {col_name}: Binary data, size: {item_size / 1024:.2f} KB")

                    if args.sample_binary and item_size > 0:
                        print(f"    Sample (first 100 bytes): {col_data.iloc[0][:100]}")
                else:
                    # For non-binary object columns
                    print(f"  {col_name}: Object type, example: {col_data.iloc[0]}")
            else:
                # For numeric or other columns
                if col_data.size > 0:
                    print(f"  {col_name}: {dtype}, min: {col_data.min()}, max: {col_data.max()}, example: {col_data.iloc[0]}")
                else:
                    print(f"  {col_name}: {dtype}, empty column")

        # Print specific metadata fields for Major-TOM dataset
        if 'product_id' in df.columns:
            print(f"\nProduct ID: {df['product_id'].iloc[0]}")
        if 'grid_cell' in df.columns:
            print(f"Grid Cell: {df['grid_cell'].iloc[0]}")
        if 'timestamp' in df.columns:
            print(f"Timestamp: {df['timestamp'].iloc[0]}")

if __name__ == "__main__":
    main()
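
Note (not part of the commit): the same inspection can be done interactively with pyarrow; the file name below is a placeholder.

    # sketch: quick interactive look at a Major-TOM parquet file
    import pyarrow.parquet as pq

    pf = pq.ParquetFile("part_00000.parquet")        # placeholder path
    print(pf.num_row_groups, pf.metadata.num_rows)   # row groups and total rows
    print(pf.schema_arrow)                           # column names and types, as printed by the script above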
src/COP-GEN-Beta/tools/print_parquet_urls.py
ADDED
@@ -0,0 +1,31 @@
import pandas as pd
import pyarrow.parquet as pq
import argparse
from pathlib import Path

def parse_args():
    parser = argparse.ArgumentParser(description='Read metadata.parquet and print download URLs')
    parser.add_argument('--metadata-path', type=str, required=True,
                        help='Path to the metadata.parquet file')
    return parser.parse_args()

def main():
    args = parse_args()
    metadata_path = Path(args.metadata_path)

    # Read the parquet file
    print(f"Reading metadata from: {metadata_path}")
    df = pq.read_table(metadata_path).to_pandas()

    # Extract unique parquet URLs
    unique_urls = df['parquet_url'].unique()

    # Print the URLs
    print(f"\nFound {len(unique_urls)} unique parquet file URLs:")
    for url in unique_urls:
        print(url)

    print(f"\nTotal number of samples in metadata: {len(df)}")

if __name__ == "__main__":
    main()
src/COP-GEN-Beta/train_triffuser_discrete.py
ADDED
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ml_collections
|
2 |
+
import torch
|
3 |
+
from torch import multiprocessing as mp
|
4 |
+
from datasets import get_dataset
|
5 |
+
from torchvision.utils import make_grid, save_image
|
6 |
+
import utils
|
7 |
+
import einops
|
8 |
+
from torch.utils._pytree import tree_map
|
9 |
+
import accelerate
|
10 |
+
from torch.utils.data import DataLoader
|
11 |
+
from tqdm.auto import tqdm
|
12 |
+
from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
|
13 |
+
import tempfile
|
14 |
+
from tools.fid_score import calculate_fid_given_paths
|
15 |
+
from absl import logging
|
16 |
+
import builtins
|
17 |
+
import os
|
18 |
+
import wandb
|
19 |
+
import libs.autoencoder
|
20 |
+
import numpy as np
|
21 |
+
|
22 |
+
|
23 |
+
def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
|
24 |
+
_betas = (
|
25 |
+
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
|
26 |
+
)
|
27 |
+
return _betas.numpy()
|
28 |
+
|
29 |
+
|
30 |
+
def get_skip(alphas, betas):
|
31 |
+
N = len(betas) - 1
|
32 |
+
skip_alphas = np.ones([N + 1, N + 1], dtype=betas.dtype)
|
33 |
+
for s in range(N + 1):
|
34 |
+
skip_alphas[s, s + 1:] = alphas[s + 1:].cumprod()
|
35 |
+
skip_betas = np.zeros([N + 1, N + 1], dtype=betas.dtype)
|
36 |
+
for t in range(N + 1):
|
37 |
+
prod = betas[1: t + 1] * skip_alphas[1: t + 1, t]
|
38 |
+
skip_betas[:t, t] = (prod[::-1].cumsum())[::-1]
|
39 |
+
return skip_alphas, skip_betas
|
40 |
+
|
41 |
+
|
42 |
+
def stp(s, ts: torch.Tensor): # scalar tensor product
|
43 |
+
if isinstance(s, np.ndarray):
|
44 |
+
s = torch.from_numpy(s).type_as(ts)
|
45 |
+
extra_dims = (1,) * (ts.dim() - 1)
|
46 |
+
return s.view(-1, *extra_dims) * ts
|
47 |
+
|
48 |
+
|
49 |
+
def mos(a, start_dim=1): # mean of square
|
50 |
+
return a.pow(2).flatten(start_dim=start_dim).mean(dim=-1)
|
51 |
+
|
52 |
+
|
53 |
+
class Schedule(object): # discrete time
|
54 |
+
def __init__(self, _betas):
|
55 |
+
r""" _betas[0...999] = betas[1...1000]
|
56 |
+
for n>=1, betas[n] is the variance of q(xn|xn-1)
|
57 |
+
for n=0, betas[0]=0
|
58 |
+
"""
|
59 |
+
|
60 |
+
self._betas = _betas
|
61 |
+
self.betas = np.append(0., _betas)
|
62 |
+
self.alphas = 1. - self.betas
|
63 |
+
self.N = len(_betas)
|
64 |
+
|
65 |
+
assert isinstance(self.betas, np.ndarray) and self.betas[0] == 0
|
66 |
+
assert isinstance(self.alphas, np.ndarray) and self.alphas[0] == 1
|
67 |
+
assert len(self.betas) == len(self.alphas)
|
68 |
+
|
69 |
+
# skip_alphas[s, t] = alphas[s + 1: t + 1].prod()
|
70 |
+
self.skip_alphas, self.skip_betas = get_skip(self.alphas, self.betas)
|
71 |
+
self.cum_alphas = self.skip_alphas[0] # cum_alphas = alphas.cumprod()
|
72 |
+
self.cum_betas = self.skip_betas[0]
|
73 |
+
self.snr = self.cum_alphas / self.cum_betas
|
74 |
+
|
75 |
+
def tilde_beta(self, s, t):
|
76 |
+
return self.skip_betas[s, t] * self.cum_betas[s] / self.cum_betas[t]
|
77 |
+
|
78 |
+
def sample(self, x0, multi_modal=False): # sample from q(xn|x0), where n is uniform
|
79 |
+
if multi_modal:
|
80 |
+
n_list = []
|
81 |
+
eps_list = []
|
82 |
+
xn_list = []
|
83 |
+
for x0_i in x0:
|
84 |
+
n = np.random.choice(list(range(1, self.N + 1)), (len(x0_i),))
|
85 |
+
eps = torch.randn_like(x0_i)
|
86 |
+
xn = stp(self.cum_alphas[n] ** 0.5, x0_i) + stp(self.cum_betas[n] ** 0.5, eps)
|
87 |
+
n_list.append(torch.tensor(n, device=x0_i.device))
|
88 |
+
eps_list.append(eps)
|
89 |
+
xn_list.append(xn)
|
90 |
+
return n_list, eps_list, xn_list
|
91 |
+
else:
|
92 |
+
n = np.random.choice(list(range(1, self.N + 1)), (len(x0),))
|
93 |
+
eps = torch.randn_like(x0)
|
94 |
+
xn = stp(self.cum_alphas[n] ** 0.5, x0) + stp(self.cum_betas[n] ** 0.5, eps)
|
95 |
+
return torch.tensor(n, device=x0.device), eps, xn
|
96 |
+
|
97 |
+
def __repr__(self):
|
98 |
+
return f'Schedule({self.betas[:10]}..., {self.N})'
|
99 |
+
|
100 |
+
|
101 |
+
def LSimple(x0, nnet, schedule, multi_modal=False, **kwargs):
|
102 |
+
if multi_modal:
|
103 |
+
n_list, eps_list, xn_list = schedule.sample(x0, multi_modal=multi_modal) # n in {1, ..., 1000}
|
104 |
+
eps_pred = nnet(xn_list, n_list, **kwargs)
|
105 |
+
return sum(mos(n - np_) for n, np_ in zip(eps_list, eps_pred))
|
106 |
+
else:
|
107 |
+
n, eps, xn = schedule.sample(x0) # n in {1, ..., 1000}
|
108 |
+
eps_pred = nnet(xn, n, **kwargs)
|
109 |
+
return mos(eps - eps_pred)
|
110 |
+
|
111 |
+
|
112 |
+
def train(config):
|
113 |
+
if config.get('benchmark', False):
|
114 |
+
torch.backends.cudnn.benchmark = True
|
115 |
+
torch.backends.cudnn.deterministic = False
|
116 |
+
|
117 |
+
mp.set_start_method('spawn')
|
118 |
+
accelerator = accelerate.Accelerator()
|
119 |
+
device = accelerator.device
|
120 |
+
accelerate.utils.set_seed(config.seed, device_specific=True)
|
121 |
+
logging.info(f'Process {accelerator.process_index} using device: {device}')
|
122 |
+
|
123 |
+
config.mixed_precision = accelerator.mixed_precision
|
124 |
+
config = ml_collections.FrozenConfigDict(config)
|
125 |
+
|
126 |
+
assert config.train.batch_size % accelerator.num_processes == 0
|
127 |
+
mini_batch_size = config.train.batch_size // accelerator.num_processes
|
128 |
+
|
129 |
+
if accelerator.is_main_process:
|
130 |
+
os.makedirs(config.ckpt_root, exist_ok=True)
|
131 |
+
os.makedirs(config.sample_dir, exist_ok=True)
|
132 |
+
accelerator.wait_for_everyone()
|
133 |
+
if accelerator.is_main_process:
|
134 |
+
wandb.init(dir=os.path.abspath(config.workdir), project=f'uvit_{config.dataset.name}', config=config.to_dict(),
|
135 |
+
name=config.hparams, job_type='train', mode='offline')
|
136 |
+
utils.set_logger(log_level='info', fname=os.path.join(config.workdir, 'output.log'))
|
137 |
+
logging.info(config)
|
138 |
+
else:
|
139 |
+
utils.set_logger(log_level='error')
|
140 |
+
builtins.print = lambda *args: None
|
141 |
+
logging.info(f'Run on {accelerator.num_processes} devices')
|
142 |
+
|
143 |
+
dataset = get_dataset(**config.dataset)
|
144 |
+
assert os.path.exists(dataset.fid_stat)
|
145 |
+
train_dataset = dataset.get_split(split='train', labeled=config.train.mode == 'cond')
|
146 |
+
train_dataset_loader = DataLoader(train_dataset, batch_size=mini_batch_size, shuffle=True, drop_last=True,
|
147 |
+
num_workers=8, pin_memory=True, persistent_workers=True)
|
148 |
+
|
149 |
+
train_state = utils.initialize_train_state(config, device)
|
150 |
+
nnet, nnet_ema, optimizer, train_dataset_loader = accelerator.prepare(
|
151 |
+
train_state.nnet, train_state.nnet_ema, train_state.optimizer, train_dataset_loader)
|
152 |
+
lr_scheduler = train_state.lr_scheduler
|
153 |
+
train_state.resume(config.ckpt_root)
|
154 |
+
|
155 |
+
autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path)
|
156 |
+
autoencoder.to(device)
|
157 |
+
|
158 |
+
@ torch.cuda.amp.autocast()
|
159 |
+
def encode(_batch):
|
160 |
+
return autoencoder.encode(_batch)
|
161 |
+
|
162 |
+
@ torch.cuda.amp.autocast()
|
163 |
+
def decode(_batch):
|
164 |
+
return autoencoder.decode(_batch)
|
165 |
+
|
166 |
+
def get_data_generator():
|
167 |
+
while True:
|
168 |
+
for data in tqdm(train_dataset_loader, disable=not accelerator.is_main_process, desc='epoch'):
|
169 |
+
yield data
|
170 |
+
|
171 |
+
data_generator = get_data_generator()
|
172 |
+
|
173 |
+
_betas = stable_diffusion_beta_schedule()
|
174 |
+
_schedule = Schedule(_betas)
|
175 |
+
logging.info(f'use {_schedule}')
|
176 |
+
|
177 |
+
|
178 |
+
def train_step(_batch):
|
179 |
+
_metrics = dict()
|
180 |
+
optimizer.zero_grad()
|
181 |
+
if config.train.mode == 'uncond': # Multi-modal data. Sample each modality independently
|
182 |
+
if config.train.multi_modal:
|
183 |
+
_zs = [autoencoder.sample(modality) if 'feature' in config.dataset.name else encode(modality) for modality in _batch]
|
184 |
+
loss = LSimple(_zs, nnet, _schedule, multi_modal=config.train.multi_modal)
|
185 |
+
else:
|
186 |
+
_z = autoencoder.sample(_batch) if 'feature' in config.dataset.name else encode(_batch)
|
187 |
+
loss = LSimple(_z, nnet, _schedule)
|
188 |
+
elif config.train.mode == 'cond':
|
189 |
+
_z = autoencoder.sample(_batch[0]) if 'feature' in config.dataset.name else encode(_batch[0])
|
190 |
+
loss = LSimple(_z, nnet, _schedule, y=_batch[1])
|
191 |
+
else:
|
192 |
+
raise NotImplementedError(config.train.mode)
|
193 |
+
_metrics['loss'] = accelerator.gather(loss.detach()).mean()
|
194 |
+
accelerator.backward(loss.mean())
|
195 |
+
optimizer.step()
|
196 |
+
lr_scheduler.step()
|
197 |
+
train_state.ema_update(config.get('ema_rate', 0.9999))
|
198 |
+
train_state.step += 1
|
199 |
+
return dict(lr=train_state.optimizer.param_groups[0]['lr'], **_metrics)
|
200 |
+
|
201 |
+
def dpm_solver_sample(_n_samples, _sample_steps, **kwargs):
|
202 |
+
_z_init = torch.randn(_n_samples, *config.z_shape, device=device)
|
203 |
+
noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
|
204 |
+
|
205 |
+
def model_fn(x, t_continuous):
|
206 |
+
t = t_continuous * _schedule.N
|
207 |
+
eps_pre = nnet_ema(x, t, **kwargs)
|
208 |
+
return eps_pre
|
209 |
+
|
210 |
+
dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
|
211 |
+
_z = dpm_solver.sample(_z_init, steps=_sample_steps, eps=1. / _schedule.N, T=1.)
|
212 |
+
return decode(_z)
|
213 |
+
|
214 |
+
def combine_joint(z):
|
215 |
+
z = torch.concat([einops.rearrange(z_i, 'B C H W -> B (C H W)') for z_i in z], dim=-1)
|
216 |
+
return z
|
217 |
+
|
218 |
+
def split_joint(x, n_modalities):
|
219 |
+
C, H, W = config.z_shape
|
220 |
+
z_dim = C * H * W
|
221 |
+
z = x.split([z_dim] * n_modalities, dim=1)
|
222 |
+
z = [einops.rearrange(z_i, 'B (C H W) -> B C H W', C=C, H=H, W=W) for z_i in z]
|
223 |
+
return z
|
224 |
+
|
225 |
+
def dpm_solver_sample_multi_modal(_n_modalities, _n_samples, _sample_steps, **kwargs):
|
226 |
+
"""here"""
|
227 |
+
|
228 |
+
_z_init = torch.randn(_n_modalities, _n_samples, *config.z_shape, device=device)
|
229 |
+
_z_init = combine_joint(_z_init)
|
230 |
+
noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
|
231 |
+
|
232 |
+
def model_fn(x, t_continuous):
|
233 |
+
t = t_continuous * _schedule.N
|
234 |
+
|
235 |
+
timesteps = [t] * _n_modalities
|
236 |
+
z = split_joint(x, _n_modalities)
|
237 |
+
z_out = nnet_ema(z, t_imgs=timesteps)
|
238 |
+
x_out = combine_joint(z_out)
|
239 |
+
# eps_pre = nnet_ema(x, t, **kwargs)
|
240 |
+
return x_out
|
241 |
+
|
242 |
+
dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
|
243 |
+
_zs = dpm_solver.sample(_z_init, steps=_sample_steps, eps=1. / _schedule.N, T=1.)
|
244 |
+
_zs = split_joint(_zs, _n_modalities)
|
245 |
+
samples_unstacked = [decode(_z) for _z in _zs]
|
246 |
+
return samples_unstacked
|
247 |
+
|
248 |
+
def eval_step(n_samples, sample_steps):
|
249 |
+
logging.info(f'eval_step: n_samples={n_samples}, sample_steps={sample_steps}'
|
250 |
+
f'mini_batch_size={config.sample.mini_batch_size}')
|
251 |
+
|
252 |
+
def sample_fn(_n_samples):
|
253 |
+
if config.train.mode == 'uncond':
|
254 |
+
kwargs = dict()
|
255 |
+
elif config.train.mode == 'cond':
|
256 |
+
kwargs = dict(y=dataset.sample_label(_n_samples, device=device))
|
257 |
+
else:
|
258 |
+
raise NotImplementedError
|
259 |
+
return dpm_solver_sample(_n_samples, sample_steps, **kwargs)
|
260 |
+
|
261 |
+
|
262 |
+
with tempfile.TemporaryDirectory() as temp_path:
|
263 |
+
path = config.sample.path or temp_path
|
264 |
+
if accelerator.is_main_process:
|
265 |
+
os.makedirs(path, exist_ok=True)
|
266 |
+
utils.sample2dir(accelerator, path, n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess)
|
267 |
+
|
268 |
+
_fid = 0
|
269 |
+
if accelerator.is_main_process:
|
270 |
+
_fid = calculate_fid_given_paths((dataset.fid_stat, path))
|
271 |
+
logging.info(f'step={train_state.step} fid{n_samples}={_fid}')
|
272 |
+
with open(os.path.join(config.workdir, 'eval.log'), 'a') as f:
|
273 |
+
print(f'step={train_state.step} fid{n_samples}={_fid}', file=f)
|
274 |
+
wandb.log({f'fid{n_samples}': _fid}, step=train_state.step)
|
275 |
+
_fid = torch.tensor(_fid, device=device)
|
276 |
+
_fid = accelerator.reduce(_fid, reduction='sum')
|
277 |
+
|
278 |
+
return _fid.item()
|
279 |
+
|
+    logging.info(f'Start fitting, step={train_state.step}, mixed_precision={config.mixed_precision}')
+
+    step_fid = []
+    while train_state.step < config.train.n_steps:
+        nnet.train()
+        batch = tree_map(lambda x: x.to(device), next(data_generator))
+        metrics = train_step(batch)
+
+        nnet.eval()
+        if accelerator.is_main_process and train_state.step % config.train.log_interval == 0:
+            logging.info(utils.dct2str(dict(step=train_state.step, **metrics)))
+            logging.info(config.workdir)
+            wandb.log(metrics, step=train_state.step)
+
+        if accelerator.is_main_process and train_state.step % config.train.eval_interval == 0:
+            torch.cuda.empty_cache()
+            logging.info('Save a grid of images...')
+            if config.train.mode == 'uncond':
+                if config.train.multi_modal:
+                    samples = dpm_solver_sample_multi_modal(_n_modalities=config.nnet.num_modalities, _n_samples=5 * 10, _sample_steps=50)
+                else:
+                    samples = dpm_solver_sample(_n_samples=5 * 10, _sample_steps=50)
+            elif config.train.mode == 'cond':
+                y = einops.repeat(torch.arange(5, device=device) % dataset.K, 'nrow -> (nrow ncol)', ncol=10)
+                samples = dpm_solver_sample(_n_samples=5 * 10, _sample_steps=50, y=y)
+            else:
+                raise NotImplementedError
+
+            if config.train.multi_modal:
+                samples = torch.stack([dataset.unpreprocess(sample) for sample in samples], dim=0)  # stack instead of cat
+                b = samples.shape[1]  # batch size
+                # Interleave samples from all modalities: for each sample index, take all
+                # modalities before moving on to the next sample.
+                samples = torch.stack([samples[j, i] for i in range(b) for j in range(config.nnet.num_modalities)]).view(-1, *samples.shape[2:])
+                # With 3 modalities we plot 9 columns (3 triplets per row), otherwise 10.
+                n_cols = 9 if config.nnet.num_modalities == 3 else 10
+                samples = make_grid(samples, n_cols)
+            else:
+                samples = make_grid(dataset.unpreprocess(samples), 10)
+            save_image(samples, os.path.join(config.sample_dir, f'{train_state.step}.png'))
+            wandb.log({'samples': wandb.Image(samples)}, step=train_state.step)
+            torch.cuda.empty_cache()
+        accelerator.wait_for_everyone()
+
+        if train_state.step % config.train.save_interval == 0 or train_state.step == config.train.n_steps:
+            torch.cuda.empty_cache()
+            logging.info(f'Save and eval checkpoint {train_state.step}...')
+            if accelerator.is_main_process:
+                try:
+                    train_state.save(os.path.join(config.ckpt_root, f'{train_state.step}.ckpt'))
+                except Exception as e:
+                    logging.error(f" ==> Failed to save checkpoint: {e}!!!")
+            accelerator.wait_for_everyone()
+            # TODO: FID evaluation during training is skipped for now
+            # fid = eval_step(n_samples=10000, sample_steps=50)  # calculate fid of the saved checkpoint
+            # step_fid.append((train_state.step, fid))
+            torch.cuda.empty_cache()
+        accelerator.wait_for_everyone()
+
+
+    logging.info(f'Finish fitting, step={train_state.step}')
+    logging.info(f'step_fid: {step_fid}')
+    if step_fid:  # guard: step_fid stays empty while in-training FID is skipped above
+        step_best = sorted(step_fid, key=lambda x: x[1])[0][0]
+        logging.info(f'step_best: {step_best}')
+        train_state.load(os.path.join(config.ckpt_root, f'{step_best}.ckpt'))
+    del metrics
+    accelerator.wait_for_everyone()
+    eval_step(n_samples=config.sample.n_samples, sample_steps=config.sample.sample_steps)
+
+
+from absl import flags
+from absl import app
+from ml_collections import config_flags
+import sys
+from pathlib import Path
+
+
+FLAGS = flags.FLAGS
+config_flags.DEFINE_config_file(
+    "config", None, "Training configuration.", lock_config=False)
+flags.mark_flags_as_required(["config"])
+flags.DEFINE_string("workdir", None, "Work unit directory.")
+
+
+def get_config_name():
+    argv = sys.argv
+    for i in range(1, len(argv)):
+        if argv[i].startswith('--config='):
+            return Path(argv[i].split('=')[-1]).stem
+
+def get_config_path():
+    argv = sys.argv
+    for i in range(1, len(argv)):
+        if argv[i].startswith('--config='):
+            path = argv[i].split('=')[-1]
+            if path.startswith('configs/'):
+                path = path[len('configs/'):]
+            return path
+
+def get_hparams():
+    argv = sys.argv
+    lst = []
+    for i in range(1, len(argv)):
+        assert '=' in argv[i]
+        if argv[i].startswith('--config.') and not argv[i].startswith('--config.dataset.path'):
+            hparam, val = argv[i].split('=')
+            hparam = hparam.split('.')[-1]
+            if hparam.endswith('path'):
+                val = Path(val).stem
+            lst.append(f'{hparam}={val}')
+    hparams = '-'.join(lst)
+    if hparams == '':
+        hparams = 'default'
+    return hparams
+
+
+def main(argv):
+    config = FLAGS.config
+    # config.config_name = get_config_name()
+    config.config_name = get_config_path().removesuffix('.py')  # removesuffix (Python >= 3.9): strip('.py') would remove characters, not the suffix
+    config.hparams = get_hparams()
+    config.workdir = FLAGS.workdir or os.path.join('workdir', config.config_name, config.hparams)
+    config.ckpt_root = os.path.join(config.workdir, 'ckpts')
+    config.sample_dir = os.path.join(config.workdir, 'samples')
+    train(config)
+
+
+if __name__ == "__main__":
+    app.run(main)
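The joint sampler above hinges on the pack/unpack round trip between per-modality latents and a single flat tensor. A minimal, self-contained sketch of the idea (toy shapes only; the real latent shape comes from config.z_shape):

import torch
import einops

# Toy sizes for illustration; the real values come from the autoencoder config.
n_modalities, B, C, H, W = 3, 2, 4, 8, 8
z = torch.randn(n_modalities, B, C, H, W)

# Pack: flatten each modality's latent and concatenate along dim=1 (mirrors combine_joint).
x = torch.cat([einops.rearrange(z_i, 'B C H W -> B (C H W)') for z_i in z], dim=1)
assert x.shape == (B, n_modalities * C * H * W)

# Unpack: split back into per-modality latents (mirrors split_joint above).
z_back = x.split([C * H * W] * n_modalities, dim=1)
z_back = [einops.rearrange(z_i, 'B (C H W) -> B C H W', C=C, H=H, W=W) for z_i in z_back]
assert all(torch.equal(a, b) for a, b in zip(z, z_back))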
src/COP-GEN-Beta/utils.py
ADDED
@@ -0,0 +1,240 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import os
+from tqdm import tqdm
+from torchvision.utils import save_image
+from absl import logging
+from PIL import Image, ImageDraw, ImageFont
+
+def set_logger(log_level='info', fname=None):
+    import logging as _logging
+    handler = logging.get_absl_handler()
+    formatter = _logging.Formatter('%(asctime)s - %(filename)s - %(message)s')
+    handler.setFormatter(formatter)
+    logging.set_verbosity(log_level)
+    if fname is not None:
+        handler = _logging.FileHandler(fname)
+        handler.setFormatter(formatter)
+        logging.get_absl_logger().addHandler(handler)
+
+
+def dct2str(dct):
+    return str({k: f'{v:.6g}' for k, v in dct.items()})
+
+
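+# Factory for the denoising network, keyed by name; 'triffuser_multi_post_ln' is the
+# multi-modal Triffuser added in this commit, alongside the original UViT variants.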
+def get_nnet(name, **kwargs):
+    if name == 'uvit':
+        from libs.uvit import UViT
+        return UViT(**kwargs)
+    elif name == 'uvit_t2i':
+        from libs.uvit_t2i import UViT
+        return UViT(**kwargs)
+    elif name == 'uvit_multi_post_ln':
+        from libs.uvit_multi_post_ln import UViT
+        return UViT(**kwargs)
+    elif name == 'uvit_multi_post_ln_v1':
+        from libs.uvit_multi_post_ln_v1 import UViT
+        return UViT(**kwargs)
+    elif name == 'triffuser_multi_post_ln':
+        from libs.triffuser_multi_post_ln import Triffuser
+        return Triffuser(**kwargs)
+    else:
+        raise NotImplementedError(name)
+
+
+def set_seed(seed: int):
+    if seed is not None:
+        torch.manual_seed(seed)
+        np.random.seed(seed)
+
+
+def get_optimizer(params, name, **kwargs):
+    if name == 'adam':
+        from torch.optim import Adam
+        return Adam(params, **kwargs)
+    elif name == 'adamw':
+        from torch.optim import AdamW
+        return AdamW(params, **kwargs)
+    else:
+        raise NotImplementedError(name)
+
+
+def customized_lr_scheduler(optimizer, warmup_steps=-1):
+    from torch.optim.lr_scheduler import LambdaLR
+    def fn(step):
+        if warmup_steps > 0:
+            return min(step / warmup_steps, 1)
+        else:
+            return 1
+    return LambdaLR(optimizer, fn)
+
+
+def get_lr_scheduler(optimizer, name, **kwargs):
+    if name == 'customized':
+        return customized_lr_scheduler(optimizer, **kwargs)
+    elif name == 'cosine':
+        from torch.optim.lr_scheduler import CosineAnnealingLR
+        return CosineAnnealingLR(optimizer, **kwargs)
+    else:
+        raise NotImplementedError(name)
+
+
+def ema(model_dest: nn.Module, model_src: nn.Module, rate):
+    param_dict_src = dict(model_src.named_parameters())
+    for p_name, p_dest in model_dest.named_parameters():
+        p_src = param_dict_src[p_name]
+        assert p_src is not p_dest
+        p_dest.data.mul_(rate).add_((1 - rate) * p_src.data)
+
+
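+# TrainState bundles everything that gets checkpointed: optimizer, LR scheduler, step
+# counter, the network and its EMA copy. save()/load() write one .pth per component
+# inside a <step>.ckpt directory; resume() picks the latest such directory in ckpt_root.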
+class TrainState(object):
+    def __init__(self, optimizer, lr_scheduler, step, nnet=None, nnet_ema=None):
+        self.optimizer = optimizer
+        self.lr_scheduler = lr_scheduler
+        self.step = step
+        self.nnet = nnet
+        self.nnet_ema = nnet_ema
+
+    def ema_update(self, rate=0.9999):
+        if self.nnet_ema is not None:
+            ema(self.nnet_ema, self.nnet, rate)
+
+    def save(self, path):
+        os.makedirs(path, exist_ok=True)
+        torch.save(self.step, os.path.join(path, 'step.pth'))
+        for key, val in self.__dict__.items():
+            if key != 'step' and val is not None:
+                torch.save(val.state_dict(), os.path.join(path, f'{key}.pth'))
+
+    def load(self, path):
+        logging.info(f'load from {path}')
+        self.step = torch.load(os.path.join(path, 'step.pth'))
+        for key, val in self.__dict__.items():
+            if key != 'step' and val is not None:
+                val.load_state_dict(torch.load(os.path.join(path, f'{key}.pth'), map_location='cpu'))
+
+    def resume(self, ckpt_root, step=None):
+        if not os.path.exists(ckpt_root):
+            return
+        if step is None:
+            ckpts = list(filter(lambda x: '.ckpt' in x, os.listdir(ckpt_root)))
+            if not ckpts:
+                return
+            steps = map(lambda x: int(x.split(".")[0]), ckpts)
+            step = max(steps)
+        ckpt_path = os.path.join(ckpt_root, f'{step}.ckpt')
+        logging.info(f'resume from {ckpt_path}')
+        self.load(ckpt_path)
+
+    def to(self, device):
+        for key, val in self.__dict__.items():
+            if isinstance(val, nn.Module):
+                val.to(device)
+
+
+def cnt_params(model):
+    return sum(param.numel() for param in model.parameters())
+
+
+def initialize_train_state(config, device):
+    params = []
+
+    nnet = get_nnet(**config.nnet)
+    params += nnet.parameters()
+    nnet_ema = get_nnet(**config.nnet)
+    nnet_ema.eval()
+    logging.info(f'nnet has {cnt_params(nnet)} parameters')
+
+    optimizer = get_optimizer(params, **config.optimizer)
+    lr_scheduler = get_lr_scheduler(optimizer, **config.lr_scheduler)
+
+    train_state = TrainState(optimizer=optimizer, lr_scheduler=lr_scheduler, step=0,
+                             nnet=nnet, nnet_ema=nnet_ema)
+    train_state.ema_update(0)
+    train_state.to(device)
+    return train_state
+
+
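+# amortize splits a sample budget into full batches plus a remainder,
+# e.g. amortize(10, 4) -> [4, 4, 2]; sample2dir iterates over this schedule.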
+def amortize(n_samples, batch_size):
+    k = n_samples // batch_size
+    r = n_samples % batch_size
+    return k * [batch_size] if r == 0 else k * [batch_size] + [r]
+
+
+def sample2dir(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None):
+    os.makedirs(path, exist_ok=True)
+    idx = 0
+    batch_size = mini_batch_size * accelerator.num_processes
+
+    for _batch_size in tqdm(amortize(n_samples, batch_size), disable=not accelerator.is_main_process, desc='sample2dir'):
+        samples = unpreprocess_fn(sample_fn(mini_batch_size))
+        samples = accelerator.gather(samples.contiguous())[:_batch_size]
+        if accelerator.is_main_process:
+            for sample in samples:
+                save_image(sample, os.path.join(path, f"{idx}.png"))
+                idx += 1
+
+
+def grad_norm(model):
+    total_norm = 0.
+    for p in model.parameters():
+        param_norm = p.grad.data.norm(2)
+        total_norm += param_norm.item() ** 2
+    total_norm = total_norm ** (1. / 2)
+    return total_norm
+
+
+
+def center_crop(width, height, img):
+    resample = {'box': Image.BOX, 'lanczos': Image.LANCZOS}['lanczos']
+    crop = np.min(img.shape[:2])
+    img = img[(img.shape[0] - crop) // 2: (img.shape[0] + crop) // 2,
+              (img.shape[1] - crop) // 2: (img.shape[1] + crop) // 2]  # center crop
+    try:
+        img = Image.fromarray(img, 'RGB')
+    except Exception:
+        img = Image.fromarray(img)
+    img = img.resize((width, height), resample)  # resize the center crop from [crop, crop] to [width, height]
+
+    return np.array(img).astype(np.uint8)
+
+
+def drawRoundRec(draw, color, x, y, w, h, r):
+    drawObject = draw
+
+    # rounded corners
+    drawObject.ellipse((x, y, x + r, y + r), fill=color)
+    drawObject.ellipse((x + w - r, y, x + w, y + r), fill=color)
+    drawObject.ellipse((x, y + h - r, x + r, y + h), fill=color)
+    drawObject.ellipse((x + w - r, y + h - r, x + w, y + h), fill=color)
+
+    # connecting rectangles
+    drawObject.rectangle((x + r / 2, y, x + w - (r / 2), y + h), fill=color)
+    drawObject.rectangle((x, y + r / 2, x + w, y + h - (r / 2)), fill=color)
+
+
+def add_water(img, text='UniDiffuser', pos=3):
+    width, height = img.size
+    scale = 4
+    scale_size = 0.5
+    img = img.resize((width * scale, height * scale), Image.LANCZOS)
+    result = Image.new(img.mode, (width * scale, height * scale), color=(255, 255, 255))
+    result.paste(img, box=(0, 0))
+
+    delta_w = int(width * scale * 0.27 * scale_size)  # text width
+    delta_h = width * scale * 0.05 * scale_size  # text height
+    postions = np.array([[0, 0], [0, height * scale - delta_h], [width * scale - delta_w, 0],
+                         [width * scale - delta_w, height * scale - delta_h]])
+    postion = postions[pos]
+    # watermark text
+    draw = ImageDraw.Draw(result)
+    fillColor = (107, 92, 231)
+    setFont = ImageFont.truetype("assets/ArialBoldMT.ttf", int(width * scale * 0.05 * scale_size))
+    delta = 20 * scale_size
+    padding = 15 * scale_size
+    drawRoundRec(draw, (223, 230, 233), postion[0] - delta - padding, postion[1] - delta - padding,
+                 w=delta_w + 2 * padding, h=delta_h + 2 * padding, r=50 * scale_size)
+    draw.text((postion[0] - delta, postion[1] - delta), text, font=setFont, fill=fillColor)
+
+    return result.resize((width, height), Image.LANCZOS)
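For reference, a minimal sketch of the EMA rule implemented by ema / TrainState.ema_update above (toy linear layers; the 0.9999 rate is the default used during training):

import torch
import torch.nn as nn
from utils import ema  # the helper defined above

src = nn.Linear(4, 4)                 # online network being optimised
dest = nn.Linear(4, 4)                # EMA copy
dest.load_state_dict(src.state_dict())

with torch.no_grad():
    for p in src.parameters():
        p.add_(0.1)                   # pretend one optimiser step moved the online weights

ema(dest, src, rate=0.9999)           # dest <- 0.9999 * dest + 0.0001 * src, parameter-wise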