MaIR / basicsr /archs /shift_scanf_util.py
P-rateek's picture
Upload 106 files
638b138 verified
import time
import torch
from einops import rearrange
import numpy as np
def test_crop_by_horz(inp, scan_len):
    """Print the strip-wise scan order of a 2-D tensor (debug helper).

    The tensor is cut into vertical strips of ``scan_len`` columns, every
    second strip is turned upside down, the strips are stacked on top of
    each other, and the second row of every 2-row window is reversed.
    The flattened result and its shape are printed; nothing is returned.

    Args:
        inp: 2-D tensor whose width is a multiple of ``scan_len``. The
            stacked strips must have an even number of rows in total.
        scan_len: Number of columns per strip.

    Note:
        ``rearrange`` may return a view of ``inp``, in which case the
        in-place flip below also modifies the caller's tensor.
    """
    strips = rearrange(inp, "h (d1 w) -> d1 h w ", w=scan_len)
    # Strips on the "way back" (odd index) are traversed bottom-to-top.
    strips[1::2] = strips[1::2].flip(dims=[-2])
    stacked = rearrange(strips, " d1 h w -> (d1 h) w ")

    windows = rearrange(stacked, "(d1 h) (d2 w) -> (d2 d1) h w ", h=2, w=scan_len)
    # Within each 2-row window the second row is read right-to-left.
    windows[:, -1] = windows[:, -1].flip(dims=[-1])

    flattened = windows.reshape(1, -1)
    print(flattened)
    print(flattened.shape)
def chw_2d(h, w):
    """Return an ``(h, w)`` integer tensor filled row-major with 1..h*w."""
    count = h * w
    values = torch.arange(1, count + 1)
    return values.reshape(h, w)
def chw_3d(c, h, w):
    """Return a ``(c, h, w)`` integer tensor filled row-major with 1..c*h*w."""
    count = c * h * w
    values = torch.arange(1, count + 1)
    return values.reshape(c, h, w)
def chw_4d(b, c, h, w, random=False):
    """Build a ``(b, c, h, w)`` test tensor.

    Args:
        b, c, h, w: Sizes of the four dimensions.
        random: When true, fill with samples from ``torch.randn``;
            otherwise fill row-major with the integers 1..b*c*h*w.

    Returns:
        Tensor of shape ``(b, c, h, w)``.
    """
    count = b * c * h * w
    if not random:
        flat = torch.arange(1, count + 1)
    else:
        flat = torch.randn(count)
    return flat.reshape(b, c, h, w)
def create_idx(b, c, h, w):
    """Return a ``(b, c, h, w)`` tensor of 0-based flat positions.

    Element ``[i, j, y, x]`` holds its own row-major offset, so the tensor
    can be pushed through a scan to record where each position ends up.
    """
    count = b * c * h * w
    positions = torch.arange(count)
    return positions.reshape(b, c, h, w)
def test_2d_horz(inp_h, inp_w):
    """Print an ``inp_h`` x ``inp_w`` index grid and its strip-wise scan.

    The grid is filled row-major with 1..inp_h*inp_w and scanned with a
    fixed strip width of 2 columns (debug helper, returns nothing).

    Args:
        inp_h: Grid height.
        inp_w: Grid width.
    """
    scan_len = 2
    # The previous hard-coded 4x4 literal was overwritten straight away by
    # this call, so only the generated grid is kept.
    inp = chw_2d(inp_h, inp_w)
    print(inp)
    test_crop_by_horz(inp, scan_len)
