import torch
import torch.nn.functional as F
import io
import cv2
import numpy as np
from PIL import Image
def normalize(
    z_t,
    i,
    max_norm_zs,
):
    """Clamp the norm of latent z_t at step i to max_norm_zs[i]; a negative
    entry disables clamping for that step. Returns the (possibly rescaled)
    latent and the scaling coefficient that was applied."""
    max_norm = max_norm_zs[i]
    if max_norm < 0:
        return z_t, 1
    norm = torch.norm(z_t)
    if norm < max_norm:
        return z_t, 1
    coeff = max_norm / norm
    z_t = z_t * coeff
    return z_t, coeff
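

# Hedged usage sketch (not part of the original Space): exercises `normalize`
# on a random latent. The latent shape and the max-norm schedule below are
# assumptions made for illustration only.
def _demo_normalize():
    z_t = torch.randn(1, 4, 64, 64)
    max_norm_zs = [-1, -1, 15.0]  # -1 disables clamping at that step
    z_clamped, coeff = normalize(z_t, i=2, max_norm_zs=max_norm_zs)
    # After clamping, the norm is at most the configured maximum.
    assert torch.norm(z_clamped) <= 15.0 + 1e-4
    print(f"coeff applied: {coeff}")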
def normalize2(x, dim):
    """Standardize x to zero mean and unit standard deviation along `dim`."""
    x_mean = x.mean(dim=dim, keepdim=True)
    x_std = x.std(dim=dim, keepdim=True)
    x_normalized = (x - x_mean) / x_std
    return x_normalized
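

# Hedged sketch (illustrative): `normalize2` standardizes along a chosen
# dimension, so the result has mean ~0 and std ~1 along that dimension.
def _demo_normalize2():
    x = torch.randn(2, 8, 16) * 3.0 + 5.0
    x_norm = normalize2(x, dim=-1)
    print(x_norm.mean(dim=-1).abs().max())  # close to 0
    print(x_norm.std(dim=-1).mean())        # close to 1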
def find_lambda_via_newton_batched(Qp, K_source, K_target, max_iter=50, tol=1e-7):
    """Find a scalar lambda such that, under a joint softmax over source and
    lambda-scaled target keys, the total attention mass on the source keys
    equals the mass on the target keys. Solved with Newton's method; the
    result is mapped to fixed fallback values when it leaves a safe range."""
    dot_QpK_source = torch.einsum("bcd,bmd->bcm", Qp, K_source)  # shape [B, C, M]
    dot_QpK_target = torch.einsum("bcd,bmd->bcm", Qp, K_target)  # shape [B, C, M]
    X = torch.exp(dot_QpK_source)
    lmbd = torch.zeros([1], device=Qp.device, dtype=Qp.dtype) + 0.7  # initial guess
    for it in range(max_iter):
        y = torch.exp(lmbd * dot_QpK_target)
        Z = (X + y).sum(dim=2, keepdim=True)  # joint softmax normalizer
        x = X / Z
        y = y / Z
        # f(lambda): difference between source and target attention mass.
        val = (x.sum(dim=(1, 2)) - y.sum(dim=(1, 2))).sum()
        # df/dlambda (dominant term; the normalizer's contribution cancels at the root).
        grad = -(dot_QpK_target * y).sum()
        if not (val.abs() > tol and grad.abs() > 1e-12):
            break
        lmbd = lmbd - val / grad  # Newton step
    # Map out-of-range solutions to empirically chosen fallback values.
    if lmbd.item() < 0.4:
        return 0.1
    elif lmbd.item() > 0.9:
        return 0.65
    return lmbd.item()
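

# Hedged sketch (illustrative; the tensor shapes are assumptions): solve for
# the lambda that balances attention mass between source and target keys,
# then use it to scale the target logits in an extended-attention softmax.
def _demo_newton_lambda():
    B, C, M, D = 1, 77, 77, 64
    Qp = torch.randn(B, C, D) * 0.1
    K_source = torch.randn(B, M, D) * 0.1
    K_target = torch.randn(B, M, D) * 0.1
    lmbd = find_lambda_via_newton_batched(Qp, K_source, K_target)
    logits = torch.cat(
        [torch.einsum("bcd,bmd->bcm", Qp, K_source),
         lmbd * torch.einsum("bcd,bmd->bcm", Qp, K_target)],
        dim=-1,
    )
    attn = logits.softmax(dim=-1)
    print(f"lambda={lmbd}, source mass={attn[..., :M].sum():.3f}, "
          f"target mass={attn[..., M:].sum():.3f}")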
def find_lambda_via_super_halley(Qp, K_source, K_target, max_iter=50, tol=1e-7):
    """Variant of the solver above using a Halley-type second-order update,
    which typically converges in fewer iterations than plain Newton when the
    second derivative is well behaved. Returns lambda as a tensor."""
    dot_QpK_source = torch.einsum("bcd,bmd->bcm", Qp, K_source)
    dot_QpK_target = torch.einsum("bcd,bmd->bcm", Qp, K_target)
    X = torch.exp(dot_QpK_source)
    lmbd = torch.zeros([], device=Qp.device, dtype=Qp.dtype) + 0.8  # initial guess
    for it in range(max_iter):
        y = torch.exp(lmbd * dot_QpK_target)
        Z = (X + y).sum(dim=2, keepdim=True)
        x = X / Z
        y = y / Z
        val = (x.sum(dim=(1, 2)) - y.sum(dim=(1, 2))).sum()
        grad = -(dot_QpK_target * y).sum()    # first derivative
        f2 = -(dot_QpK_target**2 * y).sum()   # second derivative
        if not (val.abs() > tol and grad.abs() > 1e-12):
            break
        denom = grad**2 - val * f2
        if denom.abs() < 1e-20:  # guard against a degenerate denominator
            break
        update = (val * grad) / denom  # Halley-type step
        lmbd = lmbd - update
        print(f"iter={it}, λ={lmbd.item():.6f}, val={val.item():.6e}, grad={grad.item():.6e}")
    return lmbd
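

# Hedged sketch (illustrative): the Halley-type solver prints per-iteration
# diagnostics and returns lambda as a tensor, so `.item()` extracts the value.
def _demo_super_halley():
    B, C, M, D = 1, 77, 77, 64
    Qp = torch.randn(B, C, D) * 0.1
    K_source = torch.randn(B, M, D) * 0.1
    K_target = torch.randn(B, M, D) * 0.1
    lmbd = find_lambda_via_super_halley(Qp, K_source, K_target)
    print(f"converged lambda: {lmbd.item():.4f}")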
def find_smallest_key_with_suffix(features_dict: dict, suffix: str = "_1") -> str:
    """Return the key ending in `suffix` whose leading integer (the part
    before the first '_') is smallest; keys without an integer prefix are
    skipped. Returns None if no key matches."""
    smallest_key = None
    smallest_number = float("inf")
    for key in features_dict.keys():
        if key.endswith(suffix):
            try:
                number = int(key.split("_")[0])
                if number < smallest_number:
                    smallest_number = number
                    smallest_key = key
            except ValueError:
                continue
    return smallest_key
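

# Hedged sketch (illustrative): the keys are assumed to follow the
# "<number>_<index>" pattern that the lookup above parses.
def _demo_find_smallest_key():
    features = {"981_1": None, "961_1": None, "981_2": None, "bad_1": None}
    # "961_1" wins: it ends in "_1" and has the smallest integer prefix;
    # "bad_1" is skipped because its prefix is not an integer.
    print(find_smallest_key_with_suffix(features, suffix="_1"))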
def extract_mask(masks, original_width, original_height):
    """Rasterize rectangle annotations (given in original-image coordinates)
    into a single binary mask at 512x512, then downsample it to a 64x64 grid.
    Returns None when no rectangles are provided."""
    if not masks:
        return None
    combined_mask = torch.zeros(512, 512)
    scale_x = 512 / original_width
    scale_y = 512 / original_height
    for mask in masks:
        start_x, start_y = mask["start_point"]
        end_x, end_y = mask["end_point"]
        # Normalize corner order so rectangles drawn in any direction work.
        start_x, end_x = min(start_x, end_x), max(start_x, end_x)
        start_y, end_y = min(start_y, end_y), max(start_y, end_y)
        scaled_start_x, scaled_start_y = int(start_x * scale_x), int(start_y * scale_y)
        scaled_end_x, scaled_end_y = int(end_x * scale_x), int(end_y * scale_y)
        combined_mask[scaled_start_y:scaled_end_y, scaled_start_x:scaled_end_x] += 1
    binary_mask = (combined_mask > 0).float()
    # Nearest-neighbor downsample to the 64x64 grid.
    resized_mask = F.interpolate(binary_mask[None, None, :, :], size=(64, 64), mode="nearest")[0, 0]
    return resized_mask
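

# Hedged sketch (illustrative): rectangles are given in original-image
# coordinates as dicts with "start_point" / "end_point" keys, matching the
# structure `extract_mask` reads above.
def _demo_extract_mask():
    masks = [
        {"start_point": (100, 50), "end_point": (300, 200)},
        {"start_point": (500, 400), "end_point": (350, 250)},  # drawn backwards
    ]
    mask = extract_mask(masks, original_width=640, original_height=480)
    print(mask.shape, mask.sum())  # torch.Size([64, 64]) and a nonzero area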