# cora/model/modules/new_object_detection.py
import torch
def distance_to_similarity(distances, temperature=1.0):
    """
    Convert a distance matrix into a similarity matrix via exp(-d / temperature),
    so distances can be used with the distribution-based metrics below.
    """
    similarities = torch.exp(-distances / temperature)
    # Clamp away from zero so downstream log() calls stay finite.
    similarities = torch.clamp(similarities, min=1e-8)
    return similarities
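
# Hedged usage sketch (not part of the original module; names are illustrative):
# shows how the temperature controls how quickly similarity decays with distance.
def _demo_distance_to_similarity():
    distances = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
    sharp = distance_to_similarity(distances, temperature=0.5)  # fast decay
    flat = distance_to_similarity(distances, temperature=5.0)   # flat, near-uniform
    print("sharp:\n", sharp)
    print("flat:\n", flat)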
#################################
## "New Object" Detection ##
#################################
def detect_newness_two_sided(distances, k=3, quantile=0.97):
    """
    Two-sided (mutual) top-k matching: a target patch whose top-k sources do
    not reciprocate is a candidate, and it is flagged "new" only if its
    minimum distance also exceeds the chosen quantile.
    distances: [N_src, N_tgt]
    """
    device = distances.device
    N_src, N_tgt = distances.shape
    distances = distances.clone()  # avoid mutating the caller's tensor below
    topk_src_idx_t = torch.topk(distances, k, dim=0, largest=False).indices  # [k, N_tgt]
    topk_tgt_idx_s = torch.topk(distances, k, dim=1, largest=False).indices  # [N_src, k]
    src_to_tgt_mask = torch.zeros((N_src, N_tgt), device=device)
    tgt_to_src_mask = torch.zeros((N_src, N_tgt), device=device)
    row_indices = topk_src_idx_t  # [k, N_tgt]
    col_indices = torch.arange(N_tgt, device=device).unsqueeze(0).repeat(k, 1)  # [k, N_tgt]
    src_to_tgt_mask[row_indices, col_indices] = 1.0  # mark each target's k nearest sources
    row_indices = torch.arange(N_src, device=device).unsqueeze(1).repeat(1, k)  # [N_src, k]
    col_indices = topk_tgt_idx_s  # [N_src, k]
    tgt_to_src_mask[row_indices, col_indices] = 1.0  # mark each source's k nearest targets
    # Targets with at least one mutual top-k match are considered matched.
    overlap_mask = (src_to_tgt_mask * tgt_to_src_mask).sum(dim=0) > 0  # [N_tgt]
    # Zero out matched columns so they fall below the distance threshold.
    distances[:, overlap_mask] = 0.0
    two_sided_mask = (~overlap_mask).float()
    min_distances, _ = distances.min(dim=0)
    threshold = torch.quantile(min_distances, quantile)
    threshold_mask = (min_distances > threshold).float()
    combined_mask = two_sided_mask * threshold_mask
    return combined_mask
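
# Minimal sketch (assumed shapes, not from the original repo): run the
# two-sided detector on random distances and inspect the flagged fraction.
def _demo_two_sided():
    torch.manual_seed(0)
    distances = torch.rand(64, 128)  # [N_src, N_tgt]
    mask = detect_newness_two_sided(distances, k=3, quantile=0.97)
    print("flagged targets:", int(mask.sum().item()), "of", mask.numel())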
def detect_newness_distance(min_distances, quantile=0.97):
"""
Old approach: threshold on min distance at a chosen percentile.
"""
threshold = torch.quantile(min_distances, quantile)
newness_mask = (min_distances > threshold).float()
return newness_mask
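
# Minimal usage sketch (assumed shapes; illustrative only): the baseline flags
# the ~3% of targets with the largest nearest-source distance.
def _demo_distance_baseline():
    torch.manual_seed(0)
    min_d = torch.rand(128)  # per-target nearest-source distance
    mask = detect_newness_distance(min_d, quantile=0.97)
    print(int(mask.sum().item()))  # roughly 3% of 128, i.e. about 4 flagged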
def detect_newness_topk_margin(distances, top_k=2, quantile=0.03):
"""
Top-k margin approach in distance space.
distances: [N_src, N_tgt]
Sort each column ascending => best match is index 0, second best is index 1, etc.
A smaller margin => ambiguous => likely new.
We threshold the margin at some percentile.
"""
    sorted_dists, _ = torch.sort(distances, dim=0)
    best = sorted_dists[0]  # [N_tgt]
    # With top_k < 2 the margin degenerates to zero everywhere.
    second_best = sorted_dists[1] if top_k >= 2 else sorted_dists[0]  # [N_tgt]
    margin = second_best - best  # [N_tgt]
# If margin < threshold => ambiguous => "new"
# We'll pick threshold as a quantile of margin
threshold = torch.quantile(margin, quantile)
newness_mask = (margin < threshold).float()
return newness_mask
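
# Worked toy example (illustrative values, not from the source): a target
# whose two best source distances nearly tie gets a small margin and is
# flagged "new", while a target with one clear match is not.
def _demo_topk_margin():
    # Column 0: clear best match (margin 0.9); column 1: near-tie (margin 0.01).
    distances = torch.tensor([[0.1, 0.50],
                              [1.0, 0.51],
                              [1.2, 0.90]])
    mask = detect_newness_topk_margin(distances, top_k=2, quantile=0.5)
    print(mask)  # tensor([0., 1.]) -> column 1 is flagged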
def detect_newness_entropy(distances, temperature=1.0, quantile=0.97):
"""
Entropy-based approach. First convert distance->similarity with an exponential.
Then normalize to get a distribution for each target patch, compute Shannon entropy.
High entropy => new object (no strong match).
"""
similarities = distance_to_similarity(distances, temperature=temperature)
probs = similarities / similarities.sum(dim=0, keepdim=True) # [N_src, N_tgt]
# Shannon Entropy: -sum(p log p)
entropy = -torch.sum(probs * torch.log(probs), dim=0) # [N_tgt]
# threshold
threshold = torch.quantile(entropy, quantile)
newness_mask = (entropy > threshold).float()
return newness_mask
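
# Illustrative check (toy values, assumed for demonstration): a target column
# equidistant from all sources yields a near-uniform match distribution and
# hence higher entropy than a column with one clear match.
def _demo_entropy():
    distances = torch.tensor([[0.1, 1.0],
                              [2.0, 1.0],
                              [2.0, 1.0]])
    sims = distance_to_similarity(distances)
    probs = sims / sims.sum(dim=0, keepdim=True)
    entropy = -torch.sum(probs * torch.log(probs), dim=0)
    print(entropy)  # column 1 (uniform, entropy ln 3 ~ 1.10) exceeds column 0 (~0.70)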
def detect_newness_gini(distances, temperature=1.0, quantile=0.97):
"""
Gini impurity-based approach. Convert distances to similarities,
get a distribution, compute Gini.
High Gini => wide distribution => new object.
"""
similarities = distance_to_similarity(distances, temperature=temperature)
probs = similarities / similarities.sum(dim=0, keepdim=True)
# Gini: sum(p_i*(1-p_i)) => high if spread out
gini = torch.sum(probs * (1.0 - probs), dim=0) # [N_tgt]
threshold = torch.quantile(gini, quantile)
newness_mask = (gini > threshold).float()
return newness_mask
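
# Quick algebraic sanity check (illustrative): because the probabilities sum
# to 1, the Gini impurity sum(p * (1 - p)) equals 1 - sum(p^2).
def _demo_gini_identity():
    p = torch.tensor([0.5, 0.3, 0.2])
    lhs = torch.sum(p * (1.0 - p))
    rhs = 1.0 - torch.sum(p ** 2)
    print(lhs.item(), rhs.item())  # both 0.62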
def detect_newness_kl(distances, temperature=1.0, quantile=0.97):
"""
KL-based approach. Compare distribution to uniform => if close to uniform => new object.
1) Convert distances -> similarities
2) p(x) = similarities / sum(similarities)
3) KL(p || uniform) => sum p(x) log (p(x)/(1/N_src))
4) If p is near uniform => KL small => new object.
We'll invert it => newness ~ 1/KL.
"""
similarities = distance_to_similarity(distances, temperature=temperature)
N_src = distances.shape[0]
probs = similarities / similarities.sum(dim=0, keepdim=True)
uniform_val = 1.0 / float(N_src)
kl_vals = torch.sum(probs * torch.log(probs / uniform_val), dim=0) # [N_tgt]
inv_kl = 1.0 / (kl_vals + 1e-8) # big => distribution is near uniform => new
threshold = torch.quantile(inv_kl, quantile)
newness_mask = (inv_kl > threshold).float()
return newness_mask
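
# Toy comparison (illustrative values, not from the source): KL(p || uniform)
# is zero for a flat distribution and large for a peaked one, so the inverted
# score 1/KL is high exactly when no single source stands out.
def _demo_kl():
    uniform = torch.tensor([1 / 3, 1 / 3, 1 / 3])
    peaked = torch.tensor([0.98, 0.01, 0.01])
    for p in (uniform, peaked):
        kl = torch.sum(p * torch.log(p / (1.0 / p.numel())))
        print(kl.item())  # ~0.0 for uniform, ~0.99 for peaked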
def detect_newness_variation_ratio(distances, temperature=1.0, quantile=0.97):
"""
Variation Ratio: 1 - max(prob).
1) Convert distance->similarity
2) p(x) = sim(x) / sum_x'(sim(x'))
3) var_ratio = 1 - max(p)
High var_ratio => new object.
"""
similarities = distance_to_similarity(distances, temperature=temperature)
probs = similarities / similarities.sum(dim=0, keepdim=True)
max_prob, _ = torch.max(probs, dim=0) # [N_tgt]
var_ratio = 1.0 - max_prob
threshold = torch.quantile(var_ratio, quantile)
newness_mask = (var_ratio > threshold).float()
return newness_mask
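
# Tiny numeric sketch (illustrative): a peaked match distribution has a low
# variation ratio, a flat one a high ratio.
def _demo_variation_ratio():
    peaked = torch.tensor([0.9, 0.05, 0.05])
    flat = torch.tensor([0.4, 0.3, 0.3])
    print((1.0 - peaked.max()).item())  # 0.1 -> confident match
    print((1.0 - flat.max()).item())    # 0.6 -> likely new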
def detect_newness_two_sided_ratio(
    distances,
    top_k_ratio_quantile=0.03,
    two_sided=True
):
    """
    Two-sided matching + Lowe-style ratio test in distance space.
    Ratio test: for each target t, let d0 = best distance, d1 = second best,
    ratio = d0 / (d1 + 1e-8). A ratio near 1 means the best match is barely
    better than the runner-up, i.e. ambiguous => "new".
    Note: as written, only the two-sided consistency mask is returned; the
    ratio mask is computed but not folded into the result.
    """
    N_src, N_tgt = distances.shape
    # Target -> Source: best match
    min_vals_t, best_s_for_t = torch.min(distances, dim=0)
    # Source -> Target: best match
    min_vals_s, best_t_for_s = torch.min(distances, dim=1)
    # Two-sided consistency check: flag targets whose best source does not
    # map back to them (vectorized equivalent of a per-target loop).
    twosided_mask = torch.zeros(N_tgt, device=distances.device)
    if two_sided:
        t_idx = torch.arange(N_tgt, device=distances.device)
        twosided_mask = (best_t_for_s[best_s_for_t] != t_idx).float()
    # Ratio test: ambiguous if the best match is not clearly better than the
    # second best, i.e. the ratio is among the highest (closest to 1).
    sorted_dists, _ = torch.sort(distances, dim=0)
    d0 = sorted_dists[0]
    d1 = sorted_dists[1]
    ratio = d0 / (d1 + 1e-8)
    ratio_threshold = torch.quantile(ratio, 1.0 - top_k_ratio_quantile)
    ratio_mask = (ratio > ratio_threshold).float()
    # Combine checks (currently using only the two-sided result)
    newness_mask = twosided_mask
    return newness_mask
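
# Optional smoke test (assumed entry point, not part of the original file):
# runs every detector on the same random distance matrix so the flagged
# fractions can be compared side by side.
if __name__ == "__main__":
    torch.manual_seed(0)
    dists = torch.rand(64, 128)  # [N_src, N_tgt]
    min_dists, _ = dists.min(dim=0)
    detectors = {
        "two_sided": detect_newness_two_sided(dists.clone(), k=3),  # defensive copy
        "distance": detect_newness_distance(min_dists),
        "topk_margin": detect_newness_topk_margin(dists),
        "entropy": detect_newness_entropy(dists),
        "gini": detect_newness_gini(dists),
        "kl": detect_newness_kl(dists),
        "variation_ratio": detect_newness_variation_ratio(dists),
        "two_sided_ratio": detect_newness_two_sided_ratio(dists),
    }
    for name, mask in detectors.items():
        print(f"{name:>16}: {int(mask.sum().item())}/{mask.numel()} flagged")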