# cora/utils/utils.py
import io

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image


def normalize(z_t, i, max_norm_zs):
    """Clamp the norm of `z_t` to `max_norm_zs[i]`.

    A negative cap disables clamping for that step. Returns the (possibly
    rescaled) tensor and the scaling coefficient (1 when no rescaling
    was applied).
    """
    max_norm = max_norm_zs[i]
    if max_norm < 0:
        return z_t, 1
    norm = torch.norm(z_t)
    if norm < max_norm:
        return z_t, 1
    coeff = max_norm / norm
    z_t = z_t * coeff
    return z_t, coeff
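
# Hedged usage sketch (not part of the original file): `max_norm_zs` is assumed
# to hold one norm cap per step index, with negative entries meaning "no cap".
def _demo_normalize():
    z = torch.randn(4, 64)
    z_capped, coeff = normalize(z, i=0, max_norm_zs=[5.0, -1.0])
    print(torch.norm(z_capped).item(), coeff)  # norm is at most 5.0
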
def normalize2(x, dim):
    """Standardize `x` to zero mean and unit std along `dim`."""
    x_mean = x.mean(dim=dim, keepdim=True)
    x_std = x.std(dim=dim, keepdim=True)
    return (x - x_mean) / x_std
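
# Illustrative check (a sketch, not from the original file): after
# standardization the per-slice mean is ~0 and the std ~1 along `dim`.
def _demo_normalize2():
    x = torch.randn(2, 8, 16)
    x_std = normalize2(x, dim=-1)
    print(x_std.mean(dim=-1).abs().max().item())  # ~0
    print(x_std.std(dim=-1).mean().item())        # ~1
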
def find_lambda_via_newton_batched(Qp, K_source, K_target, max_iter=50, tol=1e-7):
    """Find the blending weight `lmbd` by Newton iteration, balancing the
    joint-softmax mass on the source keys against the mass on the
    lambda-scaled target keys."""
    dot_QpK_source = torch.einsum("bcd,bmd->bcm", Qp, K_source)  # [B, C, M]
    dot_QpK_target = torch.einsum("bcd,bmd->bcm", Qp, K_target)  # [B, C, M]
    X = torch.exp(dot_QpK_source)
    lmbd = torch.full([1], 0.7, device=Qp.device, dtype=Qp.dtype)  # initial guess
    for it in range(max_iter):
        y = torch.exp(lmbd * dot_QpK_target)
        Z = (X + y).sum(dim=2, keepdim=True)  # joint normalizer per query
        x = X / Z
        y = y / Z
        # Objective: total source mass minus total target mass.
        val = (x.sum(dim=(1, 2)) - y.sum(dim=(1, 2))).sum()
        # First derivative w.r.t. lmbd (the normalizer Z is treated as constant).
        grad = -(dot_QpK_target * y).sum()
        if not (val.abs() > tol and grad.abs() > 1e-12):
            break  # converged, or gradient too small for a stable step
        lmbd = lmbd - val / grad
    # Snap roots outside [0.4, 0.9] to fallback values.
    if lmbd.item() < 0.4:
        return 0.1
    elif lmbd.item() > 0.9:
        return 0.65
    return lmbd.item()
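
# Hedged sketch (not part of the original file): shapes are assumed to be
# Qp [B, C, D] and K_* [B, M, D] to match the einsum above. The return value
# is a plain float, snapped to 0.1 / 0.65 when the root falls outside [0.4, 0.9].
def _demo_newton():
    B, C, M, D = 1, 4, 7, 8
    Qp = torch.randn(B, C, D)
    K_source = torch.randn(B, M, D)
    K_target = torch.randn(B, M, D)
    print(find_lambda_via_newton_batched(Qp, K_source, K_target))
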
def find_lambda_via_super_halley(Qp, K_source, K_target, max_iter=50, tol=1e-7):
    """Find the blending weight `lmbd` with a Halley-type second-order update.

    Same objective as `find_lambda_via_newton_batched`, but note this variant
    returns the final `lmbd` as a 0-dim tensor rather than a float.
    """
    dot_QpK_source = torch.einsum("bcd,bmd->bcm", Qp, K_source)  # [B, C, M]
    dot_QpK_target = torch.einsum("bcd,bmd->bcm", Qp, K_target)  # [B, C, M]
    X = torch.exp(dot_QpK_source)
    lmbd = torch.full([], 0.8, device=Qp.device, dtype=Qp.dtype)  # initial guess
    for it in range(max_iter):
        y = torch.exp(lmbd * dot_QpK_target)
        Z = (X + y).sum(dim=2, keepdim=True)  # joint normalizer per query
        x = X / Z
        y = y / Z
        val = (x.sum(dim=(1, 2)) - y.sum(dim=(1, 2))).sum()
        # First and second derivatives w.r.t. lmbd (Z treated as constant).
        grad = -(dot_QpK_target * y).sum()
        f2 = -(dot_QpK_target ** 2 * y).sum()
        if not (val.abs() > tol and grad.abs() > 1e-12):
            break  # converged, or gradient too small for a stable step
        denom = grad ** 2 - val * f2
        if denom.abs() < 1e-20:
            break  # avoid division by a vanishing denominator
        update = (val * grad) / denom
        lmbd = lmbd - update
        # Debug trace of the iteration.
        print(f"iter={it}, λ={lmbd.item():.6f}, val={val.item():.6e}, grad={grad.item():.6e}")
    return lmbd
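
# Hedged sketch (not part of the original file): same assumed shapes as the
# Newton variant; the result here is a 0-dim tensor, not a float.
def _demo_super_halley():
    Qp = torch.randn(1, 4, 8)
    K_source = torch.randn(1, 7, 8)
    K_target = torch.randn(1, 7, 8)
    lmbd = find_lambda_via_super_halley(Qp, K_source, K_target, max_iter=5)
    print(lmbd.item())
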
def find_smallest_key_with_suffix(features_dict: dict, suffix: str = "_1") -> str:
    """Return the key ending in `suffix` whose leading integer (the part
    before the first underscore) is smallest, or None if no key matches."""
    smallest_key = None
    smallest_number = float("inf")
    for key in features_dict:
        if key.endswith(suffix):
            try:
                number = int(key.split("_")[0])
            except ValueError:
                continue  # skip keys without an integer prefix
            if number < smallest_number:
                smallest_number = number
                smallest_key = key
    return smallest_key
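
# Illustrative example (not from the original file): keys are assumed to look
# like "<number><suffix>"; non-integer prefixes are skipped.
def _demo_find_smallest_key():
    feats = {"12_1": None, "3_1": None, "5_2": None, "bad_1": None}
    print(find_smallest_key_with_suffix(feats, "_1"))  # -> "3_1"
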
def extract_mask(masks, original_width, original_height):
    """Rasterize rectangular masks (given in original-image pixel coordinates)
    onto a 512x512 canvas, binarize the union, and downsample to 64x64 with
    nearest-neighbor interpolation. Returns None when `masks` is empty."""
    if not masks:
        return None
    combined_mask = torch.zeros(512, 512)
    scale_x = 512 / original_width
    scale_y = 512 / original_height
    for mask in masks:
        start_x, start_y = mask["start_point"]
        end_x, end_y = mask["end_point"]
        # Allow rectangles drawn in any direction.
        start_x, end_x = min(start_x, end_x), max(start_x, end_x)
        start_y, end_y = min(start_y, end_y), max(start_y, end_y)
        scaled_start_x, scaled_start_y = int(start_x * scale_x), int(start_y * scale_y)
        scaled_end_x, scaled_end_y = int(end_x * scale_x), int(end_y * scale_y)
        combined_mask[scaled_start_y:scaled_end_y, scaled_start_x:scaled_end_x] += 1
    binary_mask = (combined_mask > 0).float()
    resized_mask = F.interpolate(binary_mask[None, None, :, :], size=(64, 64), mode="nearest")[0, 0]
    return resized_mask
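
# Hedged sketch (not part of the original file): each mask is assumed to be a
# dict with "start_point"/"end_point" pixel coordinates in the original image.
def _demo_extract_mask():
    masks = [{"start_point": (10, 20), "end_point": (200, 150)}]
    m = extract_mask(masks, original_width=1024, original_height=768)
    print(m.shape, int(m.sum().item()))  # torch.Size([64, 64]), nonzero area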