|
|
|
|
|
|
|
import warnings |
|
from typing import List, Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch import Tensor |
|
|
|
|
|
def _fspecial_gauss_1d(size: int, sigma: float) -> Tensor:
    r"""Build a normalized 1-D Gaussian kernel.

    Sample positions are ``0 .. size-1`` shifted left by ``size // 2``, so
    for odd ``size`` the kernel is centred on its middle tap.

    Args:
        size (int): number of taps in the kernel
        sigma (float): standard deviation of the Gaussian

    Returns:
        torch.Tensor: float32 kernel of shape (1, 1, size) whose taps sum to 1
    """
    offsets = torch.arange(size, dtype=torch.float) - size // 2
    weights = torch.exp(-offsets.pow(2) / (2 * sigma ** 2))
    weights = weights / weights.sum()
    return weights.view(1, 1, -1)
|
|
|
|
|
def gaussian_filter(input: Tensor, win: Tensor) -> Tensor:
    r"""Blur ``input`` separably with a 1-D kernel along every spatial axis.

    The kernel is applied once per spatial dimension (dims 2 and up) with a
    grouped, unpadded, stride-1 convolution, so each blurred axis shrinks by
    ``win.shape[-1] - 1``. Axes shorter than the kernel are left untouched
    and a warning is emitted instead.

    Args:
        input (torch.Tensor): 4-d (uses conv2d) or 5-d (uses conv3d) batch to
            blur; dim 1 is used as the group/channel count.
        win (torch.Tensor): 1-D kernel laid out along its last dimension;
            every dimension between the first and the last must have size 1.

    Returns:
        torch.Tensor: the blurred tensor

    Raises:
        NotImplementedError: if ``input`` is neither 4-d nor 5-d.
    """
    # Only the last axis of the window may carry taps.
    assert all(ws == 1 for ws in win.shape[1:-1]), win.shape

    convolutions = {4: F.conv2d, 5: F.conv3d}
    if input.dim() not in convolutions:
        raise NotImplementedError(input.shape)
    conv = convolutions[input.dim()]

    channels = input.shape[1]
    blurred = input
    for offset, extent in enumerate(input.shape[2:]):
        if extent < win.shape[-1]:
            warnings.warn(
                f"Skipping Gaussian Smoothing at dimension 2+{offset} for input: {input.shape} and win size: {win.shape[-1]}"
            )
            continue
        # Rotate the taps onto spatial axis (2 + offset) so the conv only blurs along it.
        blurred = conv(blurred, weight=win.transpose(2 + offset, -1), stride=1, padding=0, groups=channels)
    return blurred
