import torch
from torch.nn import functional as F
import torchvision.transforms.functional as TF
from basicsr.utils.registry import MODEL_REGISTRY
from basicsr.models.sr_model import SRModel
from tqdm import tqdm


@MODEL_REGISTRY.register()
class MaIRPlusModel(SRModel):
    """MaIR+ model: MaIR with x8 self-ensemble and partitioned testing for image restoration."""

    def one_img_test(self, img):
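        """Restore one image by splitting it into overlapping partitions.

        Large inputs are cut into roughly 200-pixel tiles so they fit in
        memory; each tile is restored separately and the results are
        stitched back together, discarding the overlapping borders.
        """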
        _, C, h, w = img.size()
        split_token_h = h // 200 + 1  # number of horizontal cut sections
        split_token_w = w // 200 + 1  # number of vertical cut sections

        # padding
        mod_pad_h, mod_pad_w = 0, 0
        if h % split_token_h != 0:
            mod_pad_h = split_token_h - h % split_token_h
        if w % split_token_w != 0:
            mod_pad_w = split_token_w - w % split_token_w
        img = F.pad(img, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
        _, _, H, W = img.size()
        split_h = H // split_token_h  # height of each partition
        split_w = W // split_token_w  # width of each partition

        # overlapping
        shave_h = split_h // 10
        shave_w = split_w // 10
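        # interior partitions are extended by ~10% on each shared border so the
        # seam regions can be cropped away again when the outputs are merged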
        scale = self.opt.get('scale', 1)
        ral = H // split_h
        row = W // split_w
        slices = []  # list of partition borders
        for i in range(ral):
            for j in range(row):
                if i == 0 and i == ral - 1:
                    top = slice(i * split_h, (i + 1) * split_h)
                elif i == 0:
                    top = slice(i * split_h, (i + 1) * split_h + shave_h)
                elif i == ral - 1:
                    top = slice(i * split_h - shave_h, (i + 1) * split_h)
                else:
                    top = slice(i * split_h - shave_h, (i + 1) * split_h + shave_h)
                if j == 0 and j == row - 1:
                    left = slice(j * split_w, (j + 1) * split_w)
                elif j == 0:
                    left = slice(j * split_w, (j + 1) * split_w + shave_w)
                elif j == row - 1:
                    left = slice(j * split_w - shave_w, (j + 1) * split_w)
                else:
                    left = slice(j * split_w - shave_w, (j + 1) * split_w + shave_w)
                slices.append((top, left))
        img_chops = []  # list of partitions
        for top, left in slices:
            img_chops.append(img[..., top, left])
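        # restore each partition, preferring the EMA weights when they exist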
        if hasattr(self, 'net_g_ema'):
            self.net_g_ema.eval()
            with torch.no_grad():
                outputs = []
                for chop in img_chops:
                    out = self.net_g_ema(chop)  # image processing of each partition
                    outputs.append(out)
                _img = torch.zeros(1, C, H * scale, W * scale)
                # merge
                for i in range(ral):
                    for j in range(row):
                        top = slice(i * split_h * scale, (i + 1) * split_h * scale)
                        left = slice(j * split_w * scale, (j + 1) * split_w * scale)
                        if i == 0:
                            _top = slice(0, split_h * scale)
                        else:
                            _top = slice(shave_h * scale, (shave_h + split_h) * scale)
                        if j == 0:
                            _left = slice(0, split_w * scale)
                        else:
                            _left = slice(shave_w * scale, (shave_w + split_w) * scale)
                        _img[..., top, left] = outputs[i * row + j][..., _top, _left]
        else:
            self.net_g.eval()
            with torch.no_grad():
                outputs = []
                for chop in img_chops:
                    out = self.net_g(chop)  # image processing of each partition
                    outputs.append(out)
                _img = torch.zeros(1, C, H * scale, W * scale)
                # merge
                for i in range(ral):
                    for j in range(row):
                        top = slice(i * split_h * scale, (i + 1) * split_h * scale)
                        left = slice(j * split_w * scale, (j + 1) * split_w * scale)
                        if i == 0:
                            _top = slice(0, split_h * scale)
                        else:
                            _top = slice(shave_h * scale, (shave_h + split_h) * scale)
                        if j == 0:
                            _left = slice(0, split_w * scale)
                        else:
                            _left = slice(shave_w * scale, (shave_w + split_w) * scale)
                        _img[..., top, left] = outputs[i * row + j][..., _top, _left]
            self.net_g.train()
        # remove the reflection padding added before partitioning
        _, _, h, w = _img.size()
        _img = _img[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale]
        return _img
    # `one_img_test` above tests by partitioning; the methods below wrap it
    # in the self-ensemble used at inference time
    def gather(self, imgs):
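        """Undo the geometric augmentations from `augment` and average.

        Each restored output is mapped back to the original orientation
        (transpose, horizontal flip, vertical flip) and the eight aligned
        predictions are averaged into a single result.
        """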
        for i in range(len(imgs), 0, -1):
            if i > 4:
                imgs[i - 1] = imgs[i - 1].clone().transpose(2, 3)
            if (i - 1) % 4 > 1:
                imgs[i - 1] = TF.hflip(imgs[i - 1])
            if ((i - 1) % 4) % 2 == 1:
                imgs[i - 1] = TF.vflip(imgs[i - 1])
        imgs = torch.cat(imgs, dim=0)
        imgs = torch.mean(imgs, dim=0, keepdim=True)
        return imgs
    def augment(self, img):
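        """Build the eight self-ensemble variants of `img`.

        Returns the identity, vertical flip, horizontal flip, both flips,
        and the transposed (H/W-swapped) version of each of those four.
        """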
        imgs = [0] * 9
        for i in range(1, 9):
            if i == 1:
                imgs[i] = img
            elif i == 2:
                imgs[i] = TF.vflip(img)
            elif i > 2 and i <= 4:
                imgs[i] = TF.hflip(imgs[i - 2])
            elif i > 4:
                imgs[i] = imgs[i - 4].transpose(2, 3)
        return imgs[1:]
    def test(self):
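        """Self-ensemble test: restore every augmented copy of `self.lq`
        with `one_img_test`, then de-augment and average via `gather`."""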
        lqs = self.augment(self.lq)
        output = []
        for i in tqdm(range(len(lqs))):
            output.append(self.one_img_test(lqs[i]))
        self.output = self.gather(output)