Spaces:
Runtime error
Runtime error
Commit
·
fac3244
0
Parent(s):
Initial PLONK deployment for Hugging Face Spaces
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +58 -0
- README.md +78 -0
- app.py +132 -0
- plonk/__init__.py +1 -0
- plonk/callbacks/__init__.py +3 -0
- plonk/callbacks/data.py +11 -0
- plonk/callbacks/ema.py +102 -0
- plonk/callbacks/fix_nans.py +55 -0
- plonk/configs/computer/a100.yaml +8 -0
- plonk/configs/computer/cluster-node-a100.yaml +8 -0
- plonk/configs/computer/cluster-node-v100.yaml +8 -0
- plonk/configs/computer/cpu.yaml +8 -0
- plonk/configs/computer/h100.yaml +8 -0
- plonk/configs/computer/v100.yaml +8 -0
- plonk/configs/config.yaml +90 -0
- plonk/configs/dataset/combined_emb.yaml +38 -0
- plonk/configs/dataset/inaturalist_emb.yaml +38 -0
- plonk/configs/dataset/osv5m.yaml +43 -0
- plonk/configs/dataset/osv5m_emb.yaml +38 -0
- plonk/configs/dataset/test_transform/center_crop.yaml +12 -0
- plonk/configs/dataset/test_transform/clip.yaml +2 -0
- plonk/configs/dataset/test_transform/empty.yaml +2 -0
- plonk/configs/dataset/test_transform/fast_clip.yaml +12 -0
- plonk/configs/dataset/test_transform/fast_resnet.yaml +12 -0
- plonk/configs/dataset/test_transform/none.yaml +6 -0
- plonk/configs/dataset/train_transform/augmentation.yaml +85 -0
- plonk/configs/dataset/train_transform/center_crop.yaml +14 -0
- plonk/configs/dataset/train_transform/clip.yaml +2 -0
- plonk/configs/dataset/train_transform/empty.yaml +2 -0
- plonk/configs/dataset/train_transform/fast_clip.yaml +12 -0
- plonk/configs/dataset/train_transform/fast_resnet.yaml +12 -0
- plonk/configs/dataset/train_transform/none.yaml +7 -0
- plonk/configs/dataset/yfcc_emb.yaml +38 -0
- plonk/configs/exp/YFCC100M_geoadalnmlp_r2_small_sigmoid_diffusion.yaml +35 -0
- plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_linear_flow_rieman.yaml +32 -0
- plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_diffusion.yaml +36 -0
- plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_flow.yaml +38 -0
- plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml +40 -0
- plonk/configs/exp/YFCC100M_geoadalnmlp_von_fisher.yaml +26 -0
- plonk/configs/exp/YFCC100M_geoadalnmlp_von_fisher_mixture.yaml +26 -0
- plonk/configs/exp/combined_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml +40 -0
- plonk/configs/exp/iNaturalist_geoadalnmlp_r2_small_sigmoid_diffusion.yaml +36 -0
- plonk/configs/exp/iNaturalist_geoadalnmlp_r3_small_sigmoid_diffusion.yaml +37 -0
- plonk/configs/exp/iNaturalist_geoadalnmlp_r3_small_sigmoid_flow.yaml +39 -0
- plonk/configs/exp/iNaturalist_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml +40 -0
- plonk/configs/exp/iNaturalist_geoadalnmlp_von_fisher.yaml +26 -0
- plonk/configs/exp/iNaturalist_geoadalnmlp_von_fisher_mixture.yaml +26 -0
- plonk/configs/exp/osv_5m_geoadalnmlp_r2_small_sigmoid_diffusion.yaml +34 -0
- plonk/configs/exp/osv_5m_geoadalnmlp_r3_small_linear_flow_riemann.yaml +30 -0
- plonk/configs/exp/osv_5m_geoadalnmlp_r3_small_sigmoid_diffusion.yaml +35 -0
.gitignore
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Hugging Face Spaces .gitignore
|
2 |
+
|
3 |
+
# Python cache
|
4 |
+
__pycache__/
|
5 |
+
*.py[cod]
|
6 |
+
*$py.class
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Development files
|
10 |
+
.ipynb_checkpoints/
|
11 |
+
.vscode/
|
12 |
+
.idea/
|
13 |
+
*.swp
|
14 |
+
*.swo
|
15 |
+
|
16 |
+
# OS files
|
17 |
+
.DS_Store
|
18 |
+
Thumbs.db
|
19 |
+
|
20 |
+
# Temporary files
|
21 |
+
*.tmp
|
22 |
+
*.log
|
23 |
+
*.pid
|
24 |
+
|
25 |
+
# Original demo files (using streamlit)
|
26 |
+
demo/demo.py
|
27 |
+
|
28 |
+
# Environment files
|
29 |
+
.env
|
30 |
+
.env.local
|
31 |
+
|
32 |
+
# Model checkpoints (will be downloaded automatically)
|
33 |
+
checkpoints/
|
34 |
+
*.safetensors
|
35 |
+
*.bin
|
36 |
+
|
37 |
+
# Large data files
|
38 |
+
data/
|
39 |
+
datasets/
|
40 |
+
*.csv
|
41 |
+
*.json
|
42 |
+
|
43 |
+
# Training artifacts
|
44 |
+
wandb/
|
45 |
+
logs/
|
46 |
+
outputs/
|
47 |
+
|
48 |
+
# Test files
|
49 |
+
test_*.py
|
50 |
+
*_test.py
|
51 |
+
|
52 |
+
# Documentation that's not needed for the Space
|
53 |
+
*.md
|
54 |
+
!README.md
|
55 |
+
|
56 |
+
# Git files
|
57 |
+
.git/
|
58 |
+
.gitmodules
|
README.md
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: PLONK Geolocation
|
3 |
+
emoji: 🗺️
|
4 |
+
colorFrom: red
|
5 |
+
colorTo: blue
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 4.0.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: mit
|
11 |
+
---
|
12 |
+
|
13 |
+
# 🗺️ PLONK: Around the World in 80 Timesteps
|
14 |
+
|
15 |
+
A generative approach to global visual geolocation. Upload an image and PLONK will predict where it was taken!
|
16 |
+
|
17 |
+
## About
|
18 |
+
|
19 |
+
PLONK is a diffusion-based model that predicts the geographic location where a photo was taken based solely on its visual content. This Space uses the PLONK_YFCC model trained on the YFCC100M dataset.
|
20 |
+
|
21 |
+
## Features
|
22 |
+
|
23 |
+
- **Simple Prediction**: Get a single high-confidence location prediction
|
24 |
+
- **Advanced Analysis**: Explore prediction uncertainty with multiple samples and guidance control
|
25 |
+
- **Fast CPU Inference**: ~300-500ms per image on CPU-Basic tier
|
26 |
+
- **GPU Ready**: Upgrade to T4-small for ~45ms inference time
|
27 |
+
|
28 |
+
## Usage
|
29 |
+
|
30 |
+
1. Upload an image using the interface
|
31 |
+
2. Click "Submit" to get location predictions
|
32 |
+
3. For advanced analysis, try different guidance scales:
|
33 |
+
- CFG = 0.0: More diverse predictions (good for uncertainty estimation)
|
34 |
+
- CFG = 2.0: Single confident prediction (best guess)
|
35 |
+
|
36 |
+
## API Usage
|
37 |
+
|
38 |
+
This Space exposes a REST API compatible with Gradio's prediction format:
|
39 |
+
|
40 |
+
```python
|
41 |
+
import requests
|
42 |
+
|
43 |
+
url = "https://your-space-name.hf.space/api/predict"
|
44 |
+
files = {"data": open("image.jpg", "rb")}
|
45 |
+
response = requests.post(url, files=files)
|
46 |
+
print(response.json())
|
47 |
+
```
|
48 |
+
|
49 |
+
## Model Performance
|
50 |
+
|
51 |
+
- **Latency**: 300-500ms on CPU-Basic, ~45ms on T4 GPU
|
52 |
+
- **Memory**: <1GB RAM usage
|
53 |
+
- **Throughput**: ~10 req/s on T4 before saturation
|
54 |
+
|
55 |
+
## Scaling Options
|
56 |
+
|
57 |
+
- **Free CPU-Basic**: Perfect for testing and low-volume usage
|
58 |
+
- **T4-small ($0.40/hr)**: 10x faster inference for production
|
59 |
+
- **Inference Endpoints**: Auto-scaling with pay-per-use pricing
|
60 |
+
|
61 |
+
## Citation
|
62 |
+
|
63 |
+
If you use PLONK in your research, please cite:
|
64 |
+
|
65 |
+
```bibtex
|
66 |
+
@article{dufour2024plonk,
|
67 |
+
title={Around the World in 80 Timesteps: A Generative Approach to Global Visual Geolocation},
|
68 |
+
author={Dufour, Nicolas and others},
|
69 |
+
journal={arXiv preprint},
|
70 |
+
year={2024}
|
71 |
+
}
|
72 |
+
```
|
73 |
+
|
74 |
+
## Links
|
75 |
+
|
76 |
+
- 📄 [Project Page](https://nicolas-dufour.github.io/plonk)
|
77 |
+
- 💻 [Code Repository](https://github.com/nicolas-dufour/plonk)
|
78 |
+
- 🤗 [Model on Hugging Face](https://huggingface.co/nicolas-dufour/PLONK_YFCC)
|
app.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import gradio as gr
import torch
import numpy as np
from PIL import Image
# BUG FIX: `Path` was previously imported only inside the
# `if __name__ == "__main__":` block, *after* the interfaces below were
# already constructed — the `Path("demo/examples")` guard therefore raised
# NameError at import time. Import it at the top instead.
from pathlib import Path

from plonk.pipe import PlonkPipeline

# Load the model once at import time so every request reuses the same pipeline.
print("Loading PLONK_YFCC model...")
pipe = PlonkPipeline(model_path="nicolas-dufour/PLONK_YFCC")
print("Model loaded successfully!")


def predict_geolocation(image):
    """
    Predict geolocation from an uploaded image.

    Args:
        image: PIL Image, or None when nothing was uploaded.

    Returns:
        str: Formatted latitude/longitude, or a human-readable error message.
    """
    if image is None:
        return "Please upload an image"

    try:
        # Single sample with strong guidance (cfg=2.0) for the best single guess.
        predicted_gps = pipe(image, batch_size=1, cfg=2.0, num_steps=32)

        lat, lon = float(predicted_gps[0, 0]), float(predicted_gps[0, 1])

        return f"Predicted Location:\nLatitude: {lat:.6f}\nLongitude: {lon:.6f}"

    except Exception as e:
        return f"Error during prediction: {str(e)}"


def predict_geolocation_with_samples(image, num_samples=64, cfg=0.0):
    """
    Predict geolocation with multiple samples for uncertainty visualization.

    Args:
        image: PIL Image, or None when nothing was uploaded.
        num_samples: Number of samples to draw (more = better uncertainty
            estimate, but slower).
        cfg: Classifier-free guidance scale for the sampled batch.

    Returns:
        str: Formatted results with a high-confidence prediction and
        mean/std statistics over the sampled batch.
    """
    if image is None:
        return "Please upload an image"

    try:
        # Multiple predictions for uncertainty estimation.
        predicted_gps = pipe(image, batch_size=num_samples, cfg=cfg, num_steps=32)

        lats = predicted_gps[:, 0].astype(float)
        lons = predicted_gps[:, 1].astype(float)

        mean_lat, mean_lon = np.mean(lats), np.mean(lons)
        std_lat, std_lon = np.std(lats), np.std(lons)

        # Separate single-sample pass with strong guidance for a "best guess".
        high_conf_gps = pipe(image, batch_size=1, cfg=2.0, num_steps=32)
        conf_lat, conf_lon = float(high_conf_gps[0, 0]), float(high_conf_gps[0, 1])

        result = f"""Geolocation Prediction Results:

High Confidence Prediction (CFG=2.0):
Latitude: {conf_lat:.6f}
Longitude: {conf_lon:.6f}

Sample Statistics ({num_samples} samples, CFG={cfg}):
Mean Latitude: {mean_lat:.6f} ± {std_lat:.6f}
Mean Longitude: {mean_lon:.6f} ± {std_lon:.6f}
"""

        return result

    except Exception as e:
        return f"Error during prediction: {str(e)}"


# Example images are optional: only offer them when the directory is populated.
# (Path.glob on a missing directory yields nothing, so this is safe either way.)
_EXAMPLE_IMAGES = [
    ["demo/examples/condor.jpg"],
    ["demo/examples/Kilimanjaro.jpg"],
    ["demo/examples/pigeon.png"],
]
_EXAMPLES = _EXAMPLE_IMAGES if any(Path("demo/examples").glob("*")) else None

# Simple single-prediction interface.
simple_interface = gr.Interface(
    fn=predict_geolocation,
    inputs=gr.Image(type="pil", label="Upload an image"),
    outputs=gr.Textbox(label="Predicted Location", lines=4),
    title="🗺️ PLONK: Global Visual Geolocation",
    description="""
    Upload an image and PLONK will predict where it was taken!

    This uses the PLONK_YFCC model trained on the YFCC100M dataset.
    The model predicts latitude and longitude coordinates based on visual content.

    **Note**: This is running on CPU, so processing may take 300-500ms per image.
    """,
    examples=_EXAMPLES,
)

# Advanced interface with sampling options for uncertainty exploration.
advanced_interface = gr.Interface(
    fn=predict_geolocation_with_samples,
    inputs=[
        gr.Image(type="pil", label="Upload an image"),
        gr.Slider(1, 256, value=64, step=1, label="Number of samples"),
        gr.Slider(0.0, 5.0, value=0.0, step=0.1, label="Guidance scale (CFG)")
    ],
    outputs=gr.Textbox(label="Detailed Results", lines=10),
    title="🗺️ PLONK: Advanced Geolocation with Uncertainty",
    description="""
    Advanced interface showing prediction uncertainty through multiple samples.

    - **Number of samples**: More samples = better uncertainty estimation (but slower)
    - **Guidance scale**: Higher values = more confident predictions (try 2.0 for best single guess)
    """,
)

# Combine both interfaces into one tabbed app.
demo = gr.TabbedInterface(
    [simple_interface, advanced_interface],
    ["Simple Prediction", "Advanced Analysis"],
    title="PLONK: Around the World in 80 Timesteps"
)

if __name__ == "__main__":
    demo.launch()
plonk/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .pipe import PlonkPipeline
|
plonk/callbacks/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .ema import EMACallback
|
2 |
+
from .fix_nans import FixNANinGrad
|
3 |
+
from .data import IncreaseDataEpoch
|
plonk/callbacks/data.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from pytorch_lightning.callbacks import Callback


class IncreaseDataEpoch(Callback):
    """Push the trainer's current epoch into the training dataset.

    Some streaming datasets keep a ``shared_epoch`` counter that drives
    per-epoch behaviour (e.g. reshuffling across workers); this callback
    keeps that counter in sync with the Lightning epoch.
    """

    def on_train_epoch_start(self, trainer, pl_module):
        dataset = trainer.datamodule.train_dataset
        # Datasets without a shared_epoch attribute are simply left alone.
        if hasattr(dataset, "shared_epoch"):
            dataset.shared_epoch.set_value(pl_module.current_epoch)
plonk/callbacks/ema.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from pytorch_lightning import Callback
import copy
import itertools
import torch
import contextlib
from torch.distributed.fsdp import FullyShardedDataParallel


class EMACallback(Callback):
    """Maintain an exponential-moving-average (EMA) copy of a submodule.

    The EMA module is stored on the LightningModule under
    ``ema_module_attr_name`` and shadows the live module stored under
    ``module_attr_name``. Before ``start_ema_step`` only occasional slow
    updates are applied; at ``start_ema_step`` the EMA is reset, and from
    then on it is updated every batch with ``decay``.
    """

    def __init__(
        self,
        module_attr_name,
        ema_module_attr_name,
        decay=0.999,
        start_ema_step=0,
        init_ema_random=True,
    ):
        super().__init__()
        self.decay = decay
        self.module_attr_name = module_attr_name
        self.ema_module_attr_name = ema_module_attr_name
        self.start_ema_step = start_ema_step
        self.init_ema_random = init_ema_random

    def on_train_start(self, trainer, pl_module):
        # Only set up the EMA copy on a fresh run (not when resuming).
        if pl_module.global_step != 0:
            return
        if not hasattr(pl_module, self.module_attr_name):
            raise ValueError(
                f"Module {pl_module} does not have attribute {self.module_attr_name}"
            )
        if not hasattr(pl_module, self.ema_module_attr_name):
            ema_copy = (
                copy.deepcopy(getattr(pl_module, self.module_attr_name))
                .eval()
                .requires_grad_(False)
            )
            pl_module.add_module(self.ema_module_attr_name, ema_copy)
        self.reset_ema(pl_module)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        step = pl_module.global_step
        if step == self.start_ema_step:
            self.reset_ema(pl_module)
        elif step < self.start_ema_step and step % 100 == 0:
            # Occasional slow updates before the official start, for visualisation.
            self.update_ema(pl_module, decay=0.9)
        elif step > self.start_ema_step:
            self.update_ema(pl_module, decay=self.decay)

    def update_ema(self, pl_module, decay=0.999):
        """Blend the live module's parameters into the EMA copy, in place."""
        ema_module = getattr(pl_module, self.ema_module_attr_name)
        live_module = getattr(pl_module, self.module_attr_name)
        with self.get_model_context_manager(live_module), torch.no_grad():
            ema_state = ema_module.state_dict()
            named_tensors = itertools.chain(
                live_module.named_parameters(), live_module.named_buffers()
            )
            for name, tensor in named_tensors:
                # Only trainable parameters are averaged here; buffers
                # (requires_grad=False) are synced solely by reset_ema.
                if name in ema_state and tensor.requires_grad:
                    # NOTE(review): the lerp weight is `decay` itself, so with
                    # decay=0.999 the EMA jumps ~all the way to the live value
                    # every step. A conventional EMA uses weight `1 - decay`
                    # — confirm this convention is intentional before changing.
                    ema_state[name].copy_(
                        ema_state[name].detach().lerp(tensor.detach(), decay)
                    )

    def get_model_context_manager(self, module):
        """Return a context manager gathering full params when FSDP-wrapped."""
        if is_model_fsdp(module):
            return module.summon_full_params(module)
        return contextlib.nullcontext()

    def reset_ema(self, pl_module):
        """Re-initialise the EMA copy, randomly or from the live module."""
        ema_module = getattr(pl_module, self.ema_module_attr_name)
        if self.init_ema_random:
            # assumes the EMA module exposes init_weights() — TODO confirm
            ema_module.init_weights()
            return
        live_module = getattr(pl_module, self.module_attr_name)
        with self.get_model_context_manager(live_module):
            ema_state = ema_module.state_dict()
            for name, tensor in itertools.chain(
                live_module.named_parameters(), live_module.named_buffers()
            ):
                if name in ema_state:
                    ema_state[name].copy_(tensor.detach())


def is_model_fsdp(model: torch.nn.Module) -> bool:
    """Return True if *model* is, or directly contains, an FSDP wrapper."""
    try:
        if isinstance(model, FullyShardedDataParallel):
            return True
        # Check direct children for an FSDP wrapper.
        return any(
            isinstance(child, FullyShardedDataParallel)
            for _, child in model.named_children()
        )
    except ImportError:
        # Defensive, kept from the original; isinstance itself cannot raise
        # ImportError once the FSDP import at module load has succeeded.
        return False
plonk/callbacks/fix_nans.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging
from pytorch_lightning.callbacks import Callback
import torch

log = logging.getLogger(__name__)


class FixNANinGrad(Callback):
    """Sanitize non-finite gradients and stop training on a persistently bad metric.

    Args:
        monitor: list of metric names, tried in order; the first one present in
            ``trainer.callback_metrics`` is checked for finiteness every batch.
    """

    def __init__(self, monitor):
        super().__init__()
        self.monitor = monitor
        # Count of consecutive batches whose monitored metric was non-finite.
        self.continuous_nan_batchs = 0

    def on_before_optimizer_step(self, trainer, pl_module, optimizer) -> None:
        """Replace NaN/Inf gradient entries with 0 (in place) before the step."""
        has_nan = []
        is_inf = []
        for name, param in pl_module.named_parameters():
            if param.grad is None:
                continue
            if torch.isnan(param.grad).any():
                has_nan.append(name)
            if torch.isinf(param.grad).any():
                is_inf.append(name)
            # Sanitize in place so the optimizer never sees non-finite values.
            torch.nan_to_num(param.grad, nan=0, posinf=0, neginf=0, out=param.grad)
        if has_nan:
            print(f"Found NaN in {has_nan}")
        if is_inf:
            print(f"Found Inf in {is_inf}")

    def on_train_batch_end(
        self,
        trainer,
        pl_module,
        outputs,
        batch,
        batch_idx,
    ) -> None:
        """Stop training after 5 consecutive non-finite readings of the metric.

        Raises:
            ValueError: if none of the ``monitor`` names is in the logged metrics.
        """
        logs = trainer.callback_metrics
        for metric_name in self.monitor:
            if metric_name in logs:
                current = logs[metric_name].squeeze()
                break
        else:
            raise ValueError("Asked metric not in logs")

        if not torch.isfinite(current):
            self.continuous_nan_batchs += 1
            if self.continuous_nan_batchs >= 5:
                trainer.should_stop = True
                # BUG FIX: the original message lacked the f-prefix, so the
                # literal text "{self.monitor}" was logged instead of the names.
                log.info(f"Training interrupted because of NaN in {self.monitor}")
        else:
            self.continuous_nan_batchs = 0
plonk/configs/computer/a100.yaml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
devices: 1
|
2 |
+
progress_bar_refresh_rate: 2
|
3 |
+
num_workers: 8
|
4 |
+
sync_batchnorm: False
|
5 |
+
accelerator: gpu
|
6 |
+
precision: 32
|
7 |
+
strategy: auto
|
8 |
+
num_nodes: 1
|
plonk/configs/computer/cluster-node-a100.yaml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
devices: 8
|
2 |
+
num_workers: 8
|
3 |
+
progress_bar_refresh_rate: 2
|
4 |
+
sync_batchnorm: True
|
5 |
+
accelerator: gpu
|
6 |
+
precision: 32
|
7 |
+
strategy: ddp
|
8 |
+
num_nodes: 1
|
plonk/configs/computer/cluster-node-v100.yaml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
devices: 4
|
2 |
+
num_workers: 10
|
3 |
+
progress_bar_refresh_rate: 2
|
4 |
+
sync_batchnorm: True
|
5 |
+
accelerator: gpu
|
6 |
+
precision: 32
|
7 |
+
strategy: ddp
|
8 |
+
num_nodes: 1
|
plonk/configs/computer/cpu.yaml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
devices: null
|
2 |
+
num_workers: 0
|
3 |
+
progress_bar_refresh_rate: 2
|
4 |
+
sync_batchnorm: False
|
5 |
+
accelerator: cpu
|
6 |
+
precision: 32
|
7 |
+
strategy: auto
|
8 |
+
num_nodes: null
|
plonk/configs/computer/h100.yaml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
devices: 1
|
2 |
+
progress_bar_refresh_rate: 2
|
3 |
+
num_workers: 24
|
4 |
+
sync_batchnorm: False
|
5 |
+
accelerator: gpu
|
6 |
+
precision: 32
|
7 |
+
strategy: auto
|
8 |
+
num_nodes: 1
|
plonk/configs/computer/v100.yaml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
devices: 1
|
2 |
+
num_workers: 10
|
3 |
+
progress_bar_refresh_rate: 2
|
4 |
+
sync_batchnorm: False
|
5 |
+
accelerator: gpu
|
6 |
+
precision: 32
|
7 |
+
strategy: auto
|
8 |
+
num_nodes: 1
|
plonk/configs/config.yaml
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
defaults:
|
2 |
+
- model: default
|
3 |
+
- computer: v100
|
4 |
+
- dataset: osv5m_emb
|
5 |
+
- stage: null
|
6 |
+
- _self_
|
7 |
+
- exp: ???
|
8 |
+
|
9 |
+
model:
|
10 |
+
val_metrics:
|
11 |
+
_target_: metrics.distance_based.HaversineMetrics
|
12 |
+
acc_radiuses:
|
13 |
+
- 1
|
14 |
+
- 25
|
15 |
+
- 200
|
16 |
+
- 750
|
17 |
+
- 2500
|
18 |
+
acc_area: []
|
19 |
+
test_metrics:
|
20 |
+
_target_: metrics.distance_based.HaversineMetrics
|
21 |
+
acc_radiuses:
|
22 |
+
- 1
|
23 |
+
- 25
|
24 |
+
- 200
|
25 |
+
- 750
|
26 |
+
- 2500
|
27 |
+
acc_area: ${areas}
|
28 |
+
|
29 |
+
datamodule:
|
30 |
+
_target_: plonk.data.datamodule.ImageDataModule
|
31 |
+
train_dataset: ${dataset.train_dataset}
|
32 |
+
val_dataset: ${dataset.val_dataset}
|
33 |
+
test_dataset: ${dataset.test_dataset}
|
34 |
+
full_batch_size: ${dataset.full_batch_size}
|
35 |
+
eval_batch_size: ${dataset.eval_batch_size}
|
36 |
+
num_workers: ${computer.num_workers}
|
37 |
+
num_nodes: ${computer.num_nodes}
|
38 |
+
num_devices: ${computer.devices}
|
39 |
+
val_proportion: 0.02
|
40 |
+
|
41 |
+
trainer:
|
42 |
+
_target_: pytorch_lightning.Trainer
|
43 |
+
devices: ${computer.devices}
|
44 |
+
accelerator: ${computer.accelerator}
|
45 |
+
strategy: ${computer.strategy}
|
46 |
+
num_nodes: ${computer.num_nodes}
|
47 |
+
precision: ${computer.precision}
|
48 |
+
max_steps: 1000000
|
49 |
+
val_check_interval: 25000
|
50 |
+
check_val_every_n_epoch: null
|
51 |
+
|
52 |
+
logger:
|
53 |
+
_target_: pytorch_lightning.loggers.WandbLogger
|
54 |
+
save_dir: ${root_dir}/plonk
|
55 |
+
name: ${experiment_name}${logger_suffix}
|
56 |
+
project: diff_plonk
|
57 |
+
log_model: False
|
58 |
+
offline: False
|
59 |
+
|
60 |
+
checkpoints:
|
61 |
+
_target_: pytorch_lightning.callbacks.ModelCheckpoint
|
62 |
+
dirpath: ${root_dir}/plonk/checkpoints/${experiment_name}
|
63 |
+
filename: 'epoch_{epoch}'
|
64 |
+
monitor: val/loss
|
65 |
+
save_last: True
|
66 |
+
save_top_k: 0
|
67 |
+
every_n_epochs: 1
|
68 |
+
enable_version_counter: False
|
69 |
+
|
70 |
+
progress_bar:
|
71 |
+
_target_: pytorch_lightning.callbacks.TQDMProgressBar
|
72 |
+
refresh_rate: ${computer.progress_bar_refresh_rate}
|
73 |
+
|
74 |
+
data_dir: ${root_dir}/plonk/datasets
|
75 |
+
root_dir: ${hydra:runtime.cwd}
|
76 |
+
experiment_name: ${dataset.name}_${model.name}_${experiment_name_suffix}
|
77 |
+
experiment_name_suffix: base
|
78 |
+
logger_suffix: ""
|
79 |
+
mode: train # change that to eval to do the testing
|
80 |
+
areas: ['country', 'region', 'sub-region', 'city']
|
81 |
+
class_name: null
|
82 |
+
streetclip: False
|
83 |
+
blur: False
|
84 |
+
text_tuning: False
|
85 |
+
|
86 |
+
hydra:
|
87 |
+
run:
|
88 |
+
dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}/${experiment_name}
|
89 |
+
job:
|
90 |
+
chdir: true
|
plonk/configs/dataset/combined_emb.yaml
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
defaults:
|
2 |
+
- train_transform: empty
|
3 |
+
- test_transform: empty
|
4 |
+
- _self_
|
5 |
+
|
6 |
+
name: iNaturalist_OSV5M_YFCC100M_${dataset.embedding_name}
|
7 |
+
full_batch_size: 2048
|
8 |
+
cond_dim: 1024
|
9 |
+
eval_batch_size: 4096
|
10 |
+
output_type: emb
|
11 |
+
embedding_name: dinov2_vitl14_registers
|
12 |
+
|
13 |
+
train_dataset:
|
14 |
+
_partial_: true
|
15 |
+
_target_: plonk.data.webdataset.GPSWebdataset
|
16 |
+
root: ${data_dir}/YFCC100M/train/ ${data_dir}/osv5m/train/ ${data_dir}/inaturalist/train/ ${data_dir}/osv5m/train/ ${data_dir}/inaturalist/train/
|
17 |
+
train: true
|
18 |
+
embedding_name: ${dataset.embedding_name}
|
19 |
+
return_image: false
|
20 |
+
metadata_attributes: []
|
21 |
+
|
22 |
+
val_dataset:
|
23 |
+
_partial_: true
|
24 |
+
_target_: plonk.data.webdataset.GPSWebdataset
|
25 |
+
root: ${data_dir}/YFCC100M/yfcc4k/
|
26 |
+
train: false
|
27 |
+
embedding_name: ${dataset.embedding_name}
|
28 |
+
return_image: false
|
29 |
+
metadata_attributes: []
|
30 |
+
|
31 |
+
test_dataset:
|
32 |
+
_partial_: true
|
33 |
+
_target_: plonk.data.webdataset.GPSWebdataset
|
34 |
+
root: ${data_dir}/YFCC100M/yfcc4k/
|
35 |
+
train: false
|
36 |
+
embedding_name: ${dataset.embedding_name}
|
37 |
+
return_image: false
|
38 |
+
metadata_attributes: []
|
plonk/configs/dataset/inaturalist_emb.yaml
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
defaults:
|
2 |
+
- train_transform: empty
|
3 |
+
- test_transform: empty
|
4 |
+
- _self_
|
5 |
+
|
6 |
+
name: iNaturalist_${dataset.embedding_name}
|
7 |
+
full_batch_size: 512
|
8 |
+
cond_dim: 1024
|
9 |
+
eval_batch_size: 4096
|
10 |
+
output_type: emb
|
11 |
+
embedding_name: dinov2_vitl14_registers
|
12 |
+
|
13 |
+
train_dataset:
|
14 |
+
_partial_: true
|
15 |
+
_target_: plonk.data.webdataset.GPSWebdataset
|
16 |
+
root: ${data_dir}/inaturalist/train/
|
17 |
+
train: true
|
18 |
+
embedding_name: ${dataset.embedding_name}
|
19 |
+
return_image: false
|
20 |
+
metadata_attributes: []
|
21 |
+
|
22 |
+
val_dataset:
|
23 |
+
_partial_: true
|
24 |
+
_target_: plonk.data.webdataset.GPSWebdataset
|
25 |
+
root: ${data_dir}/inaturalist/val/
|
26 |
+
train: false
|
27 |
+
embedding_name: ${dataset.embedding_name}
|
28 |
+
return_image: false
|
29 |
+
metadata_attributes: []
|
30 |
+
|
31 |
+
test_dataset:
|
32 |
+
_partial_: true
|
33 |
+
_target_: plonk.data.webdataset.GPSWebdataset
|
34 |
+
root: ${data_dir}/inaturalist/test/
|
35 |
+
train: false
|
36 |
+
embedding_name: ${dataset.embedding_name}
|
37 |
+
return_image: false
|
38 |
+
metadata_attributes: []
|
plonk/configs/dataset/osv5m.yaml
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
defaults:
|
2 |
+
- train_transform: fast_clip
|
3 |
+
- test_transform: fast_clip
|
4 |
+
- _self_
|
5 |
+
|
6 |
+
name: osv5m
|
7 |
+
full_batch_size: 2048
|
8 |
+
eval_batch_size: 4096
|
9 |
+
train_dataset:
|
10 |
+
_partial_: true
|
11 |
+
_target_: plonk.data.data.OSV5M
|
12 |
+
path: ${data_dir}/osv5m/
|
13 |
+
split: train
|
14 |
+
class_name: ${class_name}
|
15 |
+
transforms: ${dataset.train_transform}
|
16 |
+
is_baseline: ${is_baseline}
|
17 |
+
areas: ${areas}
|
18 |
+
streetclip: ${streetclip}
|
19 |
+
blur: ${blur}
|
20 |
+
|
21 |
+
val_dataset:
|
22 |
+
_partial_: true
|
23 |
+
_target_: plonk.data.data.OSV5M
|
24 |
+
path: ${data_dir}/osv5m/
|
25 |
+
split: val
|
26 |
+
class_name: ${class_name}
|
27 |
+
transforms: ${dataset.test_transform}
|
28 |
+
is_baseline: ${is_baseline}
|
29 |
+
areas: ${areas}
|
30 |
+
streetclip: ${streetclip}
|
31 |
+
blur: ${blur}
|
32 |
+
|
33 |
+
test_dataset:
|
34 |
+
_partial_: true
|
35 |
+
_target_: plonk.data.data.OSV5M
|
36 |
+
path: ${data_dir}/osv5m/
|
37 |
+
split: test
|
38 |
+
class_name: ${class_name}
|
39 |
+
transforms: ${dataset.test_transform}
|
40 |
+
is_baseline: ${is_baseline}
|
41 |
+
areas: ${areas}
|
42 |
+
streetclip: ${streetclip}
|
43 |
+
blur: ${blur}
|
plonk/configs/dataset/osv5m_emb.yaml
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
defaults:
|
2 |
+
- train_transform: empty
|
3 |
+
- test_transform: empty
|
4 |
+
- _self_
|
5 |
+
|
6 |
+
name: osv5m_${dataset.embedding_name}
|
7 |
+
full_batch_size: 1024
|
8 |
+
eval_batch_size: 4096
|
9 |
+
cond_dim: 1024
|
10 |
+
output_type: emb
|
11 |
+
embedding_name: street_clip
|
12 |
+
|
13 |
+
train_dataset:
|
14 |
+
_partial_: true
|
15 |
+
_target_: plonk.data.webdataset.GPSWebdataset
|
16 |
+
root: ${data_dir}/osv5m/train/
|
17 |
+
train: true
|
18 |
+
embedding_name: ${dataset.embedding_name}
|
19 |
+
return_image: false
|
20 |
+
metadata_attributes: []
|
21 |
+
|
22 |
+
val_dataset:
|
23 |
+
_partial_: true
|
24 |
+
_target_: plonk.data.webdataset.GPSWebdataset
|
25 |
+
root: ${data_dir}/osv5m/val/
|
26 |
+
train: false
|
27 |
+
embedding_name: ${dataset.embedding_name}
|
28 |
+
return_image: false
|
29 |
+
metadata_attributes: ["unique_country", "unique_region", "unique_sub-region", "unique_city"]
|
30 |
+
|
31 |
+
test_dataset:
|
32 |
+
_partial_: true
|
33 |
+
_target_: plonk.data.webdataset.GPSWebdataset
|
34 |
+
root: ${data_dir}/osv5m/test/
|
35 |
+
train: false
|
36 |
+
embedding_name: ${dataset.embedding_name}
|
37 |
+
return_image: false
|
38 |
+
metadata_attributes: ["unique_country", "unique_region", "unique_sub-region", "unique_city"]
|
plonk/configs/dataset/test_transform/center_crop.yaml
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_target_: torchvision.transforms.Compose
|
2 |
+
transforms:
|
3 |
+
- _target_: torchvision.transforms.ToTensor
|
4 |
+
- _target_: plonk.utils.image_processing.CenterCrop
|
5 |
+
ratio: "1:1"
|
6 |
+
- _target_: torchvision.transforms.Resize
|
7 |
+
size: ${dataset.img_resolution}
|
8 |
+
interpolation: 3
|
9 |
+
antialias: true
|
10 |
+
- _target_: torchvision.transforms.Normalize
|
11 |
+
mean: 0.5
|
12 |
+
std: 0.5
|
plonk/configs/dataset/test_transform/clip.yaml
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
_target_: plonk.data.transforms.ClipTransform
|
2 |
+
split: val
|
plonk/configs/dataset/test_transform/empty.yaml
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
_target_: plonk.data.data.null_transform
|
2 |
+
_partial_: true
|
plonk/configs/dataset/test_transform/fast_clip.yaml
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_target_: torchvision.transforms.Compose
|
2 |
+
transforms:
|
3 |
+
- _target_: torchvision.transforms.Resize
|
4 |
+
size: 224
|
5 |
+
interpolation: 3
|
6 |
+
antialias: true
|
7 |
+
- _target_: torchvision.transforms.CenterCrop
|
8 |
+
size: 224
|
9 |
+
- _target_: torchvision.transforms.ToTensor
|
10 |
+
- _target_: torchvision.transforms.Normalize
|
11 |
+
mean: [0.48145466, 0.4578275, 0.40821073]
|
12 |
+
std: [0.26862954, 0.26130258, 0.27577711]
|
plonk/configs/dataset/test_transform/fast_resnet.yaml
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_target_: torchvision.transforms.Compose
|
2 |
+
transforms:
|
3 |
+
- _target_: torchvision.transforms.Resize
|
4 |
+
size: 224
|
5 |
+
interpolation: 3
|
6 |
+
antialias: true
|
7 |
+
- _target_: torchvision.transforms.CenterCrop
|
8 |
+
size: 224
|
9 |
+
- _target_: torchvision.transforms.ToTensor
|
10 |
+
- _target_: torchvision.transforms.Normalize
|
11 |
+
mean: [0.485 ,0.456 ,0.406]
|
12 |
+
std: [0.229, 0.224, 0.225]
|
plonk/configs/dataset/test_transform/none.yaml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_target_: torchvision.transforms.Compose
|
2 |
+
transforms:
|
3 |
+
- _target_: torchvision.transforms.ToTensor
|
4 |
+
- _target_: torchvision.transforms.Normalize
|
5 |
+
mean: 0.5
|
6 |
+
std: 0.5
|
plonk/configs/dataset/train_transform/augmentation.yaml
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_target_: plonk.data.augmentation.ImageAugmentation
|
2 |
+
names: "standard_augmentation,geometric_augmentation,clip_transform"
|
3 |
+
|
4 |
+
# always apply clip_transform at the end
|
5 |
+
clip_transform:
|
6 |
+
_target_: torchvision.transforms.Compose
|
7 |
+
transforms:
|
8 |
+
- _target_: torchvision.transforms.Resize
|
9 |
+
size: 224
|
10 |
+
interpolation: 3
|
11 |
+
antialias: true
|
12 |
+
- _target_: torchvision.transforms.CenterCrop
|
13 |
+
size: 224
|
14 |
+
- _target_: torchvision.transforms.ToTensor
|
15 |
+
- _target_: torchvision.transforms.Normalize
|
16 |
+
mean: [0.48145466, 0.4578275, 0.40821073]
|
17 |
+
std: [0.26862954, 0.26130258, 0.27577711]
|
18 |
+
|
19 |
+
standard_augmentation:
|
20 |
+
_target_: plonk.data.augmentation.StandardAugmentation
|
21 |
+
# by default, we use all augmentation methods
|
22 |
+
names: "brightness,contrast,sharpness,color,blur,gaussian_noise"
|
23 |
+
|
24 |
+
# random PIL brightness
|
25 |
+
brightness:
|
26 |
+
_target_: plonk.data.augmentation.PillowBrightness
|
27 |
+
p: 0.2
|
28 |
+
factor_interval: [0.5, 1.5]
|
29 |
+
|
30 |
+
# random PIL contrast
|
31 |
+
contrast:
|
32 |
+
_target_: plonk.data.augmentation.PillowContrast
|
33 |
+
p: 0.2
|
34 |
+
factor_interval: [0.3, 3]
|
35 |
+
|
36 |
+
# random PIL sharpness
|
37 |
+
sharpness:
|
38 |
+
_target_: plonk.data.augmentation.PillowSharpness
|
39 |
+
p: 0.2
|
40 |
+
factor_interval: [0.5, 30.0]
|
41 |
+
|
42 |
+
# random PIL color
|
43 |
+
color:
|
44 |
+
_target_: plonk.data.augmentation.PillowColor
|
45 |
+
p: 0.2
|
46 |
+
factor_interval: [0.0, 2.0]
|
47 |
+
|
48 |
+
# random PIL blur
|
49 |
+
blur:
|
50 |
+
_target_: plonk.data.augmentation.PillowBlur
|
51 |
+
p: 0.2
|
52 |
+
factor_interval: [1, 2]
|
53 |
+
|
54 |
+
# random numpy gaussian noise
|
55 |
+
gaussian_noise:
|
56 |
+
_target_: plonk.data.augmentation.NumpyGaussianNoise
|
57 |
+
p: 0.2
|
58 |
+
factor_interval: [0.1, 0.04]
|
59 |
+
|
60 |
+
geometric_augmentation:
|
61 |
+
_target_: plonk.data.augmentation.GeometricAugmentation
|
62 |
+
# by default, we use all augmentation methods
|
63 |
+
names: "random_rotation,random_resized_crop,random_horizontal_flip"
|
64 |
+
|
65 |
+
# random rotation
|
66 |
+
random_rotation:
|
67 |
+
_target_: torchvision.transforms.RandomRotation
|
68 |
+
degrees: [-15, 15]
|
69 |
+
|
70 |
+
# random crop
|
71 |
+
random_resized_crop:
|
72 |
+
_target_: torchvision.transforms.RandomResizedCrop
|
73 |
+
scale: [0.5, 1.0]
|
74 |
+
ratio: [0.9, 1.1]
|
75 |
+
size: 224
|
76 |
+
|
77 |
+
# random horizontal flip
|
78 |
+
random_horizontal_flip:
|
79 |
+
_target_: torchvision.transforms.RandomHorizontalFlip
|
80 |
+
p: 0.5
|
81 |
+
|
82 |
+
# random vertical flip
|
83 |
+
random_vertical_flip:
|
84 |
+
_target_: torchvision.transforms.RandomVerticalFlip
|
85 |
+
p: 0.5
|
plonk/configs/dataset/train_transform/center_crop.yaml
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_target_: torchvision.transforms.Compose
|
2 |
+
transforms:
|
3 |
+
- _target_: torchvision.transforms.ToTensor
|
4 |
+
- _target_: plonk.utils.image_processing.CenterCrop
|
5 |
+
ratio: "1:1"
|
6 |
+
- _target_: torchvision.transforms.Resize
|
7 |
+
size: ${dataset.img_resolution}
|
8 |
+
interpolation: 3
|
9 |
+
antialias: true
|
10 |
+
- _target_: torchvision.transforms.RandomHorizontalFlip
|
11 |
+
p: 0.5
|
12 |
+
- _target_: torchvision.transforms.Normalize
|
13 |
+
mean: 0.5
|
14 |
+
std: 0.5
|
plonk/configs/dataset/train_transform/clip.yaml
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
_target_: plonk.data.transforms.ClipTransform
|
2 |
+
split: val
|
plonk/configs/dataset/train_transform/empty.yaml
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
_target_: plonk.data.data.null_transform
|
2 |
+
_partial_: true
|
plonk/configs/dataset/train_transform/fast_clip.yaml
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_target_: torchvision.transforms.Compose
|
2 |
+
transforms:
|
3 |
+
- _target_: torchvision.transforms.Resize
|
4 |
+
size: 224
|
5 |
+
interpolation: 3
|
6 |
+
antialias: true
|
7 |
+
- _target_: torchvision.transforms.CenterCrop
|
8 |
+
size: 224
|
9 |
+
- _target_: torchvision.transforms.ToTensor
|
10 |
+
- _target_: torchvision.transforms.Normalize
|
11 |
+
mean: [0.48145466, 0.4578275, 0.40821073]
|
12 |
+
std: [0.26862954, 0.26130258, 0.27577711]
|
plonk/configs/dataset/train_transform/fast_resnet.yaml
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_target_: torchvision.transforms.Compose
|
2 |
+
transforms:
|
3 |
+
- _target_: torchvision.transforms.Resize
|
4 |
+
size: 224
|
5 |
+
interpolation: 3
|
6 |
+
antialias: true
|
7 |
+
- _target_: torchvision.transforms.CenterCrop
|
8 |
+
size: 224
|
9 |
+
- _target_: torchvision.transforms.ToTensor
|
10 |
+
- _target_: torchvision.transforms.Normalize
|
11 |
+
mean: [0.485 ,0.456 ,0.406]
|
12 |
+
std: [0.229, 0.224, 0.225]
|
plonk/configs/dataset/train_transform/none.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_target_: torchvision.transforms.Compose
|
2 |
+
transforms:
|
3 |
+
- _target_: torchvision.transforms.Resize
|
4 |
+
size: 224
|
5 |
+
interpolation: 3
|
6 |
+
antialias: true
|
7 |
+
- _target_: torchvision.transforms.ToTensor
|
plonk/configs/dataset/yfcc_emb.yaml
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
defaults:
|
2 |
+
- train_transform: empty
|
3 |
+
- test_transform: empty
|
4 |
+
- _self_
|
5 |
+
|
6 |
+
name: iNaturalist_${dataset.embedding_name}
|
7 |
+
full_batch_size: 2048
|
8 |
+
cond_dim: 1024
|
9 |
+
eval_batch_size: 4096
|
10 |
+
output_type: emb
|
11 |
+
embedding_name: dinov2_vitl14_registers
|
12 |
+
|
13 |
+
train_dataset:
|
14 |
+
_partial_: true
|
15 |
+
_target_: plonk.data.webdataset.GPSWebdataset
|
16 |
+
root: ${data_dir}/YFCC100M/train/
|
17 |
+
train: true
|
18 |
+
embedding_name: ${dataset.embedding_name}
|
19 |
+
return_image: false
|
20 |
+
metadata_attributes: []
|
21 |
+
|
22 |
+
val_dataset:
|
23 |
+
_partial_: true
|
24 |
+
_target_: plonk.data.webdataset.GPSWebdataset
|
25 |
+
root: ${data_dir}/YFCC100M/yfcc4k/
|
26 |
+
train: false
|
27 |
+
embedding_name: ${dataset.embedding_name}
|
28 |
+
return_image: false
|
29 |
+
metadata_attributes: []
|
30 |
+
|
31 |
+
test_dataset:
|
32 |
+
_partial_: true
|
33 |
+
_target_: plonk.data.webdataset.GPSWebdataset
|
34 |
+
root: ${data_dir}/YFCC100M/yfcc4k/
|
35 |
+
train: false
|
36 |
+
embedding_name: ${dataset.embedding_name}
|
37 |
+
return_image: false
|
38 |
+
metadata_attributes: []
|
plonk/configs/exp/YFCC100M_geoadalnmlp_r2_small_sigmoid_diffusion.yaml
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
defaults:
|
4 |
+
- override /dataset: yfcc_emb
|
5 |
+
- override /model: emb_cond
|
6 |
+
- override /model/network: geo_adaln_mlp
|
7 |
+
- override /model/train_noise_scheduler: sigmoid
|
8 |
+
- override /model/inference_noise_scheduler: sigmoid
|
9 |
+
- override /model/loss: ddpm
|
10 |
+
- _self_
|
11 |
+
|
12 |
+
model:
|
13 |
+
network:
|
14 |
+
depth: 12
|
15 |
+
dim: 512
|
16 |
+
optimizer:
|
17 |
+
optim:
|
18 |
+
lr: 8e-4
|
19 |
+
weight_decay: 0.05
|
20 |
+
loss:
|
21 |
+
cond_drop_rate: 0.1
|
22 |
+
train_noise_scheduler:
|
23 |
+
start: -7
|
24 |
+
end: 3
|
25 |
+
tau: 1.0
|
26 |
+
inference_noise_scheduler:
|
27 |
+
start: -7
|
28 |
+
end: 3
|
29 |
+
tau: 1.0
|
30 |
+
interpolant: diffusion
|
31 |
+
dataset:
|
32 |
+
full_batch_size: 1024
|
33 |
+
|
34 |
+
experiment_name_suffix: small_sigmoid
|
35 |
+
areas: []
|
plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_linear_flow_rieman.yaml
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
defaults:
|
4 |
+
- override /dataset: yfcc_emb
|
5 |
+
- override /model: emb_cond_cartesian
|
6 |
+
- override /model/network: geo_adaln_mlp
|
7 |
+
- override /model/train_noise_scheduler: linear
|
8 |
+
- override /model/inference_noise_scheduler: linear
|
9 |
+
- override /model/loss: riemannian_flow_matching
|
10 |
+
- override /model/manifold: sphere
|
11 |
+
- override /model/val_sampler: riemannian_flow_matching
|
12 |
+
- override /model/test_sampler: riemannian_flow_matching
|
13 |
+
- _self_
|
14 |
+
|
15 |
+
model:
|
16 |
+
network:
|
17 |
+
depth: 12
|
18 |
+
dim: 512
|
19 |
+
optimizer:
|
20 |
+
optim:
|
21 |
+
lr: 8e-4
|
22 |
+
weight_decay: 0.05
|
23 |
+
loss:
|
24 |
+
cond_drop_rate: 0.1
|
25 |
+
interpolant: flow_matching
|
26 |
+
|
27 |
+
dataset:
|
28 |
+
full_batch_size: 1024
|
29 |
+
|
30 |
+
areas: []
|
31 |
+
|
32 |
+
experiment_name_suffix: small_sigmoid
|
plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_diffusion.yaml
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
defaults:
|
4 |
+
- override /dataset: yfcc_emb
|
5 |
+
- override /model: emb_cond_cartesian
|
6 |
+
- override /model/network: geo_adaln_mlp
|
7 |
+
- override /model/train_noise_scheduler: sigmoid
|
8 |
+
- override /model/inference_noise_scheduler: sigmoid
|
9 |
+
- override /model/loss: ddpm
|
10 |
+
- _self_
|
11 |
+
|
12 |
+
model:
|
13 |
+
network:
|
14 |
+
depth: 12
|
15 |
+
dim: 512
|
16 |
+
optimizer:
|
17 |
+
optim:
|
18 |
+
lr: 8e-4
|
19 |
+
weight_decay: 0.05
|
20 |
+
loss:
|
21 |
+
cond_drop_rate: 0.1
|
22 |
+
train_noise_scheduler:
|
23 |
+
start: -7
|
24 |
+
end: 3
|
25 |
+
tau: 1.0
|
26 |
+
inference_noise_scheduler:
|
27 |
+
start: -7
|
28 |
+
end: 3
|
29 |
+
tau: 1.0
|
30 |
+
interpolant: diffusion
|
31 |
+
|
32 |
+
dataset:
|
33 |
+
full_batch_size: 1024
|
34 |
+
|
35 |
+
experiment_name_suffix: small_sigmoid
|
36 |
+
areas: []
|
plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_flow.yaml
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
defaults:
|
4 |
+
- override /dataset: yfcc_emb
|
5 |
+
- override /model: emb_cond_cartesian
|
6 |
+
- override /model/network: geo_adaln_mlp
|
7 |
+
- override /model/train_noise_scheduler: sigmoid
|
8 |
+
- override /model/inference_noise_scheduler: sigmoid
|
9 |
+
- override /model/loss: flow_matching
|
10 |
+
- override /model/val_sampler: flow_matching
|
11 |
+
- override /model/test_sampler: flow_matching
|
12 |
+
- _self_
|
13 |
+
|
14 |
+
model:
|
15 |
+
network:
|
16 |
+
depth: 12
|
17 |
+
dim: 512
|
18 |
+
optimizer:
|
19 |
+
optim:
|
20 |
+
lr: 8e-4
|
21 |
+
weight_decay: 0.05
|
22 |
+
loss:
|
23 |
+
cond_drop_rate: 0.1
|
24 |
+
train_noise_scheduler:
|
25 |
+
start: -7
|
26 |
+
end: 3
|
27 |
+
tau: 1.0
|
28 |
+
inference_noise_scheduler:
|
29 |
+
start: -7
|
30 |
+
end: 3
|
31 |
+
tau: 1.0
|
32 |
+
interpolant: flow_matching
|
33 |
+
|
34 |
+
dataset:
|
35 |
+
full_batch_size: 1024
|
36 |
+
|
37 |
+
experiment_name_suffix: small_sigmoid
|
38 |
+
areas: []
|
plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
defaults:
|
4 |
+
- override /dataset: yfcc_emb
|
5 |
+
- override /model: emb_cond_cartesian
|
6 |
+
- override /model/network: geo_adaln_mlp
|
7 |
+
- override /model/train_noise_scheduler: sigmoid
|
8 |
+
- override /model/inference_noise_scheduler: sigmoid
|
9 |
+
- override /model/loss: riemannian_flow_matching
|
10 |
+
- override /model/manifold: sphere
|
11 |
+
- override /model/val_sampler: riemannian_flow_matching
|
12 |
+
- override /model/test_sampler: riemannian_flow_matching
|
13 |
+
- _self_
|
14 |
+
|
15 |
+
model:
|
16 |
+
network:
|
17 |
+
depth: 12
|
18 |
+
dim: 512
|
19 |
+
optimizer:
|
20 |
+
optim:
|
21 |
+
lr: 8e-4
|
22 |
+
weight_decay: 0.05
|
23 |
+
loss:
|
24 |
+
cond_drop_rate: 0.1
|
25 |
+
train_noise_scheduler:
|
26 |
+
start: -7
|
27 |
+
end: 3
|
28 |
+
tau: 1.0
|
29 |
+
inference_noise_scheduler:
|
30 |
+
start: -7
|
31 |
+
end: 3
|
32 |
+
tau: 1.0
|
33 |
+
interpolant: flow_matching
|
34 |
+
|
35 |
+
dataset:
|
36 |
+
full_batch_size: 1024
|
37 |
+
|
38 |
+
areas: []
|
39 |
+
|
40 |
+
experiment_name_suffix: small_sigmoid
|
plonk/configs/exp/YFCC100M_geoadalnmlp_von_fisher.yaml
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
defaults:
|
4 |
+
- override /dataset: yfcc_emb
|
5 |
+
- override /model: von_fisher
|
6 |
+
- override /model/network: geo_adaln_mlp_von_fisher
|
7 |
+
- override /model/loss: von_fisher
|
8 |
+
- override /model/val_sampler: von_fisher
|
9 |
+
- override /model/test_sampler: von_fisher
|
10 |
+
- _self_
|
11 |
+
|
12 |
+
model:
|
13 |
+
network:
|
14 |
+
depth: 11 # To compensate the increase in params
|
15 |
+
dim: 512
|
16 |
+
optimizer:
|
17 |
+
optim:
|
18 |
+
lr: 1e-4
|
19 |
+
weight_decay: 0.05
|
20 |
+
dataset:
|
21 |
+
full_batch_size: 1024
|
22 |
+
trainer:
|
23 |
+
gradient_clip_val: 0.05
|
24 |
+
gradient_clip_algorithm: norm
|
25 |
+
areas: []
|
26 |
+
experiment_name_suffix: von_fisher
|
plonk/configs/exp/YFCC100M_geoadalnmlp_von_fisher_mixture.yaml
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
defaults:
|
4 |
+
- override /dataset: yfcc_emb
|
5 |
+
- override /model: von_fisher_mixture
|
6 |
+
- override /model/network: geo_adaln_mlp_von_fisher_mixture
|
7 |
+
- override /model/loss: von_fisher_mixture
|
8 |
+
- override /model/val_sampler: von_fisher_mixture
|
9 |
+
- override /model/test_sampler: von_fisher_mixture
|
10 |
+
- _self_
|
11 |
+
|
12 |
+
model:
|
13 |
+
network:
|
14 |
+
depth: 11 # To compensate the increase in params
|
15 |
+
dim: 512
|
16 |
+
optimizer:
|
17 |
+
optim:
|
18 |
+
lr: 1e-5
|
19 |
+
weight_decay: 0.05
|
20 |
+
dataset:
|
21 |
+
full_batch_size: 1024
|
22 |
+
trainer:
|
23 |
+
gradient_clip_val: 0.01
|
24 |
+
gradient_clip_algorithm: norm
|
25 |
+
experiment_name_suffix: von_fisher_mixture
|
26 |
+
areas: []
|
plonk/configs/exp/combined_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
defaults:
|
4 |
+
- override /dataset: combined_emb
|
5 |
+
- override /model: emb_cond_cartesian
|
6 |
+
- override /model/network: geo_adaln_mlp
|
7 |
+
- override /model/train_noise_scheduler: sigmoid
|
8 |
+
- override /model/inference_noise_scheduler: sigmoid
|
9 |
+
- override /model/loss: riemannian_flow_matching
|
10 |
+
- override /model/manifold: sphere
|
11 |
+
- override /model/val_sampler: riemannian_flow_matching
|
12 |
+
- override /model/test_sampler: riemannian_flow_matching
|
13 |
+
- _self_
|
14 |
+
|
15 |
+
model:
|
16 |
+
network:
|
17 |
+
depth: 12
|
18 |
+
dim: 512
|
19 |
+
optimizer:
|
20 |
+
optim:
|
21 |
+
lr: 8e-4
|
22 |
+
weight_decay: 0.05
|
23 |
+
loss:
|
24 |
+
cond_drop_rate: 0.1
|
25 |
+
train_noise_scheduler:
|
26 |
+
start: -7
|
27 |
+
end: 3
|
28 |
+
tau: 1.0
|
29 |
+
inference_noise_scheduler:
|
30 |
+
start: -7
|
31 |
+
end: 3
|
32 |
+
tau: 1.0
|
33 |
+
interpolant: flow_matching
|
34 |
+
|
35 |
+
dataset:
|
36 |
+
full_batch_size: 1024
|
37 |
+
|
38 |
+
areas: []
|
39 |
+
|
40 |
+
experiment_name_suffix: small_sigmoid
|
plonk/configs/exp/iNaturalist_geoadalnmlp_r2_small_sigmoid_diffusion.yaml
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
defaults:
|
4 |
+
- override /dataset: inaturalist_emb
|
5 |
+
- override /model: emb_cond
|
6 |
+
- override /model/network: geo_adaln_mlp
|
7 |
+
- override /model/train_noise_scheduler: sigmoid
|
8 |
+
- override /model/inference_noise_scheduler: sigmoid
|
9 |
+
- override /model/loss: ddpm
|
10 |
+
- _self_
|
11 |
+
|
12 |
+
model:
|
13 |
+
network:
|
14 |
+
depth: 12
|
15 |
+
dim: 256
|
16 |
+
optimizer:
|
17 |
+
optim:
|
18 |
+
lr: 8e-4
|
19 |
+
weight_decay: 0.1
|
20 |
+
loss:
|
21 |
+
cond_drop_rate: 0.1
|
22 |
+
train_noise_scheduler:
|
23 |
+
start: -7
|
24 |
+
end: 3
|
25 |
+
tau: 1.0
|
26 |
+
inference_noise_scheduler:
|
27 |
+
start: -7
|
28 |
+
end: 3
|
29 |
+
tau: 1.0
|
30 |
+
interpolant: diffusion
|
31 |
+
dataset:
|
32 |
+
full_batch_size: 512
|
33 |
+
|
34 |
+
areas: []
|
35 |
+
|
36 |
+
experiment_name_suffix: small_sigmoid
|
plonk/configs/exp/iNaturalist_geoadalnmlp_r3_small_sigmoid_diffusion.yaml
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
defaults:
|
4 |
+
- override /dataset: inaturalist_emb
|
5 |
+
- override /model: emb_cond_cartesian
|
6 |
+
- override /model/network: geo_adaln_mlp
|
7 |
+
- override /model/train_noise_scheduler: sigmoid
|
8 |
+
- override /model/inference_noise_scheduler: sigmoid
|
9 |
+
- override /model/loss: ddpm
|
10 |
+
- _self_
|
11 |
+
|
12 |
+
model:
|
13 |
+
network:
|
14 |
+
depth: 12
|
15 |
+
dim: 256
|
16 |
+
optimizer:
|
17 |
+
optim:
|
18 |
+
lr: 8e-4
|
19 |
+
weight_decay: 0.1
|
20 |
+
loss:
|
21 |
+
cond_drop_rate: 0.1
|
22 |
+
train_noise_scheduler:
|
23 |
+
start: -7
|
24 |
+
end: 3
|
25 |
+
tau: 1.0
|
26 |
+
inference_noise_scheduler:
|
27 |
+
start: -7
|
28 |
+
end: 3
|
29 |
+
tau: 1.0
|
30 |
+
interpolant: diffusion
|
31 |
+
|
32 |
+
dataset:
|
33 |
+
full_batch_size: 512
|
34 |
+
|
35 |
+
areas: []
|
36 |
+
|
37 |
+
experiment_name_suffix: small_sigmoid
|
plonk/configs/exp/iNaturalist_geoadalnmlp_r3_small_sigmoid_flow.yaml
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
defaults:
|
4 |
+
- override /dataset: inaturalist_emb
|
5 |
+
- override /model: emb_cond_cartesian
|
6 |
+
- override /model/network: geo_adaln_mlp
|
7 |
+
- override /model/train_noise_scheduler: sigmoid
|
8 |
+
- override /model/inference_noise_scheduler: sigmoid
|
9 |
+
- override /model/loss: flow_matching
|
10 |
+
- override /model/val_sampler: flow_matching
|
11 |
+
- override /model/test_sampler: flow_matching
|
12 |
+
- _self_
|
13 |
+
|
14 |
+
model:
|
15 |
+
network:
|
16 |
+
depth: 12
|
17 |
+
dim: 256
|
18 |
+
optimizer:
|
19 |
+
optim:
|
20 |
+
lr: 8e-4
|
21 |
+
weight_decay: 0.1
|
22 |
+
loss:
|
23 |
+
cond_drop_rate: 0.1
|
24 |
+
train_noise_scheduler:
|
25 |
+
start: -7
|
26 |
+
end: 3
|
27 |
+
tau: 1.0
|
28 |
+
inference_noise_scheduler:
|
29 |
+
start: -7
|
30 |
+
end: 3
|
31 |
+
tau: 1.0
|
32 |
+
interpolant: flow_matching
|
33 |
+
|
34 |
+
dataset:
|
35 |
+
full_batch_size: 512
|
36 |
+
|
37 |
+
areas: []
|
38 |
+
|
39 |
+
experiment_name_suffix: small_sigmoid
|
plonk/configs/exp/iNaturalist_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
defaults:
|
4 |
+
- override /dataset: inaturalist_emb
|
5 |
+
- override /model: emb_cond_cartesian
|
6 |
+
- override /model/network: geo_adaln_mlp
|
7 |
+
- override /model/train_noise_scheduler: sigmoid
|
8 |
+
- override /model/inference_noise_scheduler: sigmoid
|
9 |
+
- override /model/loss: riemannian_flow_matching
|
10 |
+
- override /model/manifold: sphere
|
11 |
+
- override /model/val_sampler: riemannian_flow_matching
|
12 |
+
- override /model/test_sampler: riemannian_flow_matching
|
13 |
+
- _self_
|
14 |
+
|
15 |
+
model:
|
16 |
+
network:
|
17 |
+
depth: 12
|
18 |
+
dim: 256
|
19 |
+
optimizer:
|
20 |
+
optim:
|
21 |
+
lr: 8e-4
|
22 |
+
weight_decay: 0.1
|
23 |
+
loss:
|
24 |
+
cond_drop_rate: 0.1
|
25 |
+
train_noise_scheduler:
|
26 |
+
start: -7
|
27 |
+
end: 3
|
28 |
+
tau: 1.0
|
29 |
+
inference_noise_scheduler:
|
30 |
+
start: -7
|
31 |
+
end: 3
|
32 |
+
tau: 1.0
|
33 |
+
interpolant: flow_matching
|
34 |
+
|
35 |
+
dataset:
|
36 |
+
full_batch_size: 512
|
37 |
+
|
38 |
+
areas: []
|
39 |
+
|
40 |
+
experiment_name_suffix: small_sigmoid
|
plonk/configs/exp/iNaturalist_geoadalnmlp_von_fisher.yaml
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
defaults:
|
4 |
+
- override /dataset: inaturalist_emb
|
5 |
+
- override /model: von_fisher
|
6 |
+
- override /model/network: geo_adaln_mlp_von_fisher
|
7 |
+
- override /model/loss: von_fisher
|
8 |
+
- override /model/val_sampler: von_fisher
|
9 |
+
- override /model/test_sampler: von_fisher
|
10 |
+
- _self_
|
11 |
+
|
12 |
+
model:
|
13 |
+
network:
|
14 |
+
depth: 11 # To compensate the increase in params
|
15 |
+
dim: 256
|
16 |
+
optimizer:
|
17 |
+
optim:
|
18 |
+
lr: 1e-4
|
19 |
+
weight_decay: 0.1
|
20 |
+
dataset:
|
21 |
+
full_batch_size: 512
|
22 |
+
trainer:
|
23 |
+
gradient_clip_val: 0.01
|
24 |
+
gradient_clip_algorithm: norm
|
25 |
+
areas: []
|
26 |
+
experiment_name_suffix: von_fisher
|
plonk/configs/exp/iNaturalist_geoadalnmlp_von_fisher_mixture.yaml
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
defaults:
|
4 |
+
- override /dataset: inaturalist_emb
|
5 |
+
- override /model: von_fisher_mixture
|
6 |
+
- override /model/network: geo_adaln_mlp_von_fisher_mixture
|
7 |
+
- override /model/loss: von_fisher_mixture
|
8 |
+
- override /model/val_sampler: von_fisher_mixture
|
9 |
+
- override /model/test_sampler: von_fisher_mixture
|
10 |
+
- _self_
|
11 |
+
|
12 |
+
model:
|
13 |
+
network:
|
14 |
+
depth: 11 # To compensate the increase in params
|
15 |
+
dim: 256
|
16 |
+
optimizer:
|
17 |
+
optim:
|
18 |
+
lr: 1e-5
|
19 |
+
weight_decay: 0.1
|
20 |
+
dataset:
|
21 |
+
full_batch_size: 512
|
22 |
+
trainer:
|
23 |
+
gradient_clip_val: 0.01
|
24 |
+
gradient_clip_algorithm: norm
|
25 |
+
areas: []
|
26 |
+
experiment_name_suffix: von_fisher_mixture
|
plonk/configs/exp/osv_5m_geoadalnmlp_r2_small_sigmoid_diffusion.yaml
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
defaults:
|
4 |
+
- override /dataset: osv5m_emb
|
5 |
+
- override /model: emb_cond
|
6 |
+
- override /model/network: geo_adaln_mlp
|
7 |
+
- override /model/train_noise_scheduler: sigmoid
|
8 |
+
- override /model/inference_noise_scheduler: sigmoid
|
9 |
+
- override /model/loss: ddpm
|
10 |
+
- _self_
|
11 |
+
|
12 |
+
model:
|
13 |
+
network:
|
14 |
+
depth: 12
|
15 |
+
dim: 512
|
16 |
+
optimizer:
|
17 |
+
optim:
|
18 |
+
lr: 8e-4
|
19 |
+
weight_decay: 0.05
|
20 |
+
loss:
|
21 |
+
cond_drop_rate: 0.1
|
22 |
+
train_noise_scheduler:
|
23 |
+
start: -7
|
24 |
+
end: 3
|
25 |
+
tau: 1.0
|
26 |
+
inference_noise_scheduler:
|
27 |
+
start: -7
|
28 |
+
end: 3
|
29 |
+
tau: 1.0
|
30 |
+
interpolant: diffusion
|
31 |
+
dataset:
|
32 |
+
full_batch_size: 1024
|
33 |
+
|
34 |
+
experiment_name_suffix: small_sigmoid
|
plonk/configs/exp/osv_5m_geoadalnmlp_r3_small_linear_flow_riemann.yaml
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
defaults:
|
4 |
+
- override /dataset: osv5m_emb
|
5 |
+
- override /model: emb_cond_cartesian
|
6 |
+
- override /model/network: geo_adaln_mlp
|
7 |
+
- override /model/train_noise_scheduler: linear
|
8 |
+
- override /model/inference_noise_scheduler: linear
|
9 |
+
- override /model/loss: riemannian_flow_matching
|
10 |
+
- override /model/manifold: sphere
|
11 |
+
- override /model/val_sampler: riemannian_flow_matching
|
12 |
+
- override /model/test_sampler: riemannian_flow_matching
|
13 |
+
- _self_
|
14 |
+
|
15 |
+
model:
|
16 |
+
network:
|
17 |
+
depth: 12
|
18 |
+
dim: 512
|
19 |
+
optimizer:
|
20 |
+
optim:
|
21 |
+
lr: 8e-4
|
22 |
+
weight_decay: 0.05
|
23 |
+
loss:
|
24 |
+
cond_drop_rate: 0.1
|
25 |
+
interpolant: flow_matching
|
26 |
+
|
27 |
+
dataset:
|
28 |
+
full_batch_size: 1024
|
29 |
+
|
30 |
+
experiment_name_suffix: small_sigmoid
|
plonk/configs/exp/osv_5m_geoadalnmlp_r3_small_sigmoid_diffusion.yaml
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
defaults:
|
4 |
+
- override /dataset: osv5m_emb
|
5 |
+
- override /model: emb_cond_cartesian
|
6 |
+
- override /model/network: geo_adaln_mlp
|
7 |
+
- override /model/train_noise_scheduler: sigmoid
|
8 |
+
- override /model/inference_noise_scheduler: sigmoid
|
9 |
+
- override /model/loss: ddpm
|
10 |
+
- _self_
|
11 |
+
|
12 |
+
model:
|
13 |
+
network:
|
14 |
+
depth: 12
|
15 |
+
dim: 512
|
16 |
+
optimizer:
|
17 |
+
optim:
|
18 |
+
lr: 8e-4
|
19 |
+
weight_decay: 0.05
|
20 |
+
loss:
|
21 |
+
cond_drop_rate: 0.1
|
22 |
+
train_noise_scheduler:
|
23 |
+
start: -7
|
24 |
+
end: 3
|
25 |
+
tau: 1.0
|
26 |
+
inference_noise_scheduler:
|
27 |
+
start: -7
|
28 |
+
end: 3
|
29 |
+
tau: 1.0
|
30 |
+
interpolant: diffusion
|
31 |
+
|
32 |
+
dataset:
|
33 |
+
full_batch_size: 1024
|
34 |
+
|
35 |
+
experiment_name_suffix: small_sigmoid
|