|
|
|
|
|
def _ssim(
    X: Tensor,
    Y: Tensor,
    data_range: float,
    win: Tensor,
    size_average: bool = True,
    K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
    retrun_seprate: bool = False,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    r""" Calculate ssim index for X and Y

    Args:
        X (torch.Tensor): images (4-d or 5-d, as accepted by gaussian_filter)
        Y (torch.Tensor): images, same shape as X
        data_range (float or int): value range of input images. (usually 1.0 or 255)
        win (torch.Tensor): 1-D gauss kernel already laid out for grouped
            convolution (see gaussian_filter)
        size_average (bool, optional): not used by this function; averaging
            is performed by the callers
        K (list or tuple, optional): scalar constants (K1, K2)
        retrun_seprate (bool, optional): if True, also compute the brightness,
            contrast and structure terms

    Returns:
        Tuple of five tensors, each reduced over the spatial dims to shape (N, C):
        ``(ssim_per_channel, cs, brightness, contrast, structure)``.
        ``brightness``, ``contrast`` and ``structure`` are zeros unless
        ``retrun_seprate`` is True.
    """
    K1, K2 = K

    compensation = 1.0

    C1 = (K1 * data_range) ** 2
    C2 = (K2 * data_range) ** 2

    win = win.to(X.device, dtype=X.dtype)

    # Local (Gaussian-weighted) means.
    mu1 = gaussian_filter(X, win)
    mu2 = gaussian_filter(Y, win)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local variances / covariance as E[ab] - E[a]E[b].
    sigma1_sq = compensation * (gaussian_filter(X * X, win) - mu1_sq)
    sigma2_sq = compensation * (gaussian_filter(Y * Y, win) - mu2_sq)
    sigma12 = compensation * (gaussian_filter(X * Y, win) - mu1_mu2)

    # The luminance term is shared by the SSIM map and the optional brightness output.
    luminance_map = (2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)
    cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
    ssim_map = luminance_map * cs_map

    ssim_per_channel = torch.flatten(ssim_map, 2).mean(-1)
    cs = torch.flatten(cs_map, 2).mean(-1)

    brightness = contrast = structure = torch.zeros_like(ssim_per_channel)
    if retrun_seprate:
        # E[x^2] - E[x]^2 can round slightly below zero; keep variances positive
        # so the square roots below are defined.
        epsilon = torch.finfo(torch.float32).eps**2
        sigma1_sq = sigma1_sq.clamp(min=epsilon)
        sigma2_sq = sigma2_sq.clamp(min=epsilon)
        # Enforce Cauchy-Schwarz (|sigma12| <= sigma1 * sigma2) while keeping the sign.
        sigma12 = torch.sign(sigma12) * torch.minimum(
            torch.sqrt(sigma1_sq * sigma2_sq), torch.abs(sigma12))

        C3 = C2 / 2
        sigma1_sigma2 = torch.sqrt(sigma1_sq) * torch.sqrt(sigma2_sq)
        contrast_map = (2 * sigma1_sigma2 + C2) / (sigma1_sq + sigma2_sq + C2)
        structure_map = (sigma12 + C3) / (sigma1_sigma2 + C3)

        # NOTE(review): contrast and structure are capped at 0.98; the reason is
        # not visible in this file — confirm this is intended.
        contrast_map = contrast_map.clamp(max=0.98)
        structure_map = structure_map.clamp(max=0.98)

        brightness = luminance_map.flatten(2).mean(-1)
        contrast = contrast_map.flatten(2).mean(-1)
        structure = structure_map.flatten(2).mean(-1)

    return ssim_per_channel, cs, brightness, contrast, structure
|
|
|
|
|
def ssim(
    X: Tensor,
    Y: Tensor,
    data_range: float = 255,
    size_average: bool = True,
    win_size: int = 11,
    win_sigma: float = 1.5,
    win: Optional[Tensor] = None,
    K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
    nonnegative_ssim: bool = False,
    retrun_seprate: bool = False,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    r""" interface of ssim

    Args:
        X (torch.Tensor): a batch of images, (N,C,H,W) or (N,C,T,H,W)
        Y (torch.Tensor): a batch of images, same shape as X
        data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
        size_average (bool, optional): if True, every returned value is averaged
            to a scalar; otherwise it is averaged over channels only (one value per image)
        win_size: (int, optional): the size of gauss kernel (must be odd)
        win_sigma: (float, optional): sigma of normal distribution
        win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel
            will be created according to win_size and win_sigma
        K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2
            constant (e.g. 0.4) if you get a negative or NaN results.
        nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu
        retrun_seprate (bool, optional): if True, brightness, contrast and
            structure similarities are computed; otherwise they are zeros

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        ``(ssim, brightness, contrast, structure)``, reduced as described by
        ``size_average``.

    Raises:
        ValueError: if X and Y differ in shape, are not 4-d/5-d after
            squeezing, or the window size is even.
    """
    if X.shape != Y.shape:
        raise ValueError(f"Input images should have the same dimensions, but got {X.shape} and {Y.shape}.")

    # Drop size-1 dimensions at index >= 2, walking from the back so the
    # remaining indices stay valid after each squeeze.
    for dim in reversed(range(2, X.dim())):
        X = X.squeeze(dim=dim)
        Y = Y.squeeze(dim=dim)

    ndim = X.dim()
    if ndim not in (4, 5):
        raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}")

    # A user-supplied window dictates the effective window size.
    if win is not None:
        win_size = win.shape[-1]
    if win_size % 2 != 1:
        raise ValueError("Window size should be odd.")

    if win is None:
        win = _fspecial_gauss_1d(win_size, win_sigma)
        # One kernel copy per channel for the grouped convolution.
        win = win.repeat([X.shape[1]] + [1] * (ndim - 1))

    ssim_per_channel, _cs, brightness, contrast, structure = _ssim(
        X, Y, data_range=data_range, win=win, size_average=False, K=K, retrun_seprate=retrun_seprate
    )

    if nonnegative_ssim:
        ssim_per_channel = torch.relu(ssim_per_channel)

    terms = (ssim_per_channel, brightness, contrast, structure)
    if size_average:
        return tuple(term.mean() for term in terms)
    return tuple(term.mean(1) for term in terms)
|
|
|
|
|
def ms_ssim(
    X: Tensor,
    Y: Tensor,
    data_range: float = 255,
    size_average: bool = True,
    win_size: int = 11,
    win_sigma: float = 1.5,
    win: Optional[Tensor] = None,
    weights: Optional[List[float]] = None,
    K: Union[Tuple[float, float], List[float]] = (0.01, 0.03)
) -> Tensor:
    r""" interface of ms-ssim

    Args:
        X (torch.Tensor): a batch of images, (N,C,[T,]H,W)
        Y (torch.Tensor): a batch of images, (N,C,[T,]H,W)
        data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
        size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
        win_size: (int, optional): the size of gauss kernel (must be odd)
        win_sigma: (float, optional): sigma of normal distribution
        win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel
            will be created according to win_size and win_sigma
        weights (list, optional): weights for different levels; its length
            sets the number of scales
        K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2
            constant (e.g. 0.4) if you get a negative or NaN results.

    Returns:
        torch.Tensor: ms-ssim results; a scalar if size_average is True,
        otherwise one value per image (shape (N,))

    Raises:
        ValueError: if X and Y differ in shape, are not 4-d/5-d after
            squeezing, or the window size is even.
        AssertionError: if the smaller of the last two dimensions is too
            small for the 4 downsamplings.
    """
    if not X.shape == Y.shape:
        raise ValueError(f"Input images should have the same dimensions, but got {X.shape} and {Y.shape}.")

    # Drop size-1 dimensions at index >= 2, from the back so indices stay valid.
    for d in range(len(X.shape) - 1, 1, -1):
        X = X.squeeze(dim=d)
        Y = Y.squeeze(dim=d)

    if len(X.shape) == 4:
        avg_pool = F.avg_pool2d
    elif len(X.shape) == 5:
        avg_pool = F.avg_pool3d
    else:
        raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}")

    if win is not None:
        win_size = win.shape[-1]

    if not (win_size % 2 == 1):
        raise ValueError("Window size should be odd.")

    # NOTE(review): kept as an assert (AssertionError) for caller compatibility;
    # asserts are stripped under `python -O`.
    smaller_side = min(X.shape[-2:])
    assert smaller_side > (win_size - 1) * (
        2 ** 4
    ), "Image size should be larger than %d due to the 4 downsamplings in ms-ssim" % ((win_size - 1) * (2 ** 4))

    if weights is None:
        weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
    weights_tensor = X.new_tensor(weights)

    if win is None:
        win = _fspecial_gauss_1d(win_size, win_sigma)
        # One kernel copy per channel for the grouped convolution.
        win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))

    levels = weights_tensor.shape[0]
    mcs = []
    for i in range(levels):
        # _ssim returns (ssim, cs, brightness, contrast, structure); MS-SSIM
        # only needs the first two. Unpacking all five into two names raised
        # "too many values to unpack".
        ssim_per_channel, cs = _ssim(X, Y, win=win, data_range=data_range, size_average=False, K=K)[:2]

        if i < levels - 1:
            mcs.append(torch.relu(cs))
            # Pad odd-sized axes so the 2x average pool keeps every pixel.
            padding = [s % 2 for s in X.shape[2:]]
            X = avg_pool(X, kernel_size=2, padding=padding)
            Y = avg_pool(Y, kernel_size=2, padding=padding)

    ssim_per_channel = torch.relu(ssim_per_channel)
    # (levels, N, C): contrast-structure of the coarse-to-fine levels plus full SSIM at the last one.
    mcs_and_ssim = torch.stack(mcs + [ssim_per_channel], dim=0)
    ms_ssim_val = torch.prod(mcs_and_ssim ** weights_tensor.view(-1, 1, 1), dim=0)

    if size_average:
        return ms_ssim_val.mean()
    else:
        return ms_ssim_val.mean(1)
