import torch


def distance_to_similarity(distances, temperature=1.0):
    """
    Convert a distance matrix into a similarity matrix so it can feed
    distribution-based metrics. Values are clamped away from zero so that
    downstream log/normalization steps stay numerically safe.
    """
    similarities = torch.exp(-distances / temperature)
    similarities = torch.clamp(similarities, min=1e-8)
    return similarities
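
# Quick sanity check (illustrative values, not part of the detection
# pipeline): with temperature=1.0 the mapping is simply exp(-d), so a
# distance of 0 maps to similarity 1 and large distances decay toward the
# 1e-8 clamp floor.
#
#   d = torch.tensor([[0.0, 1.0], [2.0, 50.0]])
#   distance_to_similarity(d)
#   # tensor([[1.0000e+00, 3.6788e-01],
#   #         [1.3534e-01, 1.0000e-08]])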

#################################
##    "New Object" Detection   ##
#################################

def detect_newness_two_sided(distances, k=3, quantile=0.97):
    """
    Two-sided top-k matching + distance threshold.
    A target patch is flagged as new only if (a) it never takes part in a
    mutual top-k match with any source patch and (b) its minimum distance
    to the source set exceeds the chosen quantile threshold.
    distances: [N_src, N_tgt]
    """
    device = distances.device
    N_src, N_tgt = distances.shape
    topk_src_idx_t = torch.topk(distances, k, dim=0, largest=False).indices  # [k, N_tgt]
    topk_tgt_idx_s = torch.topk(distances, k, dim=1, largest=False).indices  # [N_src, k]
    src_to_tgt_mask = torch.zeros((N_src, N_tgt), device=device)
    tgt_to_src_mask = torch.zeros((N_src, N_tgt), device=device)
    row_indices = topk_src_idx_t  # [k, N_tgt]
    col_indices = torch.arange(N_tgt, device=device).unsqueeze(0).repeat(k, 1)  # [k, N_tgt]
    src_to_tgt_mask[row_indices, col_indices] = 1.0  # mark top-k sources per target
    row_indices = torch.arange(N_src, device=device).unsqueeze(1).repeat(1, k)  # [N_src, k]
    col_indices = topk_tgt_idx_s  # [N_src, k]
    tgt_to_src_mask[row_indices, col_indices] = 1.0  # mark top-k targets per source
    # A target column overlaps when at least one (source, target) pair sits
    # in both top-k sets, i.e. the match is mutual.
    overlap_mask = (src_to_tgt_mask * tgt_to_src_mask).sum(dim=0) > 0  # [N_tgt]
    # Work on a copy so the caller's tensor is not mutated in place; zeroing
    # mutually matched columns keeps them below the distance threshold.
    distances = distances.clone()
    distances[:, overlap_mask] = 0.0
    two_sided_mask = (~overlap_mask).float()
    min_distances, _ = distances.min(dim=0)
    threshold = torch.quantile(min_distances, quantile)
    threshold_mask = (min_distances > threshold).float()
    combined_mask = two_sided_mask * threshold_mask
    return combined_mask
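
# Usage sketch (the shapes and cdist-based features are assumptions for
# illustration): pairwise distances between source and target patch
# embeddings feed straight into the detector, which returns one 0/1 flag
# per target patch.
#
#   src_feats = torch.randn(196, 768)                # hypothetical [N_src, D]
#   tgt_feats = torch.randn(196, 768)                # hypothetical [N_tgt, D]
#   dists = torch.cdist(src_feats, tgt_feats)        # [N_src, N_tgt]
#   new_mask = detect_newness_two_sided(dists, k=3)  # [N_tgt], 1.0 = "new"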

def detect_newness_distance(min_distances, quantile=0.97):
    """
    Old (baseline) approach: threshold the per-target minimum distance at a
    chosen percentile.
    min_distances: [N_tgt]
    """
    threshold = torch.quantile(min_distances, quantile)
    newness_mask = (min_distances > threshold).float()
    return newness_mask
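
# Caveat (applies to every quantile-thresholded detector below as well):
# since the threshold is a quantile of the very scores being tested, roughly
# a (1 - quantile) fraction of target patches is flagged by construction,
# even when nothing new is present. The quantile acts as a budget, not a test.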

def detect_newness_topk_margin(distances, top_k=2, quantile=0.03):
    """
    Top-k margin approach in distance space.
    distances: [N_src, N_tgt]
    Sort each column ascending => best match at index 0, second best at
    index 1, etc. A smaller margin between best and second best =>
    ambiguous => likely new. The margin is thresholded at the given
    percentile.
    """
    sorted_dists, _ = torch.sort(distances, dim=0)
    best = sorted_dists[0]  # [N_tgt]
    # With top_k < 2 the margin degenerates to zero everywhere.
    second_best = sorted_dists[1] if top_k >= 2 else sorted_dists[0]  # [N_tgt]
    margin = second_best - best  # [N_tgt]
    # If margin < threshold => ambiguous => "new".
    # The threshold is taken as a quantile of the margins.
    threshold = torch.quantile(margin, quantile)
    newness_mask = (margin < threshold).float()
    return newness_mask
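
# Worked example (illustrative numbers): for a target column whose sorted
# distances are [0.10, 0.12, 0.90], the margin is 0.12 - 0.10 = 0.02. Such a
# small margin means two source patches explain the target almost equally
# well, so the match is treated as ambiguous rather than confident.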

def detect_newness_entropy(distances, temperature=1.0, quantile=0.97):
    """
    Entropy-based approach. First convert distance -> similarity with an
    exponential kernel, then normalize to get a distribution over source
    patches for each target patch and compute its Shannon entropy.
    High entropy => no strong match => new object.
    """
    similarities = distance_to_similarity(distances, temperature=temperature)
    probs = similarities / similarities.sum(dim=0, keepdim=True)  # [N_src, N_tgt]
    # Shannon entropy: -sum(p log p). probs > 0 thanks to the clamp in
    # distance_to_similarity, so the log is numerically safe.
    entropy = -torch.sum(probs * torch.log(probs), dim=0)  # [N_tgt]
    threshold = torch.quantile(entropy, quantile)
    newness_mask = (entropy > threshold).float()
    return newness_mask
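
# Worked example (illustrative numbers): a peaked distribution such as
# p = [0.98, 0.01, 0.01] has entropy -sum(p log p) ≈ 0.112 nats, while the
# uniform p = [1/3, 1/3, 1/3] attains the maximum log(3) ≈ 1.099 nats, so
# target patches with no strong source match drift toward the high end.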

def detect_newness_gini(distances, temperature=1.0, quantile=0.97):
    """
    Gini impurity-based approach. Convert distances to similarities,
    normalize to a distribution, compute the Gini impurity.
    High Gini => wide distribution => new object.
    """
    similarities = distance_to_similarity(distances, temperature=temperature)
    probs = similarities / similarities.sum(dim=0, keepdim=True)
    # Gini: sum(p_i * (1 - p_i)) => high if spread out
    gini = torch.sum(probs * (1.0 - probs), dim=0)  # [N_tgt]
    threshold = torch.quantile(gini, quantile)
    newness_mask = (gini > threshold).float()
    return newness_mask
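
# Note: sum(p * (1 - p)) is algebraically 1 - sum(p**2), one minus the
# distribution's "purity". It peaks at 1 - 1/N_src for the uniform
# distribution and falls toward 0 as probability mass concentrates on a
# single source match.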

def detect_newness_kl(distances, temperature=1.0, quantile=0.97):
    """
    KL-based approach. Compare the match distribution to uniform:
    1) Convert distances -> similarities
    2) p(x) = similarities / sum(similarities)
    3) KL(p || uniform) = sum p(x) log(p(x) / (1/N_src))
    4) If p is near uniform => KL small => new object.
    The score is inverted => newness ~ 1/KL.
    """
    similarities = distance_to_similarity(distances, temperature=temperature)
    N_src = distances.shape[0]
    probs = similarities / similarities.sum(dim=0, keepdim=True)
    uniform_val = 1.0 / float(N_src)
    kl_vals = torch.sum(probs * torch.log(probs / uniform_val), dim=0)  # [N_tgt]
    inv_kl = 1.0 / (kl_vals + 1e-8)  # large => near-uniform distribution => new
    threshold = torch.quantile(inv_kl, quantile)
    newness_mask = (inv_kl > threshold).float()
    return newness_mask
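
# Note: KL(p || uniform) = log(N_src) - H(p), so this score is the Shannon
# entropy detector flipped and shifted, and the two rank target patches
# identically. A tiny verification sketch (illustrative values):
#
#   probs = torch.tensor([0.7, 0.2, 0.1])
#   kl = torch.sum(probs * torch.log(probs / (1.0 / 3.0)))
#   ent = -torch.sum(probs * torch.log(probs))
#   # kl + ent == log(3) up to float error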

def detect_newness_variation_ratio(distances, temperature=1.0, quantile=0.97):
    """
    Variation ratio: 1 - max(prob).
    1) Convert distance -> similarity
    2) p(x) = sim(x) / sum_x'(sim(x'))
    3) var_ratio = 1 - max(p)
    High var_ratio => new object.
    """
    similarities = distance_to_similarity(distances, temperature=temperature)
    probs = similarities / similarities.sum(dim=0, keepdim=True)
    max_prob, _ = torch.max(probs, dim=0)  # [N_tgt]
    var_ratio = 1.0 - max_prob
    threshold = torch.quantile(var_ratio, quantile)
    newness_mask = (var_ratio > threshold).float()
    return newness_mask
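
# Worked example (illustrative numbers): p = [0.9, 0.05, 0.05] gives a
# variation ratio of 1 - 0.9 = 0.1 (a confident match), while the uniform
# p = [1/3, 1/3, 1/3] gives 1 - 1/3 ≈ 0.667, the maximum for three sources,
# so near-uniform columns are the ones pushed past the quantile threshold.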

def detect_newness_two_sided_ratio(
    distances,
    top_k_ratio_quantile=0.03,
    two_sided=True
):
    """
    Two-sided matching + ratio test in distance space.
    Ratio test: for each target t, let d0 = best distance and d1 = second
    best, with ratio = d0 / (d1 + 1e-8). In the classic Lowe test a small
    ratio signals a distinctive match; here the logic is inverted and the
    smallest ratios (below the given quantile) are marked as ambiguous =>
    new. Note that the ratio mask is computed but, as written, only the
    two-sided consistency result is returned.
    """
    N_src, N_tgt = distances.shape
    # Target -> source: best matching source for each target.
    _, best_s_for_t = torch.min(distances, dim=0)
    # Source -> target: best matching target for each source.
    _, best_t_for_s = torch.min(distances, dim=1)
    # Two-sided consistency check: target t is flagged when its best source
    # does not map back to t.
    twosided_mask = torch.zeros(N_tgt, device=distances.device)
    if two_sided:
        back_projected = best_t_for_s[best_s_for_t]  # [N_tgt]
        targets = torch.arange(N_tgt, device=distances.device)
        twosided_mask = (back_projected != targets).float()
    # Ratio test over the two smallest distances in each target column.
    sorted_dists, _ = torch.sort(distances, dim=0)
    d0 = sorted_dists[0]
    d1 = sorted_dists[1]
    ratio = d0 / (d1 + 1e-8)
    ratio_threshold = torch.quantile(ratio, top_k_ratio_quantile)
    ratio_mask = (ratio < ratio_threshold).float()  # currently unused, see docstring
    # Combine checks (currently using only the two-sided result).
    newness_mask = twosided_mask
    return newness_mask
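
# Minimal smoke test (a sketch with assumed shapes, not part of any released
# pipeline): run each detector on a random distance matrix and report how
# many target patches it flags as new.
if __name__ == "__main__":
    torch.manual_seed(0)
    dists = torch.rand(64, 64)  # hypothetical [N_src, N_tgt] distances
    detectors = {
        "two_sided": detect_newness_two_sided,
        "distance": lambda d: detect_newness_distance(d.min(dim=0).values),
        "topk_margin": detect_newness_topk_margin,
        "entropy": detect_newness_entropy,
        "gini": detect_newness_gini,
        "kl": detect_newness_kl,
        "variation_ratio": detect_newness_variation_ratio,
        "two_sided_ratio": detect_newness_two_sided_ratio,
    }
    for name, fn in detectors.items():
        mask = fn(dists)
        print(f"{name:>16}: {int(mask.sum().item())}/{mask.numel()} flagged as new")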