def sscan_einops(inp, scan_len):
    """Flatten a 4-D tensor along a strip-wise snake path using einops.

    The width is cut into strips of ``scan_len`` columns; every second
    strip is flipped vertically, the strips are stacked along the height,
    and the second row of each 2-row window is reversed before flattening.

    Args:
        inp: Tensor of shape ``(B, C, H, W)``. ``W`` must be a multiple of
            ``scan_len`` and ``(W // scan_len) * H`` must be even, because
            the windows are fixed at 2 rows.
        scan_len: Number of columns per strip.

    Returns:
        Tensor of shape ``(B, C, 1, H * W)``.

    Note:
        The caller's tensor is left untouched. ``rearrange`` can hand back
        a view of its input, so the strips are cloned before any in-place
        flip; without the clone the flips would leak into ``inp``.
    """
    B, C, _, _ = inp.shape

    # Private copy: everything below writes in place.
    split_inp = rearrange(inp, "b c h (d1 w) -> d1 b c h w ", w=scan_len).clone()
    # Strips on the "way back" (odd index) are traversed bottom-to-top.
    split_inp[1::2] = split_inp[1::2].flip(dims=[-2])
    reverse_inp = rearrange(split_inp, " d1 b c h w -> b c (d1 h) w ")

    inp_window = rearrange(reverse_inp, "b c (d1 h) (d2 w) -> (b c d2 d1) h w ", h=2, w=scan_len)
    # Second row of every 2-row window is read right-to-left.
    inp_window[:, -1] = inp_window[:, -1].flip(dims=[-1])

    return inp_window.reshape(B, C, 1, -1)
def sscan(inp, scan_len, shift_len=0):
    """Flatten a 4-D tensor along a strip-wise snake path.

    The width is partitioned into column blocks: an optional leading block
    of ``shift_len`` columns, then full strips of ``scan_len`` columns, then
    whatever columns remain. Counting blocks from zero, every odd block is
    traversed bottom-to-top, and inside every block the odd rows are read
    right-to-left. Blocks are emitted left to right, each one row by row.

    Args:
        inp: Tensor of shape ``(B, C, H, W)``. It is not modified.
        scan_len: Width of a full strip; must be at least 1.
        shift_len: Width of the leading partial strip, ``0 <= shift_len <= W``.

    Returns:
        Tensor of shape ``(B, C, H * W)`` holding the scanned values.

    Raises:
        ValueError: If ``scan_len`` or ``shift_len`` is out of range.
    """
    if scan_len < 1:
        raise ValueError(f"scan_len must be >= 1, got {scan_len}")
    B, C, H, W = inp.shape
    if not 0 <= shift_len <= W:
        raise ValueError(f"shift_len must be in [0, {W}], got {shift_len}")

    # All flips below are written in place; work on a private copy so the
    # caller's tensor is never modified.
    inp = inp.clone()

    n_full = (W - shift_len) // scan_len   # number of full-width strips
    tail = (W - shift_len) % scan_len      # width of the trailing partial strip
    mid_end = shift_len + n_full * scan_len

    # Flip the strips on the "way back" upside down. Without a shift these
    # are the odd strips; with a shift the leading partial block counts as
    # block 0, so the even full strips are the ones flipped. Slices running
    # past the right edge are clipped (or empty), which also covers the
    # trailing partial strip.
    first_flipped = 1 if shift_len == 0 else 0
    for i in range(first_flipped, n_full + 1, 2):
        lo = shift_len + i * scan_len
        hi = lo + scan_len
        inp[:, :, :, lo:hi] = inp[:, :, :, lo:hi].flip(dims=[-2])

    # Reverse the odd rows inside every block (leading, full, trailing).
    edges = [0] + [shift_len + k * scan_len for k in range(n_full + 1)] + [W]
    for lo, hi in zip(edges[:-1], edges[1:]):
        if lo < hi:
            inp[:, :, 1::2, lo:hi] = inp[:, :, 1::2, lo:hi].flip(dims=[-1])

    # Emit the blocks left to right, each flattened row-major.
    pieces = []
    if shift_len:
        pieces.append(inp[:, :, :, :shift_len].reshape(B, C, H * shift_len))
    # Same layout as einops "b c h (d2 w) -> (b c d2) h w" followed by a
    # reshape to (B, C, -1): strip index outermost, then row, then column.
    middle = inp[:, :, :, shift_len:mid_end].reshape(B, C, H, n_full, scan_len)
    pieces.append(middle.permute(0, 1, 3, 2, 4).reshape(B, C, H * n_full * scan_len))
    if tail:
        pieces.append(inp[:, :, :, mid_end:].reshape(B, C, H * tail))
    return torch.cat(pieces, dim=-1)
