"""Per-channel image normalization helpers and a batch-level shim."""

import torch
from einops import einsum, reduce, repeat
from jaxtyping import Float
from torch import Tensor

from ..types import BatchedExample


def _channel_stats(tensor, mean, std):
    """Return ``(shift, scale)`` tensors matching ``tensor``'s dtype/device.

    Both statistics are reshaped to ``(-1, 1, 1)`` so that they broadcast
    against the last three dimensions of ``tensor``.
    """
    shift = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
    shift = shift.view(-1, 1, 1)
    scale = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
    scale = scale.view(-1, 1, 1)
    return shift, scale


def inverse_normalize_image(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Undo :func:`normalize_image`, i.e. compute ``tensor * std + mean``.

    Args:
        tensor: Normalized image tensor. Presumably laid out as
            ``(..., channel, height, width)`` -- the statistics are broadcast
            over the last three dimensions.
        mean: Per-channel mean that was subtracted during normalization.
        std: Per-channel standard deviation that was divided out.

    Returns:
        A new tensor with the normalization reverted; the input is not
        modified.
    """
    shift, scale = _channel_stats(tensor, mean, std)
    return tensor * scale + shift


def normalize_image(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Normalize an image per channel, i.e. compute ``(tensor - mean) / std``.

    Args:
        tensor: Image tensor. Presumably laid out as
            ``(..., channel, height, width)`` -- the statistics are broadcast
            over the last three dimensions.
        mean: Per-channel mean to subtract.
        std: Per-channel standard deviation to divide by.

    Returns:
        A new, normalized tensor; the input is not modified.
    """
    shift, scale = _channel_stats(tensor, mean, std)
    return (tensor - shift) / scale


def apply_normalize_shim(
    batch: BatchedExample,
    mean: tuple[float, float, float] = (0.5, 0.5, 0.5),
    std: tuple[float, float, float] = (0.5, 0.5, 0.5),
) -> BatchedExample:
    """Normalize the context images of ``batch`` and return the batch.

    Only ``batch["context"]["image"]`` is replaced; every other entry is left
    untouched. The replacement happens in place, so the returned object is the
    same ``batch`` that was passed in.

    Args:
        batch: Example batch containing a ``"context"`` mapping with an
            ``"image"`` tensor.
        mean: Per-channel mean forwarded to :func:`normalize_image`.
        std: Per-channel standard deviation forwarded to
            :func:`normalize_image`.

    Returns:
        The same ``batch`` object with its context images normalized.
    """
    context = batch["context"]
    context["image"] = normalize_image(context["image"], mean, std)
    return batch