import copy
from typing import Any

import numpy as np
import pandas as pd
import pytorch_lightning as L
import torch
import torch.nn as nn
from geoopt import Euclidean, ProductManifold
from hydra.utils import instantiate
from torch.func import jacrev, vjp, vmap
from torchdiffeq import odeint
from tqdm import tqdm

from models.samplers.riemannian_flow_sampler import ode_riemannian_flow_sampler
from utils.manifolds import Sphere

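# Lightning module for conditional geolocation with a flow-matching / diffusion
# generative model. All sub-components (network, noise schedulers, samplers,
# loss, metrics, pre/post-processing) are built from the Hydra config via
# instantiate, and an EMA copy of the network is kept for evaluation/sampling.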
class DiffGeolocalizer(L.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.network = instantiate(cfg.network)

        self.input_dim = cfg.network.input_dim
        self.train_noise_scheduler = instantiate(cfg.train_noise_scheduler)
        self.inference_noise_scheduler = instantiate(cfg.inference_noise_scheduler)
        self.data_preprocessing = instantiate(cfg.data_preprocessing)
        self.cond_preprocessing = instantiate(cfg.cond_preprocessing)
        self.preconditioning = instantiate(cfg.preconditioning)

        self.ema_network = copy.deepcopy(self.network).requires_grad_(False)
        self.ema_network.eval()
        self.postprocessing = instantiate(cfg.postprocessing)
        self.val_sampler = instantiate(cfg.val_sampler)
        self.test_sampler = instantiate(cfg.test_sampler)
        self.loss = instantiate(cfg.loss)(
            self.train_noise_scheduler,
        )
        self.val_metrics = instantiate(cfg.val_metrics)
        self.test_metrics = instantiate(cfg.test_metrics)
        self.manifold = instantiate(cfg.manifold) if hasattr(cfg, "manifold") else None

        self.interpolant = cfg.interpolant

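    # Training: preprocess the batch (data and conditioning) without tracking
    # gradients, then evaluate the flow-matching / diffusion loss on the
    # preconditioned network and log it per step and per epoch.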
    def training_step(self, batch, batch_idx):
        with torch.no_grad():
            batch = self.data_preprocessing(batch)
            batch = self.cond_preprocessing(batch)
        batch_size = batch["x_0"].shape[0]
        loss = self.loss(self.preconditioning, self.network, batch).mean()
        self.log(
            "train/loss",
            loss,
            sync_dist=True,
            on_step=True,
            on_epoch=True,
            batch_size=batch_size,
        )
        return loss

    def on_before_optimizer_step(self, optimizer):
        if self.global_step == 0:
            no_grad = []
            for name, param in self.network.named_parameters():
                if param.grad is None:
                    no_grad.append(name)
            if len(no_grad) > 0:
                print("Parameters without grad:")
                print(no_grad)

    def on_validation_start(self):
        self.validation_generator = torch.Generator(device=self.device).manual_seed(
            3407
        )
        self.validation_generator_ema = torch.Generator(device=self.device).manual_seed(
            3407
        )

    def validation_step(self, batch, batch_idx):
        batch = self.data_preprocessing(batch)
        batch = self.cond_preprocessing(batch)
        batch_size = batch["x_0"].shape[0]
        loss = self.loss(
            self.preconditioning,
            self.network,
            batch,
            generator=self.validation_generator,
        ).mean()
        self.log(
            "val/loss",
            loss,
            sync_dist=True,
            on_step=False,
            on_epoch=True,
            batch_size=batch_size,
        )
        if hasattr(self, "ema_network"):
            loss_ema = self.loss(
                self.preconditioning,
                self.ema_network,
                batch,
                generator=self.validation_generator_ema,
            ).mean()
            self.log(
                "val/loss_ema",
                loss_ema,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
                batch_size=batch_size,
            )

    def on_test_start(self):
        self.test_generator = torch.Generator(device=self.device).manual_seed(3407)

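    # Simple test protocol: draw one base sample per conditioning embedding
    # (uniform on the sphere, or standard Gaussian in the Euclidean case),
    # push it through the sampler to get a GPS prediction, and optionally log
    # the exact negative log-likelihood of the ground-truth coordinates.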
    def test_step_simple(self, batch, batch_idx):
        batch = self.data_preprocessing(batch)
        batch = self.cond_preprocessing(batch)
        batch_size = batch["x_0"].shape[0]
        if isinstance(self.manifold, Sphere):
            x_N = self.manifold.random_base(
                batch_size,
                self.input_dim,
                device=self.device,
            )
            x_N = x_N.reshape(batch_size, self.input_dim)
        else:
            x_N = torch.randn(
                batch_size,
                self.input_dim,
                device=self.device,
                generator=self.test_generator,
            )
        cond = batch[self.cfg.cond_preprocessing.output_key]

        samples = self.sample(
            x_N=x_N,
            cond=cond,
            stage="val",
            generator=self.test_generator,
            cfg=self.cfg.cfg_rate,
        )
        self.test_metrics.update({"gps": samples}, batch)
        if self.cfg.compute_nll:
            nll = -self.compute_exact_loglikelihood(batch, cfg=0).mean()
            self.log(
                "test/NLL",
                nll,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
                batch_size=batch_size,
            )

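    # "Swarm" test protocol: draw num_sample_per_cond candidate locations per
    # conditioning embedding, score every candidate with the exact model
    # log-likelihood, log the best (lowest) NLL, and evaluate the metrics on
    # the most likely candidate for each condition.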
    def test_best_nll(self, batch, batch_idx):
        batch = self.data_preprocessing(batch)
        batch = self.cond_preprocessing(batch)
        batch_size = batch["x_0"].shape[0]
        num_sample_per_cond = 32
        if isinstance(self.manifold, Sphere):
            x_N = self.manifold.random_base(
                batch_size * num_sample_per_cond,
                self.input_dim,
                device=self.device,
            )
            x_N = x_N.reshape(batch_size * num_sample_per_cond, self.input_dim)
        else:
            x_N = torch.randn(
                batch_size * num_sample_per_cond,
                self.input_dim,
                device=self.device,
                generator=self.test_generator,
            )
        cond = (
            batch[self.cfg.cond_preprocessing.output_key]
            .unsqueeze(1)
            .repeat(1, num_sample_per_cond, 1)
            .view(-1, batch[self.cfg.cond_preprocessing.output_key].shape[-1])
        )
        samples = self.sample_distribution(
            x_N,
            cond,
            sampling_batch_size=32768,
            stage="val",
            generator=self.test_generator,
            cfg=0,
        )
        samples = samples.view(batch_size * num_sample_per_cond, -1)
        batch_swarm = {"gps": samples, "emb": cond}
        nll_batch = -self.compute_exact_loglikelihood(batch_swarm, cfg=0)
        nll_batch = nll_batch.view(batch_size, num_sample_per_cond, -1)
        nll_best = nll_batch[
            torch.arange(batch_size), nll_batch.argmin(dim=1).squeeze(1)
        ]
        self.log(
            "test/best_nll",
            nll_best.mean(),
            sync_dist=True,
            on_step=False,
            on_epoch=True,
        )
        samples = samples.view(batch_size, num_sample_per_cond, -1)[
            torch.arange(batch_size), nll_batch.argmin(dim=1).squeeze(1)
        ]
        self.test_metrics.update({"gps": samples}, batch)

    def test_step(self, batch, batch_idx):
        if self.cfg.compute_swarms:
            self.test_best_nll(batch, batch_idx)
        else:
            self.test_step_simple(batch, batch_idx)

    def on_test_epoch_end(self):
        metrics = self.test_metrics.compute()
        for metric_name, metric_value in metrics.items():
            self.log(
                f"test/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
            )

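    # Optimizer setup: optionally split the parameters into two groups so that
    # LayerNorm weights and biases are excluded from weight decay, then build
    # the optimizer (and, if configured, a per-step LR scheduler) from Hydra.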
    def configure_optimizers(self):
        if self.cfg.optimizer.exclude_ln_and_biases_from_weight_decay:
            parameters_names_wd = get_parameter_names(self.network, [nn.LayerNorm])
            parameters_names_wd = [
                name for name in parameters_names_wd if "bias" not in name
            ]
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p
                        for n, p in self.network.named_parameters()
                        if n in parameters_names_wd
                    ],
                    "weight_decay": self.cfg.optimizer.optim.weight_decay,
                    "layer_adaptation": True,
                },
                {
                    "params": [
                        p
                        for n, p in self.network.named_parameters()
                        if n not in parameters_names_wd
                    ],
                    "weight_decay": 0.0,
                    "layer_adaptation": False,
                },
            ]
            optimizer = instantiate(
                self.cfg.optimizer.optim, optimizer_grouped_parameters
            )
        else:
            optimizer = instantiate(self.cfg.optimizer.optim, self.network.parameters())
        if "lr_scheduler" in self.cfg:
            scheduler = instantiate(self.cfg.lr_scheduler)(optimizer)
            return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
        else:
            return optimizer

    def lr_scheduler_step(self, scheduler, metric):
        scheduler.step(self.global_step)

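    # Sampling entry point: starts from x_N (or draws it from the base
    # distribution when only batch_size is given), applies the conditioning
    # preprocessing, and integrates with the stage-specific sampler using the
    # EMA network, with optional classifier-free guidance and trajectories.
    #
    # Hypothetical usage sketch (the cfg object and embedding tensor are
    # assumptions, not part of this module):
    #   model = DiffGeolocalizer(cfg)
    #   coords = model.sample(batch_size=8, cond=image_embeddings, stage="test")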
    def sample(
        self,
        batch_size=None,
        cond=None,
        x_N=None,
        num_steps=None,
        stage="test",
        cfg=0,
        generator=None,
        return_trajectories=False,
        postprocessing=True,
    ):
        if x_N is None:
            assert batch_size is not None
            if isinstance(self.manifold, Sphere):
                x_N = self.manifold.random_base(
                    batch_size, self.input_dim, device=self.device
                )
                x_N = x_N.reshape(batch_size, self.input_dim)
            else:
                x_N = torch.randn(batch_size, self.input_dim, device=self.device)
        batch = {"y": x_N}
        if stage == "val":
            sampler = self.val_sampler
        elif stage == "test":
            sampler = self.test_sampler
        else:
            raise ValueError(f"Unknown stage {stage}")
        batch[self.cfg.cond_preprocessing.input_key] = cond
        batch = self.cond_preprocessing(batch, device=self.device)
        if num_steps is None:
            output = sampler(
                self.ema_model,
                batch,
                conditioning_keys=self.cfg.cond_preprocessing.output_key,
                scheduler=self.inference_noise_scheduler,
                cfg_rate=cfg,
                generator=generator,
                return_trajectories=return_trajectories,
            )
        else:
            output = sampler(
                self.ema_model,
                batch,
                conditioning_keys=self.cfg.cond_preprocessing.output_key,
                scheduler=self.inference_noise_scheduler,
                num_steps=num_steps,
                cfg_rate=cfg,
                generator=generator,
                return_trajectories=return_trajectories,
            )
        if return_trajectories:
            return (
                self.postprocessing(output[0]) if postprocessing else output[0],
                [
                    self.postprocessing(frame) if postprocessing else frame
                    for frame in output[1]
                ],
            )
        else:
            return self.postprocessing(output) if postprocessing else output

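    # Chunked sampling: splits x_N and cond into chunks of sampling_batch_size,
    # runs self.sample on each chunk to bound memory use, and concatenates the
    # per-chunk outputs (and, if requested, the per-frame trajectories).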
    def sample_distribution(
        self,
        x_N,
        cond,
        sampling_batch_size=2048,
        num_steps=None,
        stage="test",
        cfg=0,
        generator=None,
        return_trajectories=False,
    ):
        if return_trajectories:
            x_0 = []
            trajectory_chunks = []
            i = -1
            for i in range(x_N.shape[0] // sampling_batch_size):
                x_N_batch = x_N[i * sampling_batch_size : (i + 1) * sampling_batch_size]
                cond_batch = cond[
                    i * sampling_batch_size : (i + 1) * sampling_batch_size
                ]
                out, traj = self.sample(
                    cond=cond_batch,
                    x_N=x_N_batch,
                    num_steps=num_steps,
                    stage=stage,
                    cfg=cfg,
                    generator=generator,
                    return_trajectories=return_trajectories,
                )
                x_0.append(out)
                trajectory_chunks.append(traj)
            if x_N.shape[0] % sampling_batch_size != 0:
                x_N_batch = x_N[(i + 1) * sampling_batch_size :]
                cond_batch = cond[(i + 1) * sampling_batch_size :]
                out, traj = self.sample(
                    cond=cond_batch,
                    x_N=x_N_batch,
                    num_steps=num_steps,
                    stage=stage,
                    cfg=cfg,
                    generator=generator,
                    return_trajectories=return_trajectories,
                )
                x_0.append(out)
                trajectory_chunks.append(traj)
            x_0 = torch.cat(x_0, dim=0)
            # Re-assemble the trajectories frame by frame across chunks.
            trajectories = [
                torch.cat(frames, dim=0) for frames in zip(*trajectory_chunks)
            ]
            return x_0, trajectories
        else:
            x_0 = []
            i = -1
            for i in range(x_N.shape[0] // sampling_batch_size):
                x_N_batch = x_N[i * sampling_batch_size : (i + 1) * sampling_batch_size]
                cond_batch = cond[
                    i * sampling_batch_size : (i + 1) * sampling_batch_size
                ]
                out = self.sample(
                    cond=cond_batch,
                    x_N=x_N_batch,
                    num_steps=num_steps,
                    stage=stage,
                    cfg=cfg,
                    generator=generator,
                    return_trajectories=return_trajectories,
                )
                x_0.append(out)
            if x_N.shape[0] % sampling_batch_size != 0:
                x_N_batch = x_N[(i + 1) * sampling_batch_size :]
                cond_batch = cond[(i + 1) * sampling_batch_size :]
                out = self.sample(
                    cond=cond_batch,
                    x_N=x_N_batch,
                    num_steps=num_steps,
                    stage=stage,
                    cfg=cfg,
                    generator=generator,
                    return_trajectories=return_trajectories,
                )
                x_0.append(out)
            x_0 = torch.cat(x_0, dim=0)
            return x_0

    def model(self, *args, **kwargs):
        return self.preconditioning(self.network, *args, **kwargs)

    def ema_model(self, *args, **kwargs):
        return self.preconditioning(self.ema_network, *args, **kwargs)

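    # Exact log-likelihood via the continuous change-of-variables formula: the
    # data point is pushed backwards through the probability-flow ODE while the
    # divergence of the vector field is accumulated, and the base-point
    # log-density plus the accumulated log-det-Jacobian gives log p(x_1). With
    # rademacher=True the divergence is estimated with a Hutchinson probe
    # instead of an exact Jacobian trace. The result is rescaled by
    # 1 / (input_dim * ln 2), i.e. expressed in bits per dimension.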
    def compute_exact_loglikelihood(
        self,
        batch=None,
        x_1=None,
        cond=None,
        t1=1.0,
        num_steps=1000,
        rademacher=False,
        data_preprocessing=True,
        cfg=0,
    ):
        nfe = [0]
        if batch is None:
            batch = {"x_0": x_1, "emb": cond}
        if data_preprocessing:
            batch = self.data_preprocessing(batch)
        batch = self.cond_preprocessing(batch)
        timesteps = self.inference_noise_scheduler(
            torch.linspace(0, t1, 2).to(batch["x_0"])
        )
        with torch.inference_mode(mode=False):

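            # Augmented ODE state: the first input_dim components hold x(t),
            # the last component accumulates the divergence (log-det-Jacobian).
            # odefunc returns [dx/dt, div(v)(x)] for the chosen interpolant.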
            def odefunc(t, tensor):
                nfe[0] += 1
                t = t.to(tensor)
                gamma = self.inference_noise_scheduler(t)
                x = tensor[..., : self.input_dim]
                y = batch["emb"]

                def vecfield(x, y):
                    if cfg > 0:
                        batch_vecfield = {
                            "y": x,
                            "emb": y,
                            "gamma": gamma.reshape(-1),
                        }
                        model_output_cond = self.ema_model(batch_vecfield)
                        batch_vecfield_uncond = {
                            "y": x,
                            "emb": torch.zeros_like(y),
                            "gamma": gamma.reshape(-1),
                        }
                        model_output_uncond = self.ema_model(batch_vecfield_uncond)
                        model_output = model_output_cond + cfg * (
                            model_output_cond - model_output_uncond
                        )
                    else:
                        batch_vecfield = {
                            "y": x,
                            "emb": y,
                            "gamma": gamma.reshape(-1),
                        }
                        model_output = self.ema_model(batch_vecfield)

                    if self.interpolant == "flow_matching":
                        d_gamma = self.inference_noise_scheduler.derivative(t).reshape(
                            -1, 1
                        )
                        return d_gamma * model_output
                    elif self.interpolant == "diffusion":
                        alpha_t = self.inference_noise_scheduler.alpha(t).reshape(-1, 1)
                        return (
                            -1 / 2 * (alpha_t * x - torch.abs(alpha_t) * model_output)
                        )
                    else:
                        raise ValueError(f"Unknown interpolant {self.interpolant}")

                if rademacher:
                    v = torch.randint_like(x, 2) * 2 - 1
                else:
                    v = None
                dx, div = output_and_div(vecfield, x, y, v=v)
                div = div.reshape(-1, 1)
                del t, x
                return torch.cat([dx, div], dim=-1)

            x_1 = batch["x_0"]
            state1 = torch.cat([x_1, torch.zeros_like(x_1[..., :1])], dim=-1)
            with torch.no_grad():
                if False and isinstance(self.manifold, Sphere):
                    print("Riemannian flow sampler")
                    product_man = ProductManifold(
                        (self.manifold, self.input_dim), (Euclidean(), 1)
                    )
                    state0 = ode_riemannian_flow_sampler(
                        odefunc,
                        state1,
                        manifold=product_man,
                        scheduler=self.inference_noise_scheduler,
                        num_steps=num_steps,
                    )
                else:
                    print("ODE solver")
                    state0 = odeint(
                        odefunc,
                        state1,
                        t=torch.linspace(0, t1, 2).to(batch["x_0"]),
                        atol=1e-6,
                        rtol=1e-6,
                        method="dopri5",
                        options={"min_step": 1e-5},
                    )[-1]
            x_0, logdetjac = state0[..., : self.input_dim], state0[..., -1]
            if self.manifold is not None:
                x_0 = self.manifold.projx(x_0)
                logp0 = self.manifold.base_logprob(x_0)
            else:
                logp0 = (
                    -1 / 2 * (x_0**2).sum(dim=-1)
                    - self.input_dim
                    * torch.log(torch.tensor(2 * np.pi, device=x_0.device))
                    / 2
                )
            print(f"nfe: {nfe[0]}")
            logp1 = logp0 + logdetjac
            logp1 = logp1 / (self.input_dim * np.log(2))
            return logp1


def get_parameter_names(model, forbidden_layer_types):
    """
    Returns the names of the model parameters that are not inside a forbidden layer.
    Taken from HuggingFace transformers.
    """
    result = []
    for name, child in model.named_children():
        result += [
            f"{name}.{n}"
            for n in get_parameter_names(child, forbidden_layer_types)
            if not isinstance(child, tuple(forbidden_layer_types))
        ]
    result += list(model._parameters.keys())
    return result


def div_fn(u):
    """Accepts a function u:R^D -> R^D."""
    J = jacrev(u, argnums=0)
    return lambda x, y: torch.trace(J(x, y).squeeze(0))

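# Divergence of the vector field: computed exactly as the trace of the Jacobian
# (via jacrev + vmap) when no probe vector is given, or estimated with the
# Hutchinson identity div(v) = E[e^T J e] using a single probe e and a vjp.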
def output_and_div(vecfield, x, y, v=None):
    if v is None:
        dx = vecfield(x, y)
        div = vmap(div_fn(vecfield))(x, y)
    else:
        vecfield_x = lambda x: vecfield(x, y)
        dx, vjpfunc = vjp(vecfield_x, x)
        vJ = vjpfunc(v)[0]
        div = torch.sum(vJ * v, dim=-1)
    return dx, div


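# Regression baseline that predicts a parametric directional distribution
# (presumably von Mises-Fisher, hence the name) directly from the conditioning
# embedding; its "exact" log-likelihood is simply the negated training loss.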
class VonFisherGeolocalizer(L.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.network = instantiate(cfg.network)

        self.input_dim = cfg.network.input_dim
        self.data_preprocessing = instantiate(cfg.data_preprocessing)
        self.cond_preprocessing = instantiate(cfg.cond_preprocessing)
        self.preconditioning = instantiate(cfg.preconditioning)

        self.ema_network = copy.deepcopy(self.network).requires_grad_(False)
        self.ema_network.eval()
        self.postprocessing = instantiate(cfg.postprocessing)
        self.val_sampler = instantiate(cfg.val_sampler)
        self.test_sampler = instantiate(cfg.test_sampler)
        self.loss = instantiate(cfg.loss)()
        self.val_metrics = instantiate(cfg.val_metrics)
        self.test_metrics = instantiate(cfg.test_metrics)

    def training_step(self, batch, batch_idx):
        with torch.no_grad():
            batch = self.data_preprocessing(batch)
            batch = self.cond_preprocessing(batch)
        batch_size = batch["x_0"].shape[0]
        loss = self.loss(self.preconditioning, self.network, batch).mean()
        self.log(
            "train/loss",
            loss,
            sync_dist=True,
            on_step=True,
            on_epoch=True,
            batch_size=batch_size,
        )
        return loss

    def on_before_optimizer_step(self, optimizer):
        if self.global_step == 0:
            no_grad = []
            for name, param in self.network.named_parameters():
                if param.grad is None:
                    no_grad.append(name)
            if len(no_grad) > 0:
                print("Parameters without grad:")
                print(no_grad)

    def on_validation_start(self):
        self.validation_generator = torch.Generator(device=self.device).manual_seed(
            3407
        )
        self.validation_generator_ema = torch.Generator(device=self.device).manual_seed(
            3407
        )

    def validation_step(self, batch, batch_idx):
        batch = self.data_preprocessing(batch)
        batch = self.cond_preprocessing(batch)
        batch_size = batch["x_0"].shape[0]
        loss = self.loss(
            self.preconditioning,
            self.network,
            batch,
            generator=self.validation_generator,
        ).mean()
        self.log(
            "val/loss",
            loss,
            sync_dist=True,
            on_step=False,
            on_epoch=True,
            batch_size=batch_size,
        )
        if hasattr(self, "ema_network"):
            loss_ema = self.loss(
                self.preconditioning,
                self.ema_network,
                batch,
                generator=self.validation_generator_ema,
            ).mean()
            self.log(
                "val/loss_ema",
                loss_ema,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
                batch_size=batch_size,
            )

    def on_test_start(self):
        self.test_generator = torch.Generator(device=self.device).manual_seed(3407)

    def test_step(self, batch, batch_idx):
        batch = self.data_preprocessing(batch)
        batch = self.cond_preprocessing(batch)
        batch_size = batch["x_0"].shape[0]
        cond = batch[self.cfg.cond_preprocessing.output_key]

        samples = self.sample(cond=cond, stage="test")
        self.test_metrics.update({"gps": samples}, batch)
        nll = -self.compute_exact_loglikelihood(batch).mean()
        self.log(
            "test/NLL",
            nll,
            sync_dist=True,
            on_step=False,
            on_epoch=True,
            batch_size=batch_size,
        )

    def on_test_epoch_end(self):
        metrics = self.test_metrics.compute()
        for metric_name, metric_value in metrics.items():
            self.log(
                f"test/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
            )

    def configure_optimizers(self):
        if self.cfg.optimizer.exclude_ln_and_biases_from_weight_decay:
            parameters_names_wd = get_parameter_names(self.network, [nn.LayerNorm])
            parameters_names_wd = [
                name for name in parameters_names_wd if "bias" not in name
            ]
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p
                        for n, p in self.network.named_parameters()
                        if n in parameters_names_wd
                    ],
                    "weight_decay": self.cfg.optimizer.optim.weight_decay,
                    "layer_adaptation": True,
                },
                {
                    "params": [
                        p
                        for n, p in self.network.named_parameters()
                        if n not in parameters_names_wd
                    ],
                    "weight_decay": 0.0,
                    "layer_adaptation": False,
                },
            ]
            optimizer = instantiate(
                self.cfg.optimizer.optim, optimizer_grouped_parameters
            )
        else:
            optimizer = instantiate(self.cfg.optimizer.optim, self.network.parameters())
        if "lr_scheduler" in self.cfg:
            scheduler = instantiate(self.cfg.lr_scheduler)(optimizer)
            return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
        else:
            return optimizer

    def lr_scheduler_step(self, scheduler, metric):
        scheduler.step(self.global_step)

    def sample(
        self,
        batch_size=None,
        cond=None,
        postprocessing=True,
        stage="val",
    ):
        batch = {}
        if stage == "val":
            sampler = self.val_sampler
        elif stage == "test":
            sampler = self.test_sampler
        else:
            raise ValueError(f"Unknown stage {stage}")
        batch[self.cfg.cond_preprocessing.input_key] = cond
        batch = self.cond_preprocessing(batch, device=self.device)
        output = sampler(
            self.ema_model,
            batch,
        )
        return self.postprocessing(output) if postprocessing else output

    def model(self, *args, **kwargs):
        return self.preconditioning(self.network, *args, **kwargs)

    def ema_model(self, *args, **kwargs):
        return self.preconditioning(self.ema_network, *args, **kwargs)

    def compute_exact_loglikelihood(
        self,
        batch=None,
    ):
        batch = self.data_preprocessing(batch)
        batch = self.cond_preprocessing(batch)
        return -self.loss(self.preconditioning, self.ema_network, batch)


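# Chance-level baseline: ignores the conditioning entirely and predicts a point
# drawn uniformly on the unit sphere (a normalized 3D Gaussian sample).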
class RandomGeolocalizer(L.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.test_metrics = instantiate(cfg.test_metrics)
        self.data_preprocessing = instantiate(cfg.data_preprocessing)
        self.cond_preprocessing = instantiate(cfg.cond_preprocessing)
        self.postprocessing = instantiate(cfg.postprocessing)

    def test_step(self, batch, batch_idx):
        batch = self.data_preprocessing(batch)
        batch = self.cond_preprocessing(batch)
        batch_size = batch["x_0"].shape[0]
        samples = torch.randn(batch_size, 3, device=self.device)
        samples = samples / samples.norm(dim=-1, keepdim=True)
        samples = self.postprocessing(samples)
        self.test_metrics.update({"gps": samples}, batch)

    def on_test_epoch_end(self):
        metrics = self.test_metrics.compute()
        for metric_name, metric_value in metrics.items():
            self.log(
                f"test/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
            )