File size: 1,731 Bytes
1b34a12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""
Copyright (c) Facebook, Inc. and its affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import torch

import fastmri
import sigpy as sp
import numpy as np


def rss(data: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """
    Combine coil images with the Root Sum of Squares (RSS).

    Every element of `data` is squared, the squares are accumulated along
    `dim` (treated as the coil dimension), and the square root of that
    accumulated value is returned.

    Parameters
    ----------
    data : torch.Tensor
        The input tensor; squaring is applied elementwise, so this is meant
        for real-valued data (use ``rss_complex`` for complex inputs).
    dim : int, optional
        The dimension reduced by the RSS (default is 0).

    Returns
    -------
    torch.Tensor
        `data` with `dim` reduced away by the RSS.
    """
    squared = data.pow(2)
    accumulated = torch.sum(squared, dim=dim)
    return accumulated.sqrt()


def mvue(
    spatial_pred: torch.Tensor, sens_maps: torch.Tensor, dim: int = 0
) -> torch.Tensor:
    """
    Combine multi-coil images into one magnitude image using sensitivity maps.

    Computes ``|sum_c(x_c * conj(S_c)) / sqrt(sum_c |S_c|^2)|`` along `dim`,
    where ``x_c`` are the per-coil images and ``S_c`` the sensitivity maps.

    Parameters
    ----------
    spatial_pred : torch.Tensor
        Per-coil images stored as a real tensor whose last dimension has
        size 2 (real and imaginary parts), as required by
        ``torch.view_as_complex``.
    sens_maps : torch.Tensor
        Sensitivity maps in the same real/imaginary layout; presumably the
        same shape as `spatial_pred` (it must at least broadcast with it).
    dim : int, optional
        The coil dimension to reduce. It is indexed on the complex view,
        i.e. after the trailing real/imaginary dimension has been removed
        (default is 0).

    Returns
    -------
    torch.Tensor
        Real-valued magnitude of the combined image. Positions where every
        sensitivity map is zero yield 0 instead of NaN.
    """
    # view_as_complex requires a unit-stride trailing dimension; .contiguous()
    # is a no-op for contiguous inputs and fixes sliced/transposed ones.
    spatial_pred = torch.view_as_complex(spatial_pred.contiguous())
    sens_maps = torch.view_as_complex(sens_maps.contiguous())

    numerator = torch.sum(spatial_pred * torch.conj(sens_maps), dim=dim)
    sens_energy = torch.sum(torch.square(torch.abs(sens_maps)), dim=dim)

    # Where all coil sensitivities vanish, the original formula is 0/0 = NaN
    # (and sqrt has an infinite gradient at 0). Divide by 1 there instead and
    # then force those positions to 0.
    has_signal = sens_energy > 0
    denominator = torch.sqrt(
        torch.where(has_signal, sens_energy, torch.ones_like(sens_energy))
    )
    magnitude = torch.abs(numerator / denominator)
    return torch.where(has_signal, magnitude, torch.zeros_like(magnitude))


def rss_complex(data: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """
    Root Sum of Squares (RSS) coil combination for complex-valued data.

    The squared magnitude of every complex entry is obtained from
    ``fastmri.complex_abs_sq``. These values are accumulated along `dim`
    (the coil dimension), and the square root of the accumulated energy is
    returned.

    Parameters
    ----------
    data : torch.Tensor
        The input tensor holding complex values, in whatever layout
        ``fastmri.complex_abs_sq`` expects.
    dim : int, optional
        The dimension reduced by the RSS, applied to the output of
        ``fastmri.complex_abs_sq`` (default is 0). NOTE(review): presumably
        that helper drops the trailing complex dimension; confirm when
        passing negative dims.

    Returns
    -------
    torch.Tensor
        The RSS of `data` along `dim`.
    """
    energy = fastmri.complex_abs_sq(data)
    return energy.sum(dim=dim).sqrt()