import os

import numpy as np
import cv2 as cv
from numpy.linalg import norm, inv
from scipy.stats import multivariate_normal as mv_norm
import joblib  # or import pickle
import torch
from torch.distributions import MultivariateNormal
import torch.nn.functional as F

init_weight = [0.7, 0.11, 0.1, 0.09]
init_u = np.zeros(3)
# initial covariance matrix
init_sigma = 225 * np.eye(3)
init_alpha = 0.05
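
# Note: init_sigma = 225 * I corresponds to a per-channel standard deviation of
# 15 intensity levels (sqrt(225)); init_alpha is the learning rate used in the
# update equations below.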


class GMM():
    def __init__(self, data_dir, train_num, alpha=init_alpha):
        self.data_dir = data_dir
        self.train_num = train_num
        self.alpha = alpha
        self.img_shape = None
        self.weight = None
        self.mu = None
        self.sigma = None
        self.K = None
        self.B = None

    def check(self, pixel, mu, sigma):
        '''
        Check whether a pixel matches a Gaussian distribution.
        Matching means the Mahalanobis distance is less than 2.5.
        '''
        # Convert numpy inputs to torch tensors
        if isinstance(mu, np.ndarray):
            mu = torch.from_numpy(mu).float()
        if isinstance(sigma, np.ndarray):
            sigma = torch.from_numpy(sigma).float()
        if isinstance(pixel, np.ndarray):
            pixel = torch.from_numpy(pixel).float()
        # Ensure all tensors are on the same device
        device = mu.device
        pixel = pixel.to(device)
        sigma = sigma.to(device)
        # Compute the Mahalanobis distance d = sqrt((x - mu)^T Sigma^-1 (x - mu))
        delta = pixel - mu
        sigma_inv = torch.linalg.inv(sigma)
        d_squared = delta @ sigma_inv @ delta
        d = torch.sqrt(d_squared + 1e-5)
        return d.item() < 2.5
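
    # A quick sanity check (illustrative values only): with the initial
    # Sigma = 225 * I, the Mahalanobis distance reduces to ||x - mu|| / 15,
    # so the 2.5 threshold corresponds to a Euclidean distance of 37.5.
    #
    #   gmm = GMM(data_dir='frames', train_num=1)       # paths are placeholders
    #   mu = np.array([100.0, 100.0, 100.0])
    #   sigma = 225 * np.eye(3)
    #   gmm.check(np.array([110.0, 100.0, 100.0]), mu, sigma)  # d ~ 0.67 -> True
    #   gmm.check(np.array([160.0, 100.0, 100.0]), mu, sigma)  # d = 4.0  -> False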

    def train(self, K=4):
        '''
        Train the model with GPU acceleration.
        '''
        self.K = K
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {device}")
        file_list = []
        for i in range(self.train_num):
            file_name = os.path.join(self.data_dir, 'b%05d' % i + '.bmp')
            file_list.append(file_name)
        # Initialize with the first image
        img_init = cv.imread(file_list[0])
        img_shape = img_init.shape
        self.img_shape = img_shape
        height, width, channels = img_shape
        # Initialize model parameters on the selected device
        self.weight = torch.full((height, width, K), 1.0 / K,
                                 dtype=torch.float32, device=device)
        self.mu = torch.zeros(height, width, K, 3,
                              dtype=torch.float32, device=device)
        self.sigma = torch.zeros(height, width, K, 3, 3,
                                 dtype=torch.float32, device=device)
        self.B = torch.ones((height, width),
                            dtype=torch.int32, device=device)
        # Initialize mu with the first image's pixel values
        img_tensor = torch.from_numpy(img_init).float().to(device)
        for k in range(K):
            self.mu[:, :, k, :] = img_tensor
        # Initialize sigma with identity matrix * 225
        self.sigma[:] = torch.eye(3, device=device) * 225
        # Training loop
        for file in file_list:
            print('training:{}'.format(file))
            img = cv.imread(file)
            img_tensor = torch.from_numpy(img).float().to(device)  # (H, W, 3)
            # Check matches for all pixels
            matches = torch.full((height, width), -1, dtype=torch.long, device=device)
            # Mahalanobis distance (x - mu)^T Sigma^-1 (x - mu) of every pixel to
            # each of the K distributions, computed once per frame
            delta = img_tensor.unsqueeze(2) - self.mu       # (H, W, K, 3)
            sigma_inv = torch.linalg.inv(self.sigma)        # (H, W, K, 3, 3)
            temp = torch.einsum('hwki,hwkij->hwkj', delta, sigma_inv)
            mahalanobis = torch.sqrt(torch.einsum('hwki,hwki->hwk', temp, delta))
            for k in range(K):
                # Update matches where distance < 2.5 and not already matched
                match_mask = (mahalanobis[:, :, k] < 2.5) & (matches == -1)
                matches[match_mask] = k
            # Process matched pixels
            for k in range(K):
                # Get mask for current distribution matches
                mask = matches == k
                if mask.any():
                    # Get matched pixels
                    matched_pixels = img_tensor[mask]                # (N, 3)
                    matched_mu = self.mu[:, :, k, :][mask]           # (N, 3)
                    matched_sigma = self.sigma[:, :, k, :, :][mask]  # (N, 3, 3)
                    try:
                        # Create multivariate normal distribution
                        mvn = MultivariateNormal(matched_mu,
                                                 covariance_matrix=matched_sigma)
                        # Calculate rho = alpha * N(x | mu, Sigma)
                        rho = self.alpha * torch.exp(mvn.log_prob(matched_pixels))
                        # Update weights
                        self.weight[:, :, k][mask] = (1 - self.alpha) * self.weight[:, :, k][mask] + self.alpha
                        # Update mu
                        delta = matched_pixels - matched_mu
                        self.mu[:, :, k, :][mask] += rho.unsqueeze(1) * delta
                        # Update sigma
                        delta_outer = torch.einsum('bi,bj->bij', delta, delta)
                        sigma_update = rho.unsqueeze(1).unsqueeze(2) * (delta_outer - matched_sigma)
                        self.sigma[:, :, k, :, :][mask] += sigma_update
                    except (RuntimeError, ValueError) as e:
                        print(f"Error updating distribution {k}: {e}")
                        continue
            # Process non-matched pixels
            non_matched = matches == -1
            if non_matched.any():
                # Find the least probable distribution for each non-matched pixel
                weight_non_matched = self.weight[non_matched]             # (N, K)
                min_weight_idx = torch.argmin(weight_non_matched, dim=1)  # (N,)
                # Coordinates of non-matched pixels
                non_matched_indices = non_matched.nonzero(as_tuple=False)  # (N, 2)
                for k in range(K):
                    # Find positions where min_weight_idx == k
                    k_mask = (min_weight_idx == k)
                    if k_mask.any():
                        selected_indices = non_matched_indices[k_mask]  # (M, 2)
                        y_idx = selected_indices[:, 0]
                        x_idx = selected_indices[:, 1]
                        # Update mu and sigma
                        self.mu[y_idx, x_idx, k, :] = img_tensor[y_idx, x_idx]
                        self.sigma[y_idx, x_idx, k, :, :] = torch.eye(3, device=device) * 225
            # Convert to numpy for reordering and debug prints
            weight_np = self.weight.cpu().numpy()
            mu_np = self.mu.cpu().numpy()
            sigma_np = self.sigma.cpu().numpy()
            B_np = self.B.cpu().numpy()
            print('img:{}'.format(img[100][100]))
            print('weight:{}'.format(weight_np[100][100]))
            # Use numpy arrays for reorder
            self.weight = weight_np
            self.mu = mu_np
            self.sigma = sigma_np
            self.B = B_np
            self.reorder()
            for i in range(self.K):
                print('u:{}'.format(self.mu[100][100][i]))
            # Move back to the device for the next iteration
            self.weight = torch.from_numpy(self.weight).to(device)
            self.mu = torch.from_numpy(self.mu).to(device)
            self.sigma = torch.from_numpy(self.sigma).to(device)
            self.B = torch.from_numpy(self.B).to(device)
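
    # For reference, each matched component above is updated with
    #     w     <- (1 - alpha) * w + alpha
    #     mu    <- mu + rho * (x - mu),      rho = alpha * N(x | mu, Sigma)
    #     Sigma <- Sigma + rho * ((x - mu)(x - mu)^T - Sigma)
    # A quick numeric sketch (illustrative values only): with alpha = 0.05 and a
    # matched component at w = 0.25, the new weight is
    # (1 - 0.05) * 0.25 + 0.05 = 0.2875; unmatched components keep their weights
    # in this implementation.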

    def save_model(self, file_path):
        """
        Save the trained model to a file.
        """
        # Only make directories if there is a directory component in the path
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        joblib.dump({
            'weight': self.weight,
            'mu': self.mu,
            'sigma': self.sigma,
            'K': self.K,
            'B': self.B,
            'img_shape': self.img_shape,
            'alpha': self.alpha,
            'data_dir': self.data_dir,
            'train_num': self.train_num
        }, file_path)
        print(f"Model saved to {file_path}")

    @classmethod
    def load_model(cls, file_path):
        """
        Load a trained model from file.
        """
        data = joblib.load(file_path)
        # Create a new instance
        gmm = cls(data['data_dir'], data['train_num'], data['alpha'])
        # Restore all attributes
        gmm.weight = data['weight']
        gmm.mu = data['mu']
        gmm.sigma = data['sigma']
        gmm.K = data['K']
        gmm.B = data['B']
        gmm.img_shape = data['img_shape']
        gmm.image_shape = data['img_shape']
        print(f"Model loaded from {file_path}")
        return gmm

    # @classmethod
    # def load_model(cls, file_path):
    #     """
    #     Load a trained model safely onto CPU, even if saved from GPU.
    #     """
    #     import pickle
    #     def cpu_load(path):
    #         with open(path, "rb") as f:
    #             unpickler = pickle._Unpickler(f)
    #             unpickler.persistent_load = lambda saved_id: torch.load(saved_id, map_location="cpu")
    #             return unpickler.load()
    #     # Force joblib to use pickle with CPU-mapped tensors
    #     data = cpu_load(file_path)
    #     # Create instance
    #     gmm = cls(data['data_dir'], data['train_num'], data['alpha'])
    #     # Assign all attributes (already CPU tensors now)
    #     gmm.weight = data['weight']
    #     gmm.mu = data['mu']
    #     gmm.sigma = data['sigma']
    #     gmm.K = data['K']
    #     gmm.B = data['B']
    #     gmm.img_shape = data['img_shape']
    #     gmm.image_shape = data['img_shape']
    #     print(f"✅ GMM model loaded on CPU from {file_path}")
    #     return gmm
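
    # Minimal save/load sketch (the file name is only an example):
    #
    #   gmm.save_model('models/gmm_traffic.joblib')
    #   gmm2 = GMM.load_model('models/gmm_traffic.joblib')
    #
    # joblib pickles the torch tensors as-is, so a model trained on GPU may fail
    # to load on a CPU-only machine; that is the case the commented-out CPU
    # loader above is meant to handle.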

    def reorder(self, T=0.90):
        '''
        Reorder the estimated components by the ratio weight / norm of the standard deviation.
        The first B components are chosen as background components.
        The default threshold is 0.90.
        '''
        epsilon = 1e-6  # to prevent divide-by-zero
        for i in range(self.img_shape[0]):
            for j in range(self.img_shape[1]):
                k_weight = self.weight[i][j]
                k_norm = []
                for k in range(self.K):
                    cov = self.sigma[i][j][k]
                    try:
                        if np.all(np.linalg.eigvals(cov) >= 0):
                            stddev = np.sqrt(cov)  # element-wise square root of the covariance
                            k_norm.append(norm(stddev))
                        else:
                            k_norm.append(epsilon)
                    except Exception:
                        k_norm.append(epsilon)
                k_norm = np.array(k_norm)
                ratio = k_weight / (k_norm + epsilon)
                descending_order = np.argsort(-ratio)
                self.weight[i][j] = self.weight[i][j][descending_order]
                self.mu[i][j] = self.mu[i][j][descending_order]
                self.sigma[i][j] = self.sigma[i][j][descending_order]
                cum_weight = 0
                for index, order in enumerate(descending_order):
                    cum_weight += self.weight[i][j][index]
                    if cum_weight > T:
                        self.B[i][j] = index + 1
                        break
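
    # Worked example (illustrative values): with weights [0.7, 0.11, 0.1, 0.09]
    # and equal covariance norms, the ratio ordering leaves the components in
    # place; the cumulative weights are 0.70, 0.81, 0.91, so with T = 0.90 the
    # first three components are treated as background (B = 3).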

    # def infer(self, img, heatmap=None, alpha=0.1):
    #     '''
    #     Perform inference with a persistent heatmap that intensifies with movement.
    #     '''
    #     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    #     img_tensor = torch.from_numpy(img).float().to(device)  # (H, W, 3)
    #     H, W, _ = img.shape
    #     # Initialize heatmap on the first frame
    #     if heatmap is None:
    #         heatmap = torch.zeros((H, W), dtype=torch.float32, device=device)
    #     # No need for an 'else' that converts from numpy,
    #     # as the tensor is passed back in subsequent calls.
    #     # --- Foreground detection logic, same as the active infer below ---
    #     detection_mask = torch.ones((H, W), dtype=torch.bool, device=device)
    #     for k in range(self.K):
    #         B_mask = (self.B >= (k + 1)).to(device)
    #         mu_k = self.mu[:, :, k, :].to(device)
    #         sigma_k = self.sigma[:, :, k, :, :].to(device)
    #         delta = (img_tensor - mu_k).unsqueeze(-1)
    #         sigma_inv = torch.linalg.inv(sigma_k)
    #         temp = torch.matmul(sigma_inv, delta)
    #         dist_sq = torch.matmul(delta.transpose(-2, -1), temp).squeeze(-1).squeeze(-1)
    #         dist = torch.sqrt(dist_sq + 1e-5)
    #         match_mask = (dist < 9.5) & B_mask
    #         detection_mask[match_mask] = False
    #         img_tensor[match_mask] = mu_k[match_mask]  # Optional: for visualization
    #     foreground_mask = detection_mask & (img_tensor.abs().sum(dim=-1) > 0)
    #     heatmap[foreground_mask] = torch.clamp(heatmap[foreground_mask] + alpha, 0, 1)
    #     # Convert heatmap tensor to a numpy array for visualization
    #     heatmap_np = heatmap.cpu().numpy()
    #     # Apply the colormap (0 -> Blue, 1 -> Red)
    #     heatmap_viz = cv.applyColorMap((heatmap_np * 255).astype(np.uint8), cv.COLORMAP_JET)
    #     # Blend the heatmap with the original image
    #     result = cv.addWeighted(img, 0.7, heatmap_viz, 0.5, 0)
    #     # Return the blended image and the heatmap tensor for the next frame
    #     return result, heatmap
    # --------------------------------------------------------------------------------------------

    def infer(self, img, heatmap=None, decay_factor=0.95, alpha=0.1):
        '''
        Perform inference with an improved heatmap reflecting the persistence of foreground objects.
        Background areas remain unchanged (no bluish tone); only heatmap areas are colored.
        '''
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        img_tensor = torch.from_numpy(img).float().to(device)  # (H, W, 3)
        H, W, _ = img.shape
        # Initialize the heatmap, or move the one passed in onto the device
        if heatmap is None:
            heatmap = torch.zeros((H, W), dtype=torch.float32, device=device)
        else:
            heatmap = torch.from_numpy(heatmap).float().to(device)
        # Detection mask initialized to True (foreground); False means background
        detection_mask = torch.ones((H, W), dtype=torch.bool, device=device)
        for k in range(self.K):
            B_mask = (self.B >= (k + 1)).to(device)
            mu_k = self.mu[:, :, k, :].to(device)
            sigma_k = self.sigma[:, :, k, :, :].to(device)
            delta = img_tensor - mu_k
            delta = delta.unsqueeze(-1)
            sigma_inv = torch.linalg.inv(sigma_k)
            temp = torch.matmul(sigma_inv, delta)
            dist_sq = torch.matmul(delta.transpose(-2, -1), temp).squeeze(-1).squeeze(-1)
            dist = torch.sqrt(dist_sq + 1e-5)
            match_mask = (dist < 9.5) & B_mask
            # Mark matched pixels as background
            detection_mask[match_mask] = False
            img_tensor[match_mask] = mu_k[match_mask]
        # Foreground mask (boolean tensor)
        foreground_mask = detection_mask & (img_tensor.abs().sum(dim=-1) > 0)
        # Update heatmap: accumulate on foreground, decay elsewhere
        heatmap[foreground_mask] = torch.clamp(heatmap[foreground_mask] + alpha, 0, 1)
        heatmap[~foreground_mask] *= decay_factor
        # Convert heatmap to numpy for visualization
        heatmap_np = heatmap.cpu().numpy()
        # Create heatmap visualization
        heatmap_viz = cv.applyColorMap((heatmap_np * 255).astype(np.uint8), cv.COLORMAP_JET)
        # Mask of significant heatmap areas (adjust the threshold as needed)
        significant_heat = (heatmap_np > 0.1)
        # Initialize result with the original image
        result = img.copy()
        # Only blend where there are significant heat areas
        if np.any(significant_heat):
            img_region = img[significant_heat]
            heat_region = heatmap_viz[significant_heat]
            # Only blend if we have valid regions
            if img_region.size > 0 and heat_region.size > 0:
                blended = cv.addWeighted(
                    img_region, 0.7,
                    heat_region, 0.3,
                    0
                )
                result[significant_heat] = blended
        return result, heatmap_np
    # _____________________________________________________________________ Decay factor, working well
    # def infer(self, img, heatmap=None, decay_factor=0.95, alpha=0.1):
    #     '''
    #     Perform inference with a binary red mask (no intensity variation) and dilation.
    #     Returns:
    #         - result: Image with solid red overlay on detections (same dtype as input)
    #         - heatmap_np: Heatmap array
    #     '''
    #     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    #     # Ensure input is a numpy array and record the original dtype
    #     original_dtype = img.dtype
    #     img = np.asarray(img).astype(np.float32)
    #     H, W, C = img.shape
    #     # Initialize tensors
    #     img_tensor = torch.from_numpy(img).float().to(device)
    #     # Initialize heatmap
    #     if heatmap is None:
    #         heatmap = torch.zeros((H, W), dtype=torch.float32, device=device)
    #     else:
    #         heatmap = torch.from_numpy(heatmap).float().to(device)
    #     # Detection processing (same logic as the active infer above)
    #     detection_mask = torch.ones((H, W), dtype=torch.bool, device=device)
    #     for k in range(self.K):
    #         B_mask = (self.B >= (k + 1)).to(device)
    #         mu_k = self.mu[:, :, k, :].to(device)
    #         sigma_k = self.sigma[:, :, k, :, :].to(device)
    #         delta = img_tensor - mu_k
    #         delta = delta.unsqueeze(-1)
    #         sigma_inv = torch.linalg.inv(sigma_k)
    #         temp = torch.matmul(sigma_inv, delta)
    #         dist_sq = torch.matmul(delta.transpose(-2, -1), temp).squeeze(-1).squeeze(-1)
    #         dist = torch.sqrt(dist_sq + 1e-5)
    #         match_mask = (dist < 9.5) & B_mask
    #         detection_mask[match_mask] = False
    #         img_tensor[match_mask] = mu_k[match_mask]
    #     # Update heatmap
    #     foreground_mask = detection_mask & (img_tensor.abs().sum(dim=-1) > 0)
    #     heatmap[foreground_mask] = torch.clamp(heatmap[foreground_mask] + alpha, 0, 1)
    #     heatmap[~foreground_mask] *= decay_factor
    #     heatmap_np = heatmap.cpu().numpy()
    #     # Create binary mask and dilate
    #     binary_mask = (heatmap_np > 0.1).astype(np.uint8)
    #     kernel = np.ones((5, 5), np.uint8)
    #     dilated_mask = cv.dilate(binary_mask, kernel, iterations=1)
    #     # Create solid red overlay (BGR)
    #     red_overlay = np.zeros_like(img)
    #     red_overlay[..., 2] = 200  # Red channel
    #     # Apply overlay using np.where instead of boolean indexing
    #     result = np.where(
    #         dilated_mask[..., np.newaxis].astype(bool),
    #         cv.addWeighted(img, 0.7, red_overlay, 0.3, 0),
    #         img
    #     )
    #     # Convert back to the original dtype
    #     if original_dtype != np.float32:
    #         result = np.clip(result, 0, 255).astype(original_dtype)
    #     return result, heatmap_np
    # ________________________________________________________________________________________________
    # def infer(self, img, heatmap=None, alpha=0.1):
    #     '''
    #     Perform inference with binary red mask (no intensity variation) and dilation.
    #     Heatmap is fully recalculated every frame: no temporal decay or retention.
    #     Returns:
    #         - result: Image with solid red overlay on detections
    #         - heatmap_np: Binary heatmap array
    #     '''
    #     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    #     # Ensure input is a numpy array and record the original dtype
    #     original_dtype = img.dtype
    #     img = np.asarray(img).astype(np.float32)
    #     H, W, C = img.shape
    #     # Initialize tensors
    #     img_tensor = torch.from_numpy(img).float().to(device)
    #     # Detection processing
    #     detection_mask = torch.ones((H, W), dtype=torch.bool, device=device)
    #     for k in range(self.K):
    #         B_mask = (self.B >= (k + 1)).to(device)
    #         mu_k = self.mu[:, :, k, :].to(device)
    #         sigma_k = self.sigma[:, :, k, :, :].to(device)
    #         delta = img_tensor - mu_k
    #         delta = delta.unsqueeze(-1)
    #         sigma_inv = torch.linalg.inv(sigma_k)
    #         temp = torch.matmul(sigma_inv, delta)
    #         dist_sq = torch.matmul(delta.transpose(-2, -1), temp).squeeze(-1).squeeze(-1)
    #         dist = torch.sqrt(dist_sq + 1e-5)
    #         match_mask = (dist < 9.5) & B_mask
    #         detection_mask[match_mask] = False
    #         img_tensor[match_mask] = mu_k[match_mask]
    #     # Generate a binary heatmap (no decay, no accumulation)
    #     foreground_mask = detection_mask & (img_tensor.abs().sum(dim=-1) > 0)
    #     heatmap = torch.zeros((H, W), dtype=torch.float32, device=device)
    #     heatmap[foreground_mask] = alpha
    #     heatmap_np = heatmap.cpu().numpy()
    #     # Create binary mask and dilate
    #     binary_mask = (heatmap_np > 0.05).astype(np.uint8)
    #     kernel = np.ones((5, 5), np.uint8)
    #     dilated_mask = cv.dilate(binary_mask, kernel, iterations=1)
    #     # Create solid red overlay (BGR)
    #     red_overlay = np.zeros_like(img)
    #     red_overlay[..., 2] = 200  # Red channel
    #     # Apply overlay
    #     result = np.where(
    #         dilated_mask[..., np.newaxis].astype(bool),
    #         cv.addWeighted(img, 0.7, red_overlay, 0.3, 0),
    #         img
    #     )
    #     # Convert back to the original dtype
    #     if original_dtype != np.float32:
    #         result = np.clip(result, 0, 255).astype(original_dtype)
    #     return result, heatmap_np

    # def infer(self, img, heatmap=None, alpha=0.1):
    #     '''
    #     Perform inference with binary red mask and GPU-based dilation.
    #     Heatmap is recalculated each frame (no temporal retention).
    #     Returns:
    #         - result: Image with red overlay where foreground is detected.
    #         - heatmap_np: Numpy array of binary heatmap.
    #     '''
    #     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    #     # Convert image to float32 and move to GPU
    #     original_dtype = img.dtype
    #     img = np.asarray(img).astype(np.float32)
    #     H, W, C = img.shape
    #     img_tensor = torch.from_numpy(img).float().to(device)
    #     # Initialize detection mask as all True (foreground by default)
    #     detection_mask = torch.ones((H, W), dtype=torch.bool, device=device)
    #     for k in range(self.K):
    #         B_mask = (self.B >= (k + 1)).to(device)
    #         mu_k = self.mu[:, :, k, :].to(device)
    #         sigma_k = self.sigma[:, :, k, :, :].to(device)
    #         delta = img_tensor - mu_k
    #         delta = delta.unsqueeze(-1)  # shape: (H, W, 3, 1)
    #         sigma_inv = torch.linalg.inv(sigma_k)
    #         temp = torch.matmul(sigma_inv, delta)
    #         dist_sq = torch.matmul(delta.transpose(-2, -1), temp).squeeze(-1).squeeze(-1)
    #         dist = torch.sqrt(dist_sq + 1e-5)
    #         match_mask = (dist < 9.5) & B_mask
    #         detection_mask[match_mask] = False
    #         # img_tensor[match_mask] = mu_k[match_mask]
    #     # Generate heatmap
    #     foreground_mask = detection_mask & (img_tensor.abs().sum(dim=-1) > 0)
    #     heatmap_tensor = torch.zeros((H, W), dtype=torch.float32, device=device)
    #     heatmap_tensor[foreground_mask] = alpha
    #     # Convert heatmap to a binary mask and apply dilation (GPU-based)
    #     binary_mask = (heatmap_tensor > 0.05).float().unsqueeze(0).unsqueeze(0)  # shape: (1, 1, H, W)
    #     kernel = torch.ones((1, 1, 5, 5), dtype=torch.float32, device=device)
    #     dilated = F.conv2d(binary_mask, kernel, padding=2)
    #     dilated_mask = (dilated > 0).squeeze().to(torch.bool)
    #     # Create red overlay (on GPU)
    #     red_overlay = torch.zeros_like(img_tensor)
    #     red_overlay[..., 2] = 200  # Red channel
    #     # Blend red overlay on detected regions
    #     result_tensor = torch.where(
    #         dilated_mask.unsqueeze(-1),
    #         0.7 * img_tensor + 0.3 * red_overlay,
    #         img_tensor
    #     )
    #     # Convert back to NumPy and the original dtype
    #     result = result_tensor.clamp(0, 255).cpu().numpy()
    #     if original_dtype != np.float32:
    #         result = result.astype(original_dtype)
    #     heatmap_np = (heatmap_tensor > 0.05).float().cpu().numpy()
    #     return result, heatmap_np
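

# Minimal end-to-end sketch (hedged): the directory layout, the frame naming
# ('b00000.bmp', 'b00001.bmp', ...) assumed by train(), and the paths below are
# placeholders; adjust them to the actual dataset before running.
if __name__ == "__main__":
    data_dir = "dataset/frames"      # hypothetical folder of b%05d.bmp frames
    model_path = "gmm_model.joblib"  # hypothetical output path

    # Train on the first 100 frames and persist the model
    gmm = GMM(data_dir, train_num=100)
    gmm.train(K=4)
    gmm.save_model(model_path)

    # Reload and run inference frame by frame, carrying the heatmap forward
    gmm = GMM.load_model(model_path)
    heatmap = None
    for i in range(100, 200):
        frame_path = os.path.join(data_dir, 'b%05d.bmp' % i)
        frame = cv.imread(frame_path)
        if frame is None:
            break
        result, heatmap = gmm.infer(frame, heatmap=heatmap)
        cv.imwrite(os.path.join(data_dir, 'out_%05d.png' % i), result)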