# Hugging Face Spaces commit fac3244 by kylanoconnor:
# "Initial PLONK deployment for Hugging Face Spaces".
import torch
from plonk.models.pretrained_models import Plonk
from plonk.models.samplers.riemannian_flow_sampler import riemannian_flow_sampler
from plonk.models.postprocessing import CartesiantoGPS
from plonk.models.schedulers import (
SigmoidScheduler,
LinearScheduler,
CosineScheduler,
)
from plonk.models.preconditioning import DDPMPrecond
from torchvision import transforms
from transformers import CLIPProcessor, CLIPVisionModel
from plonk.utils.image_processing import CenterCrop
import numpy as np
from plonk.utils.manifolds import Sphere
from torch.func import jacrev, vmap, vjp
from torchdiffeq import odeint
from tqdm import tqdm
# Default compute device: the CUDA device when one is available, else the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Supported PLONK checkpoints. Keys are the identifiers given to
# ``PlonkPipeline`` / ``Plonk.from_pretrained``; ``emb_name`` selects the image
# feature extractor built by ``load_prepocessing`` for that checkpoint.
MODELS = {
    "nicolas-dufour/PLONK_YFCC": {"emb_name": "dinov2"},
    "nicolas-dufour/PLONK_OSV_5M": {"emb_name": "street_clip"},
    "nicolas-dufour/PLONK_iNaturalist": {"emb_name": "dinov2"},
}
def scheduler_fn(
    scheduler_type: str, start: float, end: float, tau: float, clip_min: float = 1e-9
):
    """Instantiate the noise scheduler named by ``scheduler_type``.

    ``"sigmoid"`` and ``"cosine"`` receive ``start``, ``end``, ``tau`` and
    ``clip_min`` positionally; ``"linear"`` only receives ``clip_min`` (the
    other arguments are ignored for it).

    Raises:
        ValueError: If ``scheduler_type`` is not one of the names above.
    """
    if scheduler_type == "linear":
        return LinearScheduler(clip_min=clip_min)
    if scheduler_type not in ("sigmoid", "cosine"):
        raise ValueError(f"Scheduler type {scheduler_type} not supported")
    scheduler_cls = SigmoidScheduler if scheduler_type == "sigmoid" else CosineScheduler
    return scheduler_cls(start, end, tau, clip_min)
