# '''
# https://github.com/One-sixth/ms_ssim_pytorch/blob/master/ssim.py
# '''
#
# import torch
# import torch.jit
# import torch.nn.functional as F
#
#
# @torch.jit.script
# def create_window(window_size: int, sigma: float, channel: int):
#     '''
#     Create 1-D gauss kernel
#     :param window_size: the size of gauss kernel
#     :param sigma: sigma of normal distribution
#     :param channel: input channel
#     :return: 1D kernel
#     '''
#     coords = torch.arange(window_size, dtype=torch.float)
#     coords -= window_size // 2
#
#     g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
#     g /= g.sum()
#
#     g = g.reshape(1, 1, 1, -1).repeat(channel, 1, 1, 1)
#     return g
#
#
# @torch.jit.script
# def _gaussian_filter(x, window_1d, use_padding: bool):
#     '''
#     Blur input with 1-D kernel
#     :param x: batch of tensors to be blurred
#     :param window_1d: 1-D gauss kernel
#     :param use_padding: padding image before conv
#     :return: blurred tensors
#     '''
#     C = x.shape[1]
#     padding = 0
#     if use_padding:
#         window_size = window_1d.shape[3]
#         padding = window_size // 2
#     out = F.conv2d(x, window_1d, stride=1, padding=(0, padding), groups=C)
#     out = F.conv2d(out, window_1d.transpose(2, 3), stride=1, padding=(padding, 0), groups=C)
#     return out
#
#
# @torch.jit.script
# def ssim(X, Y, window, data_range: float, use_padding: bool = False):
#     '''
#     Calculate ssim index for X and Y
#     :param X: images [B, C, H, N_bins]
#     :param Y: images [B, C, H, N_bins]
#     :param window: 1-D gauss kernel
#     :param data_range: value range of input images. (usually 1.0 or 255)
#     :param use_padding: padding image before conv
#     :return:
#     '''
#
#     K1 = 0.01
#     K2 = 0.03
#     compensation = 1.0
#
#     C1 = (K1 * data_range) ** 2
#     C2 = (K2 * data_range) ** 2
#
#     mu1 = _gaussian_filter(X, window, use_padding)
#     mu2 = _gaussian_filter(Y, window, use_padding)
#     sigma1_sq = _gaussian_filter(X * X, window, use_padding)
#     sigma2_sq = _gaussian_filter(Y * Y, window, use_padding)
#     sigma12 = _gaussian_filter(X * Y, window, use_padding)
#
#     mu1_sq = mu1.pow(2)
#     mu2_sq = mu2.pow(2)
#     mu1_mu2 = mu1 * mu2
#
#     sigma1_sq = compensation * (sigma1_sq - mu1_sq)
#     sigma2_sq = compensation * (sigma2_sq - mu2_sq)
#     sigma12 = compensation * (sigma12 - mu1_mu2)
#
#     cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
#     # Fixes the issue where negative values in cs_map caused ms_ssim to output NaN.
#     cs_map = cs_map.clamp_min(0.)
#     ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map
#
#     ssim_val = ssim_map.mean(dim=(1, 2, 3))  # reduce along CHW
#     cs = cs_map.mean(dim=(1, 2, 3))
#
#     return ssim_val, cs
#
#
# @torch.jit.script
# def ms_ssim(X, Y, window, data_range: float, weights, use_padding: bool = False, eps: float = 1e-8):
#     '''
#     interface of ms-ssim
#     :param X: a batch of images, (N,C,H,W)
#     :param Y: a batch of images, (N,C,H,W)
#     :param window: 1-D gauss kernel
#     :param data_range: value range of input images. (usually 1.0 or 255)
#     :param weights: weights for different levels
#     :param use_padding: padding image before conv
#     :param eps: used to avoid NaN gradients.
#     :return:
#     '''
#     levels = weights.shape[0]
#     cs_vals = []
#     ssim_vals = []
#     for _ in range(levels):
#         ssim_val, cs = ssim(X, Y, window=window, data_range=data_range, use_padding=use_padding)
#         # Fixes an issue: when c = a ** b and a is 0, c.backward() makes a.grad become inf.
#         ssim_val = ssim_val.clamp_min(eps)
#         cs = cs.clamp_min(eps)
#         cs_vals.append(cs)
#
#         ssim_vals.append(ssim_val)
#         padding = (X.shape[2] % 2, X.shape[3] % 2)
#         X = F.avg_pool2d(X, kernel_size=2, stride=2, padding=padding)
#         Y = F.avg_pool2d(Y, kernel_size=2, stride=2, padding=padding)
#
#     cs_vals = torch.stack(cs_vals, dim=0)
#     ms_ssim_val = torch.prod((cs_vals[:-1] ** weights[:-1].unsqueeze(1)) * (ssim_vals[-1] ** weights[-1]), dim=0)
#     return ms_ssim_val
#
#
# class SSIM(torch.jit.ScriptModule):
#     __constants__ = ['data_range', 'use_padding']
#
#     def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False):
#         '''
#         :param window_size: the size of gauss kernel
#         :param window_sigma: sigma of normal distribution
#         :param data_range: value range of input images. (usually 1.0 or 255)
#         :param channel: input channels (default: 3)
#         :param use_padding: padding image before conv
#         '''
#         super().__init__()
#         assert window_size % 2 == 1, 'Window size must be odd.'
#         window = create_window(window_size, window_sigma, channel)
#         self.register_buffer('window', window)
#         self.data_range = data_range
#         self.use_padding = use_padding
#
#     @torch.jit.script_method
#     def forward(self, X, Y):
#         r = ssim(X, Y, window=self.window, data_range=self.data_range, use_padding=self.use_padding)
#         return r[0]
#
#
# class MS_SSIM(torch.jit.ScriptModule):
#     __constants__ = ['data_range', 'use_padding', 'eps']
#
#     def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False, weights=None,
#                  levels=None, eps=1e-8):
#         '''
#         class for ms-ssim
#         :param window_size: the size of gauss kernel
#         :param window_sigma: sigma of normal distribution
#         :param data_range: value range of input images. (usually 1.0 or 255)
#         :param channel: input channels
#         :param use_padding: padding image before conv
#         :param weights: weights for different levels. (default [0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
#         :param levels: number of downsampling
#         :param eps: used to fix an issue: when c = a ** b and a is 0, c.backward() makes a.grad become inf.
#         '''
#         super().__init__()
#         assert window_size % 2 == 1, 'Window size must be odd.'
#         self.data_range = data_range
#         self.use_padding = use_padding
#         self.eps = eps
#
#         window = create_window(window_size, window_sigma, channel)
#         self.register_buffer('window', window)
#
#         if weights is None:
#             weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
#         weights = torch.tensor(weights, dtype=torch.float)
#
#         if levels is not None:
#             weights = weights[:levels]
#             weights = weights / weights.sum()
#
#         self.register_buffer('weights', weights)
#
#     @torch.jit.script_method
#     def forward(self, X, Y):
#         return ms_ssim(X, Y, window=self.window, data_range=self.data_range, weights=self.weights,
#                        use_padding=self.use_padding, eps=self.eps)
#
#
# if __name__ == '__main__':
#     print('Simple Test')
#     im = torch.randint(0, 255, (5, 3, 256, 256), dtype=torch.float, device='cuda')
#     img1 = im / 255
#     img2 = img1 * 0.5
#
#     losser = SSIM(data_range=1.).cuda()
#     loss = losser(img1, img2).mean()
#
#     losser2 = MS_SSIM(data_range=1.).cuda()
#     loss2 = losser2(img1, img2).mean()
#
#     print(loss.item())
#     print(loss2.item())
#
# if __name__ == '__main__':
#     print('Training Test')
#     import cv2
#     import torch.optim
#     import numpy as np
#     import imageio
#     import time
#
#     out_test_video = False
#     # It is best not to write GIFs directly (they get very large); write an mkv file first and convert it to GIF with ffmpeg.
#     video_use_gif = False
#
#     im = cv2.imread('test_img1.jpg', 1)
#     t_im = torch.from_numpy(im).cuda().permute(2, 0, 1).float()[None] / 255.
#
#     if out_test_video:
#         if video_use_gif:
#             fps = 0.5
#             out_wh = (im.shape[1] // 2, im.shape[0] // 2)
#             suffix = '.gif'
#         else:
#             fps = 5
#             out_wh = (im.shape[1], im.shape[0])
#             suffix = '.mkv'
#         video_last_time = time.perf_counter()
#         video = imageio.get_writer('ssim_test' + suffix, fps=fps)
#
#     # Test ssim
#     print('Training SSIM')
#     rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
#     rand_im.requires_grad = True
#     optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
#     losser = SSIM(data_range=1., channel=t_im.shape[1]).cuda()
#     ssim_score = 0
#     while ssim_score < 0.999:
#         optim.zero_grad()
#         loss = losser(rand_im, t_im)
#         (-loss).sum().backward()
#         ssim_score = loss.item()
#         optim.step()
#         r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
#         r_im = cv2.putText(r_im, 'ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
#
#         if out_test_video:
#             if time.perf_counter() - video_last_time > 1. / fps:
#                 video_last_time = time.perf_counter()
#                 out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
#                 out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
#                 if isinstance(out_frame, cv2.UMat):
#                     out_frame = out_frame.get()
#                 video.append_data(out_frame)
#
#         cv2.imshow('ssim', r_im)
#         cv2.setWindowTitle('ssim', 'ssim %f' % ssim_score)
#         cv2.waitKey(1)
#
#     if out_test_video:
#         video.close()
#
#     # Test ms_ssim
#     if out_test_video:
#         if video_use_gif:
#             fps = 0.5
#             out_wh = (im.shape[1] // 2, im.shape[0] // 2)
#             suffix = '.gif'
#         else:
#             fps = 5
#             out_wh = (im.shape[1], im.shape[0])
#             suffix = '.mkv'
#         video_last_time = time.perf_counter()
#         video = imageio.get_writer('ms_ssim_test' + suffix, fps=fps)
#
#     print('Training MS_SSIM')
#     rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
#     rand_im.requires_grad = True
#     optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
#     losser = MS_SSIM(data_range=1., channel=t_im.shape[1]).cuda()
#     ssim_score = 0
#     while ssim_score < 0.999:
#         optim.zero_grad()
#         loss = losser(rand_im, t_im)
#         (-loss).sum().backward()
#         ssim_score = loss.item()
#         optim.step()
#         r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
#         r_im = cv2.putText(r_im, 'ms_ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
#
#         if out_test_video:
#             if time.perf_counter() - video_last_time > 1. / fps:
#                 video_last_time = time.perf_counter()
#                 out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
#                 out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
#                 if isinstance(out_frame, cv2.UMat):
#                     out_frame = out_frame.get()
#                 video.append_data(out_frame)
#
#         cv2.imshow('ms_ssim', r_im)
#         cv2.setWindowTitle('ms_ssim', 'ms_ssim %f' % ssim_score)
#         cv2.waitKey(1)
#
#     if out_test_video:
#         video.close()
| """ | |
| Adapted from https://github.com/Po-Hsun-Su/pytorch-ssim | |
| """ | |
import torch
import torch.nn.functional as F
from math import exp


def gaussian(window_size, sigma):
    # 1-D Gaussian kernel of length `window_size`, normalized to sum to 1.
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    # 2-D Gaussian window as the outer product of two 1-D kernels,
    # expanded to one filter per channel for a grouped convolution.
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    # Local means via Gaussian filtering (grouped conv, one filter per channel).
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local variances and covariance.
    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    # Stability constants C1 = (K1 * L)^2, C2 = (K2 * L)^2 with L = 1, i.e. inputs are assumed to lie in [0, 1].
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1)


class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        # Reuse the cached window when the channel count and tensor type still match;
        # otherwise rebuild it and move it to the input's device and dtype.
        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)
            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)
            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


# Module-level cache for the functional interface below.
window = None


def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    global window
    # Rebuild the cached window if it is missing or was created for a different channel count.
    if window is None or window.shape[0] != channel:
        window = create_window(window_size, channel)
    # Always align the cached window with the input's device and dtype.
    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)
    return _ssim(img1, img2, window, window_size, channel, size_average)
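

# Minimal usage sketch (an illustrative addition, not part of the original adaptation).
# It assumes float inputs in [0, 1] shaped (N, C, H, W), matching the constants used in _ssim.
if __name__ == '__main__':
    img1 = torch.rand(2, 3, 64, 64)
    img2 = (img1 + 0.1 * torch.randn_like(img1)).clamp(0, 1)

    # Module interface: returns the mean SSIM over the whole batch when size_average=True.
    ssim_module = SSIM(window_size=11, size_average=True)
    print('module SSIM:', ssim_module(img1, img2).item())

    # Functional interface: builds and caches the Gaussian window on first use.
    print('functional SSIM:', ssim(img1, img2, window_size=11).item())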