# def sscan_4d(inp, scan_len, ues_einops=True, fix_ending=True):
def sscan_4d(inp, scan_len, shift_len=0, fix_ending=True, use_einops=False):
    """Scan a 4-D tensor along four directions.

    Two base scans are produced, one over ``inp`` and one over its spatial
    transpose. The two extra directions are obtained either by scanning a
    spatially flipped copy of the input (``fix_ending=True``) or by
    reversing the two base sequences (``fix_ending=False``).

    Args:
        inp: Tensor of shape ``(B, C, H, W)``.
        scan_len: Strip width passed to the scan helper.
        shift_len: Leading strip width passed to ``sscan``. It is ignored
            when ``use_einops`` is true.
        fix_ending: Selects how the reversed directions are built.
        use_einops: Use ``sscan_einops`` instead of ``sscan``.

    Returns:
        Tensor of shape ``(B, 4, C, H * W)``. The order of the four
        directions differs between the two ``fix_ending`` modes:
        (input, flipped, transposed, transposed-flipped) when true,
        (input, transposed, reversed input, reversed transposed) when false.
    """
    B, _, H, W = inp.shape
    L = H * W

    def _scan(tensor):
        # Dispatch to the selected scan implementation.
        if use_einops:
            return sscan_einops(tensor, scan_len)
        return sscan(tensor, scan_len, shift_len)

    if not fix_ending:
        # Materialise the transpose before scanning `inp`: the scan helpers
        # may write into the tensor they are given.
        transposed = inp.transpose(-1, -2).contiguous()
        seq_hw = _scan(inp).reshape(B, -1, L)
        seq_wh = _scan(transposed).reshape(B, -1, L)
        x_hwwh = torch.stack([seq_hw, seq_wh], dim=1).reshape(B, 2, -1, L)
        return torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)

    # Stack the input with its spatially flipped copy along the channel
    # axis so that both are scanned in one pass.
    flipped = torch.flip(inp, dims=[-1, -2])
    both = torch.concat((inp, flipped), dim=1)
    both_t = both.transpose(-1, -2).contiguous()
    seq_hw = _scan(both).reshape(B, 2, -1, L)
    seq_wh = _scan(both_t).reshape(B, 2, -1, L)
    return torch.stack([seq_hw, seq_wh], dim=1).reshape(B, 4, -1, L)
def inverse_ids_generate(origin_ids, K=4):
    '''
    Compute the indices that undo a scan permutation.

    Input: origin_ids: (B, K, C, L), a permutation of 0..L-1 along the last axis
    Output: (B, K, C, L), argsort of ``origin_ids`` along the last axis
    Note: C is set to 1 for speeding up. ``K`` is accepted but not used.
    '''
    return origin_ids.argsort(dim=-1)
def mair_ids_generate(inp_shape, scan_len=4, K=4):
    """Build the scan indices and their inverse for a given spatial size.

    A single-image, single-channel index map is scanned with ``sscan_4d``
    (no shift), so the indices depend only on height, width and
    ``scan_len``.

    Args:
        inp_shape: 4-tuple ``(B, C, H, W)``; only ``H`` and ``W`` are used.
        scan_len: Strip width for the scan.
        K: Forwarded to ``inverse_ids_generate``.

    Returns:
        ``(scan_ids, inverse_ids)``: the first item of the ``sscan_4d``
        output (batch dimension removed) and its argsort along the last
        axis.
    """
    _, _, height, width = inp_shape
    index_map = create_idx(1, 1, height, width)
    scan_ids = sscan_4d(index_map, scan_len)[0]
    return scan_ids, inverse_ids_generate(scan_ids, K=K)
