diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..75faa9005bb3f63f3b5cf1f14a0ca9fd1c1afe41 --- /dev/null +++ b/.gitignore @@ -0,0 +1,58 @@ +# Hugging Face Spaces .gitignore + +# Python cache +__pycache__/ +*.py[cod] +*$py.class +*.so + +# Development files +.ipynb_checkpoints/ +.vscode/ +.idea/ +*.swp +*.swo + +# OS files +.DS_Store +Thumbs.db + +# Temporary files +*.tmp +*.log +*.pid + +# Original demo files (using streamlit) +demo/demo.py + +# Environment files +.env +.env.local + +# Model checkpoints (will be downloaded automatically) +checkpoints/ +*.safetensors +*.bin + +# Large data files +data/ +datasets/ +*.csv +*.json + +# Training artifacts +wandb/ +logs/ +outputs/ + +# Test files +test_*.py +*_test.py + +# Documentation that's not needed for the Space +*.md +!README.md + +# Git files +.git/ +.gitmodules \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..56d71f14df234c3515535b51bf3219f3c53ebe84 --- /dev/null +++ b/README.md @@ -0,0 +1,78 @@ +--- +title: PLONK Geolocation +emoji: 🗺️ +colorFrom: red +colorTo: blue +sdk: gradio +sdk_version: 4.0.0 +app_file: app.py +pinned: false +license: mit +--- + +# 🗺️ PLONK: Around the World in 80 Timesteps + +A generative approach to global visual geolocation. Upload an image and PLONK will predict where it was taken! + +## About + +PLONK is a diffusion-based model that predicts the geographic location where a photo was taken based solely on its visual content. This Space uses the PLONK_YFCC model trained on the YFCC100M dataset. + +## Features + +- **Simple Prediction**: Get a single high-confidence location prediction +- **Advanced Analysis**: Explore prediction uncertainty with multiple samples and guidance control +- **Fast CPU Inference**: ~300-500ms per image on CPU-Basic tier +- **GPU Ready**: Upgrade to T4-small for ~45ms inference time + +## Usage + +1. Upload an image using the interface +2. Click "Submit" to get location predictions +3. For advanced analysis, try different guidance scales: + - CFG = 0.0: More diverse predictions (good for uncertainty estimation) + - CFG = 2.0: Single confident prediction (best guess) + +## API Usage + +This Space exposes a REST API compatible with Gradio's prediction format: + +```python +import requests + +url = "https://your-space-name.hf.space/api/predict" +files = {"data": open("image.jpg", "rb")} +response = requests.post(url, files=files) +print(response.json()) +``` + +## Model Performance + +- **Latency**: 300-500ms on CPU-Basic, ~45ms on T4 GPU +- **Memory**: <1GB RAM usage +- **Throughput**: ~10 req/s on T4 before saturation + +## Scaling Options + +- **Free CPU-Basic**: Perfect for testing and low-volume usage +- **T4-small ($0.40/hr)**: 10x faster inference for production +- **Inference Endpoints**: Auto-scaling with pay-per-use pricing + +## Citation + +If you use PLONK in your research, please cite: + +```bibtex +@article{dufour2024plonk, + title={Around the World in 80 Timesteps: A Generative Approach to Global Visual Geolocation}, + author={Dufour, Nicolas and others}, + journal={arXiv preprint}, + year={2024} +} +``` + +## Links + +- 📄 [Project Page](https://nicolas-dufour.github.io/plonk) +- 💻 [Code Repository](https://github.com/nicolas-dufour/plonk) +- 🤗 [Model on Hugging Face](https://huggingface.co/nicolas-dufour/PLONK_YFCC) \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..b37de84cc1c79b5308df2a2c5ee048c3add08fca --- /dev/null +++ b/app.py @@ -0,0 +1,132 @@ +import gradio as gr +import torch +from plonk.pipe import PlonkPipeline +import numpy as np +from PIL import Image + +# Initialize the pipeline +print("Loading PLONK_YFCC model...") +pipe = PlonkPipeline(model_path="nicolas-dufour/PLONK_YFCC") +print("Model loaded successfully!") + +def predict_geolocation(image): + """ + Predict geolocation from an uploaded image + Args: + image: PIL Image + Returns: + str: Formatted latitude and longitude + """ + if image is None: + return "Please upload an image" + + try: + # Get prediction using the pipeline + # Using single sample with high confidence (cfg=2.0) for best guess + predicted_gps = pipe(image, batch_size=1, cfg=2.0, num_steps=32) + + # Extract latitude and longitude + lat, lon = float(predicted_gps[0, 0]), float(predicted_gps[0, 1]) + + # Format the result + result = f"Predicted Location:\nLatitude: {lat:.6f}\nLongitude: {lon:.6f}" + + return result + + except Exception as e: + return f"Error during prediction: {str(e)}" + +def predict_geolocation_with_samples(image, num_samples=64, cfg=0.0): + """ + Predict geolocation with multiple samples for uncertainty visualization + Args: + image: PIL Image + num_samples: Number of samples to generate + cfg: Classifier-free guidance scale + Returns: + str: Formatted results with statistics + """ + if image is None: + return "Please upload an image" + + try: + # Get multiple predictions for uncertainty estimation + predicted_gps = pipe(image, batch_size=num_samples, cfg=cfg, num_steps=32) + + # Calculate statistics + lats = predicted_gps[:, 0].astype(float) + lons = predicted_gps[:, 1].astype(float) + + mean_lat, mean_lon = np.mean(lats), np.mean(lons) + std_lat, std_lon = np.std(lats), np.std(lons) + + # Get high confidence prediction + high_conf_gps = pipe(image, batch_size=1, cfg=2.0, num_steps=32) + conf_lat, conf_lon = float(high_conf_gps[0, 0]), float(high_conf_gps[0, 1]) + + result = f"""Geolocation Prediction Results: + +High Confidence Prediction (CFG=2.0): +Latitude: {conf_lat:.6f} +Longitude: {conf_lon:.6f} + +Sample Statistics ({num_samples} samples, CFG={cfg}): +Mean Latitude: {mean_lat:.6f} ± {std_lat:.6f} +Mean Longitude: {mean_lon:.6f} ± {std_lon:.6f} + """ + + return result + + except Exception as e: + return f"Error during prediction: {str(e)}" + +# Create the Gradio interface for simple prediction +simple_interface = gr.Interface( + fn=predict_geolocation, + inputs=gr.Image(type="pil", label="Upload an image"), + outputs=gr.Textbox(label="Predicted Location", lines=4), + title="🗺️ PLONK: Global Visual Geolocation", + description=""" + Upload an image and PLONK will predict where it was taken! + + This uses the PLONK_YFCC model trained on the YFCC100M dataset. + The model predicts latitude and longitude coordinates based on visual content. + + **Note**: This is running on CPU, so processing may take 300-500ms per image. + """, + examples=[ + ["demo/examples/condor.jpg"], + ["demo/examples/Kilimanjaro.jpg"], + ["demo/examples/pigeon.png"] + ] if any(Path("demo/examples").glob("*")) else None +) + +# Create advanced interface with sampling options +advanced_interface = gr.Interface( + fn=predict_geolocation_with_samples, + inputs=[ + gr.Image(type="pil", label="Upload an image"), + gr.Slider(1, 256, value=64, step=1, label="Number of samples"), + gr.Slider(0.0, 5.0, value=0.0, step=0.1, label="Guidance scale (CFG)") + ], + outputs=gr.Textbox(label="Detailed Results", lines=10), + title="🗺️ PLONK: Advanced Geolocation with Uncertainty", + description=""" + Advanced interface showing prediction uncertainty through multiple samples. + + - **Number of samples**: More samples = better uncertainty estimation (but slower) + - **Guidance scale**: Higher values = more confident predictions (try 2.0 for best single guess) + """, +) + +# Create tabbed interface +demo = gr.TabbedInterface( + [simple_interface, advanced_interface], + ["Simple Prediction", "Advanced Analysis"], + title="PLONK: Around the World in 80 Timesteps" +) + +if __name__ == "__main__": + # Add necessary import for pathlib + from pathlib import Path + demo.launch() \ No newline at end of file diff --git a/plonk/__init__.py b/plonk/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7e20d0d3e4d7bef4f8abcf69d0355936e8770955 --- /dev/null +++ b/plonk/__init__.py @@ -0,0 +1 @@ +from .pipe import PlonkPipeline \ No newline at end of file diff --git a/plonk/callbacks/__init__.py b/plonk/callbacks/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..7e2064a43f692ee9010e8f92f9b647bdb61488b9 --- /dev/null +++ b/plonk/callbacks/__init__.py @@ -0,0 +1,3 @@ +from .ema import EMACallback +from .fix_nans import FixNANinGrad +from .data import IncreaseDataEpoch diff --git a/plonk/callbacks/data.py b/plonk/callbacks/data.py new file mode 100644 index 0000000000000000000000000000000000000000..4706e5f21fcd415f69407e401326ba472291e167 --- /dev/null +++ b/plonk/callbacks/data.py @@ -0,0 +1,11 @@ +from pytorch_lightning.callbacks import Callback + + +class IncreaseDataEpoch(Callback): + def __init__(self): + super().__init__() + + def on_train_epoch_start(self, trainer, pl_module): + epoch = pl_module.current_epoch + if hasattr(trainer.datamodule.train_dataset, "shared_epoch"): + trainer.datamodule.train_dataset.shared_epoch.set_value(epoch) diff --git a/plonk/callbacks/ema.py b/plonk/callbacks/ema.py new file mode 100755 index 0000000000000000000000000000000000000000..bf65a7bfc358234712206de408761e2b2880d102 --- /dev/null +++ b/plonk/callbacks/ema.py @@ -0,0 +1,102 @@ +from pytorch_lightning import Callback +import copy +import itertools +import torch +import contextlib +from torch.distributed.fsdp import FullyShardedDataParallel + + +class EMACallback(Callback): + def __init__( + self, + module_attr_name, + ema_module_attr_name, + decay=0.999, + start_ema_step=0, + init_ema_random=True, + ): + super().__init__() + self.decay = decay + self.module_attr_name = module_attr_name + self.ema_module_attr_name = ema_module_attr_name + self.start_ema_step = start_ema_step + self.init_ema_random = init_ema_random + + def on_train_start(self, trainer, pl_module): + if pl_module.global_step == 0: + if not hasattr(pl_module, self.module_attr_name): + raise ValueError( + f"Module {pl_module} does not have attribute {self.module_attr_name}" + ) + if not hasattr(pl_module, self.ema_module_attr_name): + pl_module.add_module( + self.ema_module_attr_name, + copy.deepcopy(getattr(pl_module, self.module_attr_name)) + .eval() + .requires_grad_(False), + ) + self.reset_ema(pl_module) + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if pl_module.global_step == self.start_ema_step: + self.reset_ema(pl_module) + elif ( + pl_module.global_step < self.start_ema_step + and pl_module.global_step % 100 == 0 + ): + ## slow ema updates for visualisation + self.update_ema(pl_module, decay=0.9) + elif pl_module.global_step > self.start_ema_step: + self.update_ema(pl_module, decay=self.decay) + + def update_ema(self, pl_module, decay=0.999): + ema_module = getattr(pl_module, self.ema_module_attr_name) + module = getattr(pl_module, self.module_attr_name) + context_manager = self.get_model_context_manager(module) + with context_manager: + with torch.no_grad(): + ema_params = ema_module.state_dict() + for name, param in itertools.chain( + module.named_parameters(), module.named_buffers() + ): + if name in ema_params: + if param.requires_grad: + ema_params[name].copy_( + ema_params[name].detach().lerp(param.detach(), decay) + ) + + def get_model_context_manager(self, module): + fsdp_enabled = is_model_fsdp(module) + model_context_manager = contextlib.nullcontext() + if fsdp_enabled: + model_context_manager = module.summon_full_params(module) + return model_context_manager + + def reset_ema(self, pl_module): + ema_module = getattr(pl_module, self.ema_module_attr_name) + if self.init_ema_random: + ema_module.init_weights() + else: + module = getattr(pl_module, self.module_attr_name) + context_manager = self.get_model_context_manager(module) + with context_manager: + ema_params = ema_module.state_dict() + for name, param in itertools.chain( + module.named_parameters(), module.named_buffers() + ): + if name in ema_params: + ema_params[name].copy_(param.detach()) + + +def is_model_fsdp(model: torch.nn.Module) -> bool: + try: + if isinstance(model, FullyShardedDataParallel): + return True + + # Check if model is wrapped with FSDP + for _, obj in model.named_children(): + if isinstance(obj, FullyShardedDataParallel): + return True + return False + except ImportError: + return False diff --git a/plonk/callbacks/fix_nans.py b/plonk/callbacks/fix_nans.py new file mode 100755 index 0000000000000000000000000000000000000000..51c1d829a4eaa2b14b2c30e54ead3d153d77ac1a --- /dev/null +++ b/plonk/callbacks/fix_nans.py @@ -0,0 +1,55 @@ +import logging +from pytorch_lightning.callbacks import Callback +import torch + +log = logging.getLogger(__name__) + + +class FixNANinGrad(Callback): + def __init__(self, monitor): + super().__init__() + self.monitor = monitor + self.continuous_nan_batchs = 0 + + def on_before_optimizer_step(self, trainer, pl_module, optimizer) -> None: + has_nan = [] + is_inf = [] + for name, param in pl_module.named_parameters(): + if param.grad is not None: + if torch.isnan(param.grad).any(): + has_nan.append(name) + if torch.isinf(param.grad).any(): + is_inf.append(name) + torch.nan_to_num(param.grad, nan=0, posinf=0, neginf=0, out=param.grad) + if len(has_nan) > 0: + print(f"Found NaN in {has_nan}") + if len(is_inf) > 0: + print(f"Found Inf in {is_inf}") + + def on_train_batch_end( + self, + trainer, + pl_module, + outputs, + batch, + batch_idx, + ) -> None: + logs = trainer.callback_metrics + i = 0 + found_metric = False + while i < len(self.monitor) and not found_metric: + if self.monitor[i] in logs.keys(): + current = logs[self.monitor[i]].squeeze() + found_metric = True + else: + i += 1 + if not found_metric: + raise ValueError("Asked metric not in logs") + + if not torch.isfinite(current): + self.continuous_nan_batchs += 1 + if self.continuous_nan_batchs >= 5: + trainer.should_stop = True + log.info("Training interrupted because of NaN in {self.monitor}") + else: + self.continuous_nan_batchs = 0 diff --git a/plonk/configs/computer/a100.yaml b/plonk/configs/computer/a100.yaml new file mode 100644 index 0000000000000000000000000000000000000000..60ac8bd5263b64cad5b659b1f71a0752f6edfe96 --- /dev/null +++ b/plonk/configs/computer/a100.yaml @@ -0,0 +1,8 @@ +devices: 1 +progress_bar_refresh_rate: 2 +num_workers: 8 +sync_batchnorm: False +accelerator: gpu +precision: 32 +strategy: auto +num_nodes: 1 diff --git a/plonk/configs/computer/cluster-node-a100.yaml b/plonk/configs/computer/cluster-node-a100.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d60903dca91d09422eefb572a41060bde0aac7b1 --- /dev/null +++ b/plonk/configs/computer/cluster-node-a100.yaml @@ -0,0 +1,8 @@ +devices: 8 +num_workers: 8 +progress_bar_refresh_rate: 2 +sync_batchnorm: True +accelerator: gpu +precision: 32 +strategy: ddp +num_nodes: 1 diff --git a/plonk/configs/computer/cluster-node-v100.yaml b/plonk/configs/computer/cluster-node-v100.yaml new file mode 100644 index 0000000000000000000000000000000000000000..48da9ac269cedd97f8619e92e54986a8124f6bd7 --- /dev/null +++ b/plonk/configs/computer/cluster-node-v100.yaml @@ -0,0 +1,8 @@ +devices: 4 +num_workers: 10 +progress_bar_refresh_rate: 2 +sync_batchnorm: True +accelerator: gpu +precision: 32 +strategy: ddp +num_nodes: 1 diff --git a/plonk/configs/computer/cpu.yaml b/plonk/configs/computer/cpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6e4e49bbe84d4bfbf0ed4849db41a20aa27d9dc2 --- /dev/null +++ b/plonk/configs/computer/cpu.yaml @@ -0,0 +1,8 @@ +devices: null +num_workers: 0 +progress_bar_refresh_rate: 2 +sync_batchnorm: False +accelerator: cpu +precision: 32 +strategy: auto +num_nodes: null diff --git a/plonk/configs/computer/h100.yaml b/plonk/configs/computer/h100.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8509aa21fc99c38e44b05d250658b45d5300cfb7 --- /dev/null +++ b/plonk/configs/computer/h100.yaml @@ -0,0 +1,8 @@ +devices: 1 +progress_bar_refresh_rate: 2 +num_workers: 24 +sync_batchnorm: False +accelerator: gpu +precision: 32 +strategy: auto +num_nodes: 1 diff --git a/plonk/configs/computer/v100.yaml b/plonk/configs/computer/v100.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d0ac2cc4c2aef6ee3a941f8508e20f5585487f8b --- /dev/null +++ b/plonk/configs/computer/v100.yaml @@ -0,0 +1,8 @@ +devices: 1 +num_workers: 10 +progress_bar_refresh_rate: 2 +sync_batchnorm: False +accelerator: gpu +precision: 32 +strategy: auto +num_nodes: 1 diff --git a/plonk/configs/config.yaml b/plonk/configs/config.yaml new file mode 100755 index 0000000000000000000000000000000000000000..5ba550e2e7213b601d67806da5310ea907daa267 --- /dev/null +++ b/plonk/configs/config.yaml @@ -0,0 +1,90 @@ +defaults: + - model: default + - computer: v100 + - dataset: osv5m_emb + - stage: null + - _self_ + - exp: ??? + +model: + val_metrics: + _target_: metrics.distance_based.HaversineMetrics + acc_radiuses: + - 1 + - 25 + - 200 + - 750 + - 2500 + acc_area: [] + test_metrics: + _target_: metrics.distance_based.HaversineMetrics + acc_radiuses: + - 1 + - 25 + - 200 + - 750 + - 2500 + acc_area: ${areas} + +datamodule: + _target_: plonk.data.datamodule.ImageDataModule + train_dataset: ${dataset.train_dataset} + val_dataset: ${dataset.val_dataset} + test_dataset: ${dataset.test_dataset} + full_batch_size: ${dataset.full_batch_size} + eval_batch_size: ${dataset.eval_batch_size} + num_workers: ${computer.num_workers} + num_nodes: ${computer.num_nodes} + num_devices: ${computer.devices} + val_proportion: 0.02 + +trainer: + _target_: pytorch_lightning.Trainer + devices: ${computer.devices} + accelerator: ${computer.accelerator} + strategy: ${computer.strategy} + num_nodes: ${computer.num_nodes} + precision: ${computer.precision} + max_steps: 1000000 + val_check_interval: 25000 + check_val_every_n_epoch: null + +logger: + _target_: pytorch_lightning.loggers.WandbLogger + save_dir: ${root_dir}/plonk + name: ${experiment_name}${logger_suffix} + project: diff_plonk + log_model: False + offline: False + +checkpoints: + _target_: pytorch_lightning.callbacks.ModelCheckpoint + dirpath: ${root_dir}/plonk/checkpoints/${experiment_name} + filename: 'epoch_{epoch}' + monitor: val/loss + save_last: True + save_top_k: 0 + every_n_epochs: 1 + enable_version_counter: False + +progress_bar: + _target_: pytorch_lightning.callbacks.TQDMProgressBar + refresh_rate: ${computer.progress_bar_refresh_rate} + +data_dir: ${root_dir}/plonk/datasets +root_dir: ${hydra:runtime.cwd} +experiment_name: ${dataset.name}_${model.name}_${experiment_name_suffix} +experiment_name_suffix: base +logger_suffix: "" +mode: train # change that to eval to do the testing +areas: ['country', 'region', 'sub-region', 'city'] +class_name: null +streetclip: False +blur: False +text_tuning: False + +hydra: + run: + dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}/${experiment_name} + job: + chdir: true diff --git a/plonk/configs/dataset/combined_emb.yaml b/plonk/configs/dataset/combined_emb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8a3bb7ddcf1e4aad935fd5740cdb38488f76ff8b --- /dev/null +++ b/plonk/configs/dataset/combined_emb.yaml @@ -0,0 +1,38 @@ +defaults: + - train_transform: empty + - test_transform: empty + - _self_ + +name: iNaturalist_OSV5M_YFCC100M_${dataset.embedding_name} +full_batch_size: 2048 +cond_dim: 1024 +eval_batch_size: 4096 +output_type: emb +embedding_name: dinov2_vitl14_registers + +train_dataset: + _partial_: true + _target_: plonk.data.webdataset.GPSWebdataset + root: ${data_dir}/YFCC100M/train/ ${data_dir}/osv5m/train/ ${data_dir}/inaturalist/train/ ${data_dir}/osv5m/train/ ${data_dir}/inaturalist/train/ + train: true + embedding_name: ${dataset.embedding_name} + return_image: false + metadata_attributes: [] + +val_dataset: + _partial_: true + _target_: plonk.data.webdataset.GPSWebdataset + root: ${data_dir}/YFCC100M/yfcc4k/ + train: false + embedding_name: ${dataset.embedding_name} + return_image: false + metadata_attributes: [] + +test_dataset: + _partial_: true + _target_: plonk.data.webdataset.GPSWebdataset + root: ${data_dir}/YFCC100M/yfcc4k/ + train: false + embedding_name: ${dataset.embedding_name} + return_image: false + metadata_attributes: [] diff --git a/plonk/configs/dataset/inaturalist_emb.yaml b/plonk/configs/dataset/inaturalist_emb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a8353eaed5fe7b4fe80f112b6204859ea6a250a5 --- /dev/null +++ b/plonk/configs/dataset/inaturalist_emb.yaml @@ -0,0 +1,38 @@ +defaults: + - train_transform: empty + - test_transform: empty + - _self_ + +name: iNaturalist_${dataset.embedding_name} +full_batch_size: 512 +cond_dim: 1024 +eval_batch_size: 4096 +output_type: emb +embedding_name: dinov2_vitl14_registers + +train_dataset: + _partial_: true + _target_: plonk.data.webdataset.GPSWebdataset + root: ${data_dir}/inaturalist/train/ + train: true + embedding_name: ${dataset.embedding_name} + return_image: false + metadata_attributes: [] + +val_dataset: + _partial_: true + _target_: plonk.data.webdataset.GPSWebdataset + root: ${data_dir}/inaturalist/val/ + train: false + embedding_name: ${dataset.embedding_name} + return_image: false + metadata_attributes: [] + +test_dataset: + _partial_: true + _target_: plonk.data.webdataset.GPSWebdataset + root: ${data_dir}/inaturalist/test/ + train: false + embedding_name: ${dataset.embedding_name} + return_image: false + metadata_attributes: [] diff --git a/plonk/configs/dataset/osv5m.yaml b/plonk/configs/dataset/osv5m.yaml new file mode 100644 index 0000000000000000000000000000000000000000..070a5e1aeb2d32281c3071634e2df205b64bbe27 --- /dev/null +++ b/plonk/configs/dataset/osv5m.yaml @@ -0,0 +1,43 @@ +defaults: + - train_transform: fast_clip + - test_transform: fast_clip + - _self_ + +name: osv5m +full_batch_size: 2048 +eval_batch_size: 4096 +train_dataset: + _partial_: true + _target_: plonk.data.data.OSV5M + path: ${data_dir}/osv5m/ + split: train + class_name: ${class_name} + transforms: ${dataset.train_transform} + is_baseline: ${is_baseline} + areas: ${areas} + streetclip: ${streetclip} + blur: ${blur} + +val_dataset: + _partial_: true + _target_: plonk.data.data.OSV5M + path: ${data_dir}/osv5m/ + split: val + class_name: ${class_name} + transforms: ${dataset.test_transform} + is_baseline: ${is_baseline} + areas: ${areas} + streetclip: ${streetclip} + blur: ${blur} + +test_dataset: + _partial_: true + _target_: plonk.data.data.OSV5M + path: ${data_dir}/osv5m/ + split: test + class_name: ${class_name} + transforms: ${dataset.test_transform} + is_baseline: ${is_baseline} + areas: ${areas} + streetclip: ${streetclip} + blur: ${blur} diff --git a/plonk/configs/dataset/osv5m_emb.yaml b/plonk/configs/dataset/osv5m_emb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4cb277dc40051a40b2cd9f45bdf12458074e1169 --- /dev/null +++ b/plonk/configs/dataset/osv5m_emb.yaml @@ -0,0 +1,38 @@ +defaults: + - train_transform: empty + - test_transform: empty + - _self_ + +name: osv5m_${dataset.embedding_name} +full_batch_size: 1024 +eval_batch_size: 4096 +cond_dim: 1024 +output_type: emb +embedding_name: street_clip + +train_dataset: + _partial_: true + _target_: plonk.data.webdataset.GPSWebdataset + root: ${data_dir}/osv5m/train/ + train: true + embedding_name: ${dataset.embedding_name} + return_image: false + metadata_attributes: [] + +val_dataset: + _partial_: true + _target_: plonk.data.webdataset.GPSWebdataset + root: ${data_dir}/osv5m/val/ + train: false + embedding_name: ${dataset.embedding_name} + return_image: false + metadata_attributes: ["unique_country", "unique_region", "unique_sub-region", "unique_city"] + +test_dataset: + _partial_: true + _target_: plonk.data.webdataset.GPSWebdataset + root: ${data_dir}/osv5m/test/ + train: false + embedding_name: ${dataset.embedding_name} + return_image: false + metadata_attributes: ["unique_country", "unique_region", "unique_sub-region", "unique_city"] diff --git a/plonk/configs/dataset/test_transform/center_crop.yaml b/plonk/configs/dataset/test_transform/center_crop.yaml new file mode 100644 index 0000000000000000000000000000000000000000..17a226a752cc43e10577bc8905add7f89f930be7 --- /dev/null +++ b/plonk/configs/dataset/test_transform/center_crop.yaml @@ -0,0 +1,12 @@ +_target_: torchvision.transforms.Compose +transforms: + - _target_: torchvision.transforms.ToTensor + - _target_: plonk.utils.image_processing.CenterCrop + ratio: "1:1" + - _target_: torchvision.transforms.Resize + size: ${dataset.img_resolution} + interpolation: 3 + antialias: true + - _target_: torchvision.transforms.Normalize + mean: 0.5 + std: 0.5 diff --git a/plonk/configs/dataset/test_transform/clip.yaml b/plonk/configs/dataset/test_transform/clip.yaml new file mode 100755 index 0000000000000000000000000000000000000000..fdeeb4b48d72f93a02ffb3e3ff2b44c5beee5ee5 --- /dev/null +++ b/plonk/configs/dataset/test_transform/clip.yaml @@ -0,0 +1,2 @@ +_target_: plonk.data.transforms.ClipTransform +split: val diff --git a/plonk/configs/dataset/test_transform/empty.yaml b/plonk/configs/dataset/test_transform/empty.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e19b331275841d011daab56a27bd697d236e6643 --- /dev/null +++ b/plonk/configs/dataset/test_transform/empty.yaml @@ -0,0 +1,2 @@ +_target_: plonk.data.data.null_transform +_partial_: true \ No newline at end of file diff --git a/plonk/configs/dataset/test_transform/fast_clip.yaml b/plonk/configs/dataset/test_transform/fast_clip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..45b6a08732e0466ba225038b8e1a27fffb3f66c7 --- /dev/null +++ b/plonk/configs/dataset/test_transform/fast_clip.yaml @@ -0,0 +1,12 @@ +_target_: torchvision.transforms.Compose +transforms: + - _target_: torchvision.transforms.Resize + size: 224 + interpolation: 3 + antialias: true + - _target_: torchvision.transforms.CenterCrop + size: 224 + - _target_: torchvision.transforms.ToTensor + - _target_: torchvision.transforms.Normalize + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] diff --git a/plonk/configs/dataset/test_transform/fast_resnet.yaml b/plonk/configs/dataset/test_transform/fast_resnet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fdbabe78156489a27370fa60e69e539170fbe150 --- /dev/null +++ b/plonk/configs/dataset/test_transform/fast_resnet.yaml @@ -0,0 +1,12 @@ +_target_: torchvision.transforms.Compose +transforms: + - _target_: torchvision.transforms.Resize + size: 224 + interpolation: 3 + antialias: true + - _target_: torchvision.transforms.CenterCrop + size: 224 + - _target_: torchvision.transforms.ToTensor + - _target_: torchvision.transforms.Normalize + mean: [0.485 ,0.456 ,0.406] + std: [0.229, 0.224, 0.225] \ No newline at end of file diff --git a/plonk/configs/dataset/test_transform/none.yaml b/plonk/configs/dataset/test_transform/none.yaml new file mode 100755 index 0000000000000000000000000000000000000000..711c1f0b1d1101281d28c9a95c19d7c0da2ae838 --- /dev/null +++ b/plonk/configs/dataset/test_transform/none.yaml @@ -0,0 +1,6 @@ +_target_: torchvision.transforms.Compose +transforms: + - _target_: torchvision.transforms.ToTensor + - _target_: torchvision.transforms.Normalize + mean: 0.5 + std: 0.5 diff --git a/plonk/configs/dataset/train_transform/augmentation.yaml b/plonk/configs/dataset/train_transform/augmentation.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a4d47baa311bfb1702f867db5d6d7a724fef79a4 --- /dev/null +++ b/plonk/configs/dataset/train_transform/augmentation.yaml @@ -0,0 +1,85 @@ +_target_: plonk.data.augmentation.ImageAugmentation +names: "standard_augmentation,geometric_augmentation,clip_transform" + +# always apply clip_transform at the end +clip_transform: + _target_: torchvision.transforms.Compose + transforms: + - _target_: torchvision.transforms.Resize + size: 224 + interpolation: 3 + antialias: true + - _target_: torchvision.transforms.CenterCrop + size: 224 + - _target_: torchvision.transforms.ToTensor + - _target_: torchvision.transforms.Normalize + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +standard_augmentation: + _target_: plonk.data.augmentation.StandardAugmentation + # by default, we all augmentation methods + names: "brightness,contrast,sharpness,color,blur,gaussian_noise" + + # random PIL brigtness + brightness: + _target_: plonk.data.augmentation.PillowBrightness + p: 0.2 + factor_interval: [0.5, 1.5] + + # random PIL contrast + contrast: + _target_: plonk.data.augmentation.PillowContrast + p: 0.2 + factor_interval: [0.3, 3] + + # random PIL sharpness + sharpness: + _target_: plonk.data.augmentation.PillowSharpness + p: 0.2 + factor_interval: [0.5, 30.0] + + # random PIL color + color: + _target_: plonk.data.augmentation.PillowColor + p: 0.2 + factor_interval: [0.0, 2.0] + + # random PIL blur + blur: + _target_: plonk.data.augmentation.PillowBlur + p: 0.2 + factor_interval: [1, 2] + + # random numpy gaussian noise + gaussian_noise: + _target_: plonk.data.augmentation.NumpyGaussianNoise + p: 0.2 + factor_interval: [0.1, 0.04] + +geometric_augmentation: + _target_: plonk.data.augmentation.GeometricAugmentation + # by default, we all augmentation methods + names: "random_rotation,random_resized_crop,random_horizontal_flip" + + # random rotation + random_rotation: + _target_: torchvision.transforms.RandomRotation + degrees: [-15, 15] + + # random crop + random_resized_crop: + _target_: torchvision.transforms.RandomResizedCrop + scale: [0.5, 1.0] + ratio: [0.9, 1.1] + size: 224 + + # random horizontal flip + random_horizontal_flip: + _target_: torchvision.transforms.RandomHorizontalFlip + p: 0.5 + + # random vertical flip + random_vertical_flip: + _target_: torchvision.transforms.RandomVerticalFlip + p: 0.5 diff --git a/plonk/configs/dataset/train_transform/center_crop.yaml b/plonk/configs/dataset/train_transform/center_crop.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cda160fe4e4315cfe2758f8d97a335b94d61da40 --- /dev/null +++ b/plonk/configs/dataset/train_transform/center_crop.yaml @@ -0,0 +1,14 @@ +_target_: torchvision.transforms.Compose +transforms: + - _target_: torchvision.transforms.ToTensor + - _target_: plonk.utils.image_processing.CenterCrop + ratio: "1:1" + - _target_: torchvision.transforms.Resize + size: ${dataset.img_resolution} + interpolation: 3 + antialias: true + - _target_: torchvision.transforms.RandomHorizontalFlip + p: 0.5 + - _target_: torchvision.transforms.Normalize + mean: 0.5 + std: 0.5 diff --git a/plonk/configs/dataset/train_transform/clip.yaml b/plonk/configs/dataset/train_transform/clip.yaml new file mode 100755 index 0000000000000000000000000000000000000000..fdeeb4b48d72f93a02ffb3e3ff2b44c5beee5ee5 --- /dev/null +++ b/plonk/configs/dataset/train_transform/clip.yaml @@ -0,0 +1,2 @@ +_target_: plonk.data.transforms.ClipTransform +split: val diff --git a/plonk/configs/dataset/train_transform/empty.yaml b/plonk/configs/dataset/train_transform/empty.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e19b331275841d011daab56a27bd697d236e6643 --- /dev/null +++ b/plonk/configs/dataset/train_transform/empty.yaml @@ -0,0 +1,2 @@ +_target_: plonk.data.data.null_transform +_partial_: true \ No newline at end of file diff --git a/plonk/configs/dataset/train_transform/fast_clip.yaml b/plonk/configs/dataset/train_transform/fast_clip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..45b6a08732e0466ba225038b8e1a27fffb3f66c7 --- /dev/null +++ b/plonk/configs/dataset/train_transform/fast_clip.yaml @@ -0,0 +1,12 @@ +_target_: torchvision.transforms.Compose +transforms: + - _target_: torchvision.transforms.Resize + size: 224 + interpolation: 3 + antialias: true + - _target_: torchvision.transforms.CenterCrop + size: 224 + - _target_: torchvision.transforms.ToTensor + - _target_: torchvision.transforms.Normalize + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] diff --git a/plonk/configs/dataset/train_transform/fast_resnet.yaml b/plonk/configs/dataset/train_transform/fast_resnet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fdbabe78156489a27370fa60e69e539170fbe150 --- /dev/null +++ b/plonk/configs/dataset/train_transform/fast_resnet.yaml @@ -0,0 +1,12 @@ +_target_: torchvision.transforms.Compose +transforms: + - _target_: torchvision.transforms.Resize + size: 224 + interpolation: 3 + antialias: true + - _target_: torchvision.transforms.CenterCrop + size: 224 + - _target_: torchvision.transforms.ToTensor + - _target_: torchvision.transforms.Normalize + mean: [0.485 ,0.456 ,0.406] + std: [0.229, 0.224, 0.225] \ No newline at end of file diff --git a/plonk/configs/dataset/train_transform/none.yaml b/plonk/configs/dataset/train_transform/none.yaml new file mode 100755 index 0000000000000000000000000000000000000000..0d54fe0045915b325145491307e283face27b3c2 --- /dev/null +++ b/plonk/configs/dataset/train_transform/none.yaml @@ -0,0 +1,7 @@ +_target_: torchvision.transforms.Compose +transforms: + - _target_: torchvision.transforms.Resize + size: 224 + interpolation: 3 + antialias: true + - _target_: torchvision.transforms.ToTensor diff --git a/plonk/configs/dataset/yfcc_emb.yaml b/plonk/configs/dataset/yfcc_emb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8aa34f8d73c14433969e146ff3d9aa2e9a0f69bb --- /dev/null +++ b/plonk/configs/dataset/yfcc_emb.yaml @@ -0,0 +1,38 @@ +defaults: + - train_transform: empty + - test_transform: empty + - _self_ + +name: iNaturalist_${dataset.embedding_name} +full_batch_size: 2048 +cond_dim: 1024 +eval_batch_size: 4096 +output_type: emb +embedding_name: dinov2_vitl14_registers + +train_dataset: + _partial_: true + _target_: plonk.data.webdataset.GPSWebdataset + root: ${data_dir}/YFCC100M/train/ + train: true + embedding_name: ${dataset.embedding_name} + return_image: false + metadata_attributes: [] + +val_dataset: + _partial_: true + _target_: plonk.data.webdataset.GPSWebdataset + root: ${data_dir}/YFCC100M/yfcc4k/ + train: false + embedding_name: ${dataset.embedding_name} + return_image: false + metadata_attributes: [] + +test_dataset: + _partial_: true + _target_: plonk.data.webdataset.GPSWebdataset + root: ${data_dir}/YFCC100M/yfcc4k/ + train: false + embedding_name: ${dataset.embedding_name} + return_image: false + metadata_attributes: [] diff --git a/plonk/configs/exp/YFCC100M_geoadalnmlp_r2_small_sigmoid_diffusion.yaml b/plonk/configs/exp/YFCC100M_geoadalnmlp_r2_small_sigmoid_diffusion.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4b3410bff2182d3f5d1b044850974900a6326ab8 --- /dev/null +++ b/plonk/configs/exp/YFCC100M_geoadalnmlp_r2_small_sigmoid_diffusion.yaml @@ -0,0 +1,35 @@ +# @package _global_ + +defaults: + - override /dataset: yfcc_emb + - override /model: emb_cond + - override /model/network: geo_adaln_mlp + - override /model/train_noise_scheduler: sigmoid + - override /model/inference_noise_scheduler: sigmoid + - override /model/loss: ddpm + - _self_ + +model: + network: + depth: 12 + dim: 512 + optimizer: + optim: + lr: 8e-4 + weight_decay: 0.05 + loss: + cond_drop_rate: 0.1 + train_noise_scheduler: + start: -7 + end: 3 + tau: 1.0 + inference_noise_scheduler: + start: -7 + end: 3 + tau: 1.0 + interpolant: diffusion +dataset: + full_batch_size: 1024 + +experiment_name_suffix: small_sigmoid +areas: [] \ No newline at end of file diff --git a/plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_linear_flow_rieman.yaml b/plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_linear_flow_rieman.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0fee68fbf405a91b29ec434bb78476b848c30f3d --- /dev/null +++ b/plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_linear_flow_rieman.yaml @@ -0,0 +1,32 @@ +# @package _global_ + +defaults: + - override /dataset: yfcc_emb + - override /model: emb_cond_cartesian + - override /model/network: geo_adaln_mlp + - override /model/train_noise_scheduler: linear + - override /model/inference_noise_scheduler: linear + - override /model/loss: riemannian_flow_matching + - override /model/manifold: sphere + - override /model/val_sampler: riemannian_flow_matching + - override /model/test_sampler: riemannian_flow_matching + - _self_ + +model: + network: + depth: 12 + dim: 512 + optimizer: + optim: + lr: 8e-4 + weight_decay: 0.05 + loss: + cond_drop_rate: 0.1 + interpolant: flow_matching + +dataset: + full_batch_size: 1024 + +areas: [] + +experiment_name_suffix: small_sigmoid diff --git a/plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_diffusion.yaml b/plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_diffusion.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1672bd4cde5c447efae5f390785244c090f77b18 --- /dev/null +++ b/plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_diffusion.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +defaults: + - override /dataset: yfcc_emb + - override /model: emb_cond_cartesian + - override /model/network: geo_adaln_mlp + - override /model/train_noise_scheduler: sigmoid + - override /model/inference_noise_scheduler: sigmoid + - override /model/loss: ddpm + - _self_ + +model: + network: + depth: 12 + dim: 512 + optimizer: + optim: + lr: 8e-4 + weight_decay: 0.05 + loss: + cond_drop_rate: 0.1 + train_noise_scheduler: + start: -7 + end: 3 + tau: 1.0 + inference_noise_scheduler: + start: -7 + end: 3 + tau: 1.0 + interpolant: diffusion + +dataset: + full_batch_size: 1024 + +experiment_name_suffix: small_sigmoid +areas: [] diff --git a/plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_flow.yaml b/plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_flow.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fb204d93b7f0d4ce333fbdd61e1dff12ce4ba87e --- /dev/null +++ b/plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_flow.yaml @@ -0,0 +1,38 @@ +# @package _global_ + +defaults: + - override /dataset: yfcc_emb + - override /model: emb_cond_cartesian + - override /model/network: geo_adaln_mlp + - override /model/train_noise_scheduler: sigmoid + - override /model/inference_noise_scheduler: sigmoid + - override /model/loss: flow_matching + - override /model/val_sampler: flow_matching + - override /model/test_sampler: flow_matching + - _self_ + +model: + network: + depth: 12 + dim: 512 + optimizer: + optim: + lr: 8e-4 + weight_decay: 0.05 + loss: + cond_drop_rate: 0.1 + train_noise_scheduler: + start: -7 + end: 3 + tau: 1.0 + inference_noise_scheduler: + start: -7 + end: 3 + tau: 1.0 + interpolant: flow_matching + +dataset: + full_batch_size: 1024 + +experiment_name_suffix: small_sigmoid +areas: [] \ No newline at end of file diff --git a/plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml b/plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d62acd07ffa09c8c618fda364da2910da20202dc --- /dev/null +++ b/plonk/configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml @@ -0,0 +1,40 @@ +# @package _global_ + +defaults: + - override /dataset: yfcc_emb + - override /model: emb_cond_cartesian + - override /model/network: geo_adaln_mlp + - override /model/train_noise_scheduler: sigmoid + - override /model/inference_noise_scheduler: sigmoid + - override /model/loss: riemannian_flow_matching + - override /model/manifold: sphere + - override /model/val_sampler: riemannian_flow_matching + - override /model/test_sampler: riemannian_flow_matching + - _self_ + +model: + network: + depth: 12 + dim: 512 + optimizer: + optim: + lr: 8e-4 + weight_decay: 0.05 + loss: + cond_drop_rate: 0.1 + train_noise_scheduler: + start: -7 + end: 3 + tau: 1.0 + inference_noise_scheduler: + start: -7 + end: 3 + tau: 1.0 + interpolant: flow_matching + +dataset: + full_batch_size: 1024 + +areas: [] + +experiment_name_suffix: small_sigmoid diff --git a/plonk/configs/exp/YFCC100M_geoadalnmlp_von_fisher.yaml b/plonk/configs/exp/YFCC100M_geoadalnmlp_von_fisher.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aba9726efc25aac006d3c6c50c273ef0b2b9d4bb --- /dev/null +++ b/plonk/configs/exp/YFCC100M_geoadalnmlp_von_fisher.yaml @@ -0,0 +1,26 @@ +# @package _global_ + +defaults: + - override /dataset: yfcc_emb + - override /model: von_fisher + - override /model/network: geo_adaln_mlp_von_fisher + - override /model/loss: von_fisher + - override /model/val_sampler: von_fisher + - override /model/test_sampler: von_fisher + - _self_ + +model: + network: + depth: 11 # To compensate the increase in params + dim: 512 + optimizer: + optim: + lr: 1e-4 + weight_decay: 0.05 +dataset: + full_batch_size: 1024 +trainer: + gradient_clip_val: 0.05 + gradient_clip_algorithm: norm +areas: [] +experiment_name_suffix: von_fisher \ No newline at end of file diff --git a/plonk/configs/exp/YFCC100M_geoadalnmlp_von_fisher_mixture.yaml b/plonk/configs/exp/YFCC100M_geoadalnmlp_von_fisher_mixture.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3ec04a70472c2417e47750f078e9ccea2b5d12d8 --- /dev/null +++ b/plonk/configs/exp/YFCC100M_geoadalnmlp_von_fisher_mixture.yaml @@ -0,0 +1,26 @@ +# @package _global_ + +defaults: + - override /dataset: yfcc_emb + - override /model: von_fisher_mixture + - override /model/network: geo_adaln_mlp_von_fisher_mixture + - override /model/loss: von_fisher_mixture + - override /model/val_sampler: von_fisher_mixture + - override /model/test_sampler: von_fisher_mixture + - _self_ + +model: + network: + depth: 11 # To compensate the increase in params + dim: 512 + optimizer: + optim: + lr: 1e-5 + weight_decay: 0.05 +dataset: + full_batch_size: 1024 +trainer: + gradient_clip_val: 0.01 + gradient_clip_algorithm: norm +experiment_name_suffix: von_fisher_mixture +areas: [] \ No newline at end of file diff --git a/plonk/configs/exp/combined_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml b/plonk/configs/exp/combined_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b047cd07a5e3cb138be093a2a30729296b067bdf --- /dev/null +++ b/plonk/configs/exp/combined_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml @@ -0,0 +1,40 @@ +# @package _global_ + +defaults: + - override /dataset: combined_emb + - override /model: emb_cond_cartesian + - override /model/network: geo_adaln_mlp + - override /model/train_noise_scheduler: sigmoid + - override /model/inference_noise_scheduler: sigmoid + - override /model/loss: riemannian_flow_matching + - override /model/manifold: sphere + - override /model/val_sampler: riemannian_flow_matching + - override /model/test_sampler: riemannian_flow_matching + - _self_ + +model: + network: + depth: 12 + dim: 512 + optimizer: + optim: + lr: 8e-4 + weight_decay: 0.05 + loss: + cond_drop_rate: 0.1 + train_noise_scheduler: + start: -7 + end: 3 + tau: 1.0 + inference_noise_scheduler: + start: -7 + end: 3 + tau: 1.0 + interpolant: flow_matching + +dataset: + full_batch_size: 1024 + +areas: [] + +experiment_name_suffix: small_sigmoid diff --git a/plonk/configs/exp/iNaturalist_geoadalnmlp_r2_small_sigmoid_diffusion.yaml b/plonk/configs/exp/iNaturalist_geoadalnmlp_r2_small_sigmoid_diffusion.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b9e44b2af3045a6f59891cd205606bbf0e8a2e10 --- /dev/null +++ b/plonk/configs/exp/iNaturalist_geoadalnmlp_r2_small_sigmoid_diffusion.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +defaults: + - override /dataset: inaturalist_emb + - override /model: emb_cond + - override /model/network: geo_adaln_mlp + - override /model/train_noise_scheduler: sigmoid + - override /model/inference_noise_scheduler: sigmoid + - override /model/loss: ddpm + - _self_ + +model: + network: + depth: 12 + dim: 256 + optimizer: + optim: + lr: 8e-4 + weight_decay: 0.1 + loss: + cond_drop_rate: 0.1 + train_noise_scheduler: + start: -7 + end: 3 + tau: 1.0 + inference_noise_scheduler: + start: -7 + end: 3 + tau: 1.0 + interpolant: diffusion +dataset: + full_batch_size: 512 + +areas: [] + +experiment_name_suffix: small_sigmoid \ No newline at end of file diff --git a/plonk/configs/exp/iNaturalist_geoadalnmlp_r3_small_sigmoid_diffusion.yaml b/plonk/configs/exp/iNaturalist_geoadalnmlp_r3_small_sigmoid_diffusion.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e87f9bbacf609fc627e85bd183d4adae9def3a10 --- /dev/null +++ b/plonk/configs/exp/iNaturalist_geoadalnmlp_r3_small_sigmoid_diffusion.yaml @@ -0,0 +1,37 @@ +# @package _global_ + +defaults: + - override /dataset: inaturalist_emb + - override /model: emb_cond_cartesian + - override /model/network: geo_adaln_mlp + - override /model/train_noise_scheduler: sigmoid + - override /model/inference_noise_scheduler: sigmoid + - override /model/loss: ddpm + - _self_ + +model: + network: + depth: 12 + dim: 256 + optimizer: + optim: + lr: 8e-4 + weight_decay: 0.1 + loss: + cond_drop_rate: 0.1 + train_noise_scheduler: + start: -7 + end: 3 + tau: 1.0 + inference_noise_scheduler: + start: -7 + end: 3 + tau: 1.0 + interpolant: diffusion + +dataset: + full_batch_size: 512 + +areas: [] + +experiment_name_suffix: small_sigmoid \ No newline at end of file diff --git a/plonk/configs/exp/iNaturalist_geoadalnmlp_r3_small_sigmoid_flow.yaml b/plonk/configs/exp/iNaturalist_geoadalnmlp_r3_small_sigmoid_flow.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6252b122ff2ea716be8ccec15cc583c075e420b3 --- /dev/null +++ b/plonk/configs/exp/iNaturalist_geoadalnmlp_r3_small_sigmoid_flow.yaml @@ -0,0 +1,39 @@ +# @package _global_ + +defaults: + - override /dataset: inaturalist_emb + - override /model: emb_cond_cartesian + - override /model/network: geo_adaln_mlp + - override /model/train_noise_scheduler: sigmoid + - override /model/inference_noise_scheduler: sigmoid + - override /model/loss: flow_matching + - override /model/val_sampler: flow_matching + - override /model/test_sampler: flow_matching + - _self_ + +model: + network: + depth: 12 + dim: 256 + optimizer: + optim: + lr: 8e-4 + weight_decay: 0.1 + loss: + cond_drop_rate: 0.1 + train_noise_scheduler: + start: -7 + end: 3 + tau: 1.0 + inference_noise_scheduler: + start: -7 + end: 3 + tau: 1.0 + interpolant: flow_matching + +dataset: + full_batch_size: 512 + +areas: [] + +experiment_name_suffix: small_sigmoid \ No newline at end of file diff --git a/plonk/configs/exp/iNaturalist_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml b/plonk/configs/exp/iNaturalist_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml new file mode 100644 index 0000000000000000000000000000000000000000..904eeac8ecf2d1980c3261db2fcf4eb1450fe4ab --- /dev/null +++ b/plonk/configs/exp/iNaturalist_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml @@ -0,0 +1,40 @@ +# @package _global_ + +defaults: + - override /dataset: inaturalist_emb + - override /model: emb_cond_cartesian + - override /model/network: geo_adaln_mlp + - override /model/train_noise_scheduler: sigmoid + - override /model/inference_noise_scheduler: sigmoid + - override /model/loss: riemannian_flow_matching + - override /model/manifold: sphere + - override /model/val_sampler: riemannian_flow_matching + - override /model/test_sampler: riemannian_flow_matching + - _self_ + +model: + network: + depth: 12 + dim: 256 + optimizer: + optim: + lr: 8e-4 + weight_decay: 0.1 + loss: + cond_drop_rate: 0.1 + train_noise_scheduler: + start: -7 + end: 3 + tau: 1.0 + inference_noise_scheduler: + start: -7 + end: 3 + tau: 1.0 + interpolant: flow_matching + +dataset: + full_batch_size: 512 + +areas: [] + +experiment_name_suffix: small_sigmoid \ No newline at end of file diff --git a/plonk/configs/exp/iNaturalist_geoadalnmlp_von_fisher.yaml b/plonk/configs/exp/iNaturalist_geoadalnmlp_von_fisher.yaml new file mode 100644 index 0000000000000000000000000000000000000000..86c7400c44efaa9f306329d738d18c8b5c9af946 --- /dev/null +++ b/plonk/configs/exp/iNaturalist_geoadalnmlp_von_fisher.yaml @@ -0,0 +1,26 @@ +# @package _global_ + +defaults: + - override /dataset: inaturalist_emb + - override /model: von_fisher + - override /model/network: geo_adaln_mlp_von_fisher + - override /model/loss: von_fisher + - override /model/val_sampler: von_fisher + - override /model/test_sampler: von_fisher + - _self_ + +model: + network: + depth: 11 # To compensate the increase in params + dim: 256 + optimizer: + optim: + lr: 1e-4 + weight_decay: 0.1 +dataset: + full_batch_size: 512 +trainer: + gradient_clip_val: 0.01 + gradient_clip_algorithm: norm +areas: [] +experiment_name_suffix: von_fisher \ No newline at end of file diff --git a/plonk/configs/exp/iNaturalist_geoadalnmlp_von_fisher_mixture.yaml b/plonk/configs/exp/iNaturalist_geoadalnmlp_von_fisher_mixture.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dfbc6019225b699de292cefd27e9d31da3515240 --- /dev/null +++ b/plonk/configs/exp/iNaturalist_geoadalnmlp_von_fisher_mixture.yaml @@ -0,0 +1,26 @@ +# @package _global_ + +defaults: + - override /dataset: inaturalist_emb + - override /model: von_fisher_mixture + - override /model/network: geo_adaln_mlp_von_fisher_mixture + - override /model/loss: von_fisher_mixture + - override /model/val_sampler: von_fisher_mixture + - override /model/test_sampler: von_fisher_mixture + - _self_ + +model: + network: + depth: 11 # To compensate the increase in params + dim: 256 + optimizer: + optim: + lr: 1e-5 + weight_decay: 0.1 +dataset: + full_batch_size: 512 +trainer: + gradient_clip_val: 0.01 + gradient_clip_algorithm: norm +areas: [] +experiment_name_suffix: von_fisher_mixture diff --git a/plonk/configs/exp/osv_5m_geoadalnmlp_r2_small_sigmoid_diffusion.yaml b/plonk/configs/exp/osv_5m_geoadalnmlp_r2_small_sigmoid_diffusion.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5c931fc74996f63e194b09d94876203421f908cd --- /dev/null +++ b/plonk/configs/exp/osv_5m_geoadalnmlp_r2_small_sigmoid_diffusion.yaml @@ -0,0 +1,34 @@ +# @package _global_ + +defaults: + - override /dataset: osv5m_emb + - override /model: emb_cond + - override /model/network: geo_adaln_mlp + - override /model/train_noise_scheduler: sigmoid + - override /model/inference_noise_scheduler: sigmoid + - override /model/loss: ddpm + - _self_ + +model: + network: + depth: 12 + dim: 512 + optimizer: + optim: + lr: 8e-4 + weight_decay: 0.05 + loss: + cond_drop_rate: 0.1 + train_noise_scheduler: + start: -7 + end: 3 + tau: 1.0 + inference_noise_scheduler: + start: -7 + end: 3 + tau: 1.0 + interpolant: diffusion +dataset: + full_batch_size: 1024 + +experiment_name_suffix: small_sigmoid \ No newline at end of file diff --git a/plonk/configs/exp/osv_5m_geoadalnmlp_r3_small_linear_flow_riemann.yaml b/plonk/configs/exp/osv_5m_geoadalnmlp_r3_small_linear_flow_riemann.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5a31ffd41250fe0abe628f0de14f2a9da2d33127 --- /dev/null +++ b/plonk/configs/exp/osv_5m_geoadalnmlp_r3_small_linear_flow_riemann.yaml @@ -0,0 +1,30 @@ +# @package _global_ + +defaults: + - override /dataset: osv5m_emb + - override /model: emb_cond_cartesian + - override /model/network: geo_adaln_mlp + - override /model/train_noise_scheduler: linear + - override /model/inference_noise_scheduler: linear + - override /model/loss: riemannian_flow_matching + - override /model/manifold: sphere + - override /model/val_sampler: riemannian_flow_matching + - override /model/test_sampler: riemannian_flow_matching + - _self_ + +model: + network: + depth: 12 + dim: 512 + optimizer: + optim: + lr: 8e-4 + weight_decay: 0.05 + loss: + cond_drop_rate: 0.1 + interpolant: flow_matching + +dataset: + full_batch_size: 1024 + +experiment_name_suffix: small_sigmoid \ No newline at end of file diff --git a/plonk/configs/exp/osv_5m_geoadalnmlp_r3_small_sigmoid_diffusion.yaml b/plonk/configs/exp/osv_5m_geoadalnmlp_r3_small_sigmoid_diffusion.yaml new file mode 100644 index 0000000000000000000000000000000000000000..df953892119cd50b386e950ebfd7e4e14a874761 --- /dev/null +++ b/plonk/configs/exp/osv_5m_geoadalnmlp_r3_small_sigmoid_diffusion.yaml @@ -0,0 +1,35 @@ +# @package _global_ + +defaults: + - override /dataset: osv5m_emb + - override /model: emb_cond_cartesian + - override /model/network: geo_adaln_mlp + - override /model/train_noise_scheduler: sigmoid + - override /model/inference_noise_scheduler: sigmoid + - override /model/loss: ddpm + - _self_ + +model: + network: + depth: 12 + dim: 512 + optimizer: + optim: + lr: 8e-4 + weight_decay: 0.05 + loss: + cond_drop_rate: 0.1 + train_noise_scheduler: + start: -7 + end: 3 + tau: 1.0 + inference_noise_scheduler: + start: -7 + end: 3 + tau: 1.0 + interpolant: diffusion + +dataset: + full_batch_size: 1024 + +experiment_name_suffix: small_sigmoid \ No newline at end of file diff --git a/plonk/configs/exp/osv_5m_geoadalnmlp_r3_small_sigmoid_flow.yaml b/plonk/configs/exp/osv_5m_geoadalnmlp_r3_small_sigmoid_flow.yaml new file mode 100644 index 0000000000000000000000000000000000000000..05459ee799d32a8ed3e87c841ac59959df7239c0 --- /dev/null +++ b/plonk/configs/exp/osv_5m_geoadalnmlp_r3_small_sigmoid_flow.yaml @@ -0,0 +1,37 @@ +# @package _global_ + +defaults: + - override /dataset: osv5m_emb + - override /model: emb_cond_cartesian + - override /model/network: geo_adaln_mlp + - override /model/train_noise_scheduler: sigmoid + - override /model/inference_noise_scheduler: sigmoid + - override /model/loss: flow_matching + - override /model/val_sampler: flow_matching + - override /model/test_sampler: flow_matching + - _self_ + +model: + network: + depth: 12 + dim: 512 + optimizer: + optim: + lr: 8e-4 + weight_decay: 0.05 + loss: + cond_drop_rate: 0.1 + train_noise_scheduler: + start: -7 + end: 3 + tau: 1.0 + inference_noise_scheduler: + start: -7 + end: 3 + tau: 1.0 + interpolant: flow_matching + +dataset: + full_batch_size: 1024 + +experiment_name_suffix: small_sigmoid \ No newline at end of file diff --git a/plonk/configs/exp/osv_5m_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml b/plonk/configs/exp/osv_5m_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5bfc89b84e0397c6aa6b363e59c4dad076414eea --- /dev/null +++ b/plonk/configs/exp/osv_5m_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml @@ -0,0 +1,38 @@ +# @package _global_ + +defaults: + - override /dataset: osv5m_emb + - override /model: emb_cond_cartesian + - override /model/network: geo_adaln_mlp + - override /model/train_noise_scheduler: sigmoid + - override /model/inference_noise_scheduler: sigmoid + - override /model/loss: riemannian_flow_matching + - override /model/manifold: sphere + - override /model/val_sampler: riemannian_flow_matching + - override /model/test_sampler: riemannian_flow_matching + - _self_ + +model: + network: + depth: 12 + dim: 512 + optimizer: + optim: + lr: 8e-4 + weight_decay: 0.05 + loss: + cond_drop_rate: 0.1 + train_noise_scheduler: + start: -7 + end: 3 + tau: 1.0 + inference_noise_scheduler: + start: -7 + end: 3 + tau: 1.0 + interpolant: flow_matching + +dataset: + full_batch_size: 1024 + +experiment_name_suffix: small_sigmoid \ No newline at end of file diff --git a/plonk/configs/exp/osv_5m_geoadalnmlp_von_fisher.yaml b/plonk/configs/exp/osv_5m_geoadalnmlp_von_fisher.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0d48f03164a22adbeceb57c3039acd0ed81f7d02 --- /dev/null +++ b/plonk/configs/exp/osv_5m_geoadalnmlp_von_fisher.yaml @@ -0,0 +1,25 @@ +# @package _global_ + +defaults: + - override /dataset: osv5m_emb + - override /model: von_fisher + - override /model/network: geo_adaln_mlp_von_fisher + - override /model/loss: von_fisher + - override /model/val_sampler: von_fisher + - override /model/test_sampler: von_fisher + - _self_ + +model: + network: + depth: 11 # To compensate the increase in params + dim: 512 + optimizer: + optim: + lr: 1e-4 + weight_decay: 0.05 +dataset: + full_batch_size: 1024 +trainer: + gradient_clip_val: 0.05 + gradient_clip_algorithm: norm +experiment_name_suffix: von_fisher \ No newline at end of file diff --git a/plonk/configs/exp/osv_5m_geoadalnmlp_von_fisher_mixture.yaml b/plonk/configs/exp/osv_5m_geoadalnmlp_von_fisher_mixture.yaml new file mode 100644 index 0000000000000000000000000000000000000000..96c0191c064b8f7a673512c365656816c41da1c4 --- /dev/null +++ b/plonk/configs/exp/osv_5m_geoadalnmlp_von_fisher_mixture.yaml @@ -0,0 +1,25 @@ +# @package _global_ + +defaults: + - override /dataset: osv5m_emb + - override /model: von_fisher_mixture + - override /model/network: geo_adaln_mlp_von_fisher_mixture + - override /model/loss: von_fisher_mixture + - override /model/val_sampler: von_fisher_mixture + - override /model/test_sampler: von_fisher_mixture + - _self_ + +model: + network: + depth: 11 # To compensate the increase in params + dim: 512 + optimizer: + optim: + lr: 1e-4 + weight_decay: 0.05 +dataset: + full_batch_size: 1024 +trainer: + gradient_clip_val: 0.05 + gradient_clip_algorithm: norm +experiment_name_suffix: von_fisher_mixture diff --git a/plonk/configs/model/cond_preprocessing/embedding.yaml b/plonk/configs/model/cond_preprocessing/embedding.yaml new file mode 100644 index 0000000000000000000000000000000000000000..45b2ca089cd36db354b1123276dc3c107afe8915 --- /dev/null +++ b/plonk/configs/model/cond_preprocessing/embedding.yaml @@ -0,0 +1,3 @@ +_target_: plonk.models.preprocessing.PrecomputedPreconditioning +input_key: emb +output_key: emb \ No newline at end of file diff --git a/plonk/configs/model/data_preprocessing/gps.yaml b/plonk/configs/model/data_preprocessing/gps.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c103609302e1926527593161dcd09d4c23f640b3 --- /dev/null +++ b/plonk/configs/model/data_preprocessing/gps.yaml @@ -0,0 +1,4 @@ +_target_: plonk.models.preprocessing.NormGPS +input_key: gps +output_key: x_0 +normalize: False \ No newline at end of file diff --git a/plonk/configs/model/data_preprocessing/gps_to_cartesian.yaml b/plonk/configs/model/data_preprocessing/gps_to_cartesian.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c0f3fdcca4bd76ce8091f712c59184e1b17d4012 --- /dev/null +++ b/plonk/configs/model/data_preprocessing/gps_to_cartesian.yaml @@ -0,0 +1,3 @@ +_target_: plonk.models.preprocessing.GPStoCartesian +input_key: gps +output_key: x_0 \ No newline at end of file diff --git a/plonk/configs/model/data_preprocessing/normalized_gps.yaml b/plonk/configs/model/data_preprocessing/normalized_gps.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4221cbe0e3dceaab318aaa267d8f0a891da6bb75 --- /dev/null +++ b/plonk/configs/model/data_preprocessing/normalized_gps.yaml @@ -0,0 +1,4 @@ +_target_: plonk.models.preprocessing.NormGPS +input_key: gps +output_key: x_0 +normalize: True \ No newline at end of file diff --git a/plonk/configs/model/emb_cond.yaml b/plonk/configs/model/emb_cond.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7f00df12fd3ad41aece86d88fd8ba509ae6f4d8c --- /dev/null +++ b/plonk/configs/model/emb_cond.yaml @@ -0,0 +1,24 @@ +defaults: + - optimizer: lamb + - lr_scheduler: warmup_cosine_decay + - network: geo_adaln_mlp + - train_noise_scheduler: sigmoid + - inference_noise_scheduler: cosine_simple + - preconditioning: ddpm + - data_preprocessing: normalized_gps + - cond_preprocessing: embedding + - postprocessing: renorm_gps + - loss: ddpm + - val_sampler: ddim + - test_sampler: ddpm + - manifold: null + - _self_ + +network: + input_dim: 2 +name: GeoMLP_R2 +ema_decay: 0.999 +start_ema_step: 0 +cfg_rate: 2.0 +interpolant: flow_matching +compute_nll: true \ No newline at end of file diff --git a/plonk/configs/model/emb_cond_cartesian.yaml b/plonk/configs/model/emb_cond_cartesian.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f8cc9be47b5f89fcd3dea1b00e6185ff0aade5d4 --- /dev/null +++ b/plonk/configs/model/emb_cond_cartesian.yaml @@ -0,0 +1,25 @@ +defaults: + - optimizer: lamb + - lr_scheduler: warmup_cosine_decay + - network: geo_adaln_mlp + - train_noise_scheduler: sigmoid + - inference_noise_scheduler: cosine_simple + - preconditioning: ddpm + - data_preprocessing: gps_to_cartesian + - cond_preprocessing: embedding + - postprocessing: cartesian_to_gps + - loss: ddpm + - val_sampler: ddim + - test_sampler: ddpm + - manifold: null + - _self_ + +network: + input_dim: 3 +name: GeoMLP_R3 +ema_decay: 0.999 +start_ema_step: 0 +cfg_rate: 2.0 +interpolant: flow_matching +compute_nll: true +compute_swarms: False \ No newline at end of file diff --git a/plonk/configs/model/inference_noise_scheduler/cosine.yaml b/plonk/configs/model/inference_noise_scheduler/cosine.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0a107cff0ecbbcb84e8ca807127e024df5bbe083 --- /dev/null +++ b/plonk/configs/model/inference_noise_scheduler/cosine.yaml @@ -0,0 +1,5 @@ +_target_: plonk.models.schedulers.CosineScheduler +start: 1 +end: 0 +tau: 1 +clip_min: 1e-9 \ No newline at end of file diff --git a/plonk/configs/model/inference_noise_scheduler/cosine_simple.yaml b/plonk/configs/model/inference_noise_scheduler/cosine_simple.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e1aacad97b573a3481df4af5373b6c8a5b53477c --- /dev/null +++ b/plonk/configs/model/inference_noise_scheduler/cosine_simple.yaml @@ -0,0 +1,3 @@ +_target_: plonk.models.schedulers.CosineSchedulerSimple +ns: 2e-4 +ds: 2.5e-4 \ No newline at end of file diff --git a/plonk/configs/model/inference_noise_scheduler/linear.yaml b/plonk/configs/model/inference_noise_scheduler/linear.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3d77bf7e8c785e0123076293aca1380faaeab6b5 --- /dev/null +++ b/plonk/configs/model/inference_noise_scheduler/linear.yaml @@ -0,0 +1,4 @@ +_target_: plonk.models.schedulers.LinearScheduler +start: 1 +end: 0 +clip_min: 1e-9 \ No newline at end of file diff --git a/plonk/configs/model/inference_noise_scheduler/sigmoid.yaml b/plonk/configs/model/inference_noise_scheduler/sigmoid.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fcd2e1f1233ac0545468dcbc44d0a4b41ae5d849 --- /dev/null +++ b/plonk/configs/model/inference_noise_scheduler/sigmoid.yaml @@ -0,0 +1,5 @@ +_target_: plonk.models.schedulers.SigmoidScheduler +start: -3 +end: 3 +tau: 0.9 +clip_min: 1e-9 \ No newline at end of file diff --git a/plonk/configs/model/loss/ddpm.yaml b/plonk/configs/model/loss/ddpm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6cbd2f3c8b4dad725f280ba2b6f6bc65adcbef48 --- /dev/null +++ b/plonk/configs/model/loss/ddpm.yaml @@ -0,0 +1,4 @@ +_partial_: true +_target_: plonk.models.losses.DDPMLoss +cond_drop_rate: 0.0 +conditioning_key: ${model.cond_preprocessing.output_key} \ No newline at end of file diff --git a/plonk/configs/model/loss/flow_matching.yaml b/plonk/configs/model/loss/flow_matching.yaml new file mode 100644 index 0000000000000000000000000000000000000000..92091fbe00ef9d486d74874124f476e5ffaa2e55 --- /dev/null +++ b/plonk/configs/model/loss/flow_matching.yaml @@ -0,0 +1,4 @@ +_partial_: true +_target_: plonk.models.losses.FlowMatchingLoss +cond_drop_rate: 0.0 +conditioning_key: ${model.cond_preprocessing.output_key} \ No newline at end of file diff --git a/plonk/configs/model/loss/riemannian_flow_matching.yaml b/plonk/configs/model/loss/riemannian_flow_matching.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a21c6f42fd0b958a1eb81c9b1c26fca3bcfc0432 --- /dev/null +++ b/plonk/configs/model/loss/riemannian_flow_matching.yaml @@ -0,0 +1,4 @@ +_partial_: true +_target_: plonk.models.losses.RiemannianFlowMatchingLoss +cond_drop_rate: 0.0 +conditioning_key: ${model.cond_preprocessing.output_key} \ No newline at end of file diff --git a/plonk/configs/model/loss/von_fisher.yaml b/plonk/configs/model/loss/von_fisher.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aff246471ac06a6adc57eebdde1839927a5e2265 --- /dev/null +++ b/plonk/configs/model/loss/von_fisher.yaml @@ -0,0 +1,2 @@ +_partial_: true +_target_: plonk.models.losses.VonFisherLoss diff --git a/plonk/configs/model/loss/von_fisher_mixture.yaml b/plonk/configs/model/loss/von_fisher_mixture.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fd1ff2712bac07708500db24aed980b862e40d5d --- /dev/null +++ b/plonk/configs/model/loss/von_fisher_mixture.yaml @@ -0,0 +1,2 @@ +_partial_: true +_target_: plonk.models.losses.VonFisherMixtureLoss diff --git a/plonk/configs/model/lr_scheduler/warmup.yaml b/plonk/configs/model/lr_scheduler/warmup.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f775b8f0d95af575b447cee74d2dcbd60e97ff55 --- /dev/null +++ b/plonk/configs/model/lr_scheduler/warmup.yaml @@ -0,0 +1,4 @@ +_partial_: true +_target_: plonk.utils.lr_scheduler.WarmupLR +warmup_steps: 500 + diff --git a/plonk/configs/model/lr_scheduler/warmup_cosine_decay.yaml b/plonk/configs/model/lr_scheduler/warmup_cosine_decay.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1a86f029dfe6e7ddf2bf044f12ed216c290cf55e --- /dev/null +++ b/plonk/configs/model/lr_scheduler/warmup_cosine_decay.yaml @@ -0,0 +1,5 @@ +_partial_: true +_target_: plonk.utils.lr_scheduler.WarmupCosineDecayLR +warmup_steps: 500 +total_steps: ${trainer.max_steps} + diff --git a/plonk/configs/model/manifold/sphere.yaml b/plonk/configs/model/manifold/sphere.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2e78753de0ae352997f5dbcbab7c26498730b224 --- /dev/null +++ b/plonk/configs/model/manifold/sphere.yaml @@ -0,0 +1 @@ +_target_: plonk.utils.manifolds.Sphere \ No newline at end of file diff --git a/plonk/configs/model/network/geo_adaln_mlp.yaml b/plonk/configs/model/network/geo_adaln_mlp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b6503b603533a96259cdcf108d89f22a7706f92b --- /dev/null +++ b/plonk/configs/model/network/geo_adaln_mlp.yaml @@ -0,0 +1,6 @@ +_target_: plonk.models.networks.mlp.GeoAdaLNMLP +input_dim: 2 +dim: 256 +depth: 8 +expansion: 4 +cond_dim: ${dataset.cond_dim} \ No newline at end of file diff --git a/plonk/configs/model/network/geo_adaln_mlp_von_fisher.yaml b/plonk/configs/model/network/geo_adaln_mlp_von_fisher.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9cca721c02d67cfe1ee00702eaa43d6f382395cf --- /dev/null +++ b/plonk/configs/model/network/geo_adaln_mlp_von_fisher.yaml @@ -0,0 +1,6 @@ +_target_: plonk.models.networks.mlp.GeoAdaLNMLPVonFisher +input_dim: 2 +dim: 256 +depth: 8 +expansion: 4 +cond_dim: ${dataset.cond_dim} \ No newline at end of file diff --git a/plonk/configs/model/network/geo_adaln_mlp_von_fisher_mixture.yaml b/plonk/configs/model/network/geo_adaln_mlp_von_fisher_mixture.yaml new file mode 100644 index 0000000000000000000000000000000000000000..09b0d331ebbf4e9753929df8bd30f4f83c67650a --- /dev/null +++ b/plonk/configs/model/network/geo_adaln_mlp_von_fisher_mixture.yaml @@ -0,0 +1,7 @@ +_target_: plonk.models.networks.mlp.GeoAdaLNMLPVonFisherMixture +input_dim: 2 +dim: 256 +depth: 8 +expansion: 4 +cond_dim: ${dataset.cond_dim} +num_mixtures: 3 \ No newline at end of file diff --git a/plonk/configs/model/network/geo_mlp.yaml b/plonk/configs/model/network/geo_mlp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8312ab4fbd1bb6e024a429f4a921963c43bfef20 --- /dev/null +++ b/plonk/configs/model/network/geo_mlp.yaml @@ -0,0 +1,5 @@ +_target_: plonk.models.networks.mlp.GeoConcatNMLP +input_dim: 2 +hidden_dim: 512 +depth: 5 +cond_dim: ${dataset.cond_dim} \ No newline at end of file diff --git a/plonk/configs/model/optimizer/adam.yaml b/plonk/configs/model/optimizer/adam.yaml new file mode 100755 index 0000000000000000000000000000000000000000..55490d3492168181115ef90949a1232fece3f7b5 --- /dev/null +++ b/plonk/configs/model/optimizer/adam.yaml @@ -0,0 +1,7 @@ +optim: + _target_: torch.optim.Adam + lr: 1e-3 + betas: [0.9, 0.999] + weight_decay: 0.01 + +exclude_ln_and_biases_from_weight_decay: False \ No newline at end of file diff --git a/plonk/configs/model/optimizer/adamw.yaml b/plonk/configs/model/optimizer/adamw.yaml new file mode 100755 index 0000000000000000000000000000000000000000..7b6217c6a98035ffa390a6ea0c8930754698d8f6 --- /dev/null +++ b/plonk/configs/model/optimizer/adamw.yaml @@ -0,0 +1,7 @@ +optim: + _target_: torch.optim.AdamW + lr: 1e-3 + betas: [0.9, 0.999] + weight_decay: 0.01 + +exclude_ln_and_biases_from_weight_decay: False \ No newline at end of file diff --git a/plonk/configs/model/optimizer/lamb.yaml b/plonk/configs/model/optimizer/lamb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0d0dabd18ced67730b0cc9f542fed5dc69797dca --- /dev/null +++ b/plonk/configs/model/optimizer/lamb.yaml @@ -0,0 +1,7 @@ +optim: + _target_: plonk.utils.optimizers.Lamb + lr: 1e-3 + betas: [0.9, 0.999] + weight_decay: 0.01 + +exclude_ln_and_biases_from_weight_decay: False \ No newline at end of file diff --git a/plonk/configs/model/optimizer/sgd.yaml b/plonk/configs/model/optimizer/sgd.yaml new file mode 100644 index 0000000000000000000000000000000000000000..15f1c6c52521dc794c5bbc7a2740c5d0659fa6eb --- /dev/null +++ b/plonk/configs/model/optimizer/sgd.yaml @@ -0,0 +1,6 @@ +optim: + _target_: torch.optim.SGD + lr: 1e-3 + weight_decay: 0.01 + +exclude_ln_and_biases_from_weight_decay: False \ No newline at end of file diff --git a/plonk/configs/model/postprocessing/cartesian_to_gps.yaml b/plonk/configs/model/postprocessing/cartesian_to_gps.yaml new file mode 100644 index 0000000000000000000000000000000000000000..04a325270fd09492bf5652ecdb7f0dbc07c3fdce --- /dev/null +++ b/plonk/configs/model/postprocessing/cartesian_to_gps.yaml @@ -0,0 +1 @@ +_target_: plonk.models.postprocessing.CartesiantoGPS \ No newline at end of file diff --git a/plonk/configs/model/postprocessing/renorm_gps.yaml b/plonk/configs/model/postprocessing/renorm_gps.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6e376b81d3255e032daa986f607b9461f274f7bf --- /dev/null +++ b/plonk/configs/model/postprocessing/renorm_gps.yaml @@ -0,0 +1 @@ +_target_: plonk.models.postprocessing.UnormGPS \ No newline at end of file diff --git a/plonk/configs/model/preconditioning/ddpm.yaml b/plonk/configs/model/preconditioning/ddpm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..905c13339b06668a3784063de388da557d522754 --- /dev/null +++ b/plonk/configs/model/preconditioning/ddpm.yaml @@ -0,0 +1 @@ +_target_: plonk.models.preconditioning.DDPMPrecond \ No newline at end of file diff --git a/plonk/configs/model/preconditioning/edm.yaml b/plonk/configs/model/preconditioning/edm.yaml new file mode 100755 index 0000000000000000000000000000000000000000..4222a8989c67e595d29cb26f231f80ed859c08b0 --- /dev/null +++ b/plonk/configs/model/preconditioning/edm.yaml @@ -0,0 +1,6 @@ +_partial_: true +_target_: plonk.models.preconditioning.EDMPrecond +label_dim: ${data.label_dim} +sigma_min: 0 +sigma_max: !!float .inf +sigma_data: 0.5 \ No newline at end of file diff --git a/plonk/configs/model/test_sampler/ddim.yaml b/plonk/configs/model/test_sampler/ddim.yaml new file mode 100644 index 0000000000000000000000000000000000000000..58b3e30f972c8db39efb0c9fe7e0734b03107cf2 --- /dev/null +++ b/plonk/configs/model/test_sampler/ddim.yaml @@ -0,0 +1,4 @@ +_partial_: true +_target_: plonk.models.samplers.ddim.ddim_sampler +num_steps: 250 +cfg_rate: ${model.cfg_rate} \ No newline at end of file diff --git a/plonk/configs/model/test_sampler/ddpm.yaml b/plonk/configs/model/test_sampler/ddpm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9ab9387c26fc1ccdc9669d0496594d27e5369b7f --- /dev/null +++ b/plonk/configs/model/test_sampler/ddpm.yaml @@ -0,0 +1,4 @@ +_partial_: true +_target_: plonk.models.samplers.ddpm.ddpm_sampler +num_steps: 1000 +cfg_rate: ${model.cfg_rate} \ No newline at end of file diff --git a/plonk/configs/model/test_sampler/edm.yaml b/plonk/configs/model/test_sampler/edm.yaml new file mode 100755 index 0000000000000000000000000000000000000000..4ab39959d81c2170744109ead68ca76448760f6b --- /dev/null +++ b/plonk/configs/model/test_sampler/edm.yaml @@ -0,0 +1,10 @@ +_partial_: true +_target_: plonk.models.samplers.edm.edm_sampler +num_steps: 18 +sigma_min: 0.002 +sigma_max: 80 +rho: 7 +S_churn: 0 +S_min: 0 +S_max: !!float .inf +S_noise: 1 \ No newline at end of file diff --git a/plonk/configs/model/test_sampler/flow_matching.yaml b/plonk/configs/model/test_sampler/flow_matching.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0224637974f8430c5a8d7898656965de450f7410 --- /dev/null +++ b/plonk/configs/model/test_sampler/flow_matching.yaml @@ -0,0 +1,4 @@ +_partial_: true +_target_: plonk.models.samplers.flow_sampler.flow_sampler +num_steps: 250 +cfg_rate: ${model.cfg_rate} \ No newline at end of file diff --git a/plonk/configs/model/test_sampler/riemannian_flow_matching.yaml b/plonk/configs/model/test_sampler/riemannian_flow_matching.yaml new file mode 100644 index 0000000000000000000000000000000000000000..068fd1613899a8d716c73e2b307034d0cb720929 --- /dev/null +++ b/plonk/configs/model/test_sampler/riemannian_flow_matching.yaml @@ -0,0 +1,4 @@ +_partial_: true +_target_: plonk.models.samplers.riemannian_flow_sampler.riemannian_flow_sampler +num_steps: 250 +cfg_rate: ${model.cfg_rate} \ No newline at end of file diff --git a/plonk/configs/model/test_sampler/von_fisher.yaml b/plonk/configs/model/test_sampler/von_fisher.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9cce1ead4476fb8ec76bc13647b90e2b6e250c48 --- /dev/null +++ b/plonk/configs/model/test_sampler/von_fisher.yaml @@ -0,0 +1,2 @@ +_partial_: true +_target_: plonk.models.samplers.von_fisher_sampling.vMF_sampler diff --git a/plonk/configs/model/test_sampler/von_fisher_mixture.yaml b/plonk/configs/model/test_sampler/von_fisher_mixture.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2661314df5ef4a6d6ac16511278f080298bc2824 --- /dev/null +++ b/plonk/configs/model/test_sampler/von_fisher_mixture.yaml @@ -0,0 +1,2 @@ +_partial_: true +_target_: plonk.models.samplers.von_fisher_sampling.vMF_mixture_sampler diff --git a/plonk/configs/model/train_noise_scheduler/cosine.yaml b/plonk/configs/model/train_noise_scheduler/cosine.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0a107cff0ecbbcb84e8ca807127e024df5bbe083 --- /dev/null +++ b/plonk/configs/model/train_noise_scheduler/cosine.yaml @@ -0,0 +1,5 @@ +_target_: plonk.models.schedulers.CosineScheduler +start: 1 +end: 0 +tau: 1 +clip_min: 1e-9 \ No newline at end of file diff --git a/plonk/configs/model/train_noise_scheduler/cosine_simple.yaml b/plonk/configs/model/train_noise_scheduler/cosine_simple.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e1aacad97b573a3481df4af5373b6c8a5b53477c --- /dev/null +++ b/plonk/configs/model/train_noise_scheduler/cosine_simple.yaml @@ -0,0 +1,3 @@ +_target_: plonk.models.schedulers.CosineSchedulerSimple +ns: 2e-4 +ds: 2.5e-4 \ No newline at end of file diff --git a/plonk/configs/model/train_noise_scheduler/linear.yaml b/plonk/configs/model/train_noise_scheduler/linear.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3d77bf7e8c785e0123076293aca1380faaeab6b5 --- /dev/null +++ b/plonk/configs/model/train_noise_scheduler/linear.yaml @@ -0,0 +1,4 @@ +_target_: plonk.models.schedulers.LinearScheduler +start: 1 +end: 0 +clip_min: 1e-9 \ No newline at end of file diff --git a/plonk/configs/model/train_noise_scheduler/sigmoid.yaml b/plonk/configs/model/train_noise_scheduler/sigmoid.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fcd2e1f1233ac0545468dcbc44d0a4b41ae5d849 --- /dev/null +++ b/plonk/configs/model/train_noise_scheduler/sigmoid.yaml @@ -0,0 +1,5 @@ +_target_: plonk.models.schedulers.SigmoidScheduler +start: -3 +end: 3 +tau: 0.9 +clip_min: 1e-9 \ No newline at end of file diff --git a/plonk/configs/model/val_sampler/ddim.yaml b/plonk/configs/model/val_sampler/ddim.yaml new file mode 100644 index 0000000000000000000000000000000000000000..58b3e30f972c8db39efb0c9fe7e0734b03107cf2 --- /dev/null +++ b/plonk/configs/model/val_sampler/ddim.yaml @@ -0,0 +1,4 @@ +_partial_: true +_target_: plonk.models.samplers.ddim.ddim_sampler +num_steps: 250 +cfg_rate: ${model.cfg_rate} \ No newline at end of file diff --git a/plonk/configs/model/val_sampler/ddpm.yaml b/plonk/configs/model/val_sampler/ddpm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9ab9387c26fc1ccdc9669d0496594d27e5369b7f --- /dev/null +++ b/plonk/configs/model/val_sampler/ddpm.yaml @@ -0,0 +1,4 @@ +_partial_: true +_target_: plonk.models.samplers.ddpm.ddpm_sampler +num_steps: 1000 +cfg_rate: ${model.cfg_rate} \ No newline at end of file diff --git a/plonk/configs/model/val_sampler/edm.yaml b/plonk/configs/model/val_sampler/edm.yaml new file mode 100755 index 0000000000000000000000000000000000000000..4ab39959d81c2170744109ead68ca76448760f6b --- /dev/null +++ b/plonk/configs/model/val_sampler/edm.yaml @@ -0,0 +1,10 @@ +_partial_: true +_target_: plonk.models.samplers.edm.edm_sampler +num_steps: 18 +sigma_min: 0.002 +sigma_max: 80 +rho: 7 +S_churn: 0 +S_min: 0 +S_max: !!float .inf +S_noise: 1 \ No newline at end of file diff --git a/plonk/configs/model/val_sampler/flow_matching.yaml b/plonk/configs/model/val_sampler/flow_matching.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0224637974f8430c5a8d7898656965de450f7410 --- /dev/null +++ b/plonk/configs/model/val_sampler/flow_matching.yaml @@ -0,0 +1,4 @@ +_partial_: true +_target_: plonk.models.samplers.flow_sampler.flow_sampler +num_steps: 250 +cfg_rate: ${model.cfg_rate} \ No newline at end of file diff --git a/plonk/configs/model/val_sampler/riemannian_flow_matching.yaml b/plonk/configs/model/val_sampler/riemannian_flow_matching.yaml new file mode 100644 index 0000000000000000000000000000000000000000..068fd1613899a8d716c73e2b307034d0cb720929 --- /dev/null +++ b/plonk/configs/model/val_sampler/riemannian_flow_matching.yaml @@ -0,0 +1,4 @@ +_partial_: true +_target_: plonk.models.samplers.riemannian_flow_sampler.riemannian_flow_sampler +num_steps: 250 +cfg_rate: ${model.cfg_rate} \ No newline at end of file diff --git a/plonk/configs/model/val_sampler/von_fisher.yaml b/plonk/configs/model/val_sampler/von_fisher.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9cce1ead4476fb8ec76bc13647b90e2b6e250c48 --- /dev/null +++ b/plonk/configs/model/val_sampler/von_fisher.yaml @@ -0,0 +1,2 @@ +_partial_: true +_target_: plonk.models.samplers.von_fisher_sampling.vMF_sampler diff --git a/plonk/configs/model/val_sampler/von_fisher_mixture.yaml b/plonk/configs/model/val_sampler/von_fisher_mixture.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2661314df5ef4a6d6ac16511278f080298bc2824 --- /dev/null +++ b/plonk/configs/model/val_sampler/von_fisher_mixture.yaml @@ -0,0 +1,2 @@ +_partial_: true +_target_: plonk.models.samplers.von_fisher_sampling.vMF_mixture_sampler diff --git a/plonk/configs/model/von_fisher.yaml b/plonk/configs/model/von_fisher.yaml new file mode 100644 index 0000000000000000000000000000000000000000..80f5429564432fa096e65b4910234b38423930f5 --- /dev/null +++ b/plonk/configs/model/von_fisher.yaml @@ -0,0 +1,19 @@ +defaults: + - optimizer: lamb + - lr_scheduler: warmup_cosine_decay + - network: geo_adaln_mlp_von_fisher + - preconditioning: ddpm + - data_preprocessing: gps_to_cartesian + - cond_preprocessing: embedding + - postprocessing: cartesian_to_gps + - loss: von_fisher + - val_sampler: von_fisher + - test_sampler: von_fisher + - _self_ + +network: + input_dim: 3 +name: GeoMLP_R3_VonFisher +ema_decay: 0.999 +start_ema_step: 0 +interpolant: von_fisher \ No newline at end of file diff --git a/plonk/configs/model/von_fisher_mixture.yaml b/plonk/configs/model/von_fisher_mixture.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ae27b7edda558d08d2746d28d98dfe80b78ef573 --- /dev/null +++ b/plonk/configs/model/von_fisher_mixture.yaml @@ -0,0 +1,19 @@ +defaults: + - optimizer: lamb + - lr_scheduler: warmup_cosine_decay + - network: geo_adaln_mlp_von_fisher_mixture + - preconditioning: ddpm + - data_preprocessing: gps_to_cartesian + - cond_preprocessing: embedding + - postprocessing: cartesian_to_gps + - loss: von_fisher_mixture + - val_sampler: von_fisher_mixture + - test_sampler: von_fisher_mixture + - _self_ + +network: + input_dim: 3 +name: GeoMLP_R3_VonFisher_Mixture +ema_decay: 0.999 +start_ema_step: 0 +interpolant: von_fisher \ No newline at end of file diff --git a/plonk/configs/stage/debug.yaml b/plonk/configs/stage/debug.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1e4e1f2a87a3d8f36cde2ea5de293f59a7bd1cdc --- /dev/null +++ b/plonk/configs/stage/debug.yaml @@ -0,0 +1,4 @@ +# @package _global_ + + +stage: debug \ No newline at end of file diff --git a/plonk/configs/stage/profile.yaml b/plonk/configs/stage/profile.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a09a37769b6bf617619e2dc92d68dc04127b9f7a --- /dev/null +++ b/plonk/configs/stage/profile.yaml @@ -0,0 +1,20 @@ +# @package _global_ + +trainer: + max_steps: 15 + profiler: + _target_: pytorch_lightning.profilers.PyTorchProfiler + dirpath: ${root_dir}/plonk/profiler_log/${experiment_name} + schedule: + _target_: torch.profiler.schedule + skip_first: 5 + wait: 2 + warmup: 1 + active: 3 + repeat: 0 + on_trace_ready: + _target_: torch.profiler.tensorboard_trace_handler + dir_name: ${root_dir}/plonk/profiler_log/${experiment_name} + with_stack: True + record_shapes: True + with_modules: True \ No newline at end of file diff --git a/plonk/metrics/__init__.py b/plonk/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/plonk/metrics/distance_based.py b/plonk/metrics/distance_based.py new file mode 100644 index 0000000000000000000000000000000000000000..0b0a0878bbc468e74e86f018247726b76f379981 --- /dev/null +++ b/plonk/metrics/distance_based.py @@ -0,0 +1,272 @@ +import torch + +from metrics.utils import haversine, reverse +from sklearn.metrics import pairwise_distances +from torchmetrics import Metric +import numpy as np +from plonk.utils.kde import BatchedKDE +from tqdm import tqdm + + +class HaversineMetrics(Metric): + """ + Computes the average haversine distance between the predicted and ground truth points. + Compute the accuracy given some radiuses. + Compute the Geoguessr score given some radiuses. + + Args: + acc_radiuses (list): list of radiuses to compute the accuracy from + acc_area (list): list of areas to compute the accuracy from. + """ + + def __init__( + self, + acc_radiuses=[], + acc_area=["country", "region", "sub-region", "city"], + use_kde=False, + manifold_k=3, + ): + super().__init__() + self.use_kde = use_kde + self.add_state("haversine_sum", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("geoguessr_sum", default=torch.tensor(0.0), dist_reduce_fx="sum") + for acc in acc_radiuses: + self.add_state( + f"close_enough_points_{acc}", + default=torch.tensor(0.0), + dist_reduce_fx="sum", + ) + for acc in acc_area: + self.add_state( + f"close_enough_points_{acc}", + default=torch.tensor(0.0), + dist_reduce_fx="sum", + ) + self.add_state( + f"count_{acc}", default=torch.tensor(0), dist_reduce_fx="sum" + ) + self.acc_radius = acc_radiuses + self.acc_area = acc_area + self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state( + "real_points", + [], + dist_reduce_fx=None, + ) + self.add_state( + "fake_points", + [], + dist_reduce_fx=None, + ) + self.manifold_k = manifold_k + + def update(self, pred, gt): + if self.use_kde: + (x_mode, y_mode), kde = estimate_kde_mode(pred["gps"]) + # self.nll_sum += -torch.log( + # kde.score(gt["gps"].unsqueeze(1).to(pred["gps"].device)) + # ).sum() + pred["gps"] = torch.stack([x_mode, y_mode], dim=1) + # Handle NaN values without modifying the original inputs + if pred["gps"].isnan().any(): + valid_mask = ~pred["gps"].isnan().any(dim=1) + pred_gps = pred["gps"][valid_mask] + gt_gps = gt["gps"][valid_mask] + if len(pred_gps) == 0: # Skip if no valid predictions remain + return + else: + pred_gps = pred["gps"] + gt_gps = gt["gps"] + haversine_distance = haversine(pred_gps, gt_gps) + for acc in self.acc_radius: + self.__dict__[f"close_enough_points_{acc}"] += ( + haversine_distance < acc + ).sum() + if len(self.acc_area) > 0: + area_pred, area_gt = reverse(pred_gps, gt, self.acc_area) + for acc in self.acc_area: + self.__dict__[f"close_enough_points_{acc}"] += ( + area_pred[acc] == area_gt["_".join(["unique", acc])] + ).sum() + self.__dict__[f"count_{acc}"] += len(area_gt["_".join(["unique", acc])]) + self.haversine_sum += haversine_distance.sum() + self.geoguessr_sum += 5000 * torch.exp(-haversine_distance / 1492.7).sum() + self.real_points.append(gt_gps) + self.fake_points.append(pred_gps) + self.count += pred_gps.shape[0] + + def compute(self): + output = { + "Haversine": self.haversine_sum / self.count, + "Geoguessr": self.geoguessr_sum / self.count, + } + for acc in self.acc_radius: + output[f"Accuracy_{acc}_km_radius"] = ( + self.__dict__[f"close_enough_points_{acc}"] / self.count + ) + for acc in self.acc_area: + output[f"Accuracy_{acc}"] = ( + self.__dict__[f"close_enough_points_{acc}"] + / self.__dict__[f"count_{acc}"] + ) + real_points = torch.cat(self.real_points, dim=0) + fake_points = torch.cat(self.fake_points, dim=0) + ( + output["precision"], + output["recall"], + output["density"], + output["coverage"], + ) = self.manifold_metrics(real_points, fake_points, self.manifold_k) + return output + + def compute_pairwise_distance(self, data_x, data_y=None): + """ + Args: + data_x: numpy.ndarray([N, feature_dim], dtype=np.float32) + data_y: numpy.ndarray([N, feature_dim], dtype=np.float32) + Returns: + numpy.ndarray([N, N], dtype=np.float32) of pairwise distances. + """ + if data_y is None: + data_y = data_x + + dists = pairwise_distances(data_x, data_y, metric="haversine", n_jobs=8) + return dists + + def get_kth_value(self, unsorted, k, axis=-1): + """ + Args: + unsorted: numpy.ndarray of any dimensionality. + k: int + Returns: + kth values along the designated axis. + """ + indices = np.argpartition(unsorted, k, axis=axis)[..., :k] + k_smallests = np.take_along_axis(unsorted, indices, axis=axis) + kth_values = k_smallests.max(axis=axis) + return kth_values + + def compute_nearest_neighbour_distances(self, input_features, nearest_k): + """ + Args: + input_features: numpy.ndarray([N, feature_dim], dtype=np.float32) + nearest_k: int + Returns: + Distances to kth nearest neighbours. + """ + distances = self.compute_pairwise_distance(input_features) + radii = self.get_kth_value(distances, k=nearest_k + 1, axis=-1) + return radii + + def compute_prdc(self, real_features, fake_features, nearest_k): + """ + Computes precision, recall, density, and coverage given two manifolds. + Args: + real_features: numpy.ndarray([N, feature_dim], dtype=np.float32) + fake_features: numpy.ndarray([N, feature_dim], dtype=np.float32) + nearest_k: int. + Returns: + dict of precision, recall, density, and coverage. + """ + + real_nearest_neighbour_distances = self.compute_nearest_neighbour_distances( + real_features, nearest_k + ) + fake_nearest_neighbour_distances = self.compute_nearest_neighbour_distances( + fake_features, nearest_k + ) + distance_real_fake = self.compute_pairwise_distance( + real_features, fake_features + ) + + precision = ( + ( + distance_real_fake + < np.expand_dims(real_nearest_neighbour_distances, axis=1) + ) + .any(axis=0) + .mean() + ) + + recall = ( + ( + distance_real_fake + < np.expand_dims(fake_nearest_neighbour_distances, axis=0) + ) + .any(axis=1) + .mean() + ) + + density = (1.0 / float(nearest_k)) * ( + distance_real_fake + < np.expand_dims(real_nearest_neighbour_distances, axis=1) + ).sum(axis=0).mean() + + coverage = ( + distance_real_fake.min(axis=1) < real_nearest_neighbour_distances + ).mean() + + return precision, recall, density, coverage + + def manifold_metrics(self, real_features, fake_features, nearest_k, num_splits=20): + """ + Computes precision, recall, density, and coverage given two manifolds. + Args: + real_features: torch.Tensor([N, feature_dim], dtype=torch.float32) + fake_features: torch.Tensor([N, feature_dim], dtype=torch.float32) + nearest_k: int. + num_splits: int. Number of splits to use for computing metrics. + Returns: + dict of precision, recall, density, and coverage. + """ + real_features = real_features.chunk(num_splits, dim=0) + fake_features = fake_features.chunk(num_splits, dim=0) + precision, recall, density, coverage = [], [], [], [] + for real, fake in tqdm( + zip(real_features, fake_features), desc="Computing manifold" + ): + p, r, d, c = self.compute_prdc( + real.cpu().numpy(), fake.cpu().numpy(), nearest_k=nearest_k + ) + precision.append(torch.tensor(p, device=real.device)) + recall.append(torch.tensor(r, device=real.device)) + density.append(torch.tensor(d, device=real.device)) + coverage.append(torch.tensor(c, device=real.device)) + return ( + torch.stack(precision).mean().item(), + torch.stack(recall).mean().item(), + torch.stack(density).mean().item(), + torch.stack(coverage).mean().item(), + ) + + +def estimate_kde_mode(points): + kde = BatchedKDE() + kde.fit(points) + batch_size = points.shape[0] + X, Y, positions = batched_make_grid(points.cpu()) + X = X.to(points.device) + Y = Y.to(points.device) + positions = positions.to(points.device) + Z = kde.score(positions).reshape(X.shape) + + x_mode = X.reshape(batch_size, -1)[ + torch.arange(batch_size), Z.reshape(batch_size, -1).argmax(dim=1) + ] + y_mode = Y.reshape(batch_size, -1)[ + torch.arange(batch_size), Z.reshape(batch_size, -1).argmax(dim=1) + ] + return (x_mode, y_mode), kde + + +def make_grid(points): + (lat_min, long_min), _ = points.min(dim=-2) + (lat_max, long_max), _ = points.max(dim=-2) + x = torch.linspace(lat_min, lat_max, 100) + y = torch.linspace(long_min, long_max, 100) + X, Y = torch.meshgrid(x, y) + positions = torch.vstack([X.flatten(), Y.flatten()]).transpose(-1, -2) + return X, Y, positions + + +batched_make_grid = torch.vmap(make_grid) diff --git a/plonk/metrics/elo.py b/plonk/metrics/elo.py new file mode 100644 index 0000000000000000000000000000000000000000..d1dfe5a3686345f7a2748189c966e0577bf1dd9f --- /dev/null +++ b/plonk/metrics/elo.py @@ -0,0 +1,21 @@ +import os +import torch +from metrics.utils import haversine + +from torchmetrics import Metric + + +class HaversineELOMetric(Metric): + """ + Computes the ELO score of the current network given previous players + + Args: + previous_players_scores (str): path to the csv containing the scores of the previous players + previous_players_predictions (str): path to the folder containing the predictions of the previous players + tag (str): tag of the current experiment + + """ + + def __init__(self, cache_folder, tag): + ### TODO + pass diff --git a/plonk/metrics/utils.py b/plonk/metrics/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d365fc49f3c534a73a5a14cbc33d3c6f1d2fb599 --- /dev/null +++ b/plonk/metrics/utils.py @@ -0,0 +1,104 @@ +import torch +import reverse_geocoder +import numpy as np + + +def haversine(pred, gt): + # expects inputs to be np arrays in (lat, lon) format as radians + # N x 2 + + # calculate the difference in latitude and longitude between the predicted and ground truth points + lat_diff = pred[:, 0] - gt[:, 0] + lon_diff = pred[:, 1] - gt[:, 1] + + # calculate the haversine formula components + lhs = torch.sin(lat_diff / 2) ** 2 + rhs = torch.cos(pred[:, 0]) * torch.cos(gt[:, 0]) * torch.sin(lon_diff / 2) ** 2 + a = lhs + rhs + + # calculate the final distance using the haversine formula + c = 2 * torch.arctan2(torch.sqrt(a), torch.sqrt(1 - a)) + distance = 6371 * c + + return distance + +def haversine_np(pred, gt): + # expects inputs to be np arrays in (lat, lon) format as radians + # N x 2 + + # calculate the difference in latitude and longitude between the predicted and ground truth points + lat_diff = pred[0] - gt[0] + lon_diff = pred[1] - gt[1] + + # calculate the haversine formula components + lhs = np.sin(lat_diff / 2) ** 2 + rhs = np.cos(pred[0]) * np.cos(gt[0]) * np.sin(lon_diff / 2) ** 2 + a = lhs + rhs + + # calculate the final distance using the haversine formula + c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a)) + distance = 6371 * c + + return distance + + +def reverse(pred, gt, area): + df = {} + gt_area = {} + nan_mask = {} + areas = ["_".join(["unique", ar]) for ar in area] + if "unique_continent" in areas: + areas.remove("unique_continent") + for ar in areas: + inter = np.array(gt[ar]) + nan_mask[ar] = inter != "nan" + gt_area[ar] = inter[nan_mask[ar]] + location = reverse_geocoder.search( + [ + (lat, lon) + for lat, lon in zip( + np.degrees(pred[:, 0].cpu()), np.degrees(pred[:, 1].cpu()) + ) + ] + ) + if "continent" in area: + continent = torch.load("continent.pt") + inter = np.array([l.get("cc", "") for l in location])[ + nan_mask["unique_country"] + ] + df["continent"] = np.array([continent[i] for i in inter]) + gt_area["unique_continent"] = np.array( + [continent[i] for i in gt_area["unique_country"]] + ) + + if "country" in area: + df["country"] = np.array([l.get("cc", "") for l in location])[ + nan_mask["unique_country"] + ] + if "region" in area: + df["region"] = np.array( + ["_".join([l.get("admin1", ""), l.get("cc", "")]) for l in location] + )[nan_mask["unique_region"]] + if "sub-region" in area: + df["sub-region"] = np.array( + [ + "_".join([l.get("admin2", ""), l.get("admin1", ""), l.get("cc", "")]) + for l in location + ] + )[nan_mask["unique_sub-region"]] + if "city" in area: + df["city"] = np.array( + [ + "_".join( + [ + l.get("name", ""), + l.get("admin2", ""), + l.get("admin1", ""), + l.get("cc", ""), + ] + ) + for l in location + ] + )[nan_mask["unique_city"]] + + return df, gt_area diff --git a/plonk/models/__init__.py b/plonk/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f7d0cf6b01d1b4a37d939c99564a626d2eaca162 --- /dev/null +++ b/plonk/models/__init__.py @@ -0,0 +1,2 @@ +# Empty file to make the directory a Python package +from .pretrained_models import Plonk diff --git a/plonk/models/losses.py b/plonk/models/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..e0532da6baad5d69b1c3728daf5d3a2f95a4b6af --- /dev/null +++ b/plonk/models/losses.py @@ -0,0 +1,155 @@ +import torch +from plonk.utils.manifolds import Sphere, geodesic +from torch.func import vjp, jvp, vmap, jacrev + + +class DDPMLoss: + def __init__( + self, + scheduler, + cond_drop_rate=0.0, + conditioning_key="label", + ): + self.scheduler = scheduler + self.cond_drop_rate = cond_drop_rate + self.conditioning_key = conditioning_key + + def __call__(self, preconditioning, network, batch, generator=None): + x_0 = batch["x_0"] + batch_size = x_0.shape[0] + device = x_0.device + t = torch.rand(batch_size, device=device, dtype=x_0.dtype, generator=generator) + gamma = self.scheduler(t).unsqueeze(-1) + n = torch.randn(x_0.shape, dtype=x_0.dtype, device=device, generator=generator) + y = torch.sqrt(gamma) * x_0 + torch.sqrt(1 - gamma) * n + batch["y"] = y + conditioning = batch[self.conditioning_key] + if conditioning is not None and self.cond_drop_rate > 0: + drop_mask = ( + torch.rand(batch_size, device=device, generator=generator) + < self.cond_drop_rate + ) + conditioning[drop_mask] = torch.zeros_like(conditioning[drop_mask]) + batch[self.conditioning_key] = conditioning.detach() + batch["gamma"] = gamma.squeeze(-1).squeeze(-1).squeeze(-1) + D_n = preconditioning(network, batch) + loss = (D_n - n) ** 2 + return loss + + +class FlowMatchingLoss: + def __init__( + self, + scheduler, + cond_drop_rate=0.0, + conditioning_key="label", + ): + self.scheduler = scheduler + self.cond_drop_rate = cond_drop_rate + self.conditioning_key = conditioning_key + + def __call__(self, preconditioning, network, batch, generator=None): + x_0 = batch["x_0"] + batch_size = x_0.shape[0] + device = x_0.device + t = torch.rand(batch_size, device=device, dtype=x_0.dtype, generator=generator) + gamma = self.scheduler(t).unsqueeze(-1) + n = torch.randn(x_0.shape, dtype=x_0.dtype, device=device, generator=generator) + y = gamma * x_0 + (1 - gamma) * n + batch["y"] = y + conditioning = batch[self.conditioning_key] + if conditioning is not None and self.cond_drop_rate > 0: + drop_mask = ( + torch.rand(batch_size, device=device, generator=generator) + < self.cond_drop_rate + ) + conditioning[drop_mask] = torch.zeros_like(conditioning[drop_mask]) + batch[self.conditioning_key] = conditioning.detach() + batch["gamma"] = gamma.squeeze(-1).squeeze(-1).squeeze(-1) + D_n = preconditioning(network, batch) + loss = (D_n - (x_0 - n)) ** 2 + return loss + + +class RiemannianFlowMatchingLoss: + def __init__( + self, + scheduler, + cond_drop_rate=0.0, + conditioning_key="label", + ): + self.scheduler = scheduler + self.cond_drop_rate = cond_drop_rate + self.conditioning_key = conditioning_key + self.manifold = Sphere() + self.manifold_dim = 3 + + def __call__(self, preconditioning, network, batch, generator=None): + x_1 = batch["x_0"] + batch_size = x_1.shape[0] + device = x_1.device + t = torch.rand(batch_size, device=device, dtype=x_1.dtype, generator=generator) + gamma = self.scheduler(t).unsqueeze(-1) + x_0 = self.manifold.random_base(x_1.shape[0], self.manifold_dim).to(x_1) + + def cond_u(x0, x1, t): + path = geodesic(self.manifold, x0, x1) + x_t, u_t = jvp(path, (t,), (torch.ones_like(t).to(t),)) + return x_t, u_t + + y, u_t = vmap(cond_u)(x_0, x_1, gamma) + y = y.reshape(batch_size, self.manifold_dim) + u_t = u_t.reshape(batch_size, self.manifold_dim) + batch["y"] = y + conditioning = batch[self.conditioning_key] + if conditioning is not None and self.cond_drop_rate > 0: + drop_mask = ( + torch.rand(batch_size, device=device, generator=generator) + < self.cond_drop_rate + ) + conditioning[drop_mask] = torch.zeros_like(conditioning[drop_mask]) + batch[self.conditioning_key] = conditioning.detach() + batch["gamma"] = gamma.squeeze(-1).squeeze(-1).squeeze(-1) + D_n = preconditioning(network, batch) + diff = D_n - u_t + loss = self.manifold.inner(y, diff, diff).mean() / self.manifold_dim + return loss + + +class VonFisherLoss: + def __init__(self, dim=3): + self.dim = dim + + def __call__(self, preconditioning, network, batch, generator=None): + x = batch["x_0"] + mu, kappa = preconditioning(network, batch) + loss = ( + torch.log((kappa + 1e-8)) + - torch.log(torch.tensor(4 * torch.pi, dtype=kappa.dtype)) + - log_sinh(kappa) + + kappa * (mu * x).sum(dim=-1, keepdim=True) + ) + return -loss + + +class VonFisherMixtureLoss: + def __init__(self, dim=3): + self.dim = dim + + def __call__(self, preconditioning, network, batch, generator=None): + x = batch["x_0"] + mu_mixture, kappa_mixture, weights = preconditioning(network, batch) + loss = 0 + for i in range(mu_mixture.shape[1]): + mu = mu_mixture[:, i] + kappa = kappa_mixture[:, i].unsqueeze(1) + loss += weights[:, i].unsqueeze(1) * ( + kappa + * torch.exp(kappa * ((mu * x).sum(dim=-1, keepdim=True) - 1)) + / (1e-8 + 2 * torch.pi * (1 - torch.exp(-2 * kappa))) + ) + return -torch.log(loss) + + +def log_sinh(x): + return x + torch.log(1e-8 + (1 - torch.exp(-2 * x)) / 2) diff --git a/plonk/models/module.py b/plonk/models/module.py new file mode 100755 index 0000000000000000000000000000000000000000..1357c87a523475a53b0ee86a3a13dc0f29f3c504 --- /dev/null +++ b/plonk/models/module.py @@ -0,0 +1,813 @@ +from typing import Any +import pytorch_lightning as L +import torch +import torch.nn as nn +from hydra.utils import instantiate +import copy +import pandas as pd +import numpy as np +from tqdm import tqdm +from plonk.utils.manifolds import Sphere +from torch.func import jacrev, vjp, vmap +from torchdiffeq import odeint +from geoopt import ProductManifold, Euclidean +from plonk.models.samplers.riemannian_flow_sampler import ode_riemannian_flow_sampler + + +class DiffGeolocalizer(L.LightningModule): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.network = instantiate(cfg.network) + # self.network = torch.compile(self.network, fullgraph=True) + self.input_dim = cfg.network.input_dim + self.train_noise_scheduler = instantiate(cfg.train_noise_scheduler) + self.inference_noise_scheduler = instantiate(cfg.inference_noise_scheduler) + self.data_preprocessing = instantiate(cfg.data_preprocessing) + self.cond_preprocessing = instantiate(cfg.cond_preprocessing) + self.preconditioning = instantiate(cfg.preconditioning) + + self.ema_network = copy.deepcopy(self.network).requires_grad_(False) + self.ema_network.eval() + self.postprocessing = instantiate(cfg.postprocessing) + self.val_sampler = instantiate(cfg.val_sampler) + self.test_sampler = instantiate(cfg.test_sampler) + self.loss = instantiate(cfg.loss)( + self.train_noise_scheduler, + ) + self.val_metrics = instantiate(cfg.val_metrics) + self.test_metrics = instantiate(cfg.test_metrics) + self.manifold = instantiate(cfg.manifold) if hasattr(cfg, "manifold") else None + + self.interpolant = cfg.interpolant + + def training_step(self, batch, batch_idx): + with torch.no_grad(): + batch = self.data_preprocessing(batch) + batch = self.cond_preprocessing(batch) + batch_size = batch["x_0"].shape[0] + loss = self.loss(self.preconditioning, self.network, batch).mean() + self.log( + "train/loss", + loss, + sync_dist=True, + on_step=True, + on_epoch=True, + batch_size=batch_size, + ) + return loss + + def on_before_optimizer_step(self, optimizer): + if self.global_step == 0: + no_grad = [] + for name, param in self.network.named_parameters(): + if param.grad is None: + no_grad.append(name) + if len(no_grad) > 0: + print("Parameters without grad:") + print(no_grad) + + def on_validation_start(self): + self.validation_generator = torch.Generator(device=self.device).manual_seed( + 3407 + ) + self.validation_generator_ema = torch.Generator(device=self.device).manual_seed( + 3407 + ) + + def validation_step(self, batch, batch_idx): + batch = self.data_preprocessing(batch) + batch = self.cond_preprocessing(batch) + batch_size = batch["x_0"].shape[0] + loss = self.loss( + self.preconditioning, + self.network, + batch, + generator=self.validation_generator, + ).mean() + self.log( + "val/loss", + loss, + sync_dist=True, + on_step=False, + on_epoch=True, + batch_size=batch_size, + ) + if hasattr(self, "ema_model"): + loss_ema = self.loss( + self.preconditioning, + self.ema_network, + batch, + generator=self.validation_generator_ema, + ).mean() + self.log( + "val/loss_ema", + loss_ema, + sync_dist=True, + on_step=False, + on_epoch=True, + batch_size=batch_size, + ) + # nll = -self.compute_exact_loglikelihood(batch).mean() + # self.log( + # "val/nll", + # nll, + # sync_dist=True, + # on_step=False, + # on_epoch=True, + # batch_size=batch_size, + # ) + + # def on_validation_epoch_end(self): + # metrics = self.val_metrics.compute() + # for metric_name, metric_value in metrics.items(): + # self.log( + # f"val/{metric_name}", + # metric_value, + # sync_dist=True, + # on_step=False, + # on_epoch=True, + # ) + + def on_test_start(self): + self.test_generator = torch.Generator(device=self.device).manual_seed(3407) + + def test_step_simple(self, batch, batch_idx): + batch = self.data_preprocessing(batch) + batch = self.cond_preprocessing(batch) + batch_size = batch["x_0"].shape[0] + if isinstance(self.manifold, Sphere): + x_N = self.manifold.random_base( + batch_size, + self.input_dim, + device=self.device, + ) + x_N = x_N.reshape(batch_size, self.input_dim) + else: + x_N = torch.randn( + batch_size, + self.input_dim, + device=self.device, + generator=self.test_generator, + ) + cond = batch[self.cfg.cond_preprocessing.output_key] + + samples = self.sample( + x_N=x_N, + cond=cond, + stage="val", + generator=self.test_generator, + cfg=self.cfg.cfg_rate, + ) + self.test_metrics.update({"gps": samples}, batch) + if self.cfg.compute_nll: + nll = -self.compute_exact_loglikelihood(batch, cfg=0).mean() + self.log( + "test/NLL", + nll, + sync_dist=True, + on_step=False, + on_epoch=True, + batch_size=batch_size, + ) + + def test_best_nll(self, batch, batch_idx): + batch = self.data_preprocessing(batch) + batch = self.cond_preprocessing(batch) + batch_size = batch["x_0"].shape[0] + num_sample_per_cond = 32 + if isinstance(self.manifold, Sphere): + x_N = self.manifold.random_base( + batch_size * num_sample_per_cond, + self.input_dim, + device=self.device, + ) + x_N = x_N.reshape(batch_size * num_sample_per_cond, self.input_dim) + else: + x_N = torch.randn( + batch_size * num_sample_per_cond, + self.input_dim, + device=self.device, + generator=self.test_generator, + ) + cond = ( + batch[self.cfg.cond_preprocessing.output_key] + .unsqueeze(1) + .repeat(1, num_sample_per_cond, 1) + .view(-1, batch[self.cfg.cond_preprocessing.output_key].shape[-1]) + ) + samples = self.sample_distribution( + x_N, + cond, + sampling_batch_size=32768, + stage="val", + generator=self.test_generator, + cfg=0, + ) + samples = samples.view(batch_size * num_sample_per_cond, -1) + batch_swarm = {"gps": samples, "emb": cond} + nll_batch = -self.compute_exact_loglikelihood(batch_swarm, cfg=0) + nll_batch = nll_batch.view(batch_size, num_sample_per_cond, -1) + nll_best = nll_batch[ + torch.arange(batch_size), nll_batch.argmin(dim=1).squeeze(1) + ] + self.log( + "test/best_nll", + nll_best.mean(), + sync_dist=True, + on_step=False, + on_epoch=True, + ) + samples = samples.view(batch_size, num_sample_per_cond, -1)[ + torch.arange(batch_size), nll_batch.argmin(dim=1).squeeze(1) + ] + self.test_metrics.update({"gps": samples}, batch) + + def test_step(self, batch, batch_idx): + if self.cfg.compute_swarms: + self.test_best_nll(batch, batch_idx) + else: + self.test_step_simple(batch, batch_idx) + + def on_test_epoch_end(self): + metrics = self.test_metrics.compute() + for metric_name, metric_value in metrics.items(): + self.log( + f"test/{metric_name}", + metric_value, + sync_dist=True, + on_step=False, + on_epoch=True, + ) + + def configure_optimizers(self): + if self.cfg.optimizer.exclude_ln_and_biases_from_weight_decay: + parameters_names_wd = get_parameter_names(self.network, [nn.LayerNorm]) + parameters_names_wd = [ + name for name in parameters_names_wd if "bias" not in name + ] + optimizer_grouped_parameters = [ + { + "params": [ + p + for n, p in self.network.named_parameters() + if n in parameters_names_wd + ], + "weight_decay": self.cfg.optimizer.optim.weight_decay, + "layer_adaptation": True, + }, + { + "params": [ + p + for n, p in self.network.named_parameters() + if n not in parameters_names_wd + ], + "weight_decay": 0.0, + "layer_adaptation": False, + }, + ] + optimizer = instantiate( + self.cfg.optimizer.optim, optimizer_grouped_parameters + ) + else: + optimizer = instantiate(self.cfg.optimizer.optim, self.network.parameters()) + if "lr_scheduler" in self.cfg: + scheduler = instantiate(self.cfg.lr_scheduler)(optimizer) + return [optimizer], [{"scheduler": scheduler, "interval": "step"}] + else: + return optimizer + + def lr_scheduler_step(self, scheduler, metric): + scheduler.step(self.global_step) + + def sample( + self, + batch_size=None, + cond=None, + x_N=None, + num_steps=None, + stage="test", + cfg=0, + generator=None, + return_trajectories=False, + postprocessing=True, + ): + if x_N is None: + assert batch_size is not None + if isinstance(self.manifold, Sphere): + x_N = self.manifold.random_base( + batch_size, self.input_dim, device=self.device + ) + x_N = x_N.reshape(batch_size, self.input_dim) + else: + x_N = torch.randn(batch_size, self.input_dim, device=self.device) + batch = {"y": x_N} + if stage == "val": + sampler = self.val_sampler + elif stage == "test": + sampler = self.test_sampler + else: + raise ValueError(f"Unknown stage {stage}") + batch[self.cfg.cond_preprocessing.input_key] = cond + batch = self.cond_preprocessing(batch, device=self.device) + if num_steps is None: + output = sampler( + self.ema_model, + batch, + conditioning_keys=self.cfg.cond_preprocessing.output_key, + scheduler=self.inference_noise_scheduler, + cfg_rate=cfg, + generator=generator, + return_trajectories=return_trajectories, + ) + else: + output = sampler( + self.ema_model, + batch, + conditioning_keys=self.cfg.cond_preprocessing.output_key, + scheduler=self.inference_noise_scheduler, + num_steps=num_steps, + cfg_rate=cfg, + generator=generator, + return_trajectories=return_trajectories, + ) + if return_trajectories: + return ( + self.postprocessing(output[0]) if postprocessing else output[0], + [ + self.postprocessing(frame) if postprocessing else frame + for frame in output[1] + ], + ) + else: + return self.postprocessing(output) if postprocessing else output + + def sample_distribution( + self, + x_N, + cond, + sampling_batch_size=2048, + num_steps=None, + stage="test", + cfg=0, + generator=None, + return_trajectories=False, + ): + if return_trajectories: + x_0 = [] + trajectories = [] + i = -1 + for i in range(x_N.shape[0] // sampling_batch_size): + x_N_batch = x_N[i * sampling_batch_size : (i + 1) * sampling_batch_size] + cond_batch = cond[ + i * sampling_batch_size : (i + 1) * sampling_batch_size + ] + out, trajectories = self.sample( + cond=cond_batch, + x_N=x_N_batch, + num_steps=num_steps, + stage=stage, + cfg=cfg, + generator=generator, + return_trajectories=return_trajectories, + ) + x_0.append(out) + trajectories.append(trajectories) + if x_N.shape[0] % sampling_batch_size != 0: + x_N_batch = x_N[(i + 1) * sampling_batch_size :] + cond_batch = cond[(i + 1) * sampling_batch_size :] + out, trajectories = self.sample( + cond=cond_batch, + x_N=x_N_batch, + num_steps=num_steps, + stage=stage, + cfg=cfg, + generator=generator, + return_trajectories=return_trajectories, + ) + x_0.append(out) + trajectories.append(trajectories) + x_0 = torch.cat(x_0, dim=1) + trajectories = [torch.cat(frame, dim=1) for frame in trajectories] + return x_0, trajectories + else: + x_0 = [] + i = -1 + for i in range(x_N.shape[0] // sampling_batch_size): + x_N_batch = x_N[i * sampling_batch_size : (i + 1) * sampling_batch_size] + cond_batch = cond[ + i * sampling_batch_size : (i + 1) * sampling_batch_size + ] + out = self.sample( + cond=cond_batch, + x_N=x_N_batch, + num_steps=num_steps, + stage=stage, + cfg=cfg, + generator=generator, + return_trajectories=return_trajectories, + ) + x_0.append(out) + if x_N.shape[0] % sampling_batch_size != 0: + x_N_batch = x_N[(i + 1) * sampling_batch_size :] + cond_batch = cond[(i + 1) * sampling_batch_size :] + out = self.sample( + cond=cond_batch, + x_N=x_N_batch, + num_steps=num_steps, + stage=stage, + cfg=cfg, + generator=generator, + return_trajectories=return_trajectories, + ) + x_0.append(out) + x_0 = torch.cat(x_0, dim=0) + return x_0 + + def model(self, *args, **kwargs): + return self.preconditioning(self.network, *args, **kwargs) + + def ema_model(self, *args, **kwargs): + return self.preconditioning(self.ema_network, *args, **kwargs) + + def compute_exact_loglikelihood( + self, + batch=None, + x_1=None, + cond=None, + t1=1.0, + num_steps=1000, + rademacher=False, + data_preprocessing=True, + cfg=0, + ): + nfe = [0] + if batch is None: + batch = {"x_0": x_1, "emb": cond} + if data_preprocessing: + batch = self.data_preprocessing(batch) + batch = self.cond_preprocessing(batch) + timesteps = self.inference_noise_scheduler( + torch.linspace(0, t1, 2).to(batch["x_0"]) + ) + with torch.inference_mode(mode=False): + + def odefunc(t, tensor): + nfe[0] += 1 + t = t.to(tensor) + gamma = self.inference_noise_scheduler(t) + x = tensor[..., : self.input_dim] + y = batch["emb"] + + def vecfield(x, y): + if cfg > 0: + batch_vecfield = { + "y": x, + "emb": y, + "gamma": gamma.reshape(-1), + } + model_output_cond = self.ema_model(batch_vecfield) + batch_vecfield_uncond = { + "y": x, + "emb": torch.zeros_like(y), + "gamma": gamma.reshape(-1), + } + model_output_uncond = self.ema_model(batch_vecfield_uncond) + model_output = model_output_cond + cfg * ( + model_output_cond - model_output_uncond + ) + + else: + batch_vecfield = { + "y": x, + "emb": y, + "gamma": gamma.reshape(-1), + } + model_output = self.ema_model(batch_vecfield) + + if self.interpolant == "flow_matching": + d_gamma = self.inference_noise_scheduler.derivative(t).reshape( + -1, 1 + ) + return d_gamma * model_output + elif self.interpolant == "diffusion": + alpha_t = self.inference_noise_scheduler.alpha(t).reshape(-1, 1) + return ( + -1 / 2 * (alpha_t * x - torch.abs(alpha_t) * model_output) + ) + else: + raise ValueError(f"Unknown interpolant {self.interpolant}") + + if rademacher: + v = torch.randint_like(x, 2) * 2 - 1 + else: + v = None + dx, div = output_and_div(vecfield, x, y, v=v) + div = div.reshape(-1, 1) + del t, x + return torch.cat([dx, div], dim=-1) + + x_1 = batch["x_0"] + state1 = torch.cat([x_1, torch.zeros_like(x_1[..., :1])], dim=-1) + with torch.no_grad(): + if False and isinstance(self.manifold, Sphere): + print("Riemannian flow sampler") + product_man = ProductManifold( + (self.manifold, self.input_dim), (Euclidean(), 1) + ) + state0 = ode_riemannian_flow_sampler( + odefunc, + state1, + manifold=product_man, + scheduler=self.inference_noise_scheduler, + num_steps=num_steps, + ) + else: + print("ODE solver") + state0 = odeint( + odefunc, + state1, + t=torch.linspace(0, t1, 2).to(batch["x_0"]), + atol=1e-6, + rtol=1e-6, + method="dopri5", + options={"min_step": 1e-5}, + )[-1] + x_0, logdetjac = state0[..., : self.input_dim], state0[..., -1] + if self.manifold is not None: + x_0 = self.manifold.projx(x_0) + logp0 = self.manifold.base_logprob(x_0) + else: + logp0 = ( + -1 / 2 * (x_0**2).sum(dim=-1) + - self.input_dim + * torch.log(torch.tensor(2 * np.pi, device=x_0.device)) + / 2 + ) + print(f"nfe: {nfe[0]}") + logp1 = logp0 + logdetjac + logp1 = logp1 / (self.input_dim * np.log(2)) + return logp1 + + +def get_parameter_names(model, forbidden_layer_types): + """ + Returns the names of the model parameters that are not inside a forbidden layer. + Taken from HuggingFace transformers. + """ + result = [] + for name, child in model.named_children(): + result += [ + f"{name}.{n}" + for n in get_parameter_names(child, forbidden_layer_types) + if not isinstance(child, tuple(forbidden_layer_types)) + ] + # Add model specific parameters (defined with nn.Parameter) since they are not in any child. + result += list(model._parameters.keys()) + return result + + +# for likelihood computation +def div_fn(u): + """Accepts a function u:R^D -> R^D.""" + J = jacrev(u, argnums=0) + return lambda x, y: torch.trace(J(x, y).squeeze(0)) + + +def output_and_div(vecfield, x, y, v=None): + if v is None: + dx = vecfield(x, y) + div = vmap(div_fn(vecfield))(x, y) + else: + vecfield_x = lambda x: vecfield(x, y) + dx, vjpfunc = vjp(vecfield_x, x) + vJ = vjpfunc(v)[0] + div = torch.sum(vJ * v, dim=-1) + return dx, div + + +class VonFisherGeolocalizer(L.LightningModule): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.network = instantiate(cfg.network) + # self.network = torch.compile(self.network, fullgraph=True) + self.input_dim = cfg.network.input_dim + self.data_preprocessing = instantiate(cfg.data_preprocessing) + self.cond_preprocessing = instantiate(cfg.cond_preprocessing) + self.preconditioning = instantiate(cfg.preconditioning) + + self.ema_network = copy.deepcopy(self.network).requires_grad_(False) + self.ema_network.eval() + self.postprocessing = instantiate(cfg.postprocessing) + self.val_sampler = instantiate(cfg.val_sampler) + self.test_sampler = instantiate(cfg.test_sampler) + self.loss = instantiate(cfg.loss)() + self.val_metrics = instantiate(cfg.val_metrics) + self.test_metrics = instantiate(cfg.test_metrics) + + def training_step(self, batch, batch_idx): + with torch.no_grad(): + batch = self.data_preprocessing(batch) + batch = self.cond_preprocessing(batch) + batch_size = batch["x_0"].shape[0] + loss = self.loss(self.preconditioning, self.network, batch).mean() + self.log( + "train/loss", + loss, + sync_dist=True, + on_step=True, + on_epoch=True, + batch_size=batch_size, + ) + return loss + + def on_before_optimizer_step(self, optimizer): + if self.global_step == 0: + no_grad = [] + for name, param in self.network.named_parameters(): + if param.grad is None: + no_grad.append(name) + if len(no_grad) > 0: + print("Parameters without grad:") + print(no_grad) + + def on_validation_start(self): + self.validation_generator = torch.Generator(device=self.device).manual_seed( + 3407 + ) + self.validation_generator_ema = torch.Generator(device=self.device).manual_seed( + 3407 + ) + + def validation_step(self, batch, batch_idx): + batch = self.data_preprocessing(batch) + batch = self.cond_preprocessing(batch) + batch_size = batch["x_0"].shape[0] + loss = self.loss( + self.preconditioning, + self.network, + batch, + generator=self.validation_generator, + ).mean() + self.log( + "val/loss", + loss, + sync_dist=True, + on_step=False, + on_epoch=True, + batch_size=batch_size, + ) + if hasattr(self, "ema_model"): + loss_ema = self.loss( + self.preconditioning, + self.ema_network, + batch, + generator=self.validation_generator_ema, + ).mean() + self.log( + "val/loss_ema", + loss_ema, + sync_dist=True, + on_step=False, + on_epoch=True, + batch_size=batch_size, + ) + + def on_test_start(self): + self.test_generator = torch.Generator(device=self.device).manual_seed(3407) + + def test_step(self, batch, batch_idx): + batch = self.data_preprocessing(batch) + batch = self.cond_preprocessing(batch) + batch_size = batch["x_0"].shape[0] + cond = batch[self.cfg.cond_preprocessing.output_key] + + samples = self.sample(cond=cond, stage="test") + self.test_metrics.update({"gps": samples}, batch) + nll = -self.compute_exact_loglikelihood(batch).mean() + self.log( + "test/NLL", + nll, + sync_dist=True, + on_step=False, + on_epoch=True, + batch_size=batch_size, + ) + + def on_test_epoch_end(self): + metrics = self.test_metrics.compute() + for metric_name, metric_value in metrics.items(): + self.log( + f"test/{metric_name}", + metric_value, + sync_dist=True, + on_step=False, + on_epoch=True, + ) + + def configure_optimizers(self): + if self.cfg.optimizer.exclude_ln_and_biases_from_weight_decay: + parameters_names_wd = get_parameter_names(self.network, [nn.LayerNorm]) + parameters_names_wd = [ + name for name in parameters_names_wd if "bias" not in name + ] + optimizer_grouped_parameters = [ + { + "params": [ + p + for n, p in self.network.named_parameters() + if n in parameters_names_wd + ], + "weight_decay": self.cfg.optimizer.optim.weight_decay, + "layer_adaptation": True, + }, + { + "params": [ + p + for n, p in self.network.named_parameters() + if n not in parameters_names_wd + ], + "weight_decay": 0.0, + "layer_adaptation": False, + }, + ] + optimizer = instantiate( + self.cfg.optimizer.optim, optimizer_grouped_parameters + ) + else: + optimizer = instantiate(self.cfg.optimizer.optim, self.network.parameters()) + if "lr_scheduler" in self.cfg: + scheduler = instantiate(self.cfg.lr_scheduler)(optimizer) + return [optimizer], [{"scheduler": scheduler, "interval": "step"}] + else: + return optimizer + + def lr_scheduler_step(self, scheduler, metric): + scheduler.step(self.global_step) + + def sample( + self, + batch_size=None, + cond=None, + postprocessing=True, + stage="val", + ): + batch = {} + if stage == "val": + sampler = self.val_sampler + elif stage == "test": + sampler = self.test_sampler + else: + raise ValueError(f"Unknown stage {stage}") + batch[self.cfg.cond_preprocessing.input_key] = cond + batch = self.cond_preprocessing(batch, device=self.device) + output = sampler( + self.ema_model, + batch, + ) + return self.postprocessing(output) if postprocessing else output + + def model(self, *args, **kwargs): + return self.preconditioning(self.network, *args, **kwargs) + + def ema_model(self, *args, **kwargs): + return self.preconditioning(self.ema_network, *args, **kwargs) + + def compute_exact_loglikelihood( + self, + batch=None, + ): + batch = self.data_preprocessing(batch) + batch = self.cond_preprocessing(batch) + return -self.loss(self.preconditioning, self.ema_network, batch) + + +class RandomGeolocalizer(L.LightningModule): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.test_metrics = instantiate(cfg.test_metrics) + self.data_preprocessing = instantiate(cfg.data_preprocessing) + self.cond_preprocessing = instantiate(cfg.cond_preprocessing) + self.postprocessing = instantiate(cfg.postprocessing) + + def test_step(self, batch, batch_idx): + batch = self.data_preprocessing(batch) + batch = self.cond_preprocessing(batch) + batch_size = batch["x_0"].shape[0] + samples = torch.randn(batch_size, 3, device=self.device) + samples = samples / samples.norm(dim=-1, keepdim=True) + samples = self.postprocessing(samples) + self.test_metrics.update({"gps": samples}, batch) + + def on_test_epoch_end(self): + metrics = self.test_metrics.compute() + for metric_name, metric_value in metrics.items(): + self.log( + f"test/{metric_name}", + metric_value, + sync_dist=True, + on_step=False, + on_epoch=True, + ) diff --git a/plonk/models/networks/__init__.py b/plonk/models/networks/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/plonk/models/networks/mlp.py b/plonk/models/networks/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..d216aa8f5a85de50160155dd604ba04bd5438394 --- /dev/null +++ b/plonk/models/networks/mlp.py @@ -0,0 +1,190 @@ +import torch.nn as nn +from plonk.models.positional_embeddings import FourierEmbedding, PositionalEmbedding +from plonk.models.networks.transformers import FusedMLP +import torch +import torch.nn.functional as F +import numpy as np +from einops import rearrange + + +class TimeEmbedder(nn.Module): + def __init__( + self, + noise_embedding_type: str, + dim: int, + time_scaling: float, + expansion: int = 4, + ): + super().__init__() + self.encode_time = ( + PositionalEmbedding(num_channels=dim, endpoint=True) + if noise_embedding_type == "positional" + else FourierEmbedding(num_channels=dim) + ) + self.time_scaling = time_scaling + self.map_time = nn.Sequential( + nn.Linear(dim, dim * expansion), + nn.SiLU(), + nn.Linear(dim * expansion, dim * expansion), + ) + + def forward(self, t): + time = self.encode_time(t * self.time_scaling) + time_mean = time.mean(dim=-1, keepdim=True) + time_std = time.std(dim=-1, keepdim=True) + time = (time - time_mean) / time_std + return self.map_time(time) + + +def get_timestep_embedding(timesteps, embedding_dim, dtype=torch.float32): + assert len(timesteps.shape) == 1 + timesteps = timesteps * 1000.0 + + half_dim = embedding_dim // 2 + emb = np.log(10000) / (half_dim - 1) + emb = (torch.arange(half_dim, dtype=dtype, device=timesteps.device) * -emb).exp() + emb = timesteps.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=-1) + if embedding_dim % 2 == 1: # zero pad + emb = F.pad(emb, (0, 1)) + assert emb.shape == (timesteps.shape[0], embedding_dim) + return emb + + +class AdaLNMLPBlock(nn.Module): + def __init__(self, dim, expansion): + super().__init__() + self.mlp = FusedMLP( + dim, dropout=0.0, hidden_layer_multiplier=expansion, activation=nn.GELU + ) + self.ada_map = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 3)) + self.ln = nn.LayerNorm(dim, elementwise_affine=False) + + nn.init.zeros_(self.mlp[-1].weight) + nn.init.zeros_(self.mlp[-1].bias) + + def forward(self, x, y): + gamma, mu, sigma = self.ada_map(y).chunk(3, dim=-1) + x_res = (1 + gamma) * self.ln(x) + mu + x = x + self.mlp(x_res) * sigma + return x + + +class GeoAdaLNMLP(nn.Module): + def __init__(self, input_dim, dim, depth, expansion, cond_dim): + super().__init__() + self.time_embedder = TimeEmbedder("positional", dim // 4, 1000, expansion=4) + self.cond_mapper = nn.Linear(cond_dim, dim) + self.initial_mapper = nn.Linear(input_dim, dim) + self.blocks = nn.ModuleList( + [AdaLNMLPBlock(dim, expansion) for _ in range(depth)] + ) + self.final_adaln = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, dim * 2), + ) + self.final_ln = nn.LayerNorm(dim, elementwise_affine=False) + self.final_linear = nn.Linear(dim, input_dim) + + def forward(self, batch): + x = batch["y"] + x = self.initial_mapper(x) + gamma = batch["gamma"] + cond = batch["emb"] + t = self.time_embedder(gamma) + cond = self.cond_mapper(cond) + cond = cond + t + for block in self.blocks: + x = block(x, cond) + gamma_last, mu_last = self.final_adaln(cond).chunk(2, dim=-1) + x = (1 + gamma_last) * self.final_ln(x) + mu_last + x = self.final_linear(x) + return x + + +class GeoAdaLNMLPVonFisher(nn.Module): + def __init__(self, input_dim, dim, depth, expansion, cond_dim): + super().__init__() + self.cond_mapper = nn.Linear(cond_dim, dim) + self.blocks = nn.ModuleList( + [AdaLNMLPBlock(dim, expansion) for _ in range(depth)] + ) + self.final_adaln = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, dim * 2), + ) + self.final_ln = nn.LayerNorm(dim, elementwise_affine=False) + self.mu_predictor = nn.Sequential( + FusedMLP(dim, dropout=0.0, hidden_layer_multiplier=2, activation=nn.GELU), + nn.Linear(dim, input_dim), + ) + self.kappa_predictor = nn.Sequential( + FusedMLP(dim, dropout=0.0, hidden_layer_multiplier=2, activation=nn.GELU), + nn.Linear(dim, 1), + torch.nn.Softplus(), + ) + self.init_registers = torch.nn.Parameter(torch.randn(dim), requires_grad=True) + torch.nn.init.trunc_normal_( + self.init_registers, std=0.02, a=-2 * 0.02, b=2 * 0.02 + ) + + def forward(self, batch): + cond = batch["emb"] + cond = self.cond_mapper(cond) + x = self.init_registers.unsqueeze(0).repeat(cond.shape[0], 1) + for block in self.blocks: + x = block(x, cond) + gamma_last, mu_last = self.final_adaln(cond).chunk(2, dim=-1) + x = (1 + gamma_last) * self.final_ln(x) + mu_last + mu = self.mu_predictor(x) + mu = mu / mu.norm(dim=-1, keepdim=True) + kappa = self.kappa_predictor(x) + return mu, kappa + + +class GeoAdaLNMLPVonFisherMixture(nn.Module): + def __init__(self, input_dim, dim, depth, expansion, cond_dim, num_mixtures=3): + super().__init__() + self.cond_mapper = nn.Linear(cond_dim, dim) + self.blocks = nn.ModuleList( + [AdaLNMLPBlock(dim, expansion) for _ in range(depth)] + ) + self.final_adaln = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, dim * 2), + ) + self.final_ln = nn.LayerNorm(dim, elementwise_affine=False) + self.mu_predictor = nn.Sequential( + FusedMLP(dim, dropout=0.0, hidden_layer_multiplier=2, activation=nn.GELU), + nn.Linear(dim, input_dim * num_mixtures), + ) + self.kappa_predictor = nn.Sequential( + FusedMLP(dim, dropout=0.0, hidden_layer_multiplier=2, activation=nn.GELU), + nn.Linear(dim, num_mixtures), + torch.nn.Softplus(), + ) + self.mixture_weights = nn.Sequential( + FusedMLP(dim, dropout=0.0, hidden_layer_multiplier=2, activation=nn.GELU), + nn.Linear(dim, num_mixtures), + torch.nn.Softmax(dim=-1), + ) + self.num_mixtures = num_mixtures + self.init_registers = torch.nn.Parameter(torch.randn(dim), requires_grad=True) + torch.nn.init.trunc_normal_( + self.init_registers, std=0.02, a=-2 * 0.02, b=2 * 0.02 + ) + + def forward(self, batch): + cond = batch["emb"] + cond = self.cond_mapper(cond) + x = self.init_registers.unsqueeze(0).repeat(cond.shape[0], 1) + for block in self.blocks: + x = block(x, cond) + gamma_last, mu_last = self.final_adaln(cond).chunk(2, dim=-1) + x = (1 + gamma_last) * self.final_ln(x) + mu_last + mu = self.mu_predictor(x) + mu = rearrange(mu, "b (n d) -> b n d", n=self.num_mixtures) + mu = mu / mu.norm(dim=-1, keepdim=True) + kappa = self.kappa_predictor(x) + weights = self.mixture_weights(x) + return mu, kappa, weights diff --git a/plonk/models/networks/transformers.py b/plonk/models/networks/transformers.py new file mode 100644 index 0000000000000000000000000000000000000000..ff07d8eaed348e3e12d6d4160f4caf74d90a3647 --- /dev/null +++ b/plonk/models/networks/transformers.py @@ -0,0 +1,329 @@ +import torch +import torch.nn as nn +from torch import Tensor +import math + +from plonk.models.positional_embeddings import PositionalEmbedding, FourierEmbedding +from einops import rearrange + +torch.fx.wrap("rearrange") +from typing import Tuple, Optional +from einops._torch_specific import allow_ops_in_compiled_graph # requires einops>=0.6.1 + +allow_ops_in_compiled_graph() + + +class FusedMLP(nn.Sequential): + def __init__( + self, + dim_model: int, + dropout: float, + activation: nn.Module, + hidden_layer_multiplier: int = 4, + bias: bool = True, + ): + super().__init__( + nn.Linear(dim_model, dim_model * hidden_layer_multiplier, bias=bias), + activation(), + nn.Dropout(dropout), + nn.Linear(dim_model * hidden_layer_multiplier, dim_model, bias=bias), + ) + + +def _cast_if_autocast_enabled(tensor): + if torch.is_autocast_enabled(): + if tensor.device.type == "cuda": + dtype = torch.get_autocast_gpu_dtype() + elif tensor.device.type == "cpu": + dtype = torch.get_autocast_cpu_dtype() + else: + raise NotImplementedError() + return tensor.to(dtype=dtype) + return tensor + + +class LayerNorm16Bits(torch.nn.LayerNorm): + """ + 16-bit friendly version of torch.nn.LayerNorm + """ + + def __init__( + self, + normalized_shape, + eps=1e-06, + elementwise_affine=True, + device=None, + dtype=None, + ): + super().__init__( + normalized_shape=normalized_shape, + eps=eps, + elementwise_affine=elementwise_affine, + device=device, + dtype=dtype, + ) + + def forward(self, x): + module_device = x.device + downcast_x = _cast_if_autocast_enabled(x) + downcast_weight = ( + _cast_if_autocast_enabled(self.weight) + if self.weight is not None + else self.weight + ) + downcast_bias = ( + _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias + ) + with torch.autocast(enabled=False, device_type=module_device.type): + return nn.functional.layer_norm( + downcast_x, + self.normalized_shape, + downcast_weight, + downcast_bias, + self.eps, + ) + + +class StochatichDepth(nn.Module): + def __init__(self, p: float): + super().__init__() + self.survival_prob = 1.0 - p + + def forward(self, x: Tensor) -> Tensor: + if self.training and self.survival_prob < 1: + mask = ( + torch.empty(x.shape[0], 1, 1, device=x.device).uniform_() + + self.survival_prob + ) + mask = mask.floor() + if self.survival_prob > 0: + mask = mask / self.survival_prob + return x * mask + else: + return x + + +class CrossAttentionOp(nn.Module): + def __init__( + self, attention_dim, num_heads, dim_q, dim_kv, use_biases=True, is_sa=False + ): + super().__init__() + self.dim_q = dim_q + self.dim_kv = dim_kv + self.attention_dim = attention_dim + self.num_heads = num_heads + self.use_biases = use_biases + self.is_sa = is_sa + if self.is_sa: + self.qkv = nn.Linear(dim_q, attention_dim * 3, bias=use_biases) + else: + self.q = nn.Linear(dim_q, attention_dim, bias=use_biases) + self.kv = nn.Linear(dim_kv, attention_dim * 2, bias=use_biases) + self.out = nn.Linear(attention_dim, dim_q, bias=use_biases) + + def forward(self, x_to, x_from=None, attention_mask=None, materialize_sdpa=False): + if x_from is None: + x_from = x_to + if self.is_sa: + q, k, v = self.qkv(x_to).chunk(3, dim=-1) + else: + q = self.q(x_to) + k, v = self.kv(x_from).chunk(2, dim=-1) + q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads) + k = rearrange(k, "b n (h d) -> b h n d", h=self.num_heads) + v = rearrange(v, "b n (h d) -> b h n d", h=self.num_heads) + if attention_mask is not None: + attention_mask = attention_mask.unsqueeze(1) + if materialize_sdpa: + x = self.materialize_sdpa(q, k, v, attention_mask) + else: + x = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=attention_mask + ) + x = rearrange(x, "b h n d -> b n (h d)") + x = self.out(x) + return x + + def materialize_sdpa(self, q, k, v, attn_mask=None): + scale = 1.0 / math.sqrt(q.shape[-1]) + + attn_matrix = torch.einsum("b h i d, b h j d -> b h i j", q, k) * scale + if attn_mask is not None: + attn_matrix = attn_matrix * attn_mask + attn_matrix = torch.nn.functional.softmax(attn_matrix, dim=-1) + return torch.einsum("b h i j, b h j d -> b h i d", attn_matrix, v) + + +class CrossAttentionBlock(nn.Module): + def __init__( + self, + dim_q: int, + dim_kv: int, + num_heads: int, + attention_dim: int = 0, + mlp_multiplier: int = 4, + dropout: float = 0.0, + stochastic_depth: float = 0.0, + use_biases: bool = True, + retrieve_attention_scores: bool = False, + use_16_bits_layer_norm: bool = False, + ): + super().__init__() + if use_16_bits_layer_norm and not retrieve_attention_scores: + LayerNorm = LayerNorm16Bits + else: + LayerNorm = nn.LayerNorm + self.retrieve_attention_scores = retrieve_attention_scores + self.initial_to_ln = LayerNorm(dim_q, eps=1e-6) + attention_dim = min(dim_q, dim_kv) if attention_dim == 0 else attention_dim + self.ca = CrossAttentionOp( + attention_dim, num_heads, dim_q, dim_kv, is_sa=False, use_biases=use_biases + ) + self.ca_stochastic_depth = StochatichDepth(stochastic_depth) + self.middle_ln = LayerNorm(dim_q, eps=1e-6) + self.ffn = FusedMLP( + dim_model=dim_q, + dropout=dropout, + activation=nn.GELU, + hidden_layer_multiplier=mlp_multiplier, + bias=use_biases, + ) + self.ffn_stochastic_depth = StochatichDepth(stochastic_depth) + + self.register_parameter( + "attention_mask_dummy", + nn.Parameter(torch.ones(1, 1, dtype=torch.bool), requires_grad=False), + ) + + def forward( + self, + to_tokens: Tensor, + from_tokens: Tensor, + to_token_mask: Optional[Tensor] = None, + from_token_mask: Optional[Tensor] = None, + ) -> Tensor: + if to_token_mask is None and from_token_mask is None: + attention_mask = None + else: + if to_token_mask is None: + to_token_mask = self.attention_mask_dummy.expand( + to_tokens.shape[0], + to_tokens.shape[1], + ) + if from_token_mask is None: + from_token_mask = self.attention_mask_dummy.expand( + from_tokens.shape[0], + from_tokens.shape[1], + ) + attention_mask = from_token_mask.unsqueeze(1) * to_token_mask.unsqueeze(2) + if self.retrieve_attention_scores: + attention_output = self.ca( + self.initial_to_ln(to_tokens), + from_tokens, + attention_mask=attention_mask, + materialize_sdpa=True, + ) + else: + attention_output = self.ca( + self.initial_to_ln(to_tokens), + from_tokens, + attention_mask=attention_mask, + ) + to_tokens = to_tokens + self.ca_stochastic_depth(attention_output) + to_tokens = to_tokens + self.ffn_stochastic_depth( + self.ffn(self.middle_ln(to_tokens)) + ) + return to_tokens + + +class SelfAttentionBlock(nn.Module): + def __init__( + self, + dim_qkv: int, + num_heads: int, + attention_dim: int = 0, + mlp_multiplier: int = 4, + dropout: float = 0.0, + stochastic_depth: float = 0.0, + use_biases: bool = True, + use_layer_scale: bool = False, + layer_scale_value: float = 0.1, + retrieve_attention_scores: bool = False, + use_16_bits_layer_norm: bool = False, + ): + super().__init__() + if use_16_bits_layer_norm and not retrieve_attention_scores: + LayerNorm = LayerNorm16Bits + else: + LayerNorm = nn.LayerNorm + self.retrieve_attention_scores = retrieve_attention_scores + self.initial_ln = LayerNorm(dim_qkv, eps=1e-6) + attention_dim = dim_qkv if attention_dim == 0 else attention_dim + self.sa = CrossAttentionOp( + attention_dim, + num_heads, + dim_qkv, + dim_qkv, + is_sa=True, + use_biases=use_biases, + ) + self.sa_stochastic_depth = StochatichDepth(stochastic_depth) + self.middle_ln = LayerNorm(dim_qkv, eps=1e-6) + self.ffn = FusedMLP( + dim_model=dim_qkv, + dropout=dropout, + activation=nn.GELU, + hidden_layer_multiplier=mlp_multiplier, + bias=use_biases, + ) + self.ffn_stochastic_depth = StochatichDepth(stochastic_depth) + self.use_layer_scale = use_layer_scale + if use_layer_scale: + self.layer_scale_1 = nn.Parameter( + torch.ones(dim_qkv) * layer_scale_value, requires_grad=True + ) + self.layer_scale_2 = nn.Parameter( + torch.ones(dim_qkv) * layer_scale_value, requires_grad=True + ) + + self.register_parameter( + "attention_mask_dummy", + nn.Parameter(torch.ones(1, 1, dtype=torch.bool), requires_grad=False), + ) + + def forward( + self, + tokens: torch.Tensor, + token_mask: Optional[torch.Tensor] = None, + ): + if token_mask is None: + attention_mask = None + else: + attention_mask = token_mask.unsqueeze(1) * self.attention_mask_dummy.expand( + tokens.shape[0], + tokens.shape[1], + ).unsqueeze(2) + if self.retrieve_attention_scores: + attention_output = self.sa( + self.initial_ln(tokens), + attention_mask=attention_mask, + materialize_sdpa=True, + ) + else: + attention_output = self.sa( + self.initial_ln(tokens), + attention_mask=attention_mask, + ) + if self.use_layer_scale: + tokens = tokens + self.sa_stochastic_depth( + self.layer_scale_1 * attention_output + ) + tokens = tokens + self.ffn_stochastic_depth( + self.layer_scale_2 * self.ffn(self.middle_ln(tokens)) + ) + else: + tokens = tokens + self.sa_stochastic_depth(attention_output) + tokens = tokens + self.ffn_stochastic_depth( + self.ffn(self.middle_ln(tokens)) + ) + return tokens diff --git a/plonk/models/positional_embeddings.py b/plonk/models/positional_embeddings.py new file mode 100755 index 0000000000000000000000000000000000000000..58f3355b4d02e4af5b572b05007dbdecbbc468f9 --- /dev/null +++ b/plonk/models/positional_embeddings.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn +import numpy as np + + +class PositionalEmbedding(nn.Module): + """ + Taken from https://github.com/NVlabs/edm + """ + + def __init__(self, num_channels, max_positions=10000, endpoint=False): + super().__init__() + self.num_channels = num_channels + self.max_positions = max_positions + self.endpoint = endpoint + freqs = torch.arange(start=0, end=self.num_channels // 2, dtype=torch.float32) + freqs = 2 * freqs / self.num_channels + freqs = (1 / self.max_positions) ** freqs + self.register_buffer("freqs", freqs) + + def forward(self, x): + x = torch.outer(x, self.freqs) + out = torch.cat([x.cos(), x.sin()], dim=1) + return out.to(x.dtype) + + +# ---------------------------------------------------------------------------- +# Timestep embedding used in the NCSN++ architecture. +class FourierEmbedding(nn.Module): + """ + Taken from https://github.com/NVlabs/edm + """ + + def __init__(self, num_channels, scale=16): + super().__init__() + self.register_buffer("freqs", torch.randn(num_channels // 2) * scale) + + def forward(self, x): + x = x.ger((2 * np.pi * self.freqs).to(x.dtype)) + x = torch.cat([x.cos(), x.sin()], dim=1) + return x diff --git a/plonk/models/postprocessing.py b/plonk/models/postprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..49b78f3599fa7cb8c798e4c72a80620365e8e96c --- /dev/null +++ b/plonk/models/postprocessing.py @@ -0,0 +1,24 @@ +import torch.nn as nn +import torch +import numpy as np + +class UnormGPS(nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("gps_normalize", torch.Tensor([np.pi * 0.5, np.pi]).unsqueeze(0)) + + def forward(self, x): + """Unormalize latitude longtitude radians to -1, 1.""" + x = torch.clamp(x, -1, 1) + return x * self.gps_normalize + +class CartesiantoGPS(nn.Module): + def __init__(self): + super().__init__() + def forward(self, cartesian): + x = cartesian[:, 0] + y = cartesian[:, 1] + z = cartesian[:, 2] + lat = z.arcsin() + lon = y.atan2(x) + return torch.stack([lat, lon], dim=-1) \ No newline at end of file diff --git a/plonk/models/preconditioning.py b/plonk/models/preconditioning.py new file mode 100755 index 0000000000000000000000000000000000000000..098f09ab31131b407d22c3637eb9f0c0ba53a59d --- /dev/null +++ b/plonk/models/preconditioning.py @@ -0,0 +1,60 @@ +import torch +from torch import nn + +# ---------------------------------------------------------------------------- +# Improved preconditioning proposed in the paper "Elucidating the Design +# Space of Diffusion-Based Generative networks" (EDM). + + +class EDMPrecond(torch.nn.Module): + def __init__( + self, + network, + label_dim=0, # Number of class labels, 0 = unconditional. + sigma_min=0, # Minimum supported noise level. + sigma_max=float("inf"), # Maximum supported noise level. + sigma_data=0.5, # Expected standard deviation of the training data. + ): + super().__init__() + self.label_dim = label_dim + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.sigma_data = sigma_data + self.network = network + + def forward(self, x, sigma, conditioning=None, **network_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + conditioning = ( + None + if self.label_dim == 0 + else torch.zeros([1, self.label_dim], device=x.device) + if conditioning is None + else conditioning.to(torch.float32) + ) + + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() + c_noise = sigma.log() / 4 + + F_x = self.network( + (c_in * x), + c_noise.flatten(), + conditioning=conditioning, + **network_kwargs, + ) + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + def round_sigma(self, sigma): + return torch.as_tensor(sigma) + + +class DDPMPrecond(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, network, batch): + F_x = network(batch) + return F_x diff --git a/plonk/models/preprocessing.py b/plonk/models/preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..ccc4030d781427a29fafc889e4916d47bd7ba584 --- /dev/null +++ b/plonk/models/preprocessing.py @@ -0,0 +1,50 @@ +import torch +from torch import nn +import numpy as np + + +class NormGPS(nn.Module): + def __init__(self, input_key="gps", output_key="x_0", normalize=True): + super().__init__() + self.input_key = input_key + self.output_key = output_key + self.normalize = normalize + if self.normalize: + self.register_buffer( + "gps_normalize", 1 / torch.Tensor([np.pi * 0.5, np.pi]).unsqueeze(0) + ) + + def forward(self, batch): + """Normalize latitude longtitude radians to -1, 1.""" # not used currently + x = batch[self.input_key] + if self.normalize: + x = x * self.gps_normalize + batch[self.output_key] = x + return batch + +class GPStoCartesian(nn.Module): + def __init__(self, input_key="gps", output_key="x_0"): + super().__init__() + self.input_key = input_key + self.output_key = output_key + + def forward(self, batch): + """Project latitude longtitude radians to 3D coordinates.""" + x = batch[self.input_key] + lat, lon = x[:, 0], x[:, 1] + x = torch.stack([lat.cos() * lon.cos(), lat.cos() * lon.sin(), lat.sin()], dim=-1) + batch[self.output_key] = x + return batch + +class PrecomputedPreconditioning: + def __init__( + self, + input_key="emb", + output_key="emb", + ): + self.input_key = input_key + self.output_key = output_key + + def __call__(self, batch, device=None): + batch[self.output_key] = batch[self.input_key] + return batch diff --git a/plonk/models/pretrained_models.py b/plonk/models/pretrained_models.py new file mode 100644 index 0000000000000000000000000000000000000000..5e7535c8b4aa6c38f6c5703f5eedfad9350d223d --- /dev/null +++ b/plonk/models/pretrained_models.py @@ -0,0 +1,58 @@ +import sys +import os + +from plonk.models.networks.mlp import GeoAdaLNMLP +from huggingface_hub import PyTorchModelHubMixin +import torch +import argparse + +models_overrides = { + "YFCC100M_geoadalnmlp_r3_small_sigmoid_flow_riemann_10M_10M": "YFCC100M_geoadalnmlp_r3_small_sigmoid_flow_riemann", + "iNaturalist_geoadalnmlp_r3_small_sigmoid_flow_riemann_-7_3": "iNaturalist_geoadalnmlp_r3_small_sigmoid_flow_riemann", + "osv_5m_geoadalnmlp_r3_small_sigmoid_flow_riemann_-7_3": "osv_5m_geoadalnmlp_r3_small_sigmoid_flow_riemann", +} + + +class Plonk( + GeoAdaLNMLP, + PyTorchModelHubMixin, + repo_url="https://github.com/nicolas-dufour/plonk", + tags=["plonk", "geolocalization", "diffusion"], + license="mit", +): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +def upload_model(checkpoint_dir, repo_name): + import hydra + from omegaconf import OmegaConf + + hydra.initialize(version_base=None, config_path=f"../configs") + cfg = hydra.compose( + config_name="config", + overrides=[ + f"exp={models_overrides[checkpoint_dir]}", + ], + ) + network_config = cfg.model.network + serialized_network_config = OmegaConf.to_container(network_config, resolve=True) + print(serialized_network_config) + del serialized_network_config["_target_"] + model = Plonk(**serialized_network_config) + ckpt = torch.load(f"checkpoints/{checkpoint_dir}/last.ckpt") + ckpt_state_dict = ckpt["state_dict"] + ckpt_state_dict = {k: v for k, v in ckpt_state_dict.items() if "ema_network" in k} + ckpt_state_dict = { + k.replace("ema_network.", ""): v for k, v in ckpt_state_dict.items() + } + model.load_state_dict(ckpt_state_dict) + model.push_to_hub(repo_name, commit_message="Fixed ckpt keys") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint_dir", type=str, required=True) + parser.add_argument("--repo_name", type=str, required=True) + args = parser.parse_args() + upload_model(args.checkpoint_dir, args.repo_name) diff --git a/plonk/models/samplers/__init__.py b/plonk/models/samplers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3016adf2f25726b3e56835d76486203060fae1c8 --- /dev/null +++ b/plonk/models/samplers/__init__.py @@ -0,0 +1 @@ +# Empty file to make the directory a Python package diff --git a/plonk/models/samplers/ddim.py b/plonk/models/samplers/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..94e5b0d71ace47aad549378d0a1a5871b7fb7454 --- /dev/null +++ b/plonk/models/samplers/ddim.py @@ -0,0 +1,62 @@ +import torch + + +def ddim_sampler( + net, + batch, + conditioning_keys=None, + scheduler=None, + num_steps=250, + cfg_rate=0, + generator=None, + return_trajectories=False, +): + if scheduler is None: + raise ValueError("Scheduler must be provided") + + x_cur = batch["y"].to(torch.float32) + if return_trajectories: + traj = [x_cur.detach()] + step_indices = torch.arange(num_steps + 1, dtype=torch.float32, device=x_cur.device) + steps = 1 - step_indices / num_steps + gammas = scheduler(steps) + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + if cfg_rate > 0 and conditioning_keys is not None: + stacked_batch = {} + stacked_batch[conditioning_keys] = torch.cat( + [batch[conditioning_keys], torch.zeros_like(batch[conditioning_keys])], + dim=0, + ) + for step, (gamma_now, gamma_next) in enumerate(zip(gammas[:-1], gammas[1:])): + with torch.cuda.amp.autocast(dtype=dtype): + if cfg_rate > 0 and conditioning_keys is not None: + stacked_batch["y"] = torch.cat([x_cur, x_cur], dim=0) + stacked_batch["gamma"] = gamma_now.expand(x_cur.shape[0] * 2) + denoised_all = net(stacked_batch) + denoised_cond, denoised_uncond = denoised_all.chunk(2, dim=0) + denoised = denoised_cond * (1 + cfg_rate) - denoised_uncond * cfg_rate + else: + batch["y"] = x_cur + batch["gamma"] = gamma_now.expand(x_cur.shape[0]) + denoised = net(batch) + + x_pred = (x_cur - torch.sqrt(1 - gamma_now) * denoised) / torch.sqrt(gamma_now) + x_pred = torch.clamp(x_pred, -1, 1) + noise_pred = (x_cur - torch.sqrt(gamma_now) * x_pred) / torch.sqrt( + 1 - gamma_now + ) + x_next = ( + torch.sqrt(gamma_next) * x_pred + torch.sqrt(1 - gamma_next) * noise_pred + ) + x_cur = x_next + if return_trajectories: + traj.append(x_cur.detach().to(torch.float32)) + + if return_trajectories: + return x_cur.to(torch.float32), traj + else: + return x_cur.to(torch.float32) + + +def circular_transformation(x, min_val=-1, max_val=1): + return (x - min_val) % (max_val - min_val) + min_val diff --git a/plonk/models/samplers/ddpm.py b/plonk/models/samplers/ddpm.py new file mode 100644 index 0000000000000000000000000000000000000000..cc8510ab527d68c8448794e796c16bbb46a457d2 --- /dev/null +++ b/plonk/models/samplers/ddpm.py @@ -0,0 +1,187 @@ +import torch + + +def ddpm_sampler( + net, + batch, + conditioning_keys=None, + scheduler=None, + uncond_tokens=None, + num_steps=1000, + cfg_rate=0, + generator=None, + use_confidence_sampling=False, + use_uncond_token=True, + confidence_value=1.0, + unconfidence_value=0.0, +): + if scheduler is None: + raise ValueError("Scheduler must be provided") + + x_cur = batch["y"].to(torch.float32) + latents = batch["previous_latents"] + if use_confidence_sampling: + batch["confidence"] = ( + torch.ones(x_cur.shape[0], device=x_cur.device) * confidence_value + ) + step_indices = torch.arange(num_steps + 1, dtype=torch.float32, device=x_cur.device) + steps = 1 - step_indices / num_steps + gammas = scheduler(steps) + latents_cond = latents_uncond = latents + # dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + dtype = torch.float32 + if cfg_rate > 0 and conditioning_keys is not None: + stacked_batch = {} + for key in conditioning_keys: + if f"{key}_mask" in batch: + if use_confidence_sampling and not use_uncond_token: + stacked_batch[f"{key}_mask"] = torch.cat( + [batch[f"{key}_mask"], batch[f"{key}_mask"]], dim=0 + ) + else: + if ( + batch[f"{key}_mask"].shape[1] + > uncond_tokens[f"{key}_mask"].shape[1] + ): + uncond_mask = ( + torch.zeros_like(batch[f"{key}_mask"]) + if batch[f"{key}_mask"].dtype == torch.bool + else torch.ones_like(batch[f"{key}_mask"]) * -torch.inf + ) + uncond_mask[:, : uncond_tokens[f"{key}_mask"].shape[1]] = ( + uncond_tokens[f"{key}_mask"] + ) + else: + uncond_mask = uncond_tokens[f"{key}_mask"] + batch[f"{key}_mask"] = torch.cat( + [ + batch[f"{key}_mask"], + torch.zeros( + batch[f"{key}_mask"].shape[0], + uncond_tokens[f"{key}_embeddings"].shape[1] + - batch[f"{key}_mask"].shape[1], + device=batch[f"{key}_mask"].device, + dtype=batch[f"{key}_mask"].dtype, + ), + ], + dim=1, + ) + stacked_batch[f"{key}_mask"] = torch.cat( + [batch[f"{key}_mask"], uncond_mask], dim=0 + ) + if f"{key}_embeddings" in batch: + if use_confidence_sampling and not use_uncond_token: + stacked_batch[f"{key}_embeddings"] = torch.cat( + [ + batch[f"{key}_embeddings"], + batch[f"{key}_embeddings"], + ], + dim=0, + ) + else: + if ( + batch[f"{key}_embeddings"].shape[1] + > uncond_tokens[f"{key}_embeddings"].shape[1] + ): + uncond_tokens[f"{key}_embeddings"] = torch.cat( + [ + uncond_tokens[f"{key}_embeddings"], + torch.zeros( + uncond_tokens[f"{key}_embeddings"].shape[0], + batch[f"{key}_embeddings"].shape[1] + - uncond_tokens[f"{key}_embeddings"].shape[1], + uncond_tokens[f"{key}_embeddings"].shape[2], + device=uncond_tokens[f"{key}_embeddings"].device, + ), + ], + dim=1, + ) + elif ( + batch[f"{key}_embeddings"].shape[1] + < uncond_tokens[f"{key}_embeddings"].shape[1] + ): + batch[f"{key}_embeddings"] = torch.cat( + [ + batch[f"{key}_embeddings"], + torch.zeros( + batch[f"{key}_embeddings"].shape[0], + uncond_tokens[f"{key}_embeddings"].shape[1] + - batch[f"{key}_embeddings"].shape[1], + batch[f"{key}_embeddings"].shape[2], + device=batch[f"{key}_embeddings"].device, + ), + ], + dim=1, + ) + stacked_batch[f"{key}_embeddings"] = torch.cat( + [ + batch[f"{key}_embeddings"], + uncond_tokens[f"{key}_embeddings"], + ], + dim=0, + ) + elif key not in batch: + raise ValueError(f"Key {key} not in batch") + else: + if isinstance(batch[key], torch.Tensor): + if use_confidence_sampling and not use_uncond_token: + stacked_batch[key] = torch.cat([batch[key], batch[key]], dim=0) + else: + stacked_batch[key] = torch.cat( + [batch[key], uncond_tokens], dim=0 + ) + elif isinstance(batch[key], list): + if use_confidence_sampling and not use_uncond_token: + stacked_batch[key] = [*batch[key], *batch[key]] + else: + stacked_batch[key] = [*batch[key], *uncond_tokens] + else: + raise ValueError( + "Conditioning must be a tensor or a list of tensors" + ) + if use_confidence_sampling: + stacked_batch["confidence"] = torch.cat( + [ + torch.ones(x_cur.shape[0], device=x_cur.device) * confidence_value, + torch.ones(x_cur.shape[0], device=x_cur.device) + * unconfidence_value, + ], + dim=0, + ) + for step, (gamma_now, gamma_next) in enumerate(zip(gammas[:-1], gammas[1:])): + with torch.cuda.amp.autocast(dtype=dtype): + if cfg_rate > 0 and conditioning_keys is not None: + stacked_batch["y"] = torch.cat([x_cur, x_cur], dim=0) + stacked_batch["gamma"] = gamma_now.expand(x_cur.shape[0] * 2) + stacked_batch["previous_latents"] = ( + torch.cat([latents_cond, latents_uncond], dim=0) + if latents is not None + else None + ) + denoised_all, latents_all = net(stacked_batch) + denoised_cond, denoised_uncond = denoised_all.chunk(2, dim=0) + latents_cond, latents_uncond = latents_all.chunk(2, dim=0) + denoised = denoised_cond * (1 + cfg_rate) - denoised_uncond * cfg_rate + else: + batch["y"] = x_cur + batch["gamma"] = gamma_now.expand(x_cur.shape[0]) + batch["previous_latents"] = latents + denoised, latents = net( + batch, + ) + x_pred = (x_cur - torch.sqrt(1 - gamma_now) * denoised) / torch.sqrt(gamma_now) + x_pred = torch.clamp(x_pred, -1, 1) + noise_pred = (x_cur - torch.sqrt(gamma_now) * x_pred) / torch.sqrt( + 1 - gamma_now + ) + + log_alpha_t = torch.log(gamma_now) - torch.log(gamma_next) + alpha_t = torch.clip(torch.exp(log_alpha_t), 0, 1) + x_mean = torch.rsqrt(alpha_t) * ( + x_cur - torch.rsqrt(1 - gamma_now) * (1 - alpha_t) * noise_pred + ) + var_t = 1 - alpha_t + eps = torch.randn(x_cur.shape, device=x_cur.device, generator=generator) + x_next = x_mean + torch.sqrt(var_t) * eps + x_cur = x_next + return x_cur.to(torch.float32) diff --git a/plonk/models/samplers/edm.py b/plonk/models/samplers/edm.py new file mode 100755 index 0000000000000000000000000000000000000000..eae4976f5ada37e2ebc72deabede9e244db9ffcb --- /dev/null +++ b/plonk/models/samplers/edm.py @@ -0,0 +1,68 @@ +import torch +import numpy as np + + +def edm_sampler( + net, + x_N, + conditioning=None, + latents=None, + randn_like=torch.randn_like, + num_steps=18, + sigma_min=0.002, + sigma_max=80, + rho=7, + S_churn=0, + S_min=0, + S_max=float("inf"), + S_noise=1, +): + # Adjust noise levels based on what's supported by the network. + sigma_min = max(sigma_min, net.sigma_min) + sigma_max = min(sigma_max, net.sigma_max) + + # Time step discretization. + step_indices = torch.arange(num_steps, dtype=torch.float64, device=x_N.device) + t_steps = ( + sigma_max ** (1 / rho) + + step_indices + / (num_steps - 1) + * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) + ) ** rho + t_steps = torch.cat( + [net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])] + ) # t_N = 0 + + # Main sampling loop. + x_next = x_N.to(torch.float64) * t_steps[0] + for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 + x_cur = x_next + + # Increase noise temporarily. + gamma = ( + min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 + ) + t_hat = net.round_sigma(t_cur + gamma * t_cur) + x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur) + + # Euler step. + denoised, latents = net( + x_hat, t_hat.expand(x_cur.shape[0]), conditioning, previous_latents=latents + ) + denoised = denoised.to(torch.float64) + d_cur = (x_hat - denoised) / t_hat + x_next = x_hat + (t_next - t_hat) * d_cur + + # Apply 2nd order correction. + if i < num_steps - 1: + denoised, latents = net( + x_next, + t_next.expand(x_cur.shape[0]), + conditioning, + previous_latents=latents, + ) + denoised = denoised.to(torch.float64) + d_prime = (x_next - denoised) / t_next + x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) + + return x_next diff --git a/plonk/models/samplers/flow_sampler.py b/plonk/models/samplers/flow_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..cc4609d415acd4a147e539bac467d5fd8bc4ae0f --- /dev/null +++ b/plonk/models/samplers/flow_sampler.py @@ -0,0 +1,57 @@ +import torch + + +def flow_sampler( + net, + batch, + conditioning_keys=None, + scheduler=None, + num_steps=250, + cfg_rate=0, + generator=None, + return_trajectories=False, +): + if scheduler is None: + raise ValueError("Scheduler must be provided") + + x_cur = batch["y"].to(torch.float32) + if return_trajectories: + traj = [x_cur.detach()] + step_indices = torch.arange(num_steps + 1, dtype=torch.float32, device=x_cur.device) + steps = 1 - step_indices / num_steps + gammas = scheduler(steps) + dtype = ( + torch.float32 + ) # torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + if cfg_rate > 0 and conditioning_keys is not None: + stacked_batch = {} + stacked_batch[conditioning_keys] = torch.cat( + [batch[conditioning_keys], torch.zeros_like(batch[conditioning_keys])], + dim=0, + ) + for step, (gamma_now, gamma_next) in enumerate(zip(gammas[:-1], gammas[1:])): + with torch.cuda.amp.autocast(dtype=dtype): + if cfg_rate > 0 and conditioning_keys is not None: + stacked_batch["y"] = torch.cat([x_cur, x_cur], dim=0) + stacked_batch["gamma"] = gamma_now.expand(x_cur.shape[0] * 2) + denoised_all = net(stacked_batch) + denoised_cond, denoised_uncond = denoised_all.chunk(2, dim=0) + denoised = denoised_cond * (1 + cfg_rate) - denoised_uncond * cfg_rate + else: + batch["y"] = x_cur + batch["gamma"] = gamma_now.expand(x_cur.shape[0]) + denoised = net(batch) + dt = gamma_next - gamma_now + x_next = x_cur + dt * denoised + x_cur = x_next + if return_trajectories: + traj.append(x_cur.detach().to(torch.float32)) + + if return_trajectories: + return x_cur.to(torch.float32), traj + else: + return x_cur.to(torch.float32) + + +def circular_transformation(x, min_val=-1, max_val=1): + return (x - min_val) % (max_val - min_val) + min_val diff --git a/plonk/models/samplers/riemannian_flow_sampler.py b/plonk/models/samplers/riemannian_flow_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..6e439b1bb6bcec78700ee7ae8e95a0ca1381c49b --- /dev/null +++ b/plonk/models/samplers/riemannian_flow_sampler.py @@ -0,0 +1,84 @@ +import torch +from plonk.utils.manifolds import Sphere +from tqdm.auto import tqdm + + +def riemannian_flow_sampler( + net, + batch, + manifold=Sphere(), + conditioning_keys=None, + scheduler=None, + num_steps=250, + cfg_rate=0, + generator=None, + return_trajectories=False, +): + if scheduler is None: + raise ValueError("Scheduler must be provided") + + x_cur = batch["y"].to(torch.float32) + if return_trajectories: + traj = [x_cur.detach()] + step_indices = torch.arange(num_steps + 1, dtype=torch.float32, device=x_cur.device) + steps = 1 - step_indices / num_steps + gammas = scheduler(steps) + dtype = torch.float32 + if cfg_rate > 0 and conditioning_keys is not None: + stacked_batch = {} + stacked_batch[conditioning_keys] = torch.cat( + [batch[conditioning_keys], torch.zeros_like(batch[conditioning_keys])], + dim=0, + ) + for step, (gamma_now, gamma_next) in enumerate(zip(gammas[:-1], gammas[1:])): + with torch.cuda.amp.autocast(dtype=dtype): + if cfg_rate > 0 and conditioning_keys is not None: + stacked_batch["y"] = torch.cat([x_cur, x_cur], dim=0) + stacked_batch["gamma"] = gamma_now.expand(x_cur.shape[0] * 2) + denoised_all = net(stacked_batch) + denoised_cond, denoised_uncond = denoised_all.chunk(2, dim=0) + denoised = denoised_cond * (1 + cfg_rate) - denoised_uncond * cfg_rate + else: + batch["y"] = x_cur + batch["gamma"] = gamma_now.expand(x_cur.shape[0]) + denoised = net(batch) + + dt = gamma_next - gamma_now + x_next = x_cur + dt * denoised # manifold.expmap(x_cur, dt * denoised) + x_next = manifold.projx(x_next) + x_cur = x_next + if return_trajectories: + traj.append(x_cur.detach().to(torch.float32)) + + if return_trajectories: + return x_cur.to(torch.float32), traj + else: + return x_cur.to(torch.float32) + + +def ode_riemannian_flow_sampler( + odefunc, + x_1, + manifold=Sphere(), + scheduler=None, + num_steps=1000, +): + if scheduler is None: + raise ValueError("Scheduler must be provided") + + x_cur = x_1.to(torch.float32) + steps = ( + torch.arange(num_steps + 1, dtype=torch.float32, device=x_cur.device) + / num_steps + ) + dtype = torch.float32 + for step, (t_now, t_next) in enumerate(zip(steps[:-1], steps[1:]), total=num_steps): + with torch.cuda.amp.autocast(dtype=dtype): + denoised = odefunc(t_now, x_cur) + gamma_now = scheduler(t_now) + gamma_next = scheduler(t_next) + dt = gamma_next - gamma_now + x_next = x_cur + dt * denoised # manifold.expmap(x_cur, dt * denoised) + x_next = manifold.projx(x_next) + x_cur = x_next + return x_cur.to(torch.float32) diff --git a/plonk/models/samplers/von_fisher_sampling.py b/plonk/models/samplers/von_fisher_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..c3afab2e65aab43455f82243dd908ac77f9486b1 --- /dev/null +++ b/plonk/models/samplers/von_fisher_sampling.py @@ -0,0 +1,105 @@ +""" +Generate multivariate von Mises Fisher samples. +PyTorch implementation of the original code from: +https://github.com/clara-labs/spherecluster +""" + +import torch + +__all__ = ["sample_vMF"] + + +def vMF_sampler( + net, + batch, +): + mu, kappa = net(batch) + return sample_vMF(mu.T, kappa.squeeze(1)) + + +def vMF_mixture_sampler( + net, + batch, +): + mu_mixture, kappa_mixture, weights = net(batch) + # Sample mixture component indices based on weights + indices = torch.multinomial(weights, num_samples=1).squeeze() + # Select corresponding mu and kappa + mu = mu_mixture[torch.arange(mu_mixture.shape[0]), indices] + kappa = kappa_mixture[torch.arange(kappa_mixture.shape[0]), indices] + return sample_vMF(mu.T, kappa) + + +def sample_vMF(mu, kappa, num_samples=1): + """Generate N-dimensional samples from von Mises Fisher + distribution around center mu ∈ R^N with concentration kappa. + mu and kappa may be vectors, + mu should have shape (N,) or (N, 1), kappa should be scalar or vector of length N. + """ + if len(mu.shape) == 1: + mu = mu.unsqueeze(1) + + if isinstance(kappa, torch.Tensor): + dim = mu.shape[0] + assert mu.shape[1] == kappa.size(0) + else: + dim = mu.shape[0] + mu = mu.repeat(1, num_samples) + kappa = torch.full((num_samples,), kappa, device=mu.device, dtype=mu.dtype) + + # sample offset from center (on sphere) with spread kappa + w = _sample_weight(kappa, dim) + + # sample a point v on the unit sphere that's orthogonal to mu + v = _sample_orthonormal_to(mu) + + # compute new point + result = v * torch.sqrt(1.0 - w**2).unsqueeze(0) + w.unsqueeze(0) * mu + return result.T + + +def _sample_weight(kappa, dim): + """Rejection sampling scheme for sampling distance from center on + surface of the sphere. + """ + dim = dim - 1 # since S^{n-1} + try: + size = kappa.size(0) + except AttributeError: + size = 1 + + b = dim / (torch.sqrt(4.0 * kappa**2 + dim**2) + 2 * kappa) + x = (1.0 - b) / (1.0 + b) + c = kappa * x + dim * torch.log(1 - x**2) + + w = torch.zeros_like(kappa) + idx = torch.zeros_like(kappa, dtype=torch.bool) + + while True: + where_zero = ~idx + if torch.all(idx): + return w + + z = ( + torch.distributions.Beta(dim / 2.0, dim / 2.0) + .sample((size,)) + .to(kappa.device) + ) + _w = (1.0 - (1.0 + b) * z) / (1.0 - (1.0 - b) * z) + u = torch.rand(size, device=kappa.device) + + _idx = kappa * _w + dim * torch.log(1.0 - x * _w) - c >= torch.log(u) + + if not torch.any(_idx): + continue + + w[where_zero] = _w[where_zero] + idx[_idx] = True + + +def _sample_orthonormal_to(mu): + """Sample point on sphere orthogonal to mu.""" + v = torch.randn(mu.shape[0], mu.shape[1], device=mu.device) + proj_mu_v = mu * ((v * mu).sum(dim=0)) / torch.norm(mu, dim=0) ** 2 + orthto = v - proj_mu_v + return orthto / torch.norm(orthto, dim=0) diff --git a/plonk/models/schedulers.py b/plonk/models/schedulers.py new file mode 100644 index 0000000000000000000000000000000000000000..c5d5c3370e76ff8ffdf613f319f4b7782c3de55c --- /dev/null +++ b/plonk/models/schedulers.py @@ -0,0 +1,106 @@ +import torch + + +class SigmoidScheduler: + def __init__(self, start=-3, end=3, tau=1, clip_min=1e-9): + self.start = start + self.end = end + self.tau = tau + self.clip_min = clip_min + + self.v_start = torch.sigmoid(torch.tensor(self.start / self.tau)) + self.v_end = torch.sigmoid(torch.tensor(self.end / self.tau)) + + def __call__(self, t): + output = ( + -torch.sigmoid((t * (self.end - self.start) + self.start) / self.tau) + + self.v_end + ) / (self.v_end - self.v_start) + return torch.clamp(output, min=self.clip_min, max=1.0) + + def derivative(self, t): + x = (t * (self.end - self.start) + self.start) / self.tau + sigmoid_x = torch.sigmoid(x) + # Chain rule: d/dt of original function + return ( + -(self.end - self.start) + * sigmoid_x + * (1 - sigmoid_x) + / (self.tau * (self.v_end - self.v_start)) + ) + + def alpha(self, t): + return -self.derivative(t) / (1e-6 + self.__call__(t)) + + +class LinearScheduler: + def __init__(self, start=1, end=0, clip_min=1e-9): + self.start = start + self.end = end + self.clip_min = clip_min + + def __call__(self, t): + output = (self.end - self.start) * t + self.start + return torch.clamp(output, min=self.clip_min, max=1.0) + + def derivative(self, t): + return torch.tensor(self.end - self.start).to(t.device) + + def alpha(self, t): + return -self.derivative(t) / (1e-6 + self.__call__(t)) + + +class CosineScheduler: + def __init__( + self, + start: float = 1, + end: float = 0, + tau: float = 1.0, + clip_min: float = 1e-9, + ): + self.start = start + self.end = end + self.tau = tau + self.clip_min = clip_min + + self.v_start = torch.cos(torch.tensor(self.start) * torch.pi / 2) ** ( + 2 * self.tau + ) + self.v_end = torch.cos(torch.tensor(self.end) * torch.pi / 2) ** (2 * self.tau) + + def __call__(self, t: float) -> float: + output = ( + torch.cos((t * (self.end - self.start) + self.start) * torch.pi / 2) + ** (2 * self.tau) + - self.v_end + ) / (self.v_start - self.v_end) + return torch.clamp(output, min=self.clip_min, max=1.0) + + def derivative(self, t: float) -> float: + x = (t * (self.end - self.start) + self.start) * torch.pi / 2 + cos_x = torch.cos(x) + # Chain rule: d/dt of original function + return ( + -2 + * self.tau + * (self.end - self.start) + * torch.pi + / 2 + * cos_x + * (cos_x ** (2 * self.tau - 1)) + * torch.sin(x) + / (self.v_start - self.v_end) + ) + + +class CosineSchedulerSimple: + def __init__(self, ns: float = 0.0002, ds: float = 0.00025): + self.ns = ns + self.ds = ds + + def __call__(self, t: float) -> float: + return torch.cos(((t + self.ns) / (1 + self.ds)) * torch.pi / 2) ** 2 + + def derivative(self, t: float) -> float: + x = ((t + self.ns) / (1 + self.ds)) * torch.pi / 2 + return -torch.pi * torch.cos(x) * torch.sin(x) / (1 + self.ds) diff --git a/plonk/pipe.py b/plonk/pipe.py new file mode 100644 index 0000000000000000000000000000000000000000..a97defc7e23a2212f633e23ed31d4fe4e20240c5 --- /dev/null +++ b/plonk/pipe.py @@ -0,0 +1,599 @@ +import torch +from plonk.models.pretrained_models import Plonk +from plonk.models.samplers.riemannian_flow_sampler import riemannian_flow_sampler + +from plonk.models.postprocessing import CartesiantoGPS + +from plonk.models.schedulers import ( + SigmoidScheduler, + LinearScheduler, + CosineScheduler, +) +from plonk.models.preconditioning import DDPMPrecond +from torchvision import transforms +from transformers import CLIPProcessor, CLIPVisionModel +from plonk.utils.image_processing import CenterCrop +import numpy as np +from plonk.utils.manifolds import Sphere +from torch.func import jacrev, vmap, vjp +from torchdiffeq import odeint +from tqdm import tqdm + +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") +MODELS = { + "nicolas-dufour/PLONK_YFCC": {"emb_name": "dinov2"}, + "nicolas-dufour/PLONK_OSV_5M": { + "emb_name": "street_clip", + }, + "nicolas-dufour/PLONK_iNaturalist": { + "emb_name": "dinov2", + }, +} + + +def scheduler_fn( + scheduler_type: str, start: float, end: float, tau: float, clip_min: float = 1e-9 +): + if scheduler_type == "sigmoid": + return SigmoidScheduler(start, end, tau, clip_min) + elif scheduler_type == "cosine": + return CosineScheduler(start, end, tau, clip_min) + elif scheduler_type == "linear": + return LinearScheduler(clip_min=clip_min) + else: + raise ValueError(f"Scheduler type {scheduler_type} not supported") + + +class DinoV2FeatureExtractor: + def __init__(self, device=device): + super().__init__() + self.device = device + self.emb_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_reg") + self.emb_model.eval() + self.emb_model.to(self.device) + self.augmentation = transforms.Compose( + [ + CenterCrop(ratio="1:1"), + transforms.Resize( + 336, interpolation=transforms.InterpolationMode.BICUBIC + ), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) + ), + ] + ) + + def __call__(self, batch): + embs = [] + with torch.no_grad(): + for img in batch["img"]: + emb = self.emb_model( + self.augmentation(img).unsqueeze(0).to(self.device) + ).squeeze(0) + embs.append(emb) + batch["emb"] = torch.stack(embs) + return batch + + +class StreetClipFeatureExtractor: + def __init__(self, device=device): + self.device = device + self.emb_model = CLIPVisionModel.from_pretrained("geolocal/StreetCLIP").to( + device + ) + self.processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP") + + def __call__(self, batch): + inputs = self.processor(images=batch["img"], return_tensors="pt") + inputs = {k: v.to(self.device) for k, v in inputs.items()} + with torch.no_grad(): + outputs = self.emb_model(**inputs) + embeddings = outputs.last_hidden_state[:, 0] + batch["emb"] = embeddings + return batch + + +def load_prepocessing(model_name, dtype=torch.float32): + if MODELS[model_name]["emb_name"] == "dinov2": + return DinoV2FeatureExtractor() + elif MODELS[model_name]["emb_name"] == "street_clip": + return StreetClipFeatureExtractor() + else: + raise ValueError(f"Embedding model {MODELS[model_name]['emb_name']} not found") + + +# Helper functions adapted from plonk/models/module.py +# for likelihood computation +def div_fn(u): + """Accepts a function u:R^D -> R^D.""" + J = jacrev(u, argnums=0) + return lambda x, y: torch.trace(J(x, y).squeeze(0)) + + +def output_and_div(vecfield, x, y, v=None): + if v is None: + dx = vecfield(x, y) + div = vmap(div_fn(vecfield))(x, y) + else: + vecfield_x = lambda x: vecfield(x, y) + dx, vjpfunc = vjp(vecfield_x, x) + vJ = vjpfunc(v)[0] + div = torch.sum(vJ * v, dim=-1) + return dx, div + + +def _gps_degrees_to_cartesian(gps_coords_deg, device): + """Converts GPS coordinates (latitude, longitude) in degrees to Cartesian coordinates.""" + if not isinstance(gps_coords_deg, np.ndarray): + gps_coords_deg = np.array(gps_coords_deg) + if gps_coords_deg.ndim == 1: + gps_coords_deg = gps_coords_deg[np.newaxis, :] + + lat_rad = np.radians(gps_coords_deg[:, 0]) + lon_rad = np.radians(gps_coords_deg[:, 1]) + x = np.cos(lat_rad) * np.cos(lon_rad) + y = np.cos(lat_rad) * np.sin(lon_rad) + z = np.sin(lat_rad) + cartesian_coords = np.stack([x, y, z], axis=-1) + return torch.tensor(cartesian_coords, dtype=torch.float32, device=device) + + +class PlonkPipeline: + """ + The PlonkPipeline class is designed to perform geolocation prediction from images using a pre-trained PLONK model. + It integrates various components such as feature extractors, samplers, and coordinate transformations to predict locations. + + Initialization: + PlonkPipeline( + model_path, + scheduler="sigmoid", + scheduler_start=-7, + scheduler_end=3, + scheduler_tau=1.0, + device="cuda", + ) + + Parameters: + model_path (str): Path to the pre-trained PLONK model. + scheduler (str): The scheduler type to use. Options are "sigmoid", "cosine", "linear". Default is "sigmoid". + scheduler_start (float): Start value for the scheduler. Default is -7. + scheduler_end (float): End value for the scheduler. Default is 3. + scheduler_tau (float): Tau value for the scheduler. Default is 1.0. + device (str): Device to run the model on. Default is "cuda". + + Methods: + model(*args, **kwargs): + Runs the preconditioning on the network with the provided arguments. + + __call__(...): + Predicts geolocation coordinates from input images. + + Parameters: + images: Input images to predict locations for. + batch_size (int, optional): Batch size for processing. + x_N (torch.Tensor, optional): Initial noise tensor. If not provided, it is generated. + num_steps (int, optional): Number of steps for the sampler. + scheduler (callable, optional): Custom scheduler function. If not provided, the default scheduler is used. + cfg (float): Classifier-free guidance scale. Default is 0. + generator (torch.Generator, optional): Random number generator. + + Returns: + torch.Tensor: Predicted latitude and longitude coordinates. + + compute_likelihood(...): + Computes the exact log-likelihood of observing the given coordinates for the given images. + + Parameters: + images: Input images (PIL Image or list of PIL Images). Optional if emb is provided. + coordinates: Target GPS coordinates (latitude, longitude) in degrees. + emb: Pre-computed embeddings. If provided, images will be ignored. + cfg (float): Classifier-free guidance scale. Default is 0 (no guidance). + rademacher (bool): Whether to use Rademacher estimator for divergence. Default is False. + atol (float): Absolute tolerance for ODE solver. Default is 1e-5. + rtol (float): Relative tolerance for ODE solver. Default is 1e-5. + normalize_logp (bool): Whether to normalize the log-likelihood by log(2) * dim. Default is True. + + compute_likelihood_grid(...): + Computes the likelihood of an image over a global grid of coordinates. + + Parameters: + image: Input PIL Image. + grid_resolution_deg (float): The resolution of the grid in degrees. Default is 10 degrees. + batch_size (int): How many grid points to process in each batch. Adjust based on available memory. Default is 1024. + cfg (float): Classifier-free guidance scale passed to compute_likelihood. Default is 0. + + Returns: + tuple: (latitude_grid, longitude_grid, likelihood_grid) + - latitude_grid (np.ndarray): 1D array of latitudes. + - longitude_grid (np.ndarray): 1D array of longitudes. + - likelihood_grid (np.ndarray): 2D array of log-likelihoods corresponding to the lat/lon grid. + + compute_localizability(...): + Computes the localizability of an image. We use importance sampling by sampling by the model and not the grid to have a more accurate estimate. + + Parameters: + image: Input PIL Image. + atol (float): Absolute tolerance for ODE solver. Default is 1e-5. + rtol (float): Relative tolerance for ODE solver. Default is 1e-5. + number_monte_carlo_samples (int): How many samples to use for importance sampling. Default is 1024. + + Returns: + torch.Tensor: Localizability of the image. + + Example Usage: + pipe = PlonkPipeline( + "path/to/plonk/model", + ) + pipe.to("cuda") + coordinates = pipe( + images, + batch_size=32 + ) + likelihood = pipe.compute_likelihood( + images, + coordinates, + cfg=0, + rademacher=False, + ) + localizability = pipe.compute_localizability( + image, + number_monte_carlo_samples=1024, + ) + """ + + def __init__( + self, + model_path, + scheduler="sigmoid", + scheduler_start=-7, + scheduler_end=3, + scheduler_tau=1.0, + device=device, + ): + self.network = Plonk.from_pretrained(model_path).to(device) + self.network.requires_grad_(False).eval() + assert scheduler in [ + "sigmoid", + "cosine", + "linear", + ], f"Scheduler {scheduler} not supported" + self.scheduler = scheduler_fn( + scheduler, scheduler_start, scheduler_end, scheduler_tau + ) + self.cond_preprocessing = load_prepocessing(model_name=model_path) + self.postprocessing = CartesiantoGPS() + self.sampler = riemannian_flow_sampler + self.model_path = model_path + self.preconditioning = DDPMPrecond() + self.device = device + # Add manifold + self.manifold = Sphere() + self.input_dim = 3 # Assuming 3D Cartesian coordinates for sphere + + def model(self, *args, **kwargs): + return self.preconditioning(self.network, *args, **kwargs) + + def __call__( + self, + images, + batch_size=None, + x_N=None, + num_steps=None, + scheduler=None, + cfg=0, + generator=None, + ): + """Sample from the model given conditioning. + + Args: + images: Conditioning input (image or list of images) + batch_size: Number of samples to generate (inferred from cond if not provided) + x_N: Initial noise tensor (generated if not provided) + num_steps: Number of sampling steps (uses default if not provided) + sampler: Custom sampler function (uses default if not provided) + scheduler: Custom scheduler function (uses default if not provided) + cfg: Classifier-free guidance scale (default 15) + generator: Random number generator + + Returns: + Sampled GPS coordinates after postprocessing + """ + # Set up batch size and initial noise + shape = [3] + if not isinstance(images, list): + images = [images] + if x_N is None: + if batch_size is None: + if isinstance(images, list): + batch_size = len(images) + else: + batch_size = 1 + x_N = torch.randn( + batch_size, *shape, device=self.device, generator=generator + ) + else: + x_N = x_N.to(self.device) + if x_N.ndim == 3: + x_N = x_N.unsqueeze(0) + batch_size = x_N.shape[0] + + # Set up batch with conditioning + batch = {"y": x_N} + batch["img"] = images + batch = self.cond_preprocessing(batch) + if len(images) > 1: + assert len(images) == batch_size + else: + batch["emb"] = batch["emb"].repeat(batch_size, 1) + + # Use default sampler/scheduler if not provided + sampler = self.sampler + if scheduler is None: + scheduler = self.scheduler + # Sample from model + if num_steps is None: + output = sampler( + self.model, + batch, + conditioning_keys="emb", + scheduler=scheduler, + cfg_rate=cfg, + generator=generator, + ) + else: + output = sampler( + self.model, + batch, + conditioning_keys="emb", + scheduler=scheduler, + num_steps=num_steps, + cfg_rate=cfg, + generator=generator, + ) + + # Apply postprocessing and return + output = self.postprocessing(output) + # To degrees + output = np.degrees(output.detach().cpu().numpy()) + return output + + def compute_likelihood( + self, + images=None, + coordinates=None, + emb=None, + cfg=0, + rademacher=False, + atol=1e-6, + rtol=1e-6, + normalize_logp=True, + ): + """ + Computes the exact log-likelihood of observing the given coordinates for the given images. + + Args: + images: Input images (PIL Image or list of PIL Images). Optional if emb is provided. + coordinates: Target GPS coordinates (latitude, longitude) in degrees. + Can be a list of pairs, numpy array (N, 2), or tensor (N, 2). + emb: Pre-computed embeddings. If provided, images will be ignored. + cfg (float): Classifier-free guidance scale. Default is 0 (no guidance). + rademacher (bool): Whether to use Rademacher estimator for divergence. Default is False. + atol (float): Absolute tolerance for ODE solver. Default is 1e-5. + rtol (float): Relative tolerance for ODE solver. Default is 1e-5. + normalize_logp (bool): Whether to normalize the log-likelihood by log(2) * dim. Default is True. + Returns: + torch.Tensor: Log-likelihood values for each input pair (image, coordinate). + """ + nfe = [0] # Counter for number of function evaluations + + # 1. Get embeddings either from images or directly from emb parameter + if emb is not None: + # Use provided embeddings directly + if isinstance(emb, torch.Tensor): + batch = {"emb": emb.to(self.device)} + else: + raise TypeError("emb must be a torch.Tensor") + else: + # Process images to get embeddings + if not isinstance(images, list): + images = [images] + batch = {"img": images} + batch = self.cond_preprocessing(batch) # Adds 'emb' key + + # 2. Preprocess coordinates (GPS degrees -> Cartesian) + x_1 = _gps_degrees_to_cartesian(coordinates, self.device) + if x_1.shape[0] != batch["emb"].shape[0]: + if x_1.shape[0] == 1: + # Repeat coordinates if only one is provided for multiple images + x_1 = x_1.repeat(batch["emb"].shape[0], 1) + elif batch["emb"].shape[0] == 1: + # Repeat embedding if only one image is provided for multiple coords + batch["emb"] = batch["emb"].repeat(x_1.shape[0], 1) + else: + raise ValueError( + f"Batch size mismatch between images ({batch['emb'].shape[0]}) and coordinates ({x_1.shape[0]})" + ) + + # Ensure correct shapes for ODE solver + if x_1.ndim == 1: + x_1 = x_1.unsqueeze(0) + if batch["emb"].ndim == 1: + batch["emb"] = batch["emb"].unsqueeze(0) + + with torch.inference_mode(mode=False): # Enable grads for jacobian calculation + # Define the ODE function + def odefunc(t, tensor): + nfe[0] += 1 + t = t.to(tensor) + gamma = self.scheduler(t) + x = tensor[..., : self.input_dim] + y = batch["emb"] # Conditioning + + def vecfield(x_vf, y_vf): + batch_vecfield = { + "y": x_vf, + "emb": y_vf, + "gamma": gamma.reshape(-1), + } + if cfg > 0: + # Apply classifier-free guidance + batch_vecfield_uncond = { + "y": x_vf, + "emb": torch.zeros_like(y_vf), # Null condition + "gamma": gamma.reshape(-1), + } + model_output_cond = self.model(batch_vecfield) + model_output_uncond = self.model(batch_vecfield_uncond) + model_output = model_output_cond + cfg * ( + model_output_cond - model_output_uncond + ) + else: + # Unconditional or naturally conditioned score + model_output = self.model(batch_vecfield) + + # Assuming 'flow_matching' interpolant based on sampler used + d_gamma = self.scheduler.derivative(t).reshape(-1, 1) + return d_gamma * model_output + + if rademacher: + v = torch.randint_like(x, 2) * 2 - 1 + else: + v = None + dx, div = output_and_div(vecfield, x, y, v=v) + div = div.reshape(-1, 1) + del t, x + return torch.cat([dx, div], dim=-1) + + # 3. Solve the ODE + state1 = torch.cat([x_1, torch.zeros_like(x_1[..., :1])], dim=-1) + + # Note: Using standard ODEINT here. For strict Riemannian integration, + # a manifold-aware solver might be needed, but this follows the + # structure from DiffGeolocalizer.compute_exact_loglikelihood more closely. + with torch.no_grad(): + state0 = odeint( + odefunc, + state1, + t=torch.linspace(0, 1.0, 2).to(x_1.device), + atol=atol, + rtol=rtol, + method="dopri5", + options={"min_step": 1e-5}, + )[ + -1 + ] # Get the state at t=0 + + x_0, logdetjac = state0[..., : self.input_dim], state0[..., -1] + + # Project final point onto the manifold (optional but good practice) + x_0 = self.manifold.projx(x_0) + + # 4. Compute log probability + # Log prob of base distribution (Gaussian projected onto sphere approx) + logp0 = self.manifold.base_logprob(x_0) + + # Change of variables formula: log p(x_1) = log p(x_0) + log |det J| + logp1 = logp0 + logdetjac + + # Optional: Normalize by log(2) * dim for bits per dimension + if normalize_logp: + logp1 = logp1 / (self.input_dim * np.log(2)) + + print(f"Likelihood NFE: {nfe[0]}") # Print number of function evaluations + return logp1 + + def compute_likelihood_grid( + self, + image, + grid_resolution_deg=10, + batch_size=1024, + cfg=0, + ): + """ + Computes the likelihood of an image over a global grid of coordinates. + + Args: + image: Input PIL Image. + grid_resolution_deg (float): The resolution of the grid in degrees. + Default is 10 degrees. + batch_size (int): How many grid points to process in each batch. + Adjust based on available memory. Default is 1024. + cfg (float): Classifier-free guidance scale passed to compute_likelihood. + Default is 0. + + Returns: + tuple: (latitude_grid, longitude_grid, likelihood_grid) + - latitude_grid (np.ndarray): 1D array of latitudes. + - longitude_grid (np.ndarray): 1D array of longitudes. + - likelihood_grid (np.ndarray): 2D array of log-likelihoods + corresponding to the lat/lon grid. + """ + # 1. Generate the grid + latitudes = np.arange(-90, 90 + grid_resolution_deg, grid_resolution_deg) + longitudes = np.arange(-180, 180 + grid_resolution_deg, grid_resolution_deg) + lon_grid, lat_grid = np.meshgrid(longitudes, latitudes) + + # Flatten the grid for processing + all_coordinates = np.vstack([lat_grid.ravel(), lon_grid.ravel()]).T + num_points = all_coordinates.shape[0] + print( + f"Computing likelihood over a {latitudes.size}x{longitudes.size} grid ({num_points} points)..." + ) + + emb = self.cond_preprocessing({"img": [image]})["emb"] + + # 2. Process in batches + all_likelihoods = [] + for i in tqdm( + range(0, num_points, batch_size), desc="Computing Likelihood Grid" + ): + coord_batch = all_coordinates[i : i + batch_size] + + # Compute likelihood for the batch + likelihood_batch = self.compute_likelihood( + emb=emb, + coordinates=coord_batch, + cfg=cfg, + rademacher=False, # Using exact divergence is better for grid + ) + all_likelihoods.append(likelihood_batch.detach().cpu().numpy()) + + # 3. Combine and reshape results + likelihood_flat = np.concatenate(all_likelihoods, axis=0) + likelihood_grid = likelihood_flat.reshape(lat_grid.shape) + + # Return grid definition and likelihood values + return latitudes, longitudes, likelihood_grid + + def compute_localizability( + self, + image, + atol=1e-6, + rtol=1e-6, + number_monte_carlo_samples=1024, + ): + """ + Computes the localizability of an image. We use importance sampling by sampling by the model and not the grid to have a more accurate estimate. + + Args: + image: Input PIL Image. + atol (float): Absolute tolerance for ODE solver. Default is 1e-5. + rtol (float): Relative tolerance for ODE solver. Default is 1e-5. + """ + samples = self(image, batch_size=number_monte_carlo_samples) + emb = self.cond_preprocessing({"img": [image]})["emb"] + localizability = self.compute_likelihood( + emb=emb, + coordinates=samples, + atol=atol, + rtol=rtol, + normalize_logp=False, + ) # importance sampling of likelihood + return localizability.mean() / (4 * torch.pi * np.log(2)) + + def to(self, device): + self.network.to(device) + self.postprocessing.to(device) + self.device = torch.device(device) + return self diff --git a/plonk/train.py b/plonk/train.py new file mode 100755 index 0000000000000000000000000000000000000000..2a1cb7d73178aec84560a165e8e4879c17c7c568 --- /dev/null +++ b/plonk/train.py @@ -0,0 +1,146 @@ +import os +import hydra +import wandb +from os.path import isfile, join +from shutil import copyfile + +import torch + +from omegaconf import OmegaConf +from hydra.core.hydra_config import HydraConfig +from hydra.utils import instantiate +from pytorch_lightning.callbacks import LearningRateMonitor +from lightning_fabric.utilities.rank_zero import _get_rank +from callbacks import EMACallback, FixNANinGrad, IncreaseDataEpoch +from plonk.models.module import DiffGeolocalizer + +torch.set_float32_matmul_precision("high") # TODO do we need that? + +# Registering the "eval" resolver allows for advanced config +# interpolation with arithmetic operations in hydra: +# https://omegaconf.readthedocs.io/en/2.3_branch/how_to_guides.html +OmegaConf.register_new_resolver("eval", eval) + + +def wandb_init(cfg): + directory = cfg.checkpoints.dirpath + if isfile(join(directory, "wandb_id.txt")) and cfg.logger_suffix == "": + with open(join(directory, "wandb_id.txt"), "r") as f: + wandb_id = f.readline() + else: + rank = _get_rank() + wandb_id = wandb.util.generate_id() + print(f"Generated wandb id: {wandb_id}") + if rank == 0 or rank is None: + with open(join(directory, "wandb_id.txt"), "w") as f: + f.write(str(wandb_id)) + + return wandb_id + + +def load_model(cfg, dict_config, wandb_id, callbacks): + directory = cfg.checkpoints.dirpath + if isfile(join(directory, "last.ckpt")): + checkpoint_path = join(directory, "last.ckpt") + logger = instantiate(cfg.logger, id=wandb_id, resume="allow") + model = DiffGeolocalizer.load_from_checkpoint(checkpoint_path, cfg=cfg.model) + ckpt_path = join(directory, "last.ckpt") + print(f"Loading form checkpoint ... {ckpt_path}") + else: + ckpt_path = None + logger = instantiate(cfg.logger, id=wandb_id, resume="allow") + log_dict = {"model": dict_config["model"], "dataset": dict_config["dataset"]} + logger._wandb_init.update({"config": log_dict}) + model = DiffGeolocalizer(cfg.model) + + trainer, strategy = cfg.trainer, cfg.trainer.strategy + # from pytorch_lightning.profilers import PyTorchProfiler + + trainer = instantiate( + trainer, + strategy=strategy, + logger=logger, + callbacks=callbacks, + # profiler=PyTorchProfiler( + # dirpath="logs", + # schedule=torch.profiler.schedule(wait=1, warmup=3, active=3, repeat=1), + # on_trace_ready=torch.profiler.tensorboard_trace_handler("./logs"), + # record_shapes=True, + # with_stack=True, + # with_flops=True, + # with_modules=True, + # ), + ) + return trainer, model, ckpt_path + + +def project_init(cfg): + print("Working directory set to {}".format(os.getcwd())) + directory = cfg.checkpoints.dirpath + os.makedirs(directory, exist_ok=True) + copyfile(".hydra/config.yaml", join(directory, "config.yaml")) + + +def callback_init(cfg): + checkpoint_callback = instantiate(cfg.checkpoints) + progress_bar = instantiate(cfg.progress_bar) + lr_monitor = LearningRateMonitor() + ema_callback = EMACallback( + "network", + "ema_network", + decay=cfg.model.ema_decay, + start_ema_step=cfg.model.start_ema_step, + init_ema_random=False, + ) + fix_nan_callback = FixNANinGrad( + monitor=["train/loss"], + ) + increase_data_epoch_callback = IncreaseDataEpoch() + callbacks = [ + checkpoint_callback, + progress_bar, + lr_monitor, + ema_callback, + fix_nan_callback, + increase_data_epoch_callback, + ] + return callbacks + + +def init_datamodule(cfg): + datamodule = instantiate(cfg.datamodule) + return datamodule + + +def hydra_boilerplate(cfg): + dict_config = OmegaConf.to_container(cfg, resolve=True) + callbacks = callback_init(cfg) + datamodule = init_datamodule(cfg) + project_init(cfg) + wandb_id = wandb_init(cfg) + trainer, model, ckpt_path = load_model(cfg, dict_config, wandb_id, callbacks) + return trainer, model, datamodule, ckpt_path + + +@hydra.main(config_path="configs", config_name="config", version_base=None) +def main(cfg): + if "stage" in cfg and cfg.stage == "debug": + import lovely_tensors as lt + + lt.monkey_patch() + trainer, model, datamodule, ckpt_path = hydra_boilerplate(cfg) + model.datamodule = datamodule + # model = torch.compile(model) + if cfg.mode == "train": + trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path) + elif cfg.mode == "eval": + trainer.test(model, datamodule=datamodule) + elif cfg.mode == "traineval": + cfg.mode = "train" + trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path) + cfg.mode = "test" + trainer.test(model, datamodule=datamodule) + + +if __name__ == "__main__": + main() diff --git a/plonk/train_random.py b/plonk/train_random.py new file mode 100755 index 0000000000000000000000000000000000000000..e53ffe7405a3bf939b086a2bd554e133467475ca --- /dev/null +++ b/plonk/train_random.py @@ -0,0 +1,146 @@ +import os +import hydra +import wandb +from os.path import isfile, join +from shutil import copyfile + +import torch + +from omegaconf import OmegaConf +from hydra.core.hydra_config import HydraConfig +from hydra.utils import instantiate +from pytorch_lightning.callbacks import LearningRateMonitor +from lightning_fabric.utilities.rank_zero import _get_rank +from callbacks import EMACallback, FixNANinGrad, IncreaseDataEpoch +from plonk.models.module import RandomGeolocalizer + +torch.set_float32_matmul_precision("high") # TODO do we need that? + +# Registering the "eval" resolver allows for advanced config +# interpolation with arithmetic operations in hydra: +# https://omegaconf.readthedocs.io/en/2.3_branch/how_to_guides.html +OmegaConf.register_new_resolver("eval", eval) + + +def wandb_init(cfg): + directory = cfg.checkpoints.dirpath + if isfile(join(directory, "wandb_id.txt")) and cfg.logger_suffix == "": + with open(join(directory, "wandb_id.txt"), "r") as f: + wandb_id = f.readline() + else: + rank = _get_rank() + wandb_id = wandb.util.generate_id() + print(f"Generated wandb id: {wandb_id}") + if rank == 0 or rank is None: + with open(join(directory, "wandb_id.txt"), "w") as f: + f.write(str(wandb_id)) + + return wandb_id + + +def load_model(cfg, dict_config, wandb_id, callbacks): + directory = cfg.checkpoints.dirpath + if isfile(join(directory, "last.ckpt")): + checkpoint_path = join(directory, "last.ckpt") + logger = instantiate(cfg.logger, id=wandb_id, resume="allow") + model = RandomGeolocalizer.load_from_checkpoint(checkpoint_path, cfg=cfg.model) + ckpt_path = join(directory, "last.ckpt") + print(f"Loading form checkpoint ... {ckpt_path}") + else: + ckpt_path = None + logger = instantiate(cfg.logger, id=wandb_id, resume="allow") + log_dict = {"model": dict_config["model"], "dataset": dict_config["dataset"]} + logger._wandb_init.update({"config": log_dict}) + model = RandomGeolocalizer(cfg.model) + + trainer, strategy = cfg.trainer, cfg.trainer.strategy + # from pytorch_lightning.profilers import PyTorchProfiler + + trainer = instantiate( + trainer, + strategy=strategy, + logger=logger, + callbacks=callbacks, + # profiler=PyTorchProfiler( + # dirpath="logs", + # schedule=torch.profiler.schedule(wait=1, warmup=3, active=3, repeat=1), + # on_trace_ready=torch.profiler.tensorboard_trace_handler("./logs"), + # record_shapes=True, + # with_stack=True, + # with_flops=True, + # with_modules=True, + # ), + ) + return trainer, model, ckpt_path + + +def project_init(cfg): + print("Working directory set to {}".format(os.getcwd())) + directory = cfg.checkpoints.dirpath + os.makedirs(directory, exist_ok=True) + copyfile(".hydra/config.yaml", join(directory, "config.yaml")) + + +def callback_init(cfg): + checkpoint_callback = instantiate(cfg.checkpoints) + progress_bar = instantiate(cfg.progress_bar) + lr_monitor = LearningRateMonitor() + ema_callback = EMACallback( + "network", + "ema_network", + decay=cfg.model.ema_decay, + start_ema_step=cfg.model.start_ema_step, + init_ema_random=False, + ) + fix_nan_callback = FixNANinGrad( + monitor=["train/loss"], + ) + increase_data_epoch_callback = IncreaseDataEpoch() + callbacks = [ + checkpoint_callback, + progress_bar, + lr_monitor, + ema_callback, + fix_nan_callback, + increase_data_epoch_callback, + ] + return callbacks + + +def init_datamodule(cfg): + datamodule = instantiate(cfg.datamodule) + return datamodule + + +def hydra_boilerplate(cfg): + dict_config = OmegaConf.to_container(cfg, resolve=True) + callbacks = callback_init(cfg) + datamodule = init_datamodule(cfg) + project_init(cfg) + wandb_id = wandb_init(cfg) + trainer, model, ckpt_path = load_model(cfg, dict_config, wandb_id, callbacks) + return trainer, model, datamodule, ckpt_path + + +@hydra.main(config_path="configs", config_name="config", version_base=None) +def main(cfg): + if "stage" in cfg and cfg.stage == "debug": + import lovely_tensors as lt + + lt.monkey_patch() + trainer, model, datamodule, ckpt_path = hydra_boilerplate(cfg) + model.datamodule = datamodule + # model = torch.compile(model) + if cfg.mode == "train": + trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path) + elif cfg.mode == "eval": + trainer.test(model, datamodule=datamodule) + elif cfg.mode == "traineval": + cfg.mode = "train" + trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path) + cfg.mode = "test" + trainer.test(model, datamodule=datamodule) + + +if __name__ == "__main__": + main() diff --git a/plonk/train_von_fisher.py b/plonk/train_von_fisher.py new file mode 100755 index 0000000000000000000000000000000000000000..b176c9d4ac2d00986b1b43682221a95459af13d4 --- /dev/null +++ b/plonk/train_von_fisher.py @@ -0,0 +1,148 @@ +import os +import hydra +import wandb +from os.path import isfile, join +from shutil import copyfile + +import torch + +from omegaconf import OmegaConf +from hydra.core.hydra_config import HydraConfig +from hydra.utils import instantiate +from pytorch_lightning.callbacks import LearningRateMonitor +from lightning_fabric.utilities.rank_zero import _get_rank +from callbacks import EMACallback, FixNANinGrad, IncreaseDataEpoch +from plonk.models.module import VonFisherGeolocalizer + +torch.set_float32_matmul_precision("high") # TODO do we need that? + +# Registering the "eval" resolver allows for advanced config +# interpolation with arithmetic operations in hydra: +# https://omegaconf.readthedocs.io/en/2.3_branch/how_to_guides.html +OmegaConf.register_new_resolver("eval", eval) + + +def wandb_init(cfg): + directory = cfg.checkpoints.dirpath + if isfile(join(directory, "wandb_id.txt")): + with open(join(directory, "wandb_id.txt"), "r") as f: + wandb_id = f.readline() + else: + rank = _get_rank() + wandb_id = wandb.util.generate_id() + print(f"Generated wandb id: {wandb_id}") + if rank == 0 or rank is None: + with open(join(directory, "wandb_id.txt"), "w") as f: + f.write(str(wandb_id)) + + return wandb_id + + +def load_model(cfg, dict_config, wandb_id, callbacks): + directory = cfg.checkpoints.dirpath + if isfile(join(directory, "last.ckpt")): + checkpoint_path = join(directory, "last.ckpt") + logger = instantiate(cfg.logger, id=wandb_id, resume="allow") + model = VonFisherGeolocalizer.load_from_checkpoint( + checkpoint_path, cfg=cfg.model + ) + ckpt_path = join(directory, "last.ckpt") + print(f"Loading form checkpoint ... {ckpt_path}") + else: + ckpt_path = None + logger = instantiate(cfg.logger, id=wandb_id, resume="allow") + log_dict = {"model": dict_config["model"], "dataset": dict_config["dataset"]} + logger._wandb_init.update({"config": log_dict}) + model = VonFisherGeolocalizer(cfg.model) + + trainer, strategy = cfg.trainer, cfg.trainer.strategy + # from pytorch_lightning.profilers import PyTorchProfiler + + trainer = instantiate( + trainer, + strategy=strategy, + logger=logger, + callbacks=callbacks, + # profiler=PyTorchProfiler( + # dirpath="logs", + # schedule=torch.profiler.schedule(wait=1, warmup=3, active=3, repeat=1), + # on_trace_ready=torch.profiler.tensorboard_trace_handler("./logs"), + # record_shapes=True, + # with_stack=True, + # with_flops=True, + # with_modules=True, + # ), + ) + return trainer, model, ckpt_path + + +def project_init(cfg): + print("Working directory set to {}".format(os.getcwd())) + directory = cfg.checkpoints.dirpath + os.makedirs(directory, exist_ok=True) + copyfile(".hydra/config.yaml", join(directory, "config.yaml")) + + +def callback_init(cfg): + checkpoint_callback = instantiate(cfg.checkpoints) + progress_bar = instantiate(cfg.progress_bar) + lr_monitor = LearningRateMonitor() + ema_callback = EMACallback( + "network", + "ema_network", + decay=cfg.model.ema_decay, + start_ema_step=cfg.model.start_ema_step, + init_ema_random=False, + ) + fix_nan_callback = FixNANinGrad( + monitor=["train/loss"], + ) + increase_data_epoch_callback = IncreaseDataEpoch() + callbacks = [ + checkpoint_callback, + progress_bar, + lr_monitor, + ema_callback, + fix_nan_callback, + increase_data_epoch_callback, + ] + return callbacks + + +def init_datamodule(cfg): + datamodule = instantiate(cfg.datamodule) + return datamodule + + +def hydra_boilerplate(cfg): + dict_config = OmegaConf.to_container(cfg, resolve=True) + callbacks = callback_init(cfg) + datamodule = init_datamodule(cfg) + project_init(cfg) + wandb_id = wandb_init(cfg) + trainer, model, ckpt_path = load_model(cfg, dict_config, wandb_id, callbacks) + return trainer, model, datamodule, ckpt_path + + +@hydra.main(config_path="configs", config_name="config", version_base=None) +def main(cfg): + if "stage" in cfg and cfg.stage == "debug": + import lovely_tensors as lt + + lt.monkey_patch() + trainer, model, datamodule, ckpt_path = hydra_boilerplate(cfg) + model.datamodule = datamodule + # model = torch.compile(model) + if cfg.mode == "train": + trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path) + elif cfg.mode == "eval": + trainer.test(model, datamodule=datamodule) + elif cfg.mode == "traineval": + cfg.mode = "train" + trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path) + cfg.mode = "test" + trainer.test(model, datamodule=datamodule) + + +if __name__ == "__main__": + main() diff --git a/plonk/utils/__init__.py b/plonk/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/plonk/utils/image_processing.py b/plonk/utils/image_processing.py new file mode 100755 index 0000000000000000000000000000000000000000..8f885eeefd3ff9f0152034b32ac441caa2b1a4cd --- /dev/null +++ b/plonk/utils/image_processing.py @@ -0,0 +1,58 @@ +import torch +import torch.nn.functional as F +import torchvision + + +def remap_image_torch(image): + image_torch = ((image + 1) / 2.0) * 255.0 + image_torch = torch.clip(image_torch, 0, 255).to(torch.uint8) + return image_torch + + +class CenterCrop(torch.nn.Module): + """Crops the given image at the center. Allows to crop to the maximum possible size. + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + ratio (str): Desired output ratio of the crop that will do the maximum possible crop with the given ratio. + """ + + def __init__(self, size=None, ratio="1:1"): + super().__init__() + self.size = size + self.ratio = ratio + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + PIL Image or Tensor: Cropped image. + """ + if self.size is None: + if isinstance(img, torch.Tensor): + h, w = img.shape[-2:] + else: + w, h = img.size + ratio = self.ratio.split(":") + ratio = float(ratio[0]) / float(ratio[1]) + ratioed_w = int(h * ratio) + ratioed_h = int(w / ratio) + if w >= h: + if ratioed_h <= h: + size = (ratioed_h, w) + else: + size = (h, ratioed_w) + else: + if ratioed_w <= w: + size = (h, ratioed_w) + else: + size = (ratioed_h, w) + else: + size = self.size + return torchvision.transforms.functional.center_crop(img, size) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size})" diff --git a/plonk/utils/kde.py b/plonk/utils/kde.py new file mode 100644 index 0000000000000000000000000000000000000000..1afe32b79c03cb4ef266fb8def417f5b162d5a5c --- /dev/null +++ b/plonk/utils/kde.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import matplotlib.pyplot as plt + + +class BatchedKDE(nn.Module): + def __init__(self, bandwith=0.0): + super().__init__() + self.bandwidth = bandwith + self.X = None + + def fit(self, X: torch.Tensor): + self.mu = X + self.nmu2 = torch.sum(X * X, dim=-1, keepdim=True) + b, n, d = X.shape + if self.bandwidth == 0: + q = torch.quantile(X.view(b, -1), 0.75) - torch.quantile( + X.view(b, -1), 0.25 + ) + self.bandwidth = ( + 0.9 * torch.min(torch.std(X, dim=(1, 2)), q / 1.34) / pow(n, 0.2) + ) + + def score(self, X): + nx2 = torch.sum(X * X, dim=-1, keepdim=True) + dot = torch.einsum("bnd, bmd -> bnm", X, self.mu) + dist = nx2 + self.nmu2.transpose(1, 2) - 2 * dot + return torch.sum( + torch.exp(-dist / self.bandwidth.unsqueeze(-1).unsqueeze(-1)), dim=-1 + ) diff --git a/plonk/utils/lr_scheduler.py b/plonk/utils/lr_scheduler.py new file mode 100755 index 0000000000000000000000000000000000000000..f7136bef13d119dd3cff31b02b7226e96c88b4cd --- /dev/null +++ b/plonk/utils/lr_scheduler.py @@ -0,0 +1,96 @@ +import math + + +class WarmupLR: + """ + Linear Warmup learning rate scheduler. After warmup, learning rate is + constant. + + Args: + optimizer (torch.optim.Optimizer): optimizer + warmup_steps (int): number of warmup steps + + """ + + def __init__(self, optimizer, warmup_steps): + self.optimizer = optimizer + self.warmup_steps = warmup_steps + self.base_lr = None + + def get_lr(self, lr, step): + return lr * min(step / max(self.warmup_steps, 1), 1.0) + + def step(self, step): + if self.base_lr is None: + self.base_lr = [ + param_group["lr"] for param_group in self.optimizer.param_groups + ] + for param_group, base_lr_group in zip( + self.optimizer.param_groups, self.base_lr + ): + param_group["lr"] = self.get_lr(base_lr_group, step) + + def state_dict(self): + return { + key: value for key, value in self.__dict__.items() if key != "optimizer" + } + + def load_state_dict(self, state_dict): + self.__dict__.update(state_dict) + + +class WarmupCosineDecayLR: + """ + Linear Warmup learning rate scheduler. After warmup, learning rate is + constant. + After warmup, learning rate follows a cosine decay. + + Args: + optimizer (torch.optim.Optimizer): optimizer + warmup_steps (int): number of warmup steps + total_steps (int): total number of steps + rate (float): cosine decay rate + """ + + def __init__(self, optimizer, warmup_steps, total_steps, rate=1.0): + self.optimizer = optimizer + self.warmup_steps = warmup_steps + self.base_lr = None + self.total_steps = total_steps + self.rate = rate + + def get_lr(self, lr, step): + if step < self.warmup_steps: + return lr * min(step / max(self.warmup_steps, 1), 1.0) + else: + return ( + 0.5 + * lr + * ( + 1 + + math.cos( + self.rate + * math.pi + * (step - self.warmup_steps) + / (self.total_steps - self.warmup_steps) + ) + ) + ) + + def step(self, step): + if self.base_lr is None: + self.base_lr = [ + param_group["lr"] for param_group in self.optimizer.param_groups + ] + for param_group, base_lr_group in zip( + self.optimizer.param_groups, self.base_lr + ): + param_group["lr"] = self.get_lr(base_lr_group, step) + + def state_dict(self): + return { + key: value for key, value in self.__dict__.items() if key != "optimizer" + } + + def load_state_dict(self, state_dict): + self.__dict__.update(state_dict) diff --git a/plonk/utils/manifolds.py b/plonk/utils/manifolds.py new file mode 100644 index 0000000000000000000000000000000000000000..94be76b6377ea1969344338443282b99bed1b7a0 --- /dev/null +++ b/plonk/utils/manifolds.py @@ -0,0 +1,43 @@ +"""Copyright (c) Meta Platforms, Inc. and affiliates.""" + +import math +import torch +from geoopt.manifolds import Sphere as geoopt_Sphere + + +class Sphere(geoopt_Sphere): + def transp(self, x, y, v): + denom = 1 + self.inner(x, x, y, keepdim=True) + res = v - self.inner(x, y, v, keepdim=True) / denom * (x + y) + cond = denom.gt(1e-3) + return torch.where(cond, res, -v) + + def uniform_logprob(self, x): + dim = x.shape[-1] + return torch.full_like( + x[..., 0], + math.lgamma(dim / 2) - (math.log(2) + (dim / 2) * math.log(math.pi)), + ) + + def random_base(self, *args, **kwargs): + return self.random_uniform(*args, **kwargs) + + def base_logprob(self, *args, **kwargs): + return self.uniform_logprob(*args, **kwargs) + + +def geodesic(manifold, start_point, end_point): + shooting_tangent_vec = manifold.logmap(start_point, end_point) + + def path(t): + """Generate parameterized function for geodesic curve. + Parameters + ---------- + t : array-like, shape=[n_points,] + Times at which to compute points of the geodesics. + """ + tangent_vecs = torch.einsum("i,...k->...ik", t, shooting_tangent_vec) + points_at_time_t = manifold.expmap(start_point.unsqueeze(-2), tangent_vecs) + return points_at_time_t + + return path diff --git a/plonk/utils/model_utils.py b/plonk/utils/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a4ebf21894807ece574ae3f88664a635a2b431ab --- /dev/null +++ b/plonk/utils/model_utils.py @@ -0,0 +1,14 @@ +def print_trainable_parameters(model): + """ + Prints the number and percentage of trainable parameters in the model. + Useful for tracking % parameters trained for LoRA. + """ + trainable_params = 0 + all_param = 0 + for _, param in model.named_parameters(): + all_param += param.numel() + if param.requires_grad: + trainable_params += param.numel() + print( + f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}" + ) diff --git a/plonk/utils/optimizers.py b/plonk/utils/optimizers.py new file mode 100644 index 0000000000000000000000000000000000000000..28fbdea278efa62e899778951ecf157b773f640b --- /dev/null +++ b/plonk/utils/optimizers.py @@ -0,0 +1,111 @@ +"""Lamb optimizer.""" + +import torch +from torch.optim import Optimizer +import math + + +class Lamb(Optimizer): + r"""Implements Lamb algorithm. + It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + adam (bool, optional): always use trust ratio = 1, which turns this into + Adam. Useful for comparison purposes. + .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: + https://arxiv.org/abs/1904.00962 + """ + + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, adam=False + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + self.adam = adam + super(Lamb, self).__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError( + "Lamb does not support sparse gradients, consider SparseAdam instad." + ) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + # Decay the first and second moment running average coefficient + # m_t + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + # v_t + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Paper v3 does not use debiasing. + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + exp_avg_hat = exp_avg / bias_correction1 + exp_avg_sq_hat = exp_avg_sq / bias_correction2 + # Apply bias to lr to avoid broadcast. + step_size = group["lr"] + + do_layer_adaptation = ( + group["layer_adaptation"] + if "layer_adaptation" in group + else group["weight_decay"] > 0 + ) + + adam_step = exp_avg_hat / exp_avg_sq_hat.sqrt().add(group["eps"]) + if group["weight_decay"] != 0: + adam_step.add_(p.data, alpha=group["weight_decay"]) + if do_layer_adaptation: + weight_norm = p.data.norm(p=2) + adam_norm = adam_step.norm(p=2) + trust_ratio = torch.where( + weight_norm.ne(0), + torch.where(adam_norm.ne(0), weight_norm / adam_norm, 1), + 1, + ) + if self.adam or not do_layer_adaptation: + trust_ratio = 1 + + p.data.add_(adam_step, alpha=-step_size * trust_ratio) + return loss diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..020ff8cf98061b120e5ed712ecbd0902935632df --- /dev/null +++ b/requirements.txt @@ -0,0 +1,14 @@ +gradio +torch +torchvision +torchaudio +transformers +accelerate +numpy +scipy==1.13.1 +scikit-learn +Pillow +einops +torchdiffeq +geoopt +huggingface_hub \ No newline at end of file