kylanoconnor commited on
Commit
fac3244
·
0 Parent(s):

Initial PLONK deployment for Hugging Face Spaces

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +58 -0
  2. README.md +78 -0
  3. app.py +132 -0
  4. plonk/__init__.py +1 -0
  5. plonk/callbacks/__init__.py +3 -0
  6. plonk/callbacks/data.py +11 -0
  7. plonk/callbacks/ema.py +102 -0
  8. plonk/callbacks/fix_nans.py +55 -0
  9. plonk/configs/computer/a100.yaml +8 -0
  10. plonk/configs/computer/cluster-node-a100.yaml +8 -0
  11. plonk/configs/computer/cluster-node-v100.yaml +8 -0
  12. plonk/configs/computer/cpu.yaml +8 -0
  13. plonk/configs/computer/h100.yaml +8 -0
  14. plonk/configs/computer/v100.yaml +8 -0
  15. plonk/configs/config.yaml +90 -0
  16. plonk/configs/dataset/combined_emb.yaml +38 -0
  17. plonk/configs/dataset/inaturalist_emb.yaml +38 -0
  18. plonk/configs/dataset/osv5m.yaml +43 -0
  19. plonk/configs/dataset/osv5m_emb.yaml +38 -0
  20. plonk/configs/dataset/test_transform/center_crop.yaml +12 -0
  21. plonk/configs/dataset/test_transform/clip.yaml +2 -0
  22. plonk/configs/dataset/test_transform/empty.yaml +2 -0
  23. plonk/configs/dataset/test_transform/fast_clip.yaml +12 -0
  24. plonk/configs/dataset/test_transform/fast_resnet.yaml +12 -0
  25. plonk/configs/dataset/test_transform/none.yaml +6 -0
  26. plonk/configs/dataset/train_transform/augmentation.yaml +85 -0
  27. plonk/configs/dataset/train_transform/center_crop.yaml +14 -0
  28. plonk/configs/dataset/train_transform/clip.yaml +2 -0
  29. plonk/configs/dataset/train_transform/empty.yaml +2 -0
  30. plonk/configs/dataset/train_transform/fast_clip.yaml +12 -0
  31. plonk/configs/dataset/train_transform/fast_resnet.yaml +12 -0
  32. plonk/configs/dataset/train_transform/none.yaml +7 -0
  33. plonk/configs/dataset/yfcc_emb.yaml +38 -0
  34. plonk/configs/exp/YFCC100M_geoadalnmlp_r2_small_sigmoid_diffusion.yaml +35 -0
  35. plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_linear_flow_rieman.yaml +32 -0
  36. plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_diffusion.yaml +36 -0
  37. plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_flow.yaml +38 -0
  38. plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml +40 -0
  39. plonk/configs/exp/YFCC100M_geoadalnmlp_von_fisher.yaml +26 -0
  40. plonk/configs/exp/YFCC100M_geoadalnmlp_von_fisher_mixture.yaml +26 -0
  41. plonk/configs/exp/combined_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml +40 -0
  42. plonk/configs/exp/iNaturalist_geoadalnmlp_r2_small_sigmoid_diffusion.yaml +36 -0
  43. plonk/configs/exp/iNaturalist_geoadalnmlp_r3_small_sigmoid_diffusion.yaml +37 -0
  44. plonk/configs/exp/iNaturalist_geoadalnmlp_r3_small_sigmoid_flow.yaml +39 -0
  45. plonk/configs/exp/iNaturalist_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml +40 -0
  46. plonk/configs/exp/iNaturalist_geoadalnmlp_von_fisher.yaml +26 -0
  47. plonk/configs/exp/iNaturalist_geoadalnmlp_von_fisher_mixture.yaml +26 -0
  48. plonk/configs/exp/osv_5m_geoadalnmlp_r2_small_sigmoid_diffusion.yaml +34 -0
  49. plonk/configs/exp/osv_5m_geoadalnmlp_r3_small_linear_flow_riemann.yaml +30 -0
  50. plonk/configs/exp/osv_5m_geoadalnmlp_r3_small_sigmoid_diffusion.yaml +35 -0