|
|
|
|
|
class SSIM(torch.nn.Module):
    r"""Module wrapper around ``ssim`` that builds its Gaussian window once."""

    def __init__(
        self,
        data_range: float = 255,
        size_average: bool = True,
        win_size: int = 11,
        win_sigma: float = 1.5,
        channel: int = 3,
        spatial_dims: int = 2,
        K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
        nonnegative_ssim: bool = False,
        retrun_seprate: bool = False,
    ) -> None:
        r""" class for ssim

        Args:
            data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
            size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
            win_size: (int, optional): the size of gauss kernel
            win_sigma: (float, optional): sigma of normal distribution
            channel (int, optional): input channels (default: 3)
            spatial_dims (int, optional): number of spatial dimensions the
                window is shaped for (default: 2); should match the inputs
            K (list or tuple, optional): scalar constants (K1, K2). Try a larger
                K2 constant (e.g. 0.4) if you get a negative or NaN results.
            nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu.
            retrun_seprate (bool, optional): if True, forward also computes the
                brightness, contrast and structure terms (otherwise they are zeros)
        """

        super(SSIM, self).__init__()
        self.win_size = win_size
        # Per-channel copy of the 1-D kernel with one singleton axis per spatial dim.
        self.win = _fspecial_gauss_1d(win_size, win_sigma).repeat([channel, 1] + [1] * spatial_dims)
        self.size_average = size_average
        self.data_range = data_range
        self.K = K
        self.nonnegative_ssim = nonnegative_ssim
        self.retrun_seprate = retrun_seprate

    def forward(self, X: Tensor, Y: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Return ``(ssim, brightness, contrast, structure)`` for ``X`` and ``Y``."""
        return ssim(
            X,
            Y,
            data_range=self.data_range,
            size_average=self.size_average,
            win=self.win,
            K=self.K,
            nonnegative_ssim=self.nonnegative_ssim,
            # getattr keeps instances pickled before this option existed working.
            retrun_seprate=getattr(self, "retrun_seprate", False),
        )
|
|
|
|
|
class MS_SSIM(torch.nn.Module):
    r"""Module wrapper around ``ms_ssim`` that builds its Gaussian window once."""

    def __init__(
        self,
        data_range: float = 255,
        size_average: bool = True,
        win_size: int = 11,
        win_sigma: float = 1.5,
        channel: int = 3,
        spatial_dims: int = 2,
        weights: Optional[List[float]] = None,
        K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
    ) -> None:
        r""" class for ms-ssim

        Args:
            data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
            size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
            win_size: (int, optional): the size of gauss kernel
            win_sigma: (float, optional): sigma of normal distribution
            channel (int, optional): input channels (default: 3)
            spatial_dims (int, optional): number of spatial dimensions the
                window is shaped for (default: 2)
            weights (list, optional): weights for different levels
            K (list or tuple, optional): scalar constants (K1, K2). Try a larger
                K2 constant (e.g. 0.4) if you get a negative or NaN results.
        """
        super().__init__()
        # Tile the 1-D kernel once per channel and add one singleton axis per spatial dim.
        repeat_pattern = [channel, 1] + [1] * spatial_dims
        gauss_1d = _fspecial_gauss_1d(win_size, win_sigma)

        self.data_range = data_range
        self.size_average = size_average
        self.win_size = win_size
        self.win = gauss_1d.repeat(repeat_pattern)
        self.weights = weights
        self.K = K

    def forward(self, X: Tensor, Y: Tensor) -> Tensor:
        """Return the MS-SSIM between ``X`` and ``Y`` using the stored settings."""
        settings = dict(
            data_range=self.data_range,
            size_average=self.size_average,
            win=self.win,
            weights=self.weights,
            K=self.K,
        )
        return ms_ssim(X, Y, **settings)
|
|