def mair_shift_ids_generate(inp_shape, scan_len=4, shift_len=0, K=4):
    """Build batched scan indices and their inverse for a shifted scan.

    A single-image, single-channel index map is scanned with ``sscan_4d``
    using ``shift_len``; the result is then tiled ``B`` times along a new
    leading axis.

    Args:
        inp_shape: 4-tuple ``(B, C, H, W)``; ``C`` is not used.
        scan_len: Strip width for the scan.
        shift_len: Width of the leading partial strip.
        K: Forwarded to ``inverse_ids_generate``.

    Returns:
        ``(scan_ids, inverse_ids)``, both with ``B`` as the leading
        dimension; ``inverse_ids`` is the argsort of ``scan_ids`` along the
        last axis.
    """
    batch, _, height, width = inp_shape
    index_map = create_idx(1, 1, height, width)
    per_image_ids = sscan_4d(index_map, scan_len, shift_len=shift_len)[0]
    # Every image in the batch shares the same scan order.
    scan_ids = per_image_ids.repeat(batch, 1, 1, 1)
    inverse_ids = inverse_ids_generate(scan_ids, K=K)
    return scan_ids, inverse_ids
def mair_ids_scan(inp, xs_scan_ids, bkdl=False, K=4):
    '''
    Gather a feature map along K precomputed scan orders.

    inp: B, C, H, W
    xs_scan_ids: K, 1, L (any shape with K * L elements, L = H * W),
        integer positions into the flattened H * W axis
    bkdl: if True return (B, K, C, L); otherwise (B, K, C * L)
    K: number of scan directions (previously only K == 4 worked)
    '''
    B, C, H, W = inp.shape
    L = H * W
    ids = xs_scan_ids.reshape(K, L)
    # Flatten the spatial axes once; the singleton axis becomes the
    # direction axis after concatenation.
    flat = inp.reshape(B, 1, C, L)
    scans = torch.cat(
        [torch.index_select(flat, -1, ids[k]) for k in range(K)], dim=1
    )
    if bkdl:
        return scans
    return scans.reshape(B, K, -1)
def mair_ids_inverse(inp, xs_scan_ids, shape=None):
    '''
    Gather each scanned sequence along its index row and concatenate the
    results along the channel axis.

    inp: (B, K, -1, L)
    xs_scan_ids: (1, K, 1, L) (any shape with K * L elements); to undo a
        scan, pass the inverse indices of that scan
    shape: optional (B, C, H, W); when given, the last axis of every
        direction is folded back to (H, W)

    Returns (B, K * C, L), or (B, K * C, H, W) when ``shape`` is given.
    K is taken from ``inp`` (previously only K == 4 worked).
    '''
    B, K, _, L = inp.shape
    ids = xs_scan_ids.reshape(K, L)
    if not shape:
        out_shape = (B, -1, L)
    else:
        B, _, H, W = shape
        out_shape = (B, -1, H, W)
    restored = [
        torch.index_select(inp[:, k], -1, ids[k]).reshape(out_shape)
        for k in range(K)
    ]
    return torch.cat(restored, dim=1)
def test_time():
    """Round-trip demo: scan a small batch with precomputed ids, then undo it.

    Builds a ``(2, 3, 3, 4)`` tensor of consecutive integers, scans it along
    the four directions, applies the inverse indices, and prints the four
    recovered chunks. Returns nothing.
    """
    scan_len = 4
    batch, channels, height, width = 2, 3, 3, 4
    index_like = chw_4d(1, 1, height, width, random=False)
    image = chw_4d(batch, channels, height, width, random=False)
    print("inp:", image)

    # Original (unshifted) scan
    scan_ids, inverse_ids = mair_ids_generate(index_like.shape, scan_len=scan_len, K=4)
    scanned = mair_ids_scan(image, scan_ids, bkdl=True)
    restored = mair_ids_inverse(scanned, inverse_ids, shape=(batch, channels, height, width))

    print("recovered input:")
    for i, part in enumerate(restored.chunk(4, dim=1)):
        print("inp_flatten:", i)
        print(part)
    print("end")
if __name__ == '__main__':
    # Run the round-trip demo and report its wall-clock duration.
    # torch.set_default_device(1)
    began = time.time()
    test_time()
    elapsed = time.time() - began
    print(f"函数运行时间:{elapsed} 秒")