# FLAIR: src/flair/utils/blur_util.py
# Commit a7169e0 by juliuse: "import flair fix"
import torch
from torch import nn
import numpy as np
import scipy
from src.flair.utils.motionblur import Kernel as MotionKernel
class Blurkernel(nn.Module):
    """Fixed depthwise blur applied independently to each of 3 input channels.

    The input is reflection-padded by ``kernel_size // 2`` and then convolved
    with a grouped ``Conv2d`` (``groups=3``), so every channel is filtered by
    the same 2-D kernel. For an odd ``kernel_size`` the spatial size is
    preserved.

    NOTE: for an even ``kernel_size`` the total padding (``kernel_size``) exceeds
    ``kernel_size - 1``, so the output is one pixel larger in H and W.

    Args:
        blur_type: ``"gaussian"`` or ``"motion"``. Any other value leaves the
            conv weights at PyTorch's default initialisation, and no kernel is
            recorded until :meth:`update_weights` is called.
        kernel_size: side length of the square kernel.
        std: Gaussian sigma for ``"gaussian"``. For ``"motion"`` it is passed
            as ``intensity`` to the motion-kernel generator.
        device: device that non-tensor kernels given to :meth:`update_weights`
            are moved to.
    """

    def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0, device=None):
        super().__init__()
        self.blur_type = blur_type
        self.kernel_size = kernel_size
        self.std = std
        self.device = device
        self.seq = nn.Sequential(
            nn.ReflectionPad2d(self.kernel_size // 2),
            # Depthwise conv: weight has shape (3, 1, kernel_size, kernel_size),
            # i.e. one kernel per channel with no cross-channel mixing.
            nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3),
        )
        self.weights_init()

    def forward(self, x):
        """Blur ``x`` (3 channels in dim -3, e.g. shape (N, 3, H, W))."""
        return self.seq(x)

    def _set_kernel(self, k):
        """Copy the 2-D kernel ``k`` into the conv weights, then record it.

        The weights are copied before ``self.k`` is assigned. A shape
        mismatch therefore raises without leaving ``get_kernel()``
        out of sync with the weights actually in use.
        """
        with torch.no_grad():
            for param in self.seq.parameters():
                # A (k, k) kernel broadcasts over the (3, 1, k, k) weight; copy_
                # also casts dtype (e.g. float64 -> float32) and device.
                param.copy_(k)
        # Store a detached tensor. Assigning an nn.Parameter here would register
        # it as a module parameter, and a later update would then overwrite it.
        self.k = k.detach()

    def weights_init(self):
        """Build the kernel selected by ``blur_type`` and load it into the conv."""
        if self.blur_type == "gaussian":
            # ``import scipy`` alone does not expose ``scipy.ndimage`` on SciPy
            # releases without lazy submodule loading.
            import scipy.ndimage

            impulse = np.zeros((self.kernel_size, self.kernel_size))
            impulse[self.kernel_size // 2, self.kernel_size // 2] = 1
            # Filtering a centred unit impulse yields the sampled Gaussian kernel.
            k = scipy.ndimage.gaussian_filter(impulse, sigma=self.std)
        elif self.blur_type == "motion":
            k = MotionKernel(
                size=(self.kernel_size, self.kernel_size), intensity=self.std
            ).kernelMatrix
        else:
            # Unknown blur type: keep the default conv init. The caller is
            # expected to supply a kernel via update_weights().
            return
        self._set_kernel(torch.from_numpy(k))

    def update_weights(self, k):
        """Replace the blur kernel with ``k``.

        Args:
            k: 2-D kernel (tensor, numpy array or nested sequence) of shape
                (kernel_size, kernel_size). Non-tensors are converted and moved
                to ``self.device``.
        """
        if not torch.is_tensor(k):
            k = torch.as_tensor(k).to(self.device)
        self._set_kernel(k)

    def get_kernel(self):
        """Return the kernel most recently loaded into the conv weights.

        Raises AttributeError if no kernel has been loaded yet, i.e. an unknown
        ``blur_type`` was given and :meth:`update_weights` was never called.
        """
        return self.k