# AnySplat: src/dataset/shims/normalize_shim.py
# Uploaded by alexnasa ("Upload 243 files", commit 2568013, verified).
import torch
from einops import einsum, reduce, repeat
from jaxtyping import Float
from torch import Tensor
from ..types import BatchedExample
def inverse_normalize_image(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Undo per-channel normalization, computing ``tensor * std + mean``.

    ``mean`` and ``std`` are converted to tensors matching ``tensor``'s dtype
    and device. They are reshaped to ``(C, 1, 1)``, so they broadcast against
    the last three dimensions of ``tensor``, with the channel axis third from
    last. The layout is presumably ``(..., C, H, W)``; confirm with callers.

    Returns a new tensor; ``tensor`` itself is not modified.
    """
    like = {"dtype": tensor.dtype, "device": tensor.device}
    # Reshape to (C, 1, 1) so each channel's statistic spans H and W.
    shift = torch.as_tensor(mean, **like).view(-1, 1, 1)
    scale = torch.as_tensor(std, **like).view(-1, 1, 1)
    rescaled = tensor * scale
    return rescaled + shift
def normalize_image(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Standardize ``tensor`` per channel, computing ``(tensor - mean) / std``.

    ``mean`` and ``std`` are converted to tensors matching ``tensor``'s dtype
    and device. They are reshaped to ``(C, 1, 1)``, so they broadcast against
    the last three dimensions of ``tensor``, with the channel axis third from
    last. With the default statistics, values in ``[0, 1]`` map to ``[-1, 1]``.

    Returns a new tensor; ``tensor`` itself is not modified.
    """

    def _per_channel(values):
        # (C, 1, 1) view so the statistic is shared across H and W.
        return torch.as_tensor(
            values, dtype=tensor.dtype, device=tensor.device
        ).view(-1, 1, 1)

    offset = _per_channel(mean)
    divisor = _per_channel(std)
    centered = tensor - offset
    return centered / divisor
def apply_normalize_shim(
    batch: BatchedExample,
    mean: tuple[float, float, float] = (0.5, 0.5, 0.5),
    std: tuple[float, float, float] = (0.5, 0.5, 0.5),
) -> BatchedExample:
    """Normalize the context images of ``batch`` in place and return ``batch``.

    Only ``batch["context"]["image"]`` is replaced, using ``normalize_image``
    with the given per-channel ``mean``/``std``. No other entry of ``batch``
    is touched. The returned object is the same ``batch`` that was passed in,
    so callers holding a reference also see the change.
    """
    context = batch["context"]
    context["image"] = normalize_image(context["image"], mean, std)
    return batch