File size: 1,010 Bytes
2568013 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
import torch
from einops import einsum, reduce, repeat
from jaxtyping import Float
from torch import Tensor
from ..types import BatchedExample
def inverse_normalize_image(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Undo a per-channel normalization by computing ``tensor * std + mean``.

    ``mean`` and ``std`` are converted to tensors with the dtype and device
    of ``tensor`` and reshaped to ``(C, 1, 1)``, so they broadcast along the
    third-from-last axis of ``tensor``.

    Args:
        tensor: Normalized input.
            NOTE(review): presumably laid out as (..., C, H, W) - confirm
            against callers.
        mean: Per-channel offset that was subtracted during normalization.
        std: Per-channel scale that was divided out during normalization.

    Returns:
        A new tensor holding the de-normalized values.
    """
    like = {"dtype": tensor.dtype, "device": tensor.device}
    # (C,) -> (C, 1, 1) so each statistic lines up with one channel plane.
    offset = torch.as_tensor(mean, **like).view(-1, 1, 1)
    scale = torch.as_tensor(std, **like).view(-1, 1, 1)
    rescaled = tensor * scale
    return rescaled + offset
def normalize_image(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Normalize per channel by computing ``(tensor - mean) / std``.

    ``mean`` and ``std`` are converted to tensors with the dtype and device
    of ``tensor`` and reshaped to ``(C, 1, 1)``, so they broadcast along the
    third-from-last axis of ``tensor``.

    Args:
        tensor: Input to normalize.
            NOTE(review): presumably laid out as (..., C, H, W) - confirm
            against callers.
        mean: Per-channel value to subtract.
        std: Per-channel value to divide by.

    Returns:
        A new tensor holding the normalized values.
    """
    like = {"dtype": tensor.dtype, "device": tensor.device}
    # (C,) -> (C, 1, 1) so each statistic lines up with one channel plane.
    shift = torch.as_tensor(mean, **like).view(-1, 1, 1)
    scale = torch.as_tensor(std, **like).view(-1, 1, 1)
    centered = tensor - shift
    return centered / scale
def apply_normalize_shim(
    batch: BatchedExample,
    mean: tuple[float, float, float] = (0.5, 0.5, 0.5),
    std: tuple[float, float, float] = (0.5, 0.5, 0.5),
) -> BatchedExample:
    """Normalize the context images of a batch in place.

    Only ``batch["context"]["image"]`` is rewritten; every other entry of
    ``batch`` is left as it was.

    Args:
        batch: Batch whose ``"context"`` entry holds an ``"image"`` tensor.
        mean: Per-channel mean forwarded to ``normalize_image``.
        std: Per-channel standard deviation forwarded to ``normalize_image``.

    Returns:
        The same ``batch`` object, after the update.
    """
    context = batch["context"]
    context["image"] = normalize_image(context["image"], mean, std)
    return batch
|