File size: 4,533 Bytes
966ae59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
r"""General purpose functions"""
from typing import Tuple, Union, Optional
import torch
from ..utils import _parse_version


def ifftshift(x: torch.Tensor) -> torch.Tensor:
    r"""Undo an FFT centre shift on every axis, mirroring ``numpy.fft.ifftshift`` for PyTorch tensors.

    Each dimension of length ``n`` is rolled by ``-(n // 2)``.
    """
    all_dims = tuple(range(x.dim()))
    offsets = [-(length // 2) for length in x.shape]
    return torch.roll(x, offsets, all_dims)


def _centered_axis(n: int, device: Optional[Union[str, torch.device]],
                   dtype: Optional[torch.dtype]) -> torch.Tensor:
    r"""Return ``n`` evenly spaced coordinates centred at zero and scaled to roughly ``[-0.5, 0.5]``."""
    if n % 2:
        # Odd: values -(n-1)/2 .. (n-1)/2 divided by (n - 1), i.e. end points exactly -0.5 and 0.5.
        # ``max`` stops n == 1 from computing 0 / 0 (NaN); a single sample is simply the origin.
        return torch.arange(-(n - 1) / 2, n / 2, device=device, dtype=dtype) / max(n - 1, 1)
    # Even: values -n/2 .. n/2 - 1 divided by n, i.e. the half-open range [-0.5, 0.5) with step 1 / n.
    return torch.arange(-n / 2, n / 2, device=device, dtype=dtype) / n


def get_meshgrid(size: Tuple[int, int], device: Optional[Union[str, torch.device]] = None,
                 dtype: Optional[torch.dtype] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Return coordinate grid matrices centered at zero point.

    Each axis of length ``n`` is sampled as follows:

    * odd ``n``: ``k / (n - 1)`` for integer ``k`` in ``[-(n - 1) / 2, (n - 1) / 2]``, giving ``[-0.5, 0.5]``;
    * even ``n``: ``k / n`` for integer ``k`` in ``[-n / 2, n / 2)``, giving ``[-0.5, 0.5)``;
    * ``n == 1``: the single coordinate ``0``.

    Args:
        size: Shape of meshgrid to create
        device: device to use for creation
        dtype: dtype to use for creation
    Returns:
        Pair of tensors, each of shape ``size``, holding the first- and second-axis coordinates
        (``'ij'`` indexing) on ``device`` with ``dtype`` values.
    """
    x = _centered_axis(size[0], device, dtype)
    y = _centered_axis(size[1], device, dtype)
    # Use indexing param depending on torch version: ``indexing`` was added to torch.meshgrid in 1.10,
    # and older releases use 'ij' behaviour. An empty parsed version presumably means the string
    # could not be parsed — TODO confirm against ``_parse_version``.
    recommended_torch_version = _parse_version("1.10.0")
    torch_version = _parse_version(torch.__version__)
    if len(torch_version) > 0 and torch_version >= recommended_torch_version:
        return torch.meshgrid(x, y, indexing='ij')
    return torch.meshgrid(x, y)


def similarity_map(map_x: torch.Tensor, map_y: torch.Tensor, constant: float, alpha: float = 0.0) -> torch.Tensor:
    r"""Compute a Dice-like similarity map between two tensors.

    Evaluates element-wise
    ``(2 * x * y - alpha * x * y + constant) / (x ** 2 + y ** 2 - alpha * x * y + constant)``.

    Args:
        map_x: First map to compare.
        map_y: Second map to compare; combined element-wise with ``map_x``.
        constant: Term added to both numerator and denominator for numerical stability.
        alpha: Masking coefficient; ``alpha * map_x * map_y`` is subtracted from both numerator and denominator.

    Returns:
        Element-wise similarity map.
    """
    masked = alpha * map_x * map_y
    numerator = 2.0 * map_x * map_y - masked + constant
    denominator = map_x ** 2 + map_y ** 2 - masked + constant
    return numerator / denominator


def gradient_map(x: torch.Tensor, kernels: torch.Tensor) -> torch.Tensor:
    r""" Compute gradient map for a given tensor and stack of kernels.

    ``x`` is convolved with every kernel, zero-padded by half the kernel extent along each spatial axis
    (so odd-sized kernels preserve ``(H, W)``). The per-pixel L2 norm over the kernel responses is returned.

    Args:
        x: Tensor with shape (N, C, H, W).
        kernels: Stack of kernels for gradient computation with shape (k_N, C, k_H, k_W), i.e. the
            weight layout expected by ``torch.nn.functional.conv2d``.
    Returns:
        Gradient magnitude of x with shape (N, 1, H, W). The convolution sums responses over input
        channels, so there is a single output channel.
    """
    # Pad height and width independently so non-square kernels also keep the spatial size.
    padding = (kernels.size(-2) // 2, kernels.size(-1) // 2)
    grads = torch.nn.functional.conv2d(x, kernels, padding=padding)

    # L2 norm across the k_N kernel responses (channel axis of ``grads``).
    return torch.sqrt(torch.sum(grads ** 2, dim=-3, keepdim=True))


def pow_for_complex(base: torch.Tensor, exp: Union[int, float]) -> torch.Tensor:
    r"""Raise every element of a real or complex tensor to the power ``exp``.

    Values are handled in polar form r * \exp(i * \phi): the modulus is raised to ``exp`` and the
    argument is multiplied by ``exp``. A 4D input holds real numbers (negative values get \phi = \pi).
    A 5D input whose last dimension has size 2 holds (real, imaginary) pairs.

    This is likely to become redundant with native complex tensors in PyTorch.

    Args:
        base: Tensor with shape (N, C, H, W) or (N, C, H, W, 2).
        exp: Exponent.
    Returns:
        Tensor with shape (N, C, H, W, 2) holding the real and imaginary parts of the result.
    Raises:
        ValueError: If ``base`` is neither 4D nor 5D with a trailing dimension of size 2.
    """
    is_real = base.dim() == 4
    is_complex = base.dim() == 5 and base.size(-1) == 2
    if not (is_real or is_complex):
        raise ValueError(f'Expected real or complex tensor, got {base.size()}')

    if is_real:
        modulus = base.abs()
        angle = torch.atan2(torch.zeros_like(base), base)
    else:
        modulus = base.pow(2).sum(dim=-1).sqrt()
        angle = torch.atan2(base[..., 1], base[..., 0])

    scaled_modulus = modulus ** exp
    scaled_angle = angle * exp
    real_part = scaled_modulus * torch.cos(scaled_angle)
    imag_part = scaled_modulus * torch.sin(scaled_angle)
    return torch.stack((real_part, imag_part), dim=-1)


def crop_patches(x: torch.Tensor, size: int = 64, stride: int = 32) -> torch.Tensor:
    r"""Crop tensor with images into small patches.

    A ``size`` x ``size`` window slides over height and width with step ``stride``. Window positions
    that would overrun the border are skipped.

    Args:
        x: Tensor with shape (N, C, H, W), expected to be images-like entities
        size: Size of a square patch
        stride: Step between patches
    Returns:
        Tensor with shape (N * n_H * n_W, C, size, size), where n_H = (H - size) // stride + 1 and
        n_W = (W - size) // stride + 1. Patches are ordered by image, then row, then column.
    """
    assert (x.shape[2] >= size) and (x.shape[3] >= size), \
        f"Images must be bigger than patch size. Got ({x.shape[2]}, {x.shape[3]}) and ({size}, {size})"
    channels = x.shape[1]
    # Unfolding the channel axis with one full-width window moves C next to the spatial window axes:
    # (N, 1, n_H, n_W, C, size, size). The reshape below then yields (C, size, size) patches
    # without an explicit permute.
    patches = x.unfold(1, channels, channels).unfold(2, size, stride).unfold(3, size, stride)
    return patches.reshape(-1, channels, size, size)