Spaces:
Running
Running
File size: 1,219 Bytes
9fd1204 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
from typing import Optional
import torch
def normalize(x: torch.Tensor, min: float = -1.0, max: float = 1.0, dim: Optional[int] = None) -> torch.Tensor:
    """
    Linearly rescale a tensor to the range [`min`, `max`].

    Args:
        x (`torch.Tensor`):
            The input tensor to normalize.
        min (`float`, defaults to `-1.0`):
            The minimum value of the normalized range.
        max (`float`, defaults to `1.0`):
            The maximum value of the normalized range.
        dim (`int`, *optional*):
            The dimension along which to normalize. If `None`, the entire tensor is normalized using its global
            minimum and maximum; otherwise each slice along `dim` is normalized with its own minimum and maximum.

    Returns:
        `torch.Tensor`: The normalized tensor of the same shape as `x`. A slice (or the whole tensor when `dim` is
        `None`) whose minimum and maximum are numerically equal has no range to stretch and is filled with `min`;
        all other slices are rescaled independently.
    """
    if dim is None:
        x_min = x.min()
        x_max = x.max()
    else:
        # keepdim=True so the per-slice extrema broadcast back against `x`.
        x_min = x.amin(dim=dim, keepdim=True)
        x_max = x.amax(dim=dim, keepdim=True)

    span = x_max - x_min
    # Evaluated per slice: one constant slice must not wipe out the others.
    degenerate = torch.isclose(x_min, x_max)
    # Swap (near-)zero spans for 1 so the division never yields inf/NaN, which
    # would otherwise leak into gradients even though `where` discards the value.
    safe_span = torch.where(degenerate, torch.ones_like(span), span)
    scaled = min + (max - min) * (x - x_min) / safe_span
    return torch.where(degenerate, torch.full_like(scaled, min), scaled)
|