class DinoV2FeatureExtractor:
    """Image embedder backed by DINOv2 ViT-L/14 with registers (via ``torch.hub``).

    Each image is center-cropped to a square, resized to 336 px with bicubic
    interpolation, converted to a tensor and normalized with the ImageNet
    channel mean/std, then embedded on its own (batch of one).
    """

    def __init__(self, device=device):
        super().__init__()
        self.device = device
        # Backbone loaded through torch.hub, switched to eval mode, moved to device.
        self.emb_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_reg")
        self.emb_model.eval()
        self.emb_model.to(self.device)
        preprocessing_steps = [
            CenterCrop(ratio="1:1"),
            transforms.Resize(336, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ]
        self.augmentation = transforms.Compose(preprocessing_steps)

    def __call__(self, batch):
        """Store the stacked embeddings of ``batch["img"]`` in ``batch["emb"]``.

        ``batch`` is mutated in place and returned.
        """
        with torch.no_grad():
            embeddings = [
                self.emb_model(
                    self.augmentation(image).unsqueeze(0).to(self.device)
                ).squeeze(0)
                for image in batch["img"]
            ]
        batch["emb"] = torch.stack(embeddings)
        return batch
class StreetClipFeatureExtractor:
    """Image embedder backed by the ``geolocal/StreetCLIP`` CLIP vision tower.

    The embedding of each image is the last hidden state at token position 0
    (presumably the class token).
    """

    def __init__(self, device=device):
        self.device = device
        checkpoint = "geolocal/StreetCLIP"
        self.emb_model = CLIPVisionModel.from_pretrained(checkpoint).to(device)
        self.processor = CLIPProcessor.from_pretrained(checkpoint)

    def __call__(self, batch):
        """Store the embeddings of ``batch["img"]`` in ``batch["emb"]``.

        ``batch`` is mutated in place and returned.
        """
        encoded = self.processor(images=batch["img"], return_tensors="pt")
        model_inputs = {name: tensor.to(self.device) for name, tensor in encoded.items()}
        with torch.no_grad():
            hidden_states = self.emb_model(**model_inputs).last_hidden_state
        batch["emb"] = hidden_states[:, 0]
        return batch
def load_prepocessing(model_name, dtype=torch.float32):
    """Build the image feature extractor matching ``model_name``.

    ``model_name`` must be a key of ``MODELS`` (otherwise a KeyError is
    raised); its ``emb_name`` selects DINOv2 or StreetCLIP. The extractor is
    created on its default device. ``dtype`` is accepted but currently unused.

    Raises:
        ValueError: If the model's ``emb_name`` has no matching extractor.
    """
    emb_name = MODELS[model_name]["emb_name"]
    if emb_name == "dinov2":
        return DinoV2FeatureExtractor()
    if emb_name == "street_clip":
        return StreetClipFeatureExtractor()
    raise ValueError(f"Embedding model {emb_name} not found")
# Helper functions adapted from plonk/models/module.py
# for likelihood computation
def div_fn(u):
    """Return ``(x, y) -> trace(d u(x, y) / d x)`` for a map ``u: R^D -> R^D``.

    A leading singleton dimension of the Jacobian (e.g. when ``u`` returns a
    ``(1, D)`` output for a ``(D,)`` input) is squeezed away before the trace.
    """
    jacobian_wrt_x = jacrev(u, argnums=0)

    def divergence(x, y):
        return torch.trace(jacobian_wrt_x(x, y).squeeze(0))

    return divergence
def output_and_div(vecfield, x, y, v=None):
    """Evaluate ``vecfield(x, y)`` together with its divergence with respect to ``x``.

    Args:
        vecfield: Callable ``(x, y) -> dx``; ``dx`` is assumed to be shaped like
            ``x`` (TODO confirm for all callers).
        x: Batched points, presumably ``(batch, dim)``.
        y: Conditioning passed through to ``vecfield`` unchanged.
        v: Optional probe tensor shaped like ``x``. When ``None`` the exact
            divergence is computed per sample (``vmap`` over ``div_fn``, i.e.
            the Jacobian trace). Otherwise ``v^T J v`` is returned, which is the
            Hutchinson trace estimate when ``v`` is a random probe with identity
            covariance (e.g. Rademacher).

    Returns:
        ``(dx, div)``.
    """
    if v is None:
        dx = vecfield(x, y)
        per_sample_divergence = div_fn(vecfield)
        return dx, vmap(per_sample_divergence)(x, y)
    dx, pullback = vjp(lambda x_only: vecfield(x_only, y), x)
    (v_jac,) = pullback(v)
    return dx, (v_jac * v).sum(dim=-1)
def _gps_degrees_to_cartesian(gps_coords_deg, device):
"""Converts GPS coordinates (latitude, longitude) in degrees to Cartesian coordinates."""
if not isinstance(gps_coords_deg, np.ndarray):
gps_coords_deg = np.array(gps_coords_deg)
if gps_coords_deg.ndim == 1:
gps_coords_deg = gps_coords_deg[np.newaxis, :]
lat_rad = np.radians(gps_coords_deg[:, 0])
lon_rad = np.radians(gps_coords_deg[:, 1])
x = np.cos(lat_rad) * np.cos(lon_rad)
y = np.cos(lat_rad) * np.sin(lon_rad)
z = np.sin(lat_rad)
cartesian_coords = np.stack([x, y, z], axis=-1)
return torch.tensor(cartesian_coords, dtype=torch.float32, device=device)
class PlonkPipeline:
    """
    Predict where an image was taken with a pretrained PLONK model.

    The pipeline combines an image feature extractor (DINOv2 or StreetCLIP,
    selected through ``MODELS``), the PLONK network wrapped in ``DDPMPrecond``
    preconditioning, the Riemannian flow sampler, and a Cartesian -> GPS
    post-processing step.

    Initialization:
        PlonkPipeline(
            model_path,
            scheduler="sigmoid",
            scheduler_start=-7,
            scheduler_end=3,
            scheduler_tau=1.0,
            device=device,
        )

    Parameters:
        model_path (str): Checkpoint identifier passed to ``Plonk.from_pretrained``.
            It must also be a key of ``MODELS`` (e.g. "nicolas-dufour/PLONK_YFCC"),
            which selects the matching feature extractor; other values raise KeyError.
        scheduler (str): "sigmoid", "cosine" or "linear". Default is "sigmoid".
        scheduler_start (float): Start value for the scheduler. Default is -7.
        scheduler_end (float): End value for the scheduler. Default is 3.
        scheduler_tau (float): Tau value for the scheduler. Default is 1.0.
            The linear scheduler ignores start, end and tau.
        device: Device all components are placed on. Defaults to CUDA when
            available, otherwise CPU.

    Raises:
        ValueError: If ``scheduler`` is not one of the supported names.

    Methods (see each method's docstring for details):
        model(*args, **kwargs): Runs the preconditioned network.
        __call__(images, ...): Samples (latitude, longitude) in degrees.
        compute_likelihood(...): Log-likelihood of coordinates given images.
        compute_likelihood_grid(image, ...): Log-likelihood over a lat/lon grid.
        compute_localizability(image, ...): Monte Carlo localizability score.
        to(device): Moves every component (including the feature extractor).

    Example usage:
        pipe = PlonkPipeline("nicolas-dufour/PLONK_YFCC")
        pipe.to("cuda")
        coordinates = pipe(image, batch_size=32)  # np.ndarray, degrees
        likelihood = pipe.compute_likelihood(image, coordinates, cfg=0)
        localizability = pipe.compute_localizability(
            image, number_monte_carlo_samples=1024
        )
    """

    def __init__(
        self,
        model_path,
        scheduler="sigmoid",
        scheduler_start=-7,
        scheduler_end=3,
        scheduler_tau=1.0,
        device=device,
    ):
        # Validate before downloading/loading any weights.
        if scheduler not in ("sigmoid", "cosine", "linear"):
            raise ValueError(f"Scheduler {scheduler} not supported")
        self.network = Plonk.from_pretrained(model_path).to(device)
        self.network.requires_grad_(False).eval()
        self.scheduler = scheduler_fn(
            scheduler, scheduler_start, scheduler_end, scheduler_tau
        )
        self.cond_preprocessing = load_prepocessing(model_name=model_path)
        self.postprocessing = CartesiantoGPS()
        self.sampler = riemannian_flow_sampler
        self.model_path = model_path
        self.preconditioning = DDPMPrecond()
        self.manifold = Sphere()
        self.input_dim = 3  # Cartesian coordinates on the unit sphere.
        # load_prepocessing builds the extractor on the module default device;
        # move everything (and set self.device) so all parts agree on `device`.
        self.to(device)

    def model(self, *args, **kwargs):
        """Run the network through the ``DDPMPrecond`` preconditioning."""
        return self.preconditioning(self.network, *args, **kwargs)

    def __call__(
        self,
        images,
        batch_size=None,
        x_N=None,
        num_steps=None,
        scheduler=None,
        cfg=0,
        generator=None,
    ):
        """Sample GPS coordinates conditioned on images.

        Args:
            images: A PIL image or a list of images. A single image is shared
                by all samples; a list of several images needs one per sample.
            batch_size: Number of samples. Defaults to ``len(images)``; ignored
                when ``x_N`` is given.
            x_N: Initial noise of shape (batch, 3) or (3,). Drawn from a
                standard normal when not provided.
            num_steps: Number of sampler steps (sampler default if None).
            scheduler: Scheduler override (the pipeline's scheduler if None).
            cfg: Classifier-free guidance scale (default 0).
            generator: Optional ``torch.Generator`` used for the initial noise
                and passed to the sampler.

        Returns:
            np.ndarray: Sampled (latitude, longitude) pairs in degrees.

        Raises:
            ValueError: If ``images`` is empty, or a list of several images does
                not match the batch size.
        """
        if not isinstance(images, list):
            images = [images]
        if not images:
            raise ValueError("images must contain at least one image")

        # Initial noise and batch size.
        if x_N is None:
            if batch_size is None:
                batch_size = len(images)
            x_N = torch.randn(
                batch_size, self.input_dim, device=self.device, generator=generator
            )
        else:
            x_N = x_N.to(self.device)
            if x_N.ndim == 1:
                # A single noise vector: add the batch dimension.
                x_N = x_N.unsqueeze(0)
            batch_size = x_N.shape[0]

        # Check before running the (expensive) feature extractor.
        if len(images) > 1 and len(images) != batch_size:
            raise ValueError(
                f"Got {len(images)} images for a batch of {batch_size} samples; "
                "pass a single image or one image per sample"
            )

        # Conditioning: one embedding per image, broadcast when only one image.
        batch = self.cond_preprocessing({"y": x_N, "img": images})
        if len(images) == 1:
            batch["emb"] = batch["emb"].repeat(batch_size, 1)

        sampler_kwargs = {
            "conditioning_keys": "emb",
            "scheduler": self.scheduler if scheduler is None else scheduler,
            "cfg_rate": cfg,
            "generator": generator,
        }
        if num_steps is not None:
            sampler_kwargs["num_steps"] = num_steps
        output = self.sampler(self.model, batch, **sampler_kwargs)

        # Cartesian -> GPS (radians) -> degrees.
        output = self.postprocessing(output)
        return np.degrees(output.detach().cpu().numpy())

    def compute_likelihood(
        self,
        images=None,
        coordinates=None,
        emb=None,
        cfg=0,
        rademacher=False,
        atol=1e-6,
        rtol=1e-6,
        normalize_logp=True,
    ):
        """
        Compute the log-likelihood of GPS coordinates given images.

        The coordinates are mapped to the unit sphere and the flow ODE is
        integrated with dopri5 while accumulating the divergence of the vector
        field; the result is the base log-density of the integrated end point
        plus the accumulated divergence.

        Args:
            images: PIL image or list of PIL images. Ignored when ``emb`` is given.
            coordinates: Target (latitude, longitude) in degrees: a single pair,
                a list of pairs, or an (N, 2) numpy array / tensor.
            emb: Pre-computed embeddings, a torch.Tensor of shape (B, D) or (D,).
            cfg (float): Classifier-free guidance scale. Default 0 (no guidance).
            rademacher (bool): Estimate the divergence with a random +/-1 probe
                instead of the exact Jacobian trace. Default False.
            atol (float): Absolute tolerance for the ODE solver. Default 1e-6.
            rtol (float): Relative tolerance for the ODE solver. Default 1e-6.
            normalize_logp (bool): Divide the result by ``input_dim * log(2)``
                (bits per dimension). Default True.

        A single image/embedding is broadcast against several coordinates and
        vice versa.

        Returns:
            torch.Tensor: One log-likelihood per (embedding, coordinate) pair.

        Raises:
            ValueError: If coordinates are missing, neither images nor emb is
                given, or the two batch sizes cannot be broadcast.
            TypeError: If ``emb`` is not a torch.Tensor.
        """
        if coordinates is None:
            raise ValueError("coordinates must be provided")
        nfe = [0]  # Number of vector-field evaluations made by the solver.

        # 1. Conditioning embeddings, either given directly or computed from images.
        if emb is not None:
            if not isinstance(emb, torch.Tensor):
                raise TypeError("emb must be a torch.Tensor")
            batch = {"emb": emb.to(self.device)}
        else:
            if images is None:
                raise ValueError("Either images or emb must be provided")
            if not isinstance(images, list):
                images = [images]
            batch = self.cond_preprocessing({"img": images})  # Adds 'emb'
        # Make embeddings 2-D *before* broadcasting, so a (D,) embedding counts
        # as one example rather than D examples.
        if batch["emb"].ndim == 1:
            batch["emb"] = batch["emb"].unsqueeze(0)

        # 2. Coordinates: GPS degrees -> Cartesian unit vectors, shape (N, 3).
        x_1 = _gps_degrees_to_cartesian(coordinates, self.device)
        num_coords, num_embs = x_1.shape[0], batch["emb"].shape[0]
        if num_coords != num_embs:
            if num_coords == 1:
                x_1 = x_1.repeat(num_embs, 1)
            elif num_embs == 1:
                batch["emb"] = batch["emb"].repeat(num_coords, 1)
            else:
                raise ValueError(
                    f"Batch size mismatch between images ({num_embs}) and coordinates ({num_coords})"
                )

        with torch.inference_mode(mode=False):  # Grads needed for the Jacobian.

            def odefunc(t, tensor):
                """Time derivative of the augmented state [x, log-det accumulator]."""
                nfe[0] += 1
                t = t.to(tensor)
                gamma = self.scheduler(t)
                x = tensor[..., : self.input_dim]
                y = batch["emb"]  # Conditioning

                def vecfield(x_vf, y_vf):
                    batch_vecfield = {
                        "y": x_vf,
                        "emb": y_vf,
                        "gamma": gamma.reshape(-1),
                    }
                    if cfg > 0:
                        # Classifier-free guidance against a zero (null) condition.
                        batch_vecfield_uncond = {
                            "y": x_vf,
                            "emb": torch.zeros_like(y_vf),
                            "gamma": gamma.reshape(-1),
                        }
                        model_output_cond = self.model(batch_vecfield)
                        model_output_uncond = self.model(batch_vecfield_uncond)
                        model_output = model_output_cond + cfg * (
                            model_output_cond - model_output_uncond
                        )
                    else:
                        model_output = self.model(batch_vecfield)
                    # Assumes a flow-matching interpolant (as used by the sampler):
                    # scale the model output by d(gamma)/dt.
                    d_gamma = self.scheduler.derivative(t).reshape(-1, 1)
                    return d_gamma * model_output

                if rademacher:
                    v = torch.randint_like(x, 2) * 2 - 1  # Random +/-1 probe.
                else:
                    v = None
                dx, div = output_and_div(vecfield, x, y, v=v)
                div = div.reshape(-1, 1)
                del t, x
                return torch.cat([dx, div], dim=-1)

            # 3. Solve the ODE on the augmented state [x_1, 0].
            state1 = torch.cat([x_1, torch.zeros_like(x_1[..., :1])], dim=-1)
            # Note: standard (Euclidean) odeint; a strictly Riemannian solver
            # might be preferable, but this mirrors
            # DiffGeolocalizer.compute_exact_loglikelihood.
            with torch.no_grad():
                # [-1] is the state at the final integration time t=1.0; it is
                # scored under the base density below.
                state0 = odeint(
                    odefunc,
                    state1,
                    t=torch.linspace(0, 1.0, 2).to(x_1.device),
                    atol=atol,
                    rtol=rtol,
                    method="dopri5",
                    options={"min_step": 1e-5},
                )[-1]
            x_0, logdetjac = state0[..., : self.input_dim], state0[..., -1]
            # Project the end point back onto the sphere.
            x_0 = self.manifold.projx(x_0)

            # 4. Change of variables: log p(x_1) = log p0(x_0) + accumulated divergence.
            logp0 = self.manifold.base_logprob(x_0)
            logp1 = logp0 + logdetjac

        if normalize_logp:
            logp1 = logp1 / (self.input_dim * np.log(2))
        print(f"Likelihood NFE: {nfe[0]}")
        return logp1

    @staticmethod
    def _grid_axis(lower, upper, step):
        """Return ``lower, lower + step, ...`` without ever exceeding ``upper``.

        ``upper`` itself is included when it lies on the grid. Unlike
        ``np.arange(lower, upper + step, step)``, this never yields a point past
        ``upper`` (e.g. a 92 degree latitude for a 7 degree step).

        Raises:
            ValueError: If ``step`` is not positive.
        """
        if step <= 0:
            raise ValueError(f"grid_resolution_deg must be positive, got {step}")
        # Small tolerance so that e.g. 180 / 0.1 (= 1799.999...) reaches the endpoint.
        num_steps = int(np.floor((upper - lower) / step + 1e-9))
        return lower + step * np.arange(num_steps + 1)

    def compute_likelihood_grid(
        self,
        image,
        grid_resolution_deg=10,
        batch_size=1024,
        cfg=0,
    ):
        """
        Compute the likelihood of an image over a global grid of coordinates.

        Args:
            image: Input PIL Image.
            grid_resolution_deg (float): Grid spacing in degrees (> 0). Default 10.
                Latitudes start at -90 and longitudes at -180; they reach 90 / 180
                when the spacing divides the range, and never exceed them.
            batch_size (int): Grid points per ``compute_likelihood`` call.
                Adjust to available memory. Default 1024.
            cfg (float): Classifier-free guidance scale. Default 0.

        Returns:
            tuple: (latitude_grid, longitude_grid, likelihood_grid)
                - latitude_grid (np.ndarray): 1D array of latitudes.
                - longitude_grid (np.ndarray): 1D array of longitudes.
                - likelihood_grid (np.ndarray): 2D array of shape
                  (len(latitudes), len(longitudes)) holding the values returned
                  by ``compute_likelihood`` (bits per dimension by default).

        Raises:
            ValueError: If ``grid_resolution_deg`` is not positive.
        """
        # 1. Build the grid and flatten it into (lat, lon) rows.
        latitudes = self._grid_axis(-90, 90, grid_resolution_deg)
        longitudes = self._grid_axis(-180, 180, grid_resolution_deg)
        lon_grid, lat_grid = np.meshgrid(longitudes, latitudes)
        all_coordinates = np.stack([lat_grid.ravel(), lon_grid.ravel()], axis=-1)
        num_points = all_coordinates.shape[0]
        print(
            f"Computing likelihood over a {latitudes.size}x{longitudes.size} grid ({num_points} points)..."
        )

        # The image embedding is computed once and reused for every batch.
        emb = self.cond_preprocessing({"img": [image]})["emb"]

        # 2. Process the grid in batches.
        all_likelihoods = []
        for start in tqdm(
            range(0, num_points, batch_size), desc="Computing Likelihood Grid"
        ):
            likelihood_batch = self.compute_likelihood(
                emb=emb,
                coordinates=all_coordinates[start : start + batch_size],
                cfg=cfg,
                rademacher=False,  # Exact divergence is preferable for a grid.
            )
            all_likelihoods.append(likelihood_batch.detach().cpu().numpy())

        # 3. Reassemble into the (lat, lon) grid shape.
        likelihood_flat = np.concatenate(all_likelihoods, axis=0)
        return latitudes, longitudes, likelihood_flat.reshape(lat_grid.shape)

    def compute_localizability(
        self,
        image,
        atol=1e-6,
        rtol=1e-6,
        number_monte_carlo_samples=1024,
    ):
        """
        Compute the localizability of an image.

        Locations are sampled from the model itself (rather than from a fixed
        grid) and their log-likelihoods are averaged, i.e. an importance-sampling
        style Monte Carlo estimate.

        Args:
            image: Input PIL Image.
            atol (float): Absolute tolerance for the ODE solver. Default 1e-6.
            rtol (float): Relative tolerance for the ODE solver. Default 1e-6.
            number_monte_carlo_samples (int): Number of model samples. Default 1024.

        Returns:
            torch.Tensor: Mean unnormalized log-likelihood of the samples divided
            by ``4 * pi * log(2)``.
        """
        samples = self(image, batch_size=number_monte_carlo_samples)
        emb = self.cond_preprocessing({"img": [image]})["emb"]
        log_likelihoods = self.compute_likelihood(
            emb=emb,
            coordinates=samples,
            atol=atol,
            rtol=rtol,
            normalize_logp=False,
        )
        return log_likelihoods.mean() / (4 * torch.pi * np.log(2))

    def to(self, device):
        """Move every component, including the feature extractor, to ``device``.

        Returns ``self`` so calls can be chained.
        """
        device = torch.device(device)
        self.network.to(device)
        self.postprocessing.to(device)
        # The feature extractors are plain classes (not nn.Modules): move their
        # backbone and update the device they send inputs to, otherwise the
        # embeddings and the noise/network end up on different devices.
        emb_model = getattr(self.cond_preprocessing, "emb_model", None)
        if emb_model is not None:
            emb_model.to(device)
        if hasattr(self.cond_preprocessing, "device"):
            self.cond_preprocessing.device = device
        self.device = device
        return self