Spaces:
Running
Running
from typing import Optional | |
import torch | |
def normalize(x: torch.Tensor, min: float = -1.0, max: float = 1.0, dim: Optional[int] = None) -> torch.Tensor: | |
""" | |
Normalize a tensor to the range [min_val, max_val]. | |
Args: | |
x (`torch.Tensor`): | |
The input tensor to normalize. | |
min (`float`, defaults to `-1.0`): | |
The minimum value of the normalized range. | |
max (`float`, defaults to `1.0`): | |
The maximum value of the normalized range. | |
dim (`int`, *optional*): | |
The dimension along which to normalize. If `None`, the entire tensor is normalized. | |
Returns: | |
The normalized tensor of the same shape as `x`. | |
""" | |
if dim is None: | |
x_min = x.min() | |
x_max = x.max() | |
if torch.isclose(x_min, x_max).any(): | |
x = torch.full_like(x, min) | |
else: | |
x = min + (max - min) * (x - x_min) / (x_max - x_min) | |
else: | |
x_min = x.amin(dim=dim, keepdim=True) | |
x_max = x.amax(dim=dim, keepdim=True) | |
if torch.isclose(x_min, x_max).any(): | |
x = torch.full_like(x, min) | |
else: | |
x = min + (max - min) * (x - x_min) / (x_max - x_min) | |
return x | |