Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,604 Bytes
90a9dd3 a7169e0 90a9dd3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import torch
from torch import nn
import numpy as np
import scipy
from src.flair.utils.motionblur import Kernel as MotionKernel
class Blurkernel(nn.Module):
    """Blur a 3-channel tensor with a fixed ``kernel_size`` x ``kernel_size`` kernel.

    The input is reflection-padded by ``kernel_size // 2`` on every side and then
    convolved with a depthwise ``Conv2d`` (``groups=3``, no bias). The same 2-D
    kernel is copied into each channel's filter. With an odd ``kernel_size`` the
    spatial size is preserved. With an even one, the output is one pixel larger
    in each spatial dimension.

    Args:
        blur_type: ``"gaussian"`` or ``"motion"``. Any other value leaves the
            conv with PyTorch's default initialisation and no stored kernel
            until :meth:`update_weights` is called.
        kernel_size: Side length of the square kernel.
        std: Gaussian sigma for ``"gaussian"``. For ``"motion"`` it is passed
            as ``intensity`` to the motion-blur kernel generator.
        device: Device that :meth:`update_weights` moves numpy kernels to.
    """

    def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0, device=None):
        super().__init__()
        self.blur_type = blur_type
        self.kernel_size = kernel_size
        self.std = std
        self.device = device
        # Pad by half the kernel so an odd-sized kernel keeps H and W. Then apply
        # one KxK filter per channel (groups=3 => depthwise convolution).
        self.seq = nn.Sequential(
            nn.ReflectionPad2d(self.kernel_size // 2),
            nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3),
        )
        self.weights_init()

    def forward(self, x):
        """Return ``x`` blurred with the current kernel (expects 3 channels)."""
        return self.seq(x)

    def weights_init(self):
        """Build the kernel selected by ``blur_type`` and load it into the conv."""
        if self.blur_type == "gaussian":
            # Import the submodule explicitly: on older SciPy releases a bare
            # ``import scipy`` does not make ``scipy.ndimage`` available.
            from scipy.ndimage import gaussian_filter

            # Gaussian-filtering a centred unit impulse yields the sampled kernel.
            impulse = np.zeros((self.kernel_size, self.kernel_size))
            impulse[self.kernel_size // 2, self.kernel_size // 2] = 1
            self._load_numpy_kernel(gaussian_filter(impulse, sigma=self.std))
        elif self.blur_type == "motion":
            motion = MotionKernel(
                size=(self.kernel_size, self.kernel_size), intensity=self.std
            )
            self._load_numpy_kernel(motion.kernelMatrix)

    def _load_numpy_kernel(self, k):
        """Wrap numpy array ``k`` as a tensor, store it as ``self.k``, load weights."""
        k = torch.from_numpy(k)
        self.k = k
        self._copy_into_params(k)

    def _copy_into_params(self, k):
        """Broadcast ``k`` into every parameter (the (3, 1, K, K) conv weight)."""
        for param in self.parameters():
            # Write through ``.data`` so the reset is not recorded by autograd.
            param.data.copy_(k)

    def update_weights(self, k):
        """Replace the blur kernel with ``k`` and remember it for :meth:`get_kernel`.

        Args:
            k: Tensor or numpy array broadcastable to the conv weight shape
                ``(3, 1, kernel_size, kernel_size)``. A numpy array is converted
                with ``torch.from_numpy`` and moved to ``self.device`` if one was
                given.
        """
        if not torch.is_tensor(k):
            k = torch.from_numpy(k)
            if self.device is not None:
                k = k.to(self.device)
        # Keep get_kernel() consistent with the weights actually in use.
        self.k = k
        self._copy_into_params(k)

    def get_kernel(self):
        """Return the kernel most recently loaded into the conv weights."""
        return self.k
|