import torch
from torch import nn
import numpy as np
import scipy
import scipy.ndimage
from src.flair.utils.motionblur import Kernel as MotionKernel


class Blurkernel(nn.Module):
    """Blur 3-channel images with one fixed 2-D kernel applied per channel.

    The module is a ``ReflectionPad2d(kernel_size // 2)`` followed by a
    grouped ``Conv2d(3, 3, kernel_size, groups=3, bias=False)``. Every
    channel is therefore convolved with the same single
    ``kernel_size x kernel_size`` kernel. With an odd ``kernel_size``, the
    output has the same spatial size as the input.

    Args:
        blur_type: ``"gaussian"`` or ``"motion"`` selects how the kernel is
            initialised. Any other value leaves the conv weights at PyTorch's
            default initialisation, and ``get_kernel()`` returns ``None``
            until ``update_weights`` is called.
        kernel_size: Side length of the square kernel.
        std: Gaussian sigma for ``"gaussian"``; passed as ``intensity`` to
            the motion-kernel generator for ``"motion"``.
        device: Device that NumPy kernels given to ``update_weights`` are
            moved to.
    """

    def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0, device=None):
        super().__init__()
        self.blur_type = blur_type
        self.kernel_size = kernel_size
        self.std = std
        self.device = device
        self.seq = nn.Sequential(
            nn.ReflectionPad2d(self.kernel_size // 2),
            nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3),
        )
        self.weights_init()

    def forward(self, x):
        """Apply the padded depthwise blur to ``x``."""
        return self.seq(x)

    def weights_init(self):
        """Build the initial kernel selected by ``blur_type`` and load it into the conv."""
        # Defined up front so get_kernel() is safe even for unrecognised blur types.
        self.k = None
        if self.blur_type == "gaussian":
            # Gaussian-filtering a centred unit impulse produces the sampled
            # Gaussian kernel on the kernel_size x kernel_size grid.
            impulse = np.zeros((self.kernel_size, self.kernel_size))
            centre = self.kernel_size // 2
            impulse[centre, centre] = 1
            k = scipy.ndimage.gaussian_filter(impulse, sigma=self.std)
        elif self.blur_type == "motion":
            k = MotionKernel(size=(self.kernel_size, self.kernel_size), intensity=self.std).kernelMatrix
        else:
            # Unknown type: keep PyTorch's default conv initialisation.
            return
        self._set_kernel(torch.from_numpy(k))

    def update_weights(self, k):
        """Replace the blur kernel with ``k`` (NumPy array or torch tensor).

        NumPy input is converted to a tensor and moved to ``self.device``;
        tensor input is used as-is. ``get_kernel()`` reflects the new kernel
        afterwards.
        """
        if not torch.is_tensor(k):
            k = torch.from_numpy(k).to(self.device)
        self._set_kernel(k)

    def get_kernel(self):
        """Return the kernel most recently loaded into the conv, or ``None`` if none was."""
        return self.k

    def _set_kernel(self, k):
        """Record ``k`` as the current kernel and copy it into every parameter."""
        self.k = k
        for param in self.parameters():
            # The only parameter is the Conv2d weight of shape
            # (3, 1, kernel_size, kernel_size); a 2-D k broadcasts over the
            # leading dims, and copy_ converts dtype/device as needed.
            param.data.copy_(k)