|
import torch |
|
|
|
from metrics.utils import haversine, reverse |
|
from sklearn.metrics import pairwise_distances |
|
from torchmetrics import Metric |
|
import numpy as np |
|
from utils.kde import BatchedKDE |
|
from tqdm import tqdm |
|
|
|
|
|
class HaversineMetrics(Metric):
    """Accumulate geolocation metrics between predicted and ground-truth GPS points.

    Values reported by ``compute``:

    * ``Haversine``: mean haversine distance between predictions and ground truth.
    * ``Geoguessr``: mean of ``5000 * exp(-distance / 1492.7)``.
    * ``Accuracy_{r}_km_radius``: fraction of predictions whose haversine distance
      to the ground truth is below ``r``, for every ``r`` in ``acc_radiuses``.
    * ``Accuracy_{area}``: fraction of predictions whose area, as resolved by
      ``reverse``, equals the ground-truth one, for every ``area`` in ``acc_area``.
    * ``precision`` / ``recall`` / ``density`` / ``coverage``: manifold metrics
      between the ground-truth and the predicted point sets.

    Args:
        acc_radiuses (list): radii to compute the accuracy from.
        acc_area (list): area levels to compute the accuracy from.
        use_kde (bool): if True, ``pred["gps"]`` is replaced by the mode of a KDE
            fitted on it before any metric is computed.
        manifold_k (int): number of nearest neighbours used by the manifold metrics.
    """

    def __init__(
        self,
        acc_radiuses=(),
        acc_area=("country", "region", "sub-region", "city"),
        use_kde=False,
        manifold_k=3,
    ):
        super().__init__()
        self.use_kde = use_kde
        self.manifold_k = manifold_k
        # Private copies: immutable defaults in the signature, and later edits of
        # the caller's sequences cannot desynchronise the registered states.
        self.acc_radius = list(acc_radiuses)
        self.acc_area = list(acc_area)

        self.add_state("haversine_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("geoguessr_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")

        for acc in self.acc_radius:
            self.add_state(
                f"close_enough_points_{acc}",
                default=torch.tensor(0.0),
                dist_reduce_fx="sum",
            )
        for acc in self.acc_area:
            self.add_state(
                f"close_enough_points_{acc}",
                default=torch.tensor(0.0),
                dist_reduce_fx="sum",
            )
            # Area accuracies keep their own denominator (see ``update``).
            self.add_state(
                f"count_{acc}", default=torch.tensor(0), dist_reduce_fx="sum"
            )

        # Raw points are kept (one tensor per update) for the manifold metrics.
        self.add_state("real_points", [], dist_reduce_fx=None)
        self.add_state("fake_points", [], dist_reduce_fx=None)

    def _add_to_state(self, name, value):
        """Add ``value`` to the metric state registered under ``name``."""
        setattr(self, name, getattr(self, name) + value)

    def update(self, pred, gt):
        """Accumulate the statistics of one batch.

        Args:
            pred (dict): predictions; ``pred["gps"]`` holds the predicted
                coordinates. NOTE: when ``use_kde`` is set, ``pred["gps"]`` is
                overwritten in place (in the caller's dict) by the KDE modes.
            gt (dict): ground truth; ``gt["gps"]`` holds the true coordinates.
                The whole dict is also handed to ``reverse`` for area accuracies.

        Rows of ``pred["gps"]`` containing NaN are dropped together with the
        matching rows of ``gt["gps"]``; a batch with no valid row is ignored.
        """
        if self.use_kde:
            # assumes pred["gps"] is a batch of point sets, one KDE per batch
            # entry — TODO confirm against BatchedKDE
            (x_mode, y_mode), _ = estimate_kde_mode(pred["gps"])
            pred["gps"] = torch.stack([x_mode, y_mode], dim=1)

        pred_gps, gt_gps = pred["gps"], gt["gps"]
        if pred_gps.isnan().any():
            valid_mask = ~pred_gps.isnan().any(dim=1)
            pred_gps = pred_gps[valid_mask]
            gt_gps = gt_gps[valid_mask]
            if len(pred_gps) == 0:
                return

        haversine_distance = haversine(pred_gps, gt_gps)

        for acc in self.acc_radius:
            self._add_to_state(
                f"close_enough_points_{acc}", (haversine_distance < acc).sum()
            )

        if self.acc_area:
            # NOTE(review): `reverse` receives the unfiltered `gt`; if it reads
            # per-sample labels from it, they may not line up with `pred_gps`
            # when NaN rows were dropped above — confirm.
            area_pred, area_gt = reverse(pred_gps, gt, self.acc_area)
            for acc in self.acc_area:
                gt_labels = area_gt["_".join(["unique", acc])]
                self._add_to_state(
                    f"close_enough_points_{acc}", (area_pred[acc] == gt_labels).sum()
                )
                self._add_to_state(f"count_{acc}", len(gt_labels))

        self.haversine_sum += haversine_distance.sum()
        self.geoguessr_sum += 5000 * torch.exp(-haversine_distance / 1492.7).sum()
        self.real_points.append(gt_gps)
        self.fake_points.append(pred_gps)
        self.count += pred_gps.shape[0]

    def compute(self):
        """Return a dict with every accumulated metric.

        Ratios are NaN when nothing was accumulated; the manifold metrics are
        NaN as well in that case instead of failing on an empty concatenation.
        """
        output = {
            "Haversine": self.haversine_sum / self.count,
            "Geoguessr": self.geoguessr_sum / self.count,
        }
        for acc in self.acc_radius:
            output[f"Accuracy_{acc}_km_radius"] = (
                getattr(self, f"close_enough_points_{acc}") / self.count
            )
        for acc in self.acc_area:
            output[f"Accuracy_{acc}"] = getattr(
                self, f"close_enough_points_{acc}"
            ) / getattr(self, f"count_{acc}")

        if self.real_points and self.fake_points:
            real_points = torch.cat(self.real_points, dim=0)
            fake_points = torch.cat(self.fake_points, dim=0)
            manifold = self.manifold_metrics(
                real_points, fake_points, self.manifold_k
            )
        else:
            # No valid sample was ever seen: torch.cat would raise on [].
            manifold = (float("nan"),) * 4
        (
            output["precision"],
            output["recall"],
            output["density"],
            output["coverage"],
        ) = manifold
        return output

    def compute_pairwise_distance(self, data_x, data_y=None):
        """Pairwise haversine distances between two point sets.

        Args:
            data_x: numpy.ndarray([N, feature_dim], dtype=np.float32)
            data_y: numpy.ndarray([M, feature_dim], dtype=np.float32);
                defaults to ``data_x``.
        Returns:
            numpy.ndarray([N, M]) of pairwise distances.

        NOTE(review): scikit-learn's haversine metric expects ``[lat, lon]`` in
        radians and returns unit-sphere distances — confirm the stored points
        follow that convention.
        """
        if data_y is None:
            data_y = data_x
        return pairwise_distances(data_x, data_y, metric="haversine", n_jobs=8)

    def get_kth_value(self, unsorted, k, axis=-1):
        """Return the k-th smallest value along ``axis``.

        Args:
            unsorted: numpy.ndarray of any dimensionality.
            k: int, 1-based rank (``k=1`` is the minimum).
            axis: axis along which the rank is taken.
        Returns:
            kth values along the designated axis (that axis is removed).
        Raises:
            ValueError: if ``k`` is smaller than 1 or larger than the axis length.
        """
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")
        # After partitioning at index k - 1 that slot holds the k-th smallest
        # element, whichever axis is requested.
        kth_index = k - 1
        partitioned = np.partition(unsorted, kth_index, axis=axis)
        return np.take(partitioned, kth_index, axis=axis)

    def compute_nearest_neighbour_distances(self, input_features, nearest_k):
        """Distance from every point to its ``nearest_k``-th nearest neighbour.

        Args:
            input_features: numpy.ndarray([N, feature_dim], dtype=np.float32)
            nearest_k: int
        Returns:
            numpy.ndarray([N]) of distances to kth nearest neighbours.
        """
        distances = self.compute_pairwise_distance(input_features)
        # + 1 because each point is its own closest neighbour (distance 0).
        return self.get_kth_value(distances, k=nearest_k + 1, axis=-1)

    def compute_prdc(self, real_features, fake_features, nearest_k):
        """Compute precision, recall, density, and coverage given two manifolds.

        Args:
            real_features: numpy.ndarray([N, feature_dim], dtype=np.float32)
            fake_features: numpy.ndarray([M, feature_dim], dtype=np.float32)
            nearest_k: int.
        Returns:
            tuple ``(precision, recall, density, coverage)``.
        """
        real_radii = self.compute_nearest_neighbour_distances(
            real_features, nearest_k
        )
        fake_radii = self.compute_nearest_neighbour_distances(
            fake_features, nearest_k
        )
        # Rows index real points, columns index fake points.
        distance_real_fake = self.compute_pairwise_distance(
            real_features, fake_features
        )

        # [i, j] is True when fake point j lies inside the k-NN ball of real i.
        fake_in_real_ball = distance_real_fake < np.expand_dims(real_radii, axis=1)
        # [i, j] is True when real point i lies inside the k-NN ball of fake j.
        real_in_fake_ball = distance_real_fake < np.expand_dims(fake_radii, axis=0)

        precision = fake_in_real_ball.any(axis=0).mean()
        recall = real_in_fake_ball.any(axis=1).mean()
        density = (1.0 / float(nearest_k)) * fake_in_real_ball.sum(axis=0).mean()
        coverage = (distance_real_fake.min(axis=1) < real_radii).mean()

        return precision, recall, density, coverage

    def manifold_metrics(self, real_features, fake_features, nearest_k, num_splits=20):
        """Average precision, recall, density, and coverage over chunks of points.

        The point sets are split into at most ``num_splits`` aligned chunks so
        that the pairwise distance matrices stay small; the metrics of the
        chunks are then averaged with equal weight.

        Args:
            real_features: torch.Tensor([N, feature_dim], dtype=torch.float32)
            fake_features: torch.Tensor([N, feature_dim], dtype=torch.float32)
            nearest_k: int.
            num_splits: int. Number of splits to use for computing metrics.
        Returns:
            tuple of floats ``(precision, recall, density, coverage)``; all NaN
            when no chunk holds enough points.
        """
        real_chunks = real_features.chunk(num_splits, dim=0)
        fake_chunks = fake_features.chunk(num_splits, dim=0)
        # The k-NN radius needs the point itself plus ``nearest_k`` neighbours.
        min_points = nearest_k + 1

        per_chunk = []
        for real, fake in tqdm(
            zip(real_chunks, fake_chunks),
            total=min(len(real_chunks), len(fake_chunks)),
            desc="Computing manifold",
        ):
            # The last chunk can be much smaller than the others; skip chunks
            # that cannot define a k-NN radius instead of crashing.
            if min(real.shape[0], fake.shape[0]) < min_points:
                continue
            per_chunk.append(
                self.compute_prdc(
                    real.detach().cpu().numpy(),
                    fake.detach().cpu().numpy(),
                    nearest_k=nearest_k,
                )
            )

        if not per_chunk:
            nan = float("nan")
            return nan, nan, nan, nan

        precision, recall, density, coverage = np.mean(
            np.asarray(per_chunk, dtype=np.float64), axis=0
        )
        return float(precision), float(recall), float(density), float(coverage)
|
|
|
|
|
def estimate_kde_mode(points):
    """Find, for every batch entry, the grid point where a fitted KDE peaks.

    A ``BatchedKDE`` is fitted on ``points``, scored on the grid produced by
    ``batched_make_grid``, and the grid coordinates with the highest score are
    returned for each batch entry.

    Args:
        points: tensor whose first dimension is the batch.
            (assumes a trailing (n_points, 2) layout, as ``make_grid`` unpacks
            two coordinates — TODO confirm)

    Returns:
        ``((x_mode, y_mode), kde)``: the coordinates of the best grid point per
        batch entry, and the fitted ``BatchedKDE``.
    """
    n_batches = points.shape[0]
    device = points.device

    kde = BatchedKDE()
    kde.fit(points)

    # The grid is built from a CPU copy, then moved back next to the points.
    grid_x, grid_y, grid_positions = batched_make_grid(points.cpu())
    grid_x = grid_x.to(device)
    grid_y = grid_y.to(device)
    grid_positions = grid_positions.to(device)

    density = kde.score(grid_positions).reshape(grid_x.shape)

    # Flat index of the densest grid cell for each batch entry.
    best_cell = density.reshape(n_batches, -1).argmax(dim=1)
    rows = torch.arange(n_batches)

    x_mode = grid_x.reshape(n_batches, -1)[rows, best_cell]
    y_mode = grid_y.reshape(n_batches, -1)[rows, best_cell]
    return (x_mode, y_mode), kde
|
|
|
|
|
def make_grid(points, grid_size=100):
    """Build a regular 2-D grid spanning the bounding box of ``points``.

    Args:
        points: tensor of shape (n_points, 2); the variable names treat column
            0 as latitude and column 1 as longitude.
        grid_size: number of grid steps along each axis (default 100).

    Returns:
        ``(X, Y, positions)`` where ``X`` and ``Y`` are (grid_size, grid_size)
        coordinate matrices with ``X[i, j]`` the i-th step of the first axis and
        ``Y[i, j]`` the j-th step of the second axis, and ``positions`` is the
        (grid_size * grid_size, 2) list of grid points in row-major order.
    """
    (lat_min, long_min), _ = points.min(dim=-2)
    (lat_max, long_max), _ = points.max(dim=-2)
    x = torch.linspace(lat_min, lat_max, grid_size)
    y = torch.linspace(long_min, long_max, grid_size)
    # "ij" is the layout the legacy default produced; passing it explicitly
    # avoids the deprecation warning and pins the behaviour.
    X, Y = torch.meshgrid(x, y, indexing="ij")
    positions = torch.vstack([X.flatten(), Y.flatten()]).transpose(-1, -2)
    return X, Y, positions


# Batched variant: maps ``make_grid`` over the leading (batch) dimension.
batched_make_grid = torch.vmap(make_grid)
|
|