.gitignore ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hugging Face Spaces .gitignore
2
+
3
+ # Python cache
4
+ __pycache__/
5
+ *.py[cod]
6
+ *$py.class
7
+ *.so
8
+
9
+ # Development files
10
+ .ipynb_checkpoints/
11
+ .vscode/
12
+ .idea/
13
+ *.swp
14
+ *.swo
15
+
16
+ # OS files
17
+ .DS_Store
18
+ Thumbs.db
19
+
20
+ # Temporary files
21
+ *.tmp
22
+ *.log
23
+ *.pid
24
+
25
+ # Original demo files (using streamlit)
26
+ demo/demo.py
27
+
28
+ # Environment files
29
+ .env
30
+ .env.local
31
+
32
+ # Model checkpoints (will be downloaded automatically)
33
+ checkpoints/
34
+ *.safetensors
35
+ *.bin
36
+
37
+ # Large data files
38
+ data/
39
+ datasets/
40
+ *.csv
41
+ *.json
42
+
43
+ # Training artifacts
44
+ wandb/
45
+ logs/
46
+ outputs/
47
+
48
+ # Test files
49
+ test_*.py
50
+ *_test.py
51
+
52
+ # Documentation that's not needed for the Space
53
+ *.md
54
+ !README.md
55
+
56
+ # Git files
57
+ .git/
58
+ .gitmodules
README.md ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: PLONK Geolocation
3
+ emoji: 🗺️
4
+ colorFrom: red
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 4.0.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ # 🗺️ PLONK: Around the World in 80 Timesteps
14
+
15
+ A generative approach to global visual geolocation. Upload an image and PLONK will predict where it was taken!
16
+
17
+ ## About
18
+
19
+ PLONK is a diffusion-based model that predicts the geographic location where a photo was taken based solely on its visual content. This Space uses the PLONK_YFCC model trained on the YFCC100M dataset.
20
+
21
+ ## Features
22
+
23
+ - **Simple Prediction**: Get a single high-confidence location prediction
24
+ - **Advanced Analysis**: Explore prediction uncertainty with multiple samples and guidance control
25
+ - **Fast CPU Inference**: ~300-500ms per image on CPU-Basic tier
26
+ - **GPU Ready**: Upgrade to T4-small for ~45ms inference time
27
+
28
+ ## Usage
29
+
30
+ 1. Upload an image using the interface
31
+ 2. Click "Submit" to get location predictions
32
+ 3. For advanced analysis, try different guidance scales:
33
+ - CFG = 0.0: More diverse predictions (good for uncertainty estimation)
34
+ - CFG = 2.0: Single confident prediction (best guess)
35
+
36
+ ## API Usage
37
+
38
+ This Space exposes a REST API compatible with Gradio's prediction format:
39
+
40
+ ```python
41
+ import requests
42
+
43
+ url = "https://your-space-name.hf.space/api/predict"
44
+ files = {"data": open("image.jpg", "rb")}
45
+ response = requests.post(url, files=files)
46
+ print(response.json())
47
+ ```
48
+
49
+ ## Model Performance
50
+
51
+ - **Latency**: 300-500ms on CPU-Basic, ~45ms on T4 GPU
52
+ - **Memory**: <1GB RAM usage
53
+ - **Throughput**: ~10 req/s on T4 before saturation
54
+
55
+ ## Scaling Options
56
+
57
+ - **Free CPU-Basic**: Perfect for testing and low-volume usage
58
+ - **T4-small ($0.40/hr)**: 10x faster inference for production
59
+ - **Inference Endpoints**: Auto-scaling with pay-per-use pricing
60
+
61
+ ## Citation
62
+
63
+ If you use PLONK in your research, please cite:
64
+
65
+ ```bibtex
66
+ @article{dufour2024plonk,
67
+ title={Around the World in 80 Timesteps: A Generative Approach to Global Visual Geolocation},
68
+ author={Dufour, Nicolas and others},
69
+ journal={arXiv preprint},
70
+ year={2024}
71
+ }
72
+ ```
73
+
74
+ ## Links
75
+
76
+ - 📄 [Project Page](https://nicolas-dufour.github.io/plonk)
77
+ - 💻 [Code Repository](https://github.com/nicolas-dufour/plonk)
78
+ - 🤗 [Model on Hugging Face](https://huggingface.co/nicolas-dufour/PLONK_YFCC)
app.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import torch
from pathlib import Path  # BUGFIX: Path is used at module level for the examples
                          # list; importing it only under __main__ raised NameError.
from plonk.pipe import PlonkPipeline
import numpy as np
from PIL import Image

# Load the pipeline once at import time so every request reuses the same model.
print("Loading PLONK_YFCC model...")
pipe = PlonkPipeline(model_path="nicolas-dufour/PLONK_YFCC")
print("Model loaded successfully!")
11
+
12
def predict_geolocation(image):
    """
    Predict where an image was taken.

    Args:
        image: PIL Image uploaded by the user (or None).
    Returns:
        str: Human-readable latitude/longitude, or an error message.
    """
    if image is None:
        return "Please upload an image"

    try:
        # Single sample with strong guidance (cfg=2.0) for the model's best guess.
        coords = pipe(image, batch_size=1, cfg=2.0, num_steps=32)
        latitude = float(coords[0, 0])
        longitude = float(coords[0, 1])
        return (
            "Predicted Location:\n"
            f"Latitude: {latitude:.6f}\n"
            f"Longitude: {longitude:.6f}"
        )
    except Exception as e:
        return f"Error during prediction: {str(e)}"
38
+
39
def predict_geolocation_with_samples(image, num_samples=64, cfg=0.0):
    """
    Predict geolocation with multiple samples for uncertainty visualization.

    Args:
        image: PIL Image uploaded by the user (or None).
        num_samples: Number of GPS samples to draw (more = better uncertainty
            estimate, but slower).
        cfg: Classifier-free guidance scale for the sampled batch.
    Returns:
        str: Formatted results with a high-confidence guess and sample statistics.
    """
    if image is None:
        return "Please upload an image"

    try:
        # Draw multiple predictions to estimate the spread of the model's guesses.
        predicted_gps = pipe(image, batch_size=num_samples, cfg=cfg, num_steps=32)

        # BUGFIX/robustness: np.asarray handles both NumPy arrays and (CPU)
        # torch tensors, whereas .astype is NumPy-only and would crash if the
        # pipeline returns a tensor.
        coords = np.asarray(predicted_gps, dtype=float)
        lats = coords[:, 0]
        lons = coords[:, 1]

        mean_lat, mean_lon = np.mean(lats), np.mean(lons)
        std_lat, std_lon = np.std(lats), np.std(lons)

        # Separate single-sample run with strong guidance for the best guess.
        high_conf_gps = pipe(image, batch_size=1, cfg=2.0, num_steps=32)
        conf_lat, conf_lon = float(high_conf_gps[0, 0]), float(high_conf_gps[0, 1])

        result = f"""Geolocation Prediction Results:

High Confidence Prediction (CFG=2.0):
Latitude: {conf_lat:.6f}
Longitude: {conf_lon:.6f}

Sample Statistics ({num_samples} samples, CFG={cfg}):
Mean Latitude: {mean_lat:.6f} ± {std_lat:.6f}
Mean Longitude: {mean_lon:.6f} ± {std_lon:.6f}
"""

        return result

    except Exception as e:
        return f"Error during prediction: {str(e)}"
82
+
83
+ # Create the Gradio interface for simple prediction
84
+ simple_interface = gr.Interface(
85
+ fn=predict_geolocation,
86
+ inputs=gr.Image(type="pil", label="Upload an image"),
87
+ outputs=gr.Textbox(label="Predicted Location", lines=4),
88
+ title="🗺️ PLONK: Global Visual Geolocation",
89
+ description="""
90
+ Upload an image and PLONK will predict where it was taken!
91
+
92
+ This uses the PLONK_YFCC model trained on the YFCC100M dataset.
93
+ The model predicts latitude and longitude coordinates based on visual content.
94
+
95
+ **Note**: This is running on CPU, so processing may take 300-500ms per image.
96
+ """,
97
+ examples=[
98
+ ["demo/examples/condor.jpg"],
99
+ ["demo/examples/Kilimanjaro.jpg"],
100
+ ["demo/examples/pigeon.png"]
101
+ ] if any(Path("demo/examples").glob("*")) else None
102
+ )
103
+
104
+ # Create advanced interface with sampling options
105
+ advanced_interface = gr.Interface(
106
+ fn=predict_geolocation_with_samples,
107
+ inputs=[
108
+ gr.Image(type="pil", label="Upload an image"),
109
+ gr.Slider(1, 256, value=64, step=1, label="Number of samples"),
110
+ gr.Slider(0.0, 5.0, value=0.0, step=0.1, label="Guidance scale (CFG)")
111
+ ],
112
+ outputs=gr.Textbox(label="Detailed Results", lines=10),
113
+ title="🗺️ PLONK: Advanced Geolocation with Uncertainty",
114
+ description="""
115
+ Advanced interface showing prediction uncertainty through multiple samples.
116
+
117
+ - **Number of samples**: More samples = better uncertainty estimation (but slower)
118
+ - **Guidance scale**: Higher values = more confident predictions (try 2.0 for best single guess)
119
+ """,
120
+ )
121
+
122
+ # Create tabbed interface
123
+ demo = gr.TabbedInterface(
124
+ [simple_interface, advanced_interface],
125
+ ["Simple Prediction", "Advanced Analysis"],
126
+ title="PLONK: Around the World in 80 Timesteps"
127
+ )
128
+
129
+ if __name__ == "__main__":
130
+ # Add necessary import for pathlib
131
+ from pathlib import Path
132
+ demo.launch()
plonk/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .pipe import PlonkPipeline
plonk/callbacks/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .ema import EMACallback
2
+ from .fix_nans import FixNANinGrad
3
+ from .data import IncreaseDataEpoch
plonk/callbacks/data.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pytorch_lightning.callbacks import Callback
2
+
3
+
4
class IncreaseDataEpoch(Callback):
    """Push the current epoch number into datasets exposing a ``shared_epoch``.

    NOTE(review): presumably used by webdataset-style loaders that shuffle per
    epoch via a shared counter — confirm against the dataset implementation.
    """

    def __init__(self):
        super().__init__()

    def on_train_epoch_start(self, trainer, pl_module):
        dataset = trainer.datamodule.train_dataset
        if hasattr(dataset, "shared_epoch"):
            dataset.shared_epoch.set_value(pl_module.current_epoch)
plonk/callbacks/ema.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pytorch_lightning import Callback
2
+ import copy
3
+ import itertools
4
+ import torch
5
+ import contextlib
6
+ from torch.distributed.fsdp import FullyShardedDataParallel
7
+
8
+
9
class EMACallback(Callback):
    """Maintain an exponential moving average (EMA) copy of a submodule.

    The EMA module is created (if missing) when training starts and updated
    after every optimizer step once ``start_ema_step`` is reached. Before that
    point it is only loosely synced every 100 steps so it can be visualised.
    """

    def __init__(
        self,
        module_attr_name,
        ema_module_attr_name,
        decay=0.999,
        start_ema_step=0,
        init_ema_random=True,
    ):
        """
        Args:
            module_attr_name: Attribute name of the trained module on the
                LightningModule.
            ema_module_attr_name: Attribute name under which the EMA copy is
                stored on the LightningModule.
            decay: Fraction of the EMA weights retained per update
                (0.999 keeps 99.9% of the running average).
            start_ema_step: Global step at which real EMA tracking begins.
            init_ema_random: If True, re-initialize EMA weights randomly on
                reset; otherwise copy the current module weights.
        """
        super().__init__()
        self.decay = decay
        self.module_attr_name = module_attr_name
        self.ema_module_attr_name = ema_module_attr_name
        self.start_ema_step = start_ema_step
        self.init_ema_random = init_ema_random

    def on_train_start(self, trainer, pl_module):
        # Only set up the EMA copy on a fresh run (global_step == 0), not on resume.
        if pl_module.global_step == 0:
            if not hasattr(pl_module, self.module_attr_name):
                raise ValueError(
                    f"Module {pl_module} does not have attribute {self.module_attr_name}"
                )
            if not hasattr(pl_module, self.ema_module_attr_name):
                pl_module.add_module(
                    self.ema_module_attr_name,
                    copy.deepcopy(getattr(pl_module, self.module_attr_name))
                    .eval()
                    .requires_grad_(False),
                )
            self.reset_ema(pl_module)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if pl_module.global_step == self.start_ema_step:
            self.reset_ema(pl_module)
        elif (
            pl_module.global_step < self.start_ema_step
            and pl_module.global_step % 100 == 0
        ):
            ## slow ema updates for visualisation
            self.update_ema(pl_module, decay=0.9)
        elif pl_module.global_step > self.start_ema_step:
            self.update_ema(pl_module, decay=self.decay)

    def update_ema(self, pl_module, decay=0.999):
        """Blend the trained weights into the EMA copy.

        Implements ``ema = decay * ema + (1 - decay) * param``, i.e. ``decay``
        is the fraction of the running average that is retained.
        """
        ema_module = getattr(pl_module, self.ema_module_attr_name)
        module = getattr(pl_module, self.module_attr_name)
        context_manager = self.get_model_context_manager(module)
        with context_manager:
            with torch.no_grad():
                ema_params = ema_module.state_dict()
                for name, param in itertools.chain(
                    module.named_parameters(), module.named_buffers()
                ):
                    if name in ema_params:
                        if param.requires_grad:
                            # BUGFIX: the lerp weight must be (1 - decay).
                            # Passing `decay` directly moved the EMA 99.9%
                            # toward the raw weights each update for
                            # decay=0.999, defeating the averaging.
                            ema_params[name].copy_(
                                ema_params[name]
                                .detach()
                                .lerp(param.detach(), 1.0 - decay)
                            )

    def get_model_context_manager(self, module):
        # FSDP shards parameters across ranks; summon the full weights so the
        # state_dict copy below sees complete tensors.
        fsdp_enabled = is_model_fsdp(module)
        model_context_manager = contextlib.nullcontext()
        if fsdp_enabled:
            model_context_manager = module.summon_full_params(module)
        return model_context_manager

    def reset_ema(self, pl_module):
        """Re-initialize the EMA copy, either randomly or from the live module."""
        ema_module = getattr(pl_module, self.ema_module_attr_name)
        if self.init_ema_random:
            ema_module.init_weights()
        else:
            module = getattr(pl_module, self.module_attr_name)
            context_manager = self.get_model_context_manager(module)
            with context_manager:
                ema_params = ema_module.state_dict()
                for name, param in itertools.chain(
                    module.named_parameters(), module.named_buffers()
                ):
                    if name in ema_params:
                        ema_params[name].copy_(param.detach())
89
+
90
+
91
def is_model_fsdp(model: torch.nn.Module) -> bool:
    """Return True if *model* is an FSDP instance or directly wraps one."""
    try:
        if isinstance(model, FullyShardedDataParallel):
            return True
        # A model whose immediate children are FSDP-wrapped also counts.
        return any(
            isinstance(child, FullyShardedDataParallel)
            for _, child in model.named_children()
        )
    except ImportError:
        # FSDP unavailable in this torch build — treat as not sharded.
        return False
plonk/callbacks/fix_nans.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from pytorch_lightning.callbacks import Callback
3
+ import torch
4
+
5
+ log = logging.getLogger(__name__)
6
+
7
+
8
class FixNANinGrad(Callback):
    """Zero out NaN/Inf gradients and stop training on persistent NaN metrics.

    Args:
        monitor: List of metric names to check, in priority order; the first
            one present in the trainer's logged metrics is used.
    """

    def __init__(self, monitor):
        super().__init__()
        self.monitor = monitor
        # Count of consecutive batches whose monitored metric was non-finite.
        self.continuous_nan_batchs = 0

    def on_before_optimizer_step(self, trainer, pl_module, optimizer) -> None:
        """Replace NaN/Inf gradient entries with zeros before the optimizer step."""
        has_nan = []
        is_inf = []
        for name, param in pl_module.named_parameters():
            if param.grad is not None:
                if torch.isnan(param.grad).any():
                    has_nan.append(name)
                if torch.isinf(param.grad).any():
                    is_inf.append(name)
                # Sanitize in place so the optimizer never sees NaN/Inf values.
                torch.nan_to_num(param.grad, nan=0, posinf=0, neginf=0, out=param.grad)
        if has_nan:
            print(f"Found NaN in {has_nan}")
        if is_inf:
            print(f"Found Inf in {is_inf}")

    def on_train_batch_end(
        self,
        trainer,
        pl_module,
        outputs,
        batch,
        batch_idx,
    ) -> None:
        """Track the monitored metric; stop after 5 consecutive non-finite batches."""
        logs = trainer.callback_metrics
        current = None
        # Pick the first monitored metric that was actually logged.
        for metric_name in self.monitor:
            if metric_name in logs:
                current = logs[metric_name].squeeze()
                break
        if current is None:
            raise ValueError("Asked metric not in logs")

        if not torch.isfinite(current):
            self.continuous_nan_batchs += 1
            if self.continuous_nan_batchs >= 5:
                trainer.should_stop = True
                # BUGFIX: original logged the literal text "{self.monitor}"
                # (missing f-prefix); use lazy %-formatting instead.
                log.info("Training interrupted because of NaN in %s", self.monitor)
        else:
            self.continuous_nan_batchs = 0
plonk/configs/computer/a100.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ devices: 1
2
+ progress_bar_refresh_rate: 2
3
+ num_workers: 8
4
+ sync_batchnorm: False
5
+ accelerator: gpu
6
+ precision: 32
7
+ strategy: auto
8
+ num_nodes: 1
plonk/configs/computer/cluster-node-a100.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ devices: 8
2
+ num_workers: 8
3
+ progress_bar_refresh_rate: 2
4
+ sync_batchnorm: True
5
+ accelerator: gpu
6
+ precision: 32
7
+ strategy: ddp
8
+ num_nodes: 1
plonk/configs/computer/cluster-node-v100.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ devices: 4
2
+ num_workers: 10
3
+ progress_bar_refresh_rate: 2
4
+ sync_batchnorm: True
5
+ accelerator: gpu
6
+ precision: 32
7
+ strategy: ddp
8
+ num_nodes: 1
plonk/configs/computer/cpu.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ devices: null
2
+ num_workers: 0
3
+ progress_bar_refresh_rate: 2
4
+ sync_batchnorm: False
5
+ accelerator: cpu
6
+ precision: 32
7
+ strategy: auto
8
+ num_nodes: null
plonk/configs/computer/h100.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ devices: 1
2
+ progress_bar_refresh_rate: 2
3
+ num_workers: 24
4
+ sync_batchnorm: False
5
+ accelerator: gpu
6
+ precision: 32
7
+ strategy: auto
8
+ num_nodes: 1
plonk/configs/computer/v100.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ devices: 1
2
+ num_workers: 10
3
+ progress_bar_refresh_rate: 2
4
+ sync_batchnorm: False
5
+ accelerator: gpu
6
+ precision: 32
7
+ strategy: auto
8
+ num_nodes: 1
plonk/configs/config.yaml ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - model: default
3
+ - computer: v100
4
+ - dataset: osv5m_emb
5
+ - stage: null
6
+ - _self_
7
+ - exp: ???
8
+
9
+ model:
10
+ val_metrics:
11
+ _target_: metrics.distance_based.HaversineMetrics
12
+ acc_radiuses:
13
+ - 1
14
+ - 25
15
+ - 200
16
+ - 750
17
+ - 2500
18
+ acc_area: []
19
+ test_metrics:
20
+ _target_: metrics.distance_based.HaversineMetrics
21
+ acc_radiuses:
22
+ - 1
23
+ - 25
24
+ - 200
25
+ - 750
26
+ - 2500
27
+ acc_area: ${areas}
28
+
29
+ datamodule:
30
+ _target_: plonk.data.datamodule.ImageDataModule
31
+ train_dataset: ${dataset.train_dataset}
32
+ val_dataset: ${dataset.val_dataset}
33
+ test_dataset: ${dataset.test_dataset}
34
+ full_batch_size: ${dataset.full_batch_size}
35
+ eval_batch_size: ${dataset.eval_batch_size}
36
+ num_workers: ${computer.num_workers}
37
+ num_nodes: ${computer.num_nodes}
38
+ num_devices: ${computer.devices}
39
+ val_proportion: 0.02
40
+
41
+ trainer:
42
+ _target_: pytorch_lightning.Trainer
43
+ devices: ${computer.devices}
44
+ accelerator: ${computer.accelerator}
45
+ strategy: ${computer.strategy}
46
+ num_nodes: ${computer.num_nodes}
47
+ precision: ${computer.precision}
48
+ max_steps: 1000000
49
+ val_check_interval: 25000
50
+ check_val_every_n_epoch: null
51
+
52
+ logger:
53
+ _target_: pytorch_lightning.loggers.WandbLogger
54
+ save_dir: ${root_dir}/plonk
55
+ name: ${experiment_name}${logger_suffix}
56
+ project: diff_plonk
57
+ log_model: False
58
+ offline: False
59
+
60
+ checkpoints:
61
+ _target_: pytorch_lightning.callbacks.ModelCheckpoint
62
+ dirpath: ${root_dir}/plonk/checkpoints/${experiment_name}
63
+ filename: 'epoch_{epoch}'
64
+ monitor: val/loss
65
+ save_last: True
66
+ save_top_k: 0
67
+ every_n_epochs: 1
68
+ enable_version_counter: False
69
+
70
+ progress_bar:
71
+ _target_: pytorch_lightning.callbacks.TQDMProgressBar
72
+ refresh_rate: ${computer.progress_bar_refresh_rate}
73
+
74
+ data_dir: ${root_dir}/plonk/datasets
75
+ root_dir: ${hydra:runtime.cwd}
76
+ experiment_name: ${dataset.name}_${model.name}_${experiment_name_suffix}
77
+ experiment_name_suffix: base
78
+ logger_suffix: ""
79
+ mode: train # change that to eval to do the testing
80
+ areas: ['country', 'region', 'sub-region', 'city']
81
+ class_name: null
82
+ streetclip: False
83
+ blur: False
84
+ text_tuning: False
85
+
86
+ hydra:
87
+ run:
88
+ dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}/${experiment_name}
89
+ job:
90
+ chdir: true
plonk/configs/dataset/combined_emb.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - train_transform: empty
3
+ - test_transform: empty
4
+ - _self_
5
+
6
+ name: iNaturalist_OSV5M_YFCC100M_${dataset.embedding_name}
7
+ full_batch_size: 2048
8
+ cond_dim: 1024
9
+ eval_batch_size: 4096
10
+ output_type: emb
11
+ embedding_name: dinov2_vitl14_registers
12
+
13
+ train_dataset:
14
+ _partial_: true
15
+ _target_: plonk.data.webdataset.GPSWebdataset
16
+ root: ${data_dir}/YFCC100M/train/ ${data_dir}/osv5m/train/ ${data_dir}/inaturalist/train/ ${data_dir}/osv5m/train/ ${data_dir}/inaturalist/train/
17
+ train: true
18
+ embedding_name: ${dataset.embedding_name}
19
+ return_image: false
20
+ metadata_attributes: []
21
+
22
+ val_dataset:
23
+ _partial_: true
24
+ _target_: plonk.data.webdataset.GPSWebdataset
25
+ root: ${data_dir}/YFCC100M/yfcc4k/
26
+ train: false
27
+ embedding_name: ${dataset.embedding_name}
28
+ return_image: false
29
+ metadata_attributes: []
30
+
31
+ test_dataset:
32
+ _partial_: true
33
+ _target_: plonk.data.webdataset.GPSWebdataset
34
+ root: ${data_dir}/YFCC100M/yfcc4k/
35
+ train: false
36
+ embedding_name: ${dataset.embedding_name}
37
+ return_image: false
38
+ metadata_attributes: []
plonk/configs/dataset/inaturalist_emb.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - train_transform: empty
3
+ - test_transform: empty
4
+ - _self_
5
+
6
+ name: iNaturalist_${dataset.embedding_name}
7
+ full_batch_size: 512
8
+ cond_dim: 1024
9
+ eval_batch_size: 4096
10
+ output_type: emb
11
+ embedding_name: dinov2_vitl14_registers
12
+
13
+ train_dataset:
14
+ _partial_: true
15
+ _target_: plonk.data.webdataset.GPSWebdataset
16
+ root: ${data_dir}/inaturalist/train/
17
+ train: true
18
+ embedding_name: ${dataset.embedding_name}
19
+ return_image: false
20
+ metadata_attributes: []
21
+
22
+ val_dataset:
23
+ _partial_: true
24
+ _target_: plonk.data.webdataset.GPSWebdataset
25
+ root: ${data_dir}/inaturalist/val/
26
+ train: false
27
+ embedding_name: ${dataset.embedding_name}
28
+ return_image: false
29
+ metadata_attributes: []
30
+
31
+ test_dataset:
32
+ _partial_: true
33
+ _target_: plonk.data.webdataset.GPSWebdataset
34
+ root: ${data_dir}/inaturalist/test/
35
+ train: false
36
+ embedding_name: ${dataset.embedding_name}
37
+ return_image: false
38
+ metadata_attributes: []
plonk/configs/dataset/osv5m.yaml ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - train_transform: fast_clip
3
+ - test_transform: fast_clip
4
+ - _self_
5
+
6
+ name: osv5m
7
+ full_batch_size: 2048
8
+ eval_batch_size: 4096
9
+ train_dataset:
10
+ _partial_: true
11
+ _target_: plonk.data.data.OSV5M
12
+ path: ${data_dir}/osv5m/
13
+ split: train
14
+ class_name: ${class_name}
15
+ transforms: ${dataset.train_transform}
16
+ is_baseline: ${is_baseline}
17
+ areas: ${areas}
18
+ streetclip: ${streetclip}
19
+ blur: ${blur}
20
+
21
+ val_dataset:
22
+ _partial_: true
23
+ _target_: plonk.data.data.OSV5M
24
+ path: ${data_dir}/osv5m/
25
+ split: val
26
+ class_name: ${class_name}
27
+ transforms: ${dataset.test_transform}
28
+ is_baseline: ${is_baseline}
29
+ areas: ${areas}
30
+ streetclip: ${streetclip}
31
+ blur: ${blur}
32
+
33
+ test_dataset:
34
+ _partial_: true
35
+ _target_: plonk.data.data.OSV5M
36
+ path: ${data_dir}/osv5m/
37
+ split: test
38
+ class_name: ${class_name}
39
+ transforms: ${dataset.test_transform}
40
+ is_baseline: ${is_baseline}
41
+ areas: ${areas}
42
+ streetclip: ${streetclip}
43
+ blur: ${blur}
plonk/configs/dataset/osv5m_emb.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - train_transform: empty
3
+ - test_transform: empty
4
+ - _self_
5
+
6
+ name: osv5m_${dataset.embedding_name}
7
+ full_batch_size: 1024
8
+ eval_batch_size: 4096
9
+ cond_dim: 1024
10
+ output_type: emb
11
+ embedding_name: street_clip
12
+
13
+ train_dataset:
14
+ _partial_: true
15
+ _target_: plonk.data.webdataset.GPSWebdataset
16
+ root: ${data_dir}/osv5m/train/
17
+ train: true
18
+ embedding_name: ${dataset.embedding_name}
19
+ return_image: false
20
+ metadata_attributes: []
21
+
22
+ val_dataset:
23
+ _partial_: true
24
+ _target_: plonk.data.webdataset.GPSWebdataset
25
+ root: ${data_dir}/osv5m/val/
26
+ train: false
27
+ embedding_name: ${dataset.embedding_name}
28
+ return_image: false
29
+ metadata_attributes: ["unique_country", "unique_region", "unique_sub-region", "unique_city"]
30
+
31
+ test_dataset:
32
+ _partial_: true
33
+ _target_: plonk.data.webdataset.GPSWebdataset
34
+ root: ${data_dir}/osv5m/test/
35
+ train: false
36
+ embedding_name: ${dataset.embedding_name}
37
+ return_image: false
38
+ metadata_attributes: ["unique_country", "unique_region", "unique_sub-region", "unique_city"]
plonk/configs/dataset/test_transform/center_crop.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: torchvision.transforms.Compose
2
+ transforms:
3
+ - _target_: torchvision.transforms.ToTensor
4
+ - _target_: plonk.utils.image_processing.CenterCrop
5
+ ratio: "1:1"
6
+ - _target_: torchvision.transforms.Resize
7
+ size: ${dataset.img_resolution}
8
+ interpolation: 3
9
+ antialias: true
10
+ - _target_: torchvision.transforms.Normalize
11
+ mean: 0.5
12
+ std: 0.5
plonk/configs/dataset/test_transform/clip.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ _target_: plonk.data.transforms.ClipTransform
2
+ split: val
plonk/configs/dataset/test_transform/empty.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ _target_: plonk.data.data.null_transform
2
+ _partial_: true
plonk/configs/dataset/test_transform/fast_clip.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: torchvision.transforms.Compose
2
+ transforms:
3
+ - _target_: torchvision.transforms.Resize
4
+ size: 224
5
+ interpolation: 3
6
+ antialias: true
7
+ - _target_: torchvision.transforms.CenterCrop
8
+ size: 224
9
+ - _target_: torchvision.transforms.ToTensor
10
+ - _target_: torchvision.transforms.Normalize
11
+ mean: [0.48145466, 0.4578275, 0.40821073]
12
+ std: [0.26862954, 0.26130258, 0.27577711]
plonk/configs/dataset/test_transform/fast_resnet.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: torchvision.transforms.Compose
2
+ transforms:
3
+ - _target_: torchvision.transforms.Resize
4
+ size: 224
5
+ interpolation: 3
6
+ antialias: true
7
+ - _target_: torchvision.transforms.CenterCrop
8
+ size: 224
9
+ - _target_: torchvision.transforms.ToTensor
10
+ - _target_: torchvision.transforms.Normalize
11
+ mean: [0.485 ,0.456 ,0.406]
12
+ std: [0.229, 0.224, 0.225]
plonk/configs/dataset/test_transform/none.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ _target_: torchvision.transforms.Compose
2
+ transforms:
3
+ - _target_: torchvision.transforms.ToTensor
4
+ - _target_: torchvision.transforms.Normalize
5
+ mean: 0.5
6
+ std: 0.5
plonk/configs/dataset/train_transform/augmentation.yaml ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: plonk.data.augmentation.ImageAugmentation
2
+ names: "standard_augmentation,geometric_augmentation,clip_transform"
3
+
4
+ # always apply clip_transform at the end
5
+ clip_transform:
6
+ _target_: torchvision.transforms.Compose
7
+ transforms:
8
+ - _target_: torchvision.transforms.Resize
9
+ size: 224
10
+ interpolation: 3
11
+ antialias: true
12
+ - _target_: torchvision.transforms.CenterCrop
13
+ size: 224
14
+ - _target_: torchvision.transforms.ToTensor
15
+ - _target_: torchvision.transforms.Normalize
16
+ mean: [0.48145466, 0.4578275, 0.40821073]
17
+ std: [0.26862954, 0.26130258, 0.27577711]
18
+
19
+ standard_augmentation:
20
+ _target_: plonk.data.augmentation.StandardAugmentation
21
+ # by default, we all augmentation methods
22
+ names: "brightness,contrast,sharpness,color,blur,gaussian_noise" # by default, we use all augmentation methods
23
+
24
+ # random PIL brigtness
25
+ brightness:
26
+ _target_: plonk.data.augmentation.PillowBrightness
27
+ p: 0.2
28
+ factor_interval: [0.5, 1.5]
29
+
30
+ # random PIL contrast
31
+ contrast:
32
+ _target_: plonk.data.augmentation.PillowContrast
33
+ p: 0.2
34
+ factor_interval: [0.3, 3]
35
+
36
+ # random PIL sharpness
37
+ sharpness:
38
+ _target_: plonk.data.augmentation.PillowSharpness
39
+ p: 0.2
40
+ factor_interval: [0.5, 30.0]
41
+
42
+ # random PIL color
43
+ color:
44
+ _target_: plonk.data.augmentation.PillowColor
45
+ p: 0.2
46
+ factor_interval: [0.0, 2.0]
47
+
48
+ # random PIL blur
49
+ blur:
50
+ _target_: plonk.data.augmentation.PillowBlur
51
+ p: 0.2
52
+ factor_interval: [1, 2]
53
+
54
+ # random numpy gaussian noise
55
+ gaussian_noise:
56
+ _target_: plonk.data.augmentation.NumpyGaussianNoise
57
+ p: 0.2
58
+ factor_interval: [0.1, 0.04]
59
+
60
+ geometric_augmentation:
61
+ _target_: plonk.data.augmentation.GeometricAugmentation
62
+ # by default, we all augmentation methods
63
+ names: "random_rotation,random_resized_crop,random_horizontal_flip" # by default, we use all augmentation methods
64
+
65
+ # random rotation
66
+ random_rotation:
67
+ _target_: torchvision.transforms.RandomRotation
68
+ degrees: [-15, 15]
69
+
70
+ # random crop
71
+ random_resized_crop:
72
+ _target_: torchvision.transforms.RandomResizedCrop
73
+ scale: [0.5, 1.0]
74
+ ratio: [0.9, 1.1]
75
+ size: 224
76
+
77
+ # random horizontal flip
78
+ random_horizontal_flip:
79
+ _target_: torchvision.transforms.RandomHorizontalFlip
80
+ p: 0.5
81
+
82
+ # random vertical flip
83
+ random_vertical_flip:
84
+ _target_: torchvision.transforms.RandomVerticalFlip
85
+ p: 0.5
plonk/configs/dataset/train_transform/center_crop.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: torchvision.transforms.Compose
2
+ transforms:
3
+ - _target_: torchvision.transforms.ToTensor
4
+ - _target_: plonk.utils.image_processing.CenterCrop
5
+ ratio: "1:1"
6
+ - _target_: torchvision.transforms.Resize
7
+ size: ${dataset.img_resolution}
8
+ interpolation: 3
9
+ antialias: true
10
+ - _target_: torchvision.transforms.RandomHorizontalFlip
11
+ p: 0.5
12
+ - _target_: torchvision.transforms.Normalize
13
+ mean: 0.5
14
+ std: 0.5
plonk/configs/dataset/train_transform/clip.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ _target_: plonk.data.transforms.ClipTransform
2
+ split: val
plonk/configs/dataset/train_transform/empty.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ _target_: plonk.data.data.null_transform
2
+ _partial_: true
plonk/configs/dataset/train_transform/fast_clip.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: torchvision.transforms.Compose
2
+ transforms:
3
+ - _target_: torchvision.transforms.Resize
4
+ size: 224
5
+ interpolation: 3
6
+ antialias: true
7
+ - _target_: torchvision.transforms.CenterCrop
8
+ size: 224
9
+ - _target_: torchvision.transforms.ToTensor
10
+ - _target_: torchvision.transforms.Normalize
11
+ mean: [0.48145466, 0.4578275, 0.40821073]
12
+ std: [0.26862954, 0.26130258, 0.27577711]
plonk/configs/dataset/train_transform/fast_resnet.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: torchvision.transforms.Compose
2
+ transforms:
3
+ - _target_: torchvision.transforms.Resize
4
+ size: 224
5
+ interpolation: 3
6
+ antialias: true
7
+ - _target_: torchvision.transforms.CenterCrop
8
+ size: 224
9
+ - _target_: torchvision.transforms.ToTensor
10
+ - _target_: torchvision.transforms.Normalize
11
+ mean: [0.485, 0.456, 0.406]
12
+ std: [0.229, 0.224, 0.225]
plonk/configs/dataset/train_transform/none.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ _target_: torchvision.transforms.Compose
2
+ transforms:
3
+ - _target_: torchvision.transforms.Resize
4
+ size: 224
5
+ interpolation: 3
6
+ antialias: true
7
+ - _target_: torchvision.transforms.ToTensor
plonk/configs/dataset/yfcc_emb.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - train_transform: empty
3
+ - test_transform: empty
4
+ - _self_
5
+
6
+ name: iNaturalist_${dataset.embedding_name} # NOTE(review): this is yfcc_emb.yaml and all roots point to YFCC100M — name prefix looks like a copy-paste from inaturalist_emb; confirm
7
+ full_batch_size: 2048
8
+ cond_dim: 1024
9
+ eval_batch_size: 4096
10
+ output_type: emb
11
+ embedding_name: dinov2_vitl14_registers
12
+
13
+ train_dataset:
14
+ _partial_: true
15
+ _target_: plonk.data.webdataset.GPSWebdataset
16
+ root: ${data_dir}/YFCC100M/train/
17
+ train: true
18
+ embedding_name: ${dataset.embedding_name}
19
+ return_image: false
20
+ metadata_attributes: []
21
+
22
+ val_dataset:
23
+ _partial_: true
24
+ _target_: plonk.data.webdataset.GPSWebdataset
25
+ root: ${data_dir}/YFCC100M/yfcc4k/
26
+ train: false
27
+ embedding_name: ${dataset.embedding_name}
28
+ return_image: false
29
+ metadata_attributes: []
30
+
31
+ test_dataset:
32
+ _partial_: true
33
+ _target_: plonk.data.webdataset.GPSWebdataset
34
+ root: ${data_dir}/YFCC100M/yfcc4k/
35
+ train: false
36
+ embedding_name: ${dataset.embedding_name}
37
+ return_image: false
38
+ metadata_attributes: []
plonk/configs/exp/YFCC100M_geoadalnmlp_r2_small_sigmoid_diffusion.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: yfcc_emb
5
+ - override /model: emb_cond
6
+ - override /model/network: geo_adaln_mlp
7
+ - override /model/train_noise_scheduler: sigmoid
8
+ - override /model/inference_noise_scheduler: sigmoid
9
+ - override /model/loss: ddpm
10
+ - _self_
11
+
12
+ model:
13
+ network:
14
+ depth: 12
15
+ dim: 512
16
+ optimizer:
17
+ optim:
18
+ lr: 8e-4
19
+ weight_decay: 0.05
20
+ loss:
21
+ cond_drop_rate: 0.1
22
+ train_noise_scheduler:
23
+ start: -7
24
+ end: 3
25
+ tau: 1.0
26
+ inference_noise_scheduler:
27
+ start: -7
28
+ end: 3
29
+ tau: 1.0
30
+ interpolant: diffusion
31
+ dataset:
32
+ full_batch_size: 1024
33
+
34
+ experiment_name_suffix: small_sigmoid
35
+ areas: []
plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_linear_flow_rieman.yaml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: yfcc_emb
5
+ - override /model: emb_cond_cartesian
6
+ - override /model/network: geo_adaln_mlp
7
+ - override /model/train_noise_scheduler: linear
8
+ - override /model/inference_noise_scheduler: linear
9
+ - override /model/loss: riemannian_flow_matching
10
+ - override /model/manifold: sphere
11
+ - override /model/val_sampler: riemannian_flow_matching
12
+ - override /model/test_sampler: riemannian_flow_matching
13
+ - _self_
14
+
15
+ model:
16
+ network:
17
+ depth: 12
18
+ dim: 512
19
+ optimizer:
20
+ optim:
21
+ lr: 8e-4
22
+ weight_decay: 0.05
23
+ loss:
24
+ cond_drop_rate: 0.1
25
+ interpolant: flow_matching
26
+
27
+ dataset:
28
+ full_batch_size: 1024
29
+
30
+ areas: []
31
+
32
+ experiment_name_suffix: small_sigmoid
plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_diffusion.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: yfcc_emb
5
+ - override /model: emb_cond_cartesian
6
+ - override /model/network: geo_adaln_mlp
7
+ - override /model/train_noise_scheduler: sigmoid
8
+ - override /model/inference_noise_scheduler: sigmoid
9
+ - override /model/loss: ddpm
10
+ - _self_
11
+
12
+ model:
13
+ network:
14
+ depth: 12
15
+ dim: 512
16
+ optimizer:
17
+ optim:
18
+ lr: 8e-4
19
+ weight_decay: 0.05
20
+ loss:
21
+ cond_drop_rate: 0.1
22
+ train_noise_scheduler:
23
+ start: -7
24
+ end: 3
25
+ tau: 1.0
26
+ inference_noise_scheduler:
27
+ start: -7
28
+ end: 3
29
+ tau: 1.0
30
+ interpolant: diffusion
31
+
32
+ dataset:
33
+ full_batch_size: 1024
34
+
35
+ experiment_name_suffix: small_sigmoid
36
+ areas: []
plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_flow.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: yfcc_emb
5
+ - override /model: emb_cond_cartesian
6
+ - override /model/network: geo_adaln_mlp
7
+ - override /model/train_noise_scheduler: sigmoid
8
+ - override /model/inference_noise_scheduler: sigmoid
9
+ - override /model/loss: flow_matching
10
+ - override /model/val_sampler: flow_matching
11
+ - override /model/test_sampler: flow_matching
12
+ - _self_
13
+
14
+ model:
15
+ network:
16
+ depth: 12
17
+ dim: 512
18
+ optimizer:
19
+ optim:
20
+ lr: 8e-4
21
+ weight_decay: 0.05
22
+ loss:
23
+ cond_drop_rate: 0.1
24
+ train_noise_scheduler:
25
+ start: -7
26
+ end: 3
27
+ tau: 1.0
28
+ inference_noise_scheduler:
29
+ start: -7
30
+ end: 3
31
+ tau: 1.0
32
+ interpolant: flow_matching
33
+
34
+ dataset:
35
+ full_batch_size: 1024
36
+
37
+ experiment_name_suffix: small_sigmoid
38
+ areas: []
plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: yfcc_emb
5
+ - override /model: emb_cond_cartesian
6
+ - override /model/network: geo_adaln_mlp
7
+ - override /model/train_noise_scheduler: sigmoid
8
+ - override /model/inference_noise_scheduler: sigmoid
9
+ - override /model/loss: riemannian_flow_matching
10
+ - override /model/manifold: sphere
11
+ - override /model/val_sampler: riemannian_flow_matching
12
+ - override /model/test_sampler: riemannian_flow_matching
13
+ - _self_
14
+
15
+ model:
16
+ network:
17
+ depth: 12
18
+ dim: 512
19
+ optimizer:
20
+ optim:
21
+ lr: 8e-4
22
+ weight_decay: 0.05
23
+ loss:
24
+ cond_drop_rate: 0.1
25
+ train_noise_scheduler:
26
+ start: -7
27
+ end: 3
28
+ tau: 1.0
29
+ inference_noise_scheduler:
30
+ start: -7
31
+ end: 3
32
+ tau: 1.0
33
+ interpolant: flow_matching
34
+
35
+ dataset:
36
+ full_batch_size: 1024
37
+
38
+ areas: []
39
+
40
+ experiment_name_suffix: small_sigmoid
plonk/configs/exp/YFCC100M_geoadalnmlp_von_fisher.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: yfcc_emb
5
+ - override /model: von_fisher
6
+ - override /model/network: geo_adaln_mlp_von_fisher
7
+ - override /model/loss: von_fisher
8
+ - override /model/val_sampler: von_fisher
9
+ - override /model/test_sampler: von_fisher
10
+ - _self_
11
+
12
+ model:
13
+ network:
14
+ depth: 11 # To compensate the increase in params
15
+ dim: 512
16
+ optimizer:
17
+ optim:
18
+ lr: 1e-4
19
+ weight_decay: 0.05
20
+ dataset:
21
+ full_batch_size: 1024
22
+ trainer:
23
+ gradient_clip_val: 0.05
24
+ gradient_clip_algorithm: norm
25
+ areas: []
26
+ experiment_name_suffix: von_fisher
plonk/configs/exp/YFCC100M_geoadalnmlp_von_fisher_mixture.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: yfcc_emb
5
+ - override /model: von_fisher_mixture
6
+ - override /model/network: geo_adaln_mlp_von_fisher_mixture
7
+ - override /model/loss: von_fisher_mixture
8
+ - override /model/val_sampler: von_fisher_mixture
9
+ - override /model/test_sampler: von_fisher_mixture
10
+ - _self_
11
+
12
+ model:
13
+ network:
14
+ depth: 11 # To compensate the increase in params
15
+ dim: 512
16
+ optimizer:
17
+ optim:
18
+ lr: 1e-5
19
+ weight_decay: 0.05
20
+ dataset:
21
+ full_batch_size: 1024
22
+ trainer:
23
+ gradient_clip_val: 0.01
24
+ gradient_clip_algorithm: norm
25
+ experiment_name_suffix: von_fisher_mixture
26
+ areas: []
plonk/configs/exp/combined_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: combined_emb
5
+ - override /model: emb_cond_cartesian
6
+ - override /model/network: geo_adaln_mlp
7
+ - override /model/train_noise_scheduler: sigmoid
8
+ - override /model/inference_noise_scheduler: sigmoid
9
+ - override /model/loss: riemannian_flow_matching
10
+ - override /model/manifold: sphere
11
+ - override /model/val_sampler: riemannian_flow_matching
12
+ - override /model/test_sampler: riemannian_flow_matching
13
+ - _self_
14
+
15
+ model:
16
+ network:
17
+ depth: 12
18
+ dim: 512
19
+ optimizer:
20
+ optim:
21
+ lr: 8e-4
22
+ weight_decay: 0.05
23
+ loss:
24
+ cond_drop_rate: 0.1
25
+ train_noise_scheduler:
26
+ start: -7
27
+ end: 3
28
+ tau: 1.0
29
+ inference_noise_scheduler:
30
+ start: -7
31
+ end: 3
32
+ tau: 1.0
33
+ interpolant: flow_matching
34
+
35
+ dataset:
36
+ full_batch_size: 1024
37
+
38
+ areas: []
39
+
40
+ experiment_name_suffix: small_sigmoid
plonk/configs/exp/iNaturalist_geoadalnmlp_r2_small_sigmoid_diffusion.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: inaturalist_emb
5
+ - override /model: emb_cond
6
+ - override /model/network: geo_adaln_mlp
7
+ - override /model/train_noise_scheduler: sigmoid
8
+ - override /model/inference_noise_scheduler: sigmoid
9
+ - override /model/loss: ddpm
10
+ - _self_
11
+
12
+ model:
13
+ network:
14
+ depth: 12
15
+ dim: 256
16
+ optimizer:
17
+ optim:
18
+ lr: 8e-4
19
+ weight_decay: 0.1
20
+ loss:
21
+ cond_drop_rate: 0.1
22
+ train_noise_scheduler:
23
+ start: -7
24
+ end: 3
25
+ tau: 1.0
26
+ inference_noise_scheduler:
27
+ start: -7
28
+ end: 3
29
+ tau: 1.0
30
+ interpolant: diffusion
31
+ dataset:
32
+ full_batch_size: 512
33
+
34
+ areas: []
35
+
36
+ experiment_name_suffix: small_sigmoid
plonk/configs/exp/iNaturalist_geoadalnmlp_r3_small_sigmoid_diffusion.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: inaturalist_emb
5
+ - override /model: emb_cond_cartesian
6
+ - override /model/network: geo_adaln_mlp
7
+ - override /model/train_noise_scheduler: sigmoid
8
+ - override /model/inference_noise_scheduler: sigmoid
9
+ - override /model/loss: ddpm
10
+ - _self_
11
+
12
+ model:
13
+ network:
14
+ depth: 12
15
+ dim: 256
16
+ optimizer:
17
+ optim:
18
+ lr: 8e-4
19
+ weight_decay: 0.1
20
+ loss:
21
+ cond_drop_rate: 0.1
22
+ train_noise_scheduler:
23
+ start: -7
24
+ end: 3
25
+ tau: 1.0
26
+ inference_noise_scheduler:
27
+ start: -7
28
+ end: 3
29
+ tau: 1.0
30
+ interpolant: diffusion
31
+
32
+ dataset:
33
+ full_batch_size: 512
34
+
35
+ areas: []
36
+
37
+ experiment_name_suffix: small_sigmoid
plonk/configs/exp/iNaturalist_geoadalnmlp_r3_small_sigmoid_flow.yaml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: inaturalist_emb
5
+ - override /model: emb_cond_cartesian
6
+ - override /model/network: geo_adaln_mlp
7
+ - override /model/train_noise_scheduler: sigmoid
8
+ - override /model/inference_noise_scheduler: sigmoid
9
+ - override /model/loss: flow_matching
10
+ - override /model/val_sampler: flow_matching
11
+ - override /model/test_sampler: flow_matching
12
+ - _self_
13
+
14
+ model:
15
+ network:
16
+ depth: 12
17
+ dim: 256
18
+ optimizer:
19
+ optim:
20
+ lr: 8e-4
21
+ weight_decay: 0.1
22
+ loss:
23
+ cond_drop_rate: 0.1
24
+ train_noise_scheduler:
25
+ start: -7
26
+ end: 3
27
+ tau: 1.0
28
+ inference_noise_scheduler:
29
+ start: -7
30
+ end: 3
31
+ tau: 1.0
32
+ interpolant: flow_matching
33
+
34
+ dataset:
35
+ full_batch_size: 512
36
+
37
+ areas: []
38
+
39
+ experiment_name_suffix: small_sigmoid
plonk/configs/exp/iNaturalist_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: inaturalist_emb
5
+ - override /model: emb_cond_cartesian
6
+ - override /model/network: geo_adaln_mlp
7
+ - override /model/train_noise_scheduler: sigmoid
8
+ - override /model/inference_noise_scheduler: sigmoid
9
+ - override /model/loss: riemannian_flow_matching
10
+ - override /model/manifold: sphere
11
+ - override /model/val_sampler: riemannian_flow_matching
12
+ - override /model/test_sampler: riemannian_flow_matching
13
+ - _self_
14
+
15
+ model:
16
+ network:
17
+ depth: 12
18
+ dim: 256
19
+ optimizer:
20
+ optim:
21
+ lr: 8e-4
22
+ weight_decay: 0.1
23
+ loss:
24
+ cond_drop_rate: 0.1
25
+ train_noise_scheduler:
26
+ start: -7
27
+ end: 3
28
+ tau: 1.0
29
+ inference_noise_scheduler:
30
+ start: -7
31
+ end: 3
32
+ tau: 1.0
33
+ interpolant: flow_matching
34
+
35
+ dataset:
36
+ full_batch_size: 512
37
+
38
+ areas: []
39
+
40
+ experiment_name_suffix: small_sigmoid
plonk/configs/exp/iNaturalist_geoadalnmlp_von_fisher.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: inaturalist_emb
5
+ - override /model: von_fisher
6
+ - override /model/network: geo_adaln_mlp_von_fisher
7
+ - override /model/loss: von_fisher
8
+ - override /model/val_sampler: von_fisher
9
+ - override /model/test_sampler: von_fisher
10
+ - _self_
11
+
12
+ model:
13
+ network:
14
+ depth: 11 # To compensate the increase in params
15
+ dim: 256
16
+ optimizer:
17
+ optim:
18
+ lr: 1e-4
19
+ weight_decay: 0.1
20
+ dataset:
21
+ full_batch_size: 512
22
+ trainer:
23
+ gradient_clip_val: 0.01
24
+ gradient_clip_algorithm: norm
25
+ areas: []
26
+ experiment_name_suffix: von_fisher
plonk/configs/exp/iNaturalist_geoadalnmlp_von_fisher_mixture.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: inaturalist_emb
5
+ - override /model: von_fisher_mixture
6
+ - override /model/network: geo_adaln_mlp_von_fisher_mixture
7
+ - override /model/loss: von_fisher_mixture
8
+ - override /model/val_sampler: von_fisher_mixture
9
+ - override /model/test_sampler: von_fisher_mixture
10
+ - _self_
11
+
12
+ model:
13
+ network:
14
+ depth: 11 # To compensate the increase in params
15
+ dim: 256
16
+ optimizer:
17
+ optim:
18
+ lr: 1e-5
19
+ weight_decay: 0.1
20
+ dataset:
21
+ full_batch_size: 512
22
+ trainer:
23
+ gradient_clip_val: 0.01
24
+ gradient_clip_algorithm: norm
25
+ areas: []
26
+ experiment_name_suffix: von_fisher_mixture
plonk/configs/exp/osv_5m_geoadalnmlp_r2_small_sigmoid_diffusion.yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: osv5m_emb
5
+ - override /model: emb_cond
6
+ - override /model/network: geo_adaln_mlp
7
+ - override /model/train_noise_scheduler: sigmoid
8
+ - override /model/inference_noise_scheduler: sigmoid
9
+ - override /model/loss: ddpm
10
+ - _self_
11
+
12
+ model:
13
+ network:
14
+ depth: 12
15
+ dim: 512
16
+ optimizer:
17
+ optim:
18
+ lr: 8e-4
19
+ weight_decay: 0.05
20
+ loss:
21
+ cond_drop_rate: 0.1
22
+ train_noise_scheduler:
23
+ start: -7
24
+ end: 3
25
+ tau: 1.0
26
+ inference_noise_scheduler:
27
+ start: -7
28
+ end: 3
29
+ tau: 1.0
30
+ interpolant: diffusion
31
+ dataset:
32
+ full_batch_size: 1024
33
+
34
+ experiment_name_suffix: small_sigmoid
plonk/configs/exp/osv_5m_geoadalnmlp_r3_small_linear_flow_riemann.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: osv5m_emb
5
+ - override /model: emb_cond_cartesian
6
+ - override /model/network: geo_adaln_mlp
7
+ - override /model/train_noise_scheduler: linear
8
+ - override /model/inference_noise_scheduler: linear
9
+ - override /model/loss: riemannian_flow_matching
10
+ - override /model/manifold: sphere
11
+ - override /model/val_sampler: riemannian_flow_matching
12
+ - override /model/test_sampler: riemannian_flow_matching
13
+ - _self_
14
+
15
+ model:
16
+ network:
17
+ depth: 12
18
+ dim: 512
19
+ optimizer:
20
+ optim:
21
+ lr: 8e-4
22
+ weight_decay: 0.05
23
+ loss:
24
+ cond_drop_rate: 0.1
25
+ interpolant: flow_matching
26
+
27
+ dataset:
28
+ full_batch_size: 1024
29
+
30
+ experiment_name_suffix: small_sigmoid
plonk/configs/exp/osv_5m_geoadalnmlp_r3_small_sigmoid_diffusion.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: osv5m_emb
5
+ - override /model: emb_cond_cartesian
6
+ - override /model/network: geo_adaln_mlp
7
+ - override /model/train_noise_scheduler: sigmoid
8
+ - override /model/inference_noise_scheduler: sigmoid
9
+ - override /model/loss: ddpm
10
+ - _self_
11
+
12
+ model:
13
+ network:
14
+ depth: 12
15
+ dim: 512
16
+ optimizer:
17
+ optim:
18
+ lr: 8e-4
19
+ weight_decay: 0.05
20
+ loss:
21
+ cond_drop_rate: 0.1
22
+ train_noise_scheduler:
23
+ start: -7
24
+ end: 3
25
+ tau: 1.0
26
+ inference_noise_scheduler:
27
+ start: -7
28
+ end: 3
29
+ tau: 1.0
30
+ interpolant: diffusion
31
+
32
+ dataset:
33
+ full_batch_size: 1024
34
+
35
+ experiment_name_suffix: small_sigmoid