File size: 1,219 Bytes
9fd1204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from typing import Optional

import torch


def normalize(x: torch.Tensor, min: float = -1.0, max: float = 1.0, dim: Optional[int] = None) -> torch.Tensor:
    """
    Normalize a tensor to the range [min_val, max_val].

    Args:
        x (`torch.Tensor`):
            The input tensor to normalize.
        min (`float`, defaults to `-1.0`):
            The minimum value of the normalized range.
        max (`float`, defaults to `1.0`):
            The maximum value of the normalized range.
        dim (`int`, *optional*):
            The dimension along which to normalize. If `None`, the entire tensor is normalized.

    Returns:
        The normalized tensor of the same shape as `x`.
    """
    if dim is None:
        x_min = x.min()
        x_max = x.max()
        if torch.isclose(x_min, x_max).any():
            x = torch.full_like(x, min)
        else:
            x = min + (max - min) * (x - x_min) / (x_max - x_min)
    else:
        x_min = x.amin(dim=dim, keepdim=True)
        x_max = x.amax(dim=dim, keepdim=True)
        if torch.isclose(x_min, x_max).any():
            x = torch.full_like(x, min)
        else:
            x = min + (max - min) * (x - x_min) / (x_max - x_min)
    return x