# FLAIR: src/flair/utils/utils_sisr.py
# -*- coding: utf-8 -*-
import torch.fft
import torch
import numpy as np
from scipy import ndimage
from scipy.interpolate import interp2d
def splits(a, sf):
'''split a into sfxsf distinct blocks
Args:
a: NxCxWxH
sf: split factor
Returns:
b: NxCx(W/sf)x(H/sf)x(sf^2)
'''
b = torch.stack(torch.chunk(a, sf, dim=2), dim=4)
b = torch.cat(torch.chunk(b, sf, dim=3), dim=4)
return b
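
# Minimal usage sketch for splits(): an 8x8 image with sf=2 yields four 4x4
# blocks stacked along a new last dimension (values here are illustrative).
def _demo_splits():
    a = torch.arange(64, dtype=torch.float32).reshape(1, 1, 8, 8)
    b = splits(a, sf=2)
    print(b.shape)  # torch.Size([1, 1, 4, 4, 4])
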
def p2o(psf, shape):
'''
Convert point-spread function to optical transfer function.
otf = p2o(psf) computes the Fast Fourier Transform (FFT) of the
point-spread function (PSF) array and creates the optical transfer
function (OTF) array that is not influenced by the PSF off-centering.
Args:
psf: NxCxhxw
shape: [H, W]
Returns:
        otf: NxCxHxW, complex
'''
otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf)
otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf)
for axis, axis_size in enumerate(psf.shape[2:]):
otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2)
otf = torch.fft.fftn(otf, dim=(-2,-1))
#n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf)))
#otf[..., 1][torch.abs(otf[..., 1]) < n_ops*2.22e-16] = torch.tensor(0).type_as(psf)
return otf
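
# Minimal usage sketch for p2o(): a normalized 3x3 box PSF padded to 16x16.
# The zero-frequency entry of the OTF equals the kernel sum (1.0 here).
def _demo_p2o():
    psf = torch.full((1, 1, 3, 3), 1.0 / 9.0)
    otf = p2o(psf, (16, 16))
    print(otf.shape, otf.dtype)         # torch.Size([1, 1, 16, 16]) torch.complex64
    print(torch.real(otf[0, 0, 0, 0]))  # ~1.0, the DC component
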
def upsample(x, sf=3):
'''s-fold upsampler
Upsampling the spatial size by filling the new entries with zeros
x: tensor image, NxCxWxH
'''
st = 0
z = torch.zeros((x.shape[0], x.shape[1], x.shape[2]*sf, x.shape[3]*sf)).type_as(x)
z[..., st::sf, st::sf].copy_(x)
return z
def downsample(x, sf=3):
'''s-fold downsampler
Keeping the upper-left pixel for each distinct sfxsf patch and discarding the others
x: tensor image, NxCxWxH
'''
st = 0
return x[..., st::sf, st::sf]
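
# Minimal usage sketch: downsample() is a left inverse of upsample() for the
# same sf, since upsampling only fills the positions that downsampling keeps.
def _demo_sampling():
    x = torch.rand(1, 3, 4, 4)
    z = upsample(x, sf=3)   # [1, 3, 12, 12], zero everywhere except every 3rd pixel
    assert torch.equal(downsample(z, sf=3), x)
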
def data_solution_simple(x, F2K, FKFy, rho):
    '''Fourier-domain data step for a uniform kernel:
    X = F^-1( (FKFy + rho*F(x)) / (F2K + rho) ),
    with F2K = |F(K)|^2 from pre_calculate and FKFy = conj(F(K)) * F(y).
    '''
    rho = rho.clip(min=1e-2)
    numerator = FKFy + torch.fft.fftn(rho*x, dim=(-2, -1))
    denominator = F2K + rho
    FX = numerator / denominator
    Xest = torch.real(torch.fft.ifftn(FX, dim=(-2, -1)))
    return Xest
def data_solution_nonuniform(x, FK, FKC, F2KM, FKFMy, rho):
    '''Fourier-domain data step for the non-uniform (masked multi-kernel) case:
    X = F^-1( (FKFMy + rho*F(x)) / (F2KM + rho) ).
    (FK and FKC are unused in this closed form.)
    '''
    rho = rho.clip(min=1e-2)
    numerator = FKFMy + torch.fft.fftn(rho*x, dim=(-2, -1))
    denominator = F2KM + rho
    FX = numerator / denominator
    Xest = torch.real(torch.fft.ifftn(FX, dim=(-2, -1)))
    return Xest
def pre_calculate(x, k):
'''
Args:
x: NxCxHxW, LR input
k: NxCxhxw
Returns:
FK, FKC, F2K
will be reused during iterations
'''
    h, w = x.shape[-2:]
    FK = p2o(k, (h, w))
FKC = torch.conj(FK)
F2K = torch.pow(torch.abs(FK), 2)
return FK, FKC, F2K
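
# Sketch of the uniform data step (sizes, kernel and rho are illustrative, and
# FKFy is assumed to be conj(F(K)) * F(y) as in the usual HQS splitting):
# the OTFs are precomputed once and reused at every iteration.
def _demo_data_step():
    y = torch.rand(1, 3, 64, 64)              # blurred observation
    z = torch.rand(1, 3, 64, 64)              # current estimate (e.g. denoiser output)
    k = torch.full((1, 1, 5, 5), 1.0 / 25.0)  # 5x5 box blur kernel
    rho = torch.tensor(0.1)

    FK, FKC, F2K = pre_calculate(y, k)
    FKFy = FKC * torch.fft.fftn(y, dim=(-2, -1))
    x_hat = data_solution_simple(z, F2K, FKFy, rho)
    print(x_hat.shape)                        # torch.Size([1, 3, 64, 64])
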
def pre_calculate_FK(k):
'''
Args:
k: [25, 1, 33, 33] 25 is the number of filters
    Returns:
        FK: [25, 1, 512, 512], OTF of each filter (zero-padded to 512x512)
        FKC: complex conjugate of FK
    '''
# [25, 1, 512, 512] (expanded from) [25, 1, 33, 33]
FK = p2o(k, (512, 512))
FKC = torch.conj(FK)
return FK, FKC
def pre_calculate_nonuniform(x, y, FK, FKC, mask):
'''
Args:
x: [1, 3, 512, 512]
y: [1, 3, 512, 512]
FK: [25, 1, 512, 512] 25 is the number of filters
FKC: [25, 1, 512, 512]
        mask: [1, 25, 512, 512], per-filter spatial masks
    Returns:
        F2KM: [1, 1, 512, 512], sum_i |F(K_i)|^2 * mask_i^2 (elementwise)
        FKFMy: [1, 3, 512, 512], sum_i F(K_i) * F(mask_i * y)
    '''
mask = mask.transpose(0, 1)
    h, w = x.shape[-2:]
# [1, 3, 512, 512] -> [25, 3, 512, 512]
By = y.repeat(mask.shape[0], 1, 1, 1)
# [25, 3, 512, 512]
My = mask * By
# or use just fft..?
FMy = torch.fft.fft2(My)
# [25, 3, 512, 512]
FKFMy = FK * FMy
# [1, 3, 512, 512]
FKFMy = torch.sum(FKFMy, dim=0, keepdim=True)
# [25, 1, 512, 512]
F2KM = torch.abs(FKC * (mask ** 2) * FK)
# [1, 1, 512, 512]
F2KM = torch.sum(F2KM, dim=0, keepdim=True)
return F2KM, FKFMy
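
# Sketch of the non-uniform data step. The pipeline uses 25 filters on 512x512
# images (512 is hard-coded in pre_calculate_FK); a smaller filter count is
# used here only to keep the sketch light. The masks are assumed to sum to one
# at every pixel.
def _demo_nonuniform_step():
    n_filters = 4
    y = torch.rand(1, 3, 512, 512)
    x = torch.rand(1, 3, 512, 512)
    k = torch.rand(n_filters, 1, 33, 33)
    k = k / k.sum(dim=(-2, -1), keepdim=True)        # normalize each filter
    mask = torch.softmax(torch.rand(1, n_filters, 512, 512), dim=1)
    rho = torch.tensor(0.1)

    FK, FKC = pre_calculate_FK(k)
    F2KM, FKFMy = pre_calculate_nonuniform(x, y, FK, FKC, mask)
    x_hat = data_solution_nonuniform(x, FK, FKC, F2KM, FKFMy, rho)
    print(x_hat.shape)                               # torch.Size([1, 3, 512, 512])
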
def classical_degradation(x, k, sf=3):
''' blur + downsampling
Args:
x: HxWxC image, [0, 1]/[0, 255]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
'''
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
#x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
st = 0
return x[st::sf, st::sf, ...]
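
# Minimal usage sketch: blur a random HxWxC image with a normalized 7x7 box
# kernel and keep every 3rd pixel (image content and kernel are illustrative).
def _demo_classical_degradation():
    x = np.random.rand(96, 96, 3)
    k = np.ones((7, 7)) / 49.0
    lr = classical_degradation(x, k, sf=3)
    print(lr.shape)  # (32, 32, 3)
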
def shift_pixel(x, sf, upper_left=True):
"""shift pixel for super-resolution with different scale factors
Args:
        x: HxWxC or HxW, image or kernel
sf: scale factor
upper_left: shift direction
"""
h, w = x.shape[:2]
shift = (sf-1)*0.5
xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
if upper_left:
x1 = xv + shift
y1 = yv + shift
else:
x1 = xv - shift
y1 = yv - shift
x1 = np.clip(x1, 0, w-1)
y1 = np.clip(y1, 0, h-1)
    # NOTE: scipy.interpolate.interp2d was removed in SciPy 1.14; this
    # resampling requires an older SciPy (or a RegularGridInterpolator port).
    if x.ndim == 2:
        x = interp2d(xv, yv, x)(x1, y1)
    elif x.ndim == 3:
        for i in range(x.shape[-1]):
            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
return x
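
# Minimal usage sketch: for sf=2 the grid is shifted by half a pixel towards
# the upper-left corner. Requires a SciPy version that still ships interp2d
# (it was removed in SciPy 1.14).
def _demo_shift_pixel():
    k = np.random.rand(17, 17)
    k_shifted = shift_pixel(k.copy(), sf=2, upper_left=True)
    print(k_shifted.shape)  # (17, 17)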