|
import time |
|
import torch |
|
from einops import rearrange |
|
import numpy as np |
|
|
|
def test_crop_by_horz(inp, scan_len):
    """Print (and return) the S-shaped strip scan of a 2-D tensor.

    ``inp`` (H, W) is cut into vertical strips of ``scan_len`` columns.
    Strips with an odd index are flipped upside down, and inside every strip
    the odd rows are reversed left-to-right. The strips are then read row by
    row, one after another.

    Args:
        inp: 2-D tensor of shape (H, W). W must be a multiple of ``scan_len``
            (``rearrange`` raises otherwise). ``inp`` is not modified.
        scan_len: strip width.

    Returns:
        The flattened scan, shape (1, H * W). It is also printed together
        with its shape.
    """
    # (n_strips, H, scan_len). Clone because rearrange may return a view of
    # ``inp`` and the flips below are written in place.
    strips = rearrange(inp, "h (d1 w) -> d1 h w ", w=scan_len).clone()
    # Odd strips are traversed bottom-to-top.
    strips[1::2] = strips[1::2].flip(dims=[-2])
    # Odd rows of every strip run right-to-left. Doing this per strip keeps
    # the pattern correct for odd H as well.
    strips[:, 1::2] = strips[:, 1::2].flip(dims=[-1])

    inp_flatten = strips.reshape(1, -1)
    print(inp_flatten)
    print(inp_flatten.shape)
    return inp_flatten
|
|
|
def chw_2d(h, w):
    """Return an (h, w) integer tensor holding 1..h*w in row-major order."""
    count = h * w
    return (torch.arange(count) + 1).reshape(h, w)
|
|
|
def chw_3d(c, h, w):
    """Return a (c, h, w) integer tensor holding 1..c*h*w in row-major order."""
    count = c * h * w
    return (torch.arange(count) + 1).reshape(c, h, w)
|
|
|
def chw_4d(b, c, h, w, random=False):
    """Return a (b, c, h, w) tensor for testing.

    Args:
        b, c, h, w: output dimensions.
        random: if True, fill with standard-normal samples drawn as one flat
            vector of length b*c*h*w. Otherwise fill with the integers
            1..b*c*h*w in row-major order.
    """
    numel = b * c * h * w
    flat = torch.randn(numel) if random else torch.arange(1, numel + 1)
    return flat.reshape(b, c, h, w)
|
|
|
def create_idx(b, c, h, w):
    """Return a (b, c, h, w) tensor whose entries are their own flat indices.

    The values are 0..b*c*h*w-1 in row-major order. Scanning this tensor
    therefore yields the gather indices of the scan.
    """
    total = b * c * h * w
    return torch.arange(total).view(b, c, h, w)
|
|
|
def test_2d_horz(inp_h, inp_w, scan_len=2):
    """Print an (inp_h, inp_w) tensor of 1..inp_h*inp_w and its strip scan.

    For inp_h = inp_w = 4 the printed input is::

        [[ 1,  2,  3,  4],
         [ 5,  6,  7,  8],
         [ 9, 10, 11, 12],
         [13, 14, 15, 16]]

    Args:
        inp_h: number of rows.
        inp_w: number of columns (a multiple of ``scan_len``).
        scan_len: strip width forwarded to ``test_crop_by_horz`` (default 2).
    """
    inp = chw_2d(inp_h, inp_w)
    print(inp)
    test_crop_by_horz(inp, scan_len)
|
|
|
def sscan_einops(inp, scan_len):
    """Einops version of the strip-wise S-shaped scan (no shift support).

    ``inp`` (B, C, H, W) is cut into vertical strips of ``scan_len`` columns.
    Strips with an odd index are flipped upside down, and inside every strip
    the odd rows are reversed left-to-right. The strips are then read row by
    row, one after another. This is the same ordering ``sscan(inp, scan_len)``
    produces, but reshaped to (B, C, 1, H * W). Odd rows are reversed per
    strip, so the ordering is also correct when H is odd.

    Args:
        inp: tensor of shape (B, C, H, W). W must be a multiple of
            ``scan_len`` (``rearrange`` raises otherwise). ``inp`` is not
            modified.
        scan_len: strip width.

    Returns:
        Tensor of shape (B, C, 1, H * W) with the elements in scan order.
    """
    # (n_strips, B, C, H, scan_len). Clone because rearrange may return a
    # view of ``inp`` and the flips below are written in place.
    strips = rearrange(inp, "b c h (d1 w) -> d1 b c h w ", w=scan_len).clone()
    # Odd strips are traversed bottom-to-top.
    strips[1::2] = strips[1::2].flip(dims=[-2])
    # Odd rows of every strip run right-to-left.
    strips[..., 1::2, :] = strips[..., 1::2, :].flip(dims=[-1])
    # Read strip by strip, row by row, column by column.
    return rearrange(strips, "d1 b c h w -> b c (d1 h w)").unsqueeze(2)
|
|
|
def sscan(inp, scan_len, shift_len=0):
    """Flatten a (B, C, H, W) tensor along a strip-wise S-shaped path.

    The W axis is split into vertical strips, in this order:

    * a leading strip of ``shift_len`` columns (only when ``shift_len > 0``),
    * as many full strips of ``scan_len`` columns as fit after it,
    * a narrower trailing strip holding the remaining columns, if any.

    Strips at an odd position in that order are flipped upside down. Inside
    every strip the odd rows (1, 3, ...) are reversed left-to-right. Each
    strip is then read row by row, and the strips are concatenated.

    Side effect: the flips are written back into ``inp``, so the input tensor
    is modified in place.

    Args:
        inp: tensor of shape (B, C, H, W).
        scan_len: width of a full strip (positive).
        shift_len: width of the leading strip; ``0 <= shift_len <= W``.

    Returns:
        Tensor of shape (B, C, H * W) holding the elements of ``inp`` in scan
        order.

    Raises:
        ValueError: if ``shift_len`` is outside ``[0, W]``.
    """
    B, C, H, W = inp.shape
    if not 0 <= shift_len <= W:
        raise ValueError(f"shift_len must be in [0, {W}], got {shift_len}")

    n_full = (W - shift_len) // scan_len
    full_end = shift_len + n_full * scan_len

    # Column boundaries of all strips, in scan order.
    edges = [0] if shift_len == 0 else [0, shift_len]
    edges += [shift_len + k * scan_len for k in range(1, n_full + 1)]
    if full_end < W:
        edges.append(W)  # narrower trailing strip

    pieces = []
    for pos, (lo, hi) in enumerate(zip(edges, edges[1:])):
        if pos % 2:
            # Odd strips are traversed bottom-to-top.
            inp[..., lo:hi] = inp[..., lo:hi].flip(dims=[-2])
        # Odd rows run right-to-left, giving the snake inside the strip.
        inp[..., 1::2, lo:hi] = inp[..., 1::2, lo:hi].flip(dims=[-1])
        # Explicit size instead of -1: reshaping a 0-element tensor with -1
        # is ambiguous and raises in PyTorch.
        pieces.append(inp[..., lo:hi].reshape(B, C, H * (hi - lo)))

    if not pieces:  # W == 0: nothing to scan
        return inp.reshape(B, C, 0)
    return torch.cat(pieces, dim=-1)
|
|
|
|
|
|
|
def sscan_4d(inp, scan_len, shift_len=0, fix_ending=True, use_einops=False):
    """Build the four S-shaped scan sequences of a (B, C, H, W) tensor.

    Two base scans are used: ``line1`` scans the tensor strip-wise along W,
    and ``line2`` scans its H/W-transpose the same way.

    * ``fix_ending=True``: the base scans are also applied to ``inp`` flipped
      along both spatial axes. The output directions are
      [line1(inp), line1(flipped inp), line2(inp), line2(flipped inp)].
    * ``fix_ending=False``: the extra directions are the base scans reversed
      along the sequence axis. The output directions are
      [line1, line2, reversed line1, reversed line2].

    Args:
        inp: tensor of shape (B, C, H, W). It is not modified.
        scan_len: strip width passed to the scan helper.
        shift_len: leading-strip width passed to ``sscan``. Must be 0 when
            ``use_einops`` is True.
        fix_ending: selects one of the two direction layouts above.
        use_einops: use ``sscan_einops`` instead of ``sscan``.

    Returns:
        Tensor of shape (B, 4, C, H * W).

    Raises:
        ValueError: if ``use_einops`` is True and ``shift_len`` is non-zero.
    """
    if use_einops and shift_len != 0:
        # sscan_einops has no leading shifted strip. Silently ignoring
        # shift_len would return an unshifted scan.
        raise ValueError("shift_len is not supported when use_einops=True")

    B, C, H, W = inp.shape
    L = H * W

    def scan(x):
        # sscan writes its flips back into its argument, so only tensors
        # owned by this function may be passed here.
        if use_einops:
            return sscan_einops(x, scan_len)
        return sscan(x, scan_len, shift_len)

    def transposed_copy(x):
        # clone() always allocates fresh storage. .contiguous() may return a
        # view aliasing ``x`` when the transpose is already contiguous.
        return x.transpose(-1, -2).clone(memory_format=torch.contiguous_format)

    if fix_ending:
        inp_reverse = torch.flip(inp, dims=[-1, -2])
        inp_cat = torch.concat((inp, inp_reverse), dim=1)  # fresh tensor
        inp_cat_t = transposed_copy(inp_cat)
        line1 = scan(inp_cat)
        line2 = scan(inp_cat_t)
        # (B, 2 scans, 2 [original, flipped], C, L) -> (B, 4, C, L)
        xs = torch.stack(
            [line1.reshape(B, 2, -1, L), line2.reshape(B, 2, -1, L)], dim=1
        ).reshape(B, 4, -1, L)
    else:
        inp_t = transposed_copy(inp)
        # Scan a private copy so the caller's tensor stays untouched.
        line1 = scan(inp.clone())
        line2 = scan(inp_t)
        x_hwwh = torch.stack(
            [line1.reshape(B, -1, L), line2.reshape(B, -1, L)], dim=1
        )  # (B, 2, C, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)

    return xs
|
|
|
def inverse_ids_generate(origin_ids, K=4):
    """Return the indices that sort ``origin_ids`` along its last axis.

    If every slice along the last axis is a permutation of 0..L-1, the
    result is the inverse permutation. Gathering with it therefore undoes a
    gather with ``origin_ids``.

    Args:
        origin_ids: integer index tensor of any rank, e.g. (B, K, C, L) or
            (K, C, L). Callers here use C = 1 to keep this cheap.
        K: unused; kept for call-site compatibility.

    Returns:
        Tensor with the same shape as ``origin_ids``.
    """
    return origin_ids.argsort(dim=-1)
|
|
|
|
|
def mair_ids_generate(inp_shape, scan_len=4, K=4):
    """Return the scan indices and their inverse for an H x W grid.

    Args:
        inp_shape: 4-sequence (B, C, H, W); only H and W are used.
        scan_len: strip width forwarded to ``sscan_4d``.
        K: forwarded to ``inverse_ids_generate``.

    Returns:
        ``(xs_scan_ids, xs_inverse_ids)``, both of shape (4, 1, H * W).
    """
    _, _, height, width = inp_shape
    # Each cell holds its own flat position, so the scanned tensor is the
    # gather-index list of the scan.
    flat_positions = create_idx(1, 1, height, width)
    scan_ids = sscan_4d(flat_positions, scan_len)[0]
    inverse_ids = inverse_ids_generate(scan_ids, K=K)
    return scan_ids, inverse_ids
|
|
|
|
|
def mair_shift_ids_generate(inp_shape, scan_len=4, shift_len=0, K=4):
    """Return batch-repeated shifted-scan indices and their inverse.

    Args:
        inp_shape: 4-sequence (B, C, H, W); B, H and W are used.
        scan_len: strip width forwarded to ``sscan_4d``.
        shift_len: leading-strip width forwarded to ``sscan_4d``.
        K: forwarded to ``inverse_ids_generate``.

    Returns:
        ``(xs_scan_ids, xs_inverse_ids)``, both of shape (B, 4, 1, H * W).

    NOTE(review): ``mair_ids_scan`` / ``mair_ids_inverse`` reshape their ids
    to (K, L), which only fits B == 1. Confirm how batched ids from here are
    consumed.
    """
    batch, _, height, width = inp_shape
    flat_positions = create_idx(1, 1, height, width)
    # (4, 1, L) -> (batch, 4, 1, L): same indices for every batch element.
    scan_ids = sscan_4d(flat_positions, scan_len, shift_len=shift_len)[0].repeat(batch, 1, 1, 1)
    inverse_ids = inverse_ids_generate(scan_ids, K=K)
    return scan_ids, inverse_ids
|
|
|
|
|
def mair_ids_scan(inp, xs_scan_ids, bkdl=False, K=4):
    """Gather ``inp`` along K scan orders.

    Args:
        inp: tensor of shape (B, C, H, W).
        xs_scan_ids: integer indices with K * H * W elements, e.g. (K, 1, L).
            Row k is the flat spatial index visited at each step of scan k.
        bkdl: if True return (B, K, C, L); otherwise (B, K, C * L).
        K: number of scan directions.

    Returns:
        Tensor where out[b, k, c, l] == inp.reshape(B, C, L)[b, c, ids[k, l]]
        (flattened over the last two axes when ``bkdl`` is False).
    """
    B, C, H, W = inp.shape
    L = H * W
    xs_scan_ids = xs_scan_ids.reshape(K, L)

    # Reshape once and reuse it for every direction.
    flat = inp.reshape(B, 1, C, L)
    inp_flatten = torch.cat(
        [torch.index_select(flat, -1, xs_scan_ids[k]) for k in range(K)], dim=1
    )
    if not bkdl:
        inp_flatten = inp_flatten.reshape(B, K, -1)
    return inp_flatten
|
|
|
def mair_ids_inverse(inp, xs_scan_ids, shape=None):
    """Undo K scan orders by gathering each direction with its own indices.

    Args:
        inp: tensor of shape (B, K, C, L), one scanned sequence per direction.
        xs_scan_ids: integer indices with K * L elements, e.g. (1, K, 1, L).
            Typically the inverse ids from ``inverse_ids_generate``.
        shape: optional 4-sequence (B, C, H, W). If given, each direction is
            reshaped to (B, -1, H, W), using B from ``shape``.

    Returns:
        (B, K * C, L) without ``shape``, or (B, K * C, H, W) with it. Each
        direction occupies a contiguous block of channels, in direction order.
    """
    B, K, _, L = inp.shape
    xs_scan_ids = xs_scan_ids.reshape(K, L)

    if shape:
        B, _, H, W = shape
        out_shape = (B, -1, H, W)
    else:
        out_shape = (B, -1, L)

    restored = [
        torch.index_select(inp[:, k], -1, xs_scan_ids[k]).reshape(out_shape)
        for k in range(K)
    ]
    return torch.cat(restored, dim=1)
|
|
|
|
|
def test_time():
    """Scan-and-restore demo on a small integer tensor.

    Builds index tensors with ``mair_ids_generate``, scans a (2, 3, 3, 4)
    tensor into 4 directions with ``mair_ids_scan``, inverts the scan with
    ``mair_ids_inverse`` and prints the input and each recovered direction.
    """
    scan_len = 4
    inp_b, inp_c, inp_h, inp_w = 2, 3, 3, 4
    inp_rgb = chw_4d(inp_b, inp_c, inp_h, inp_w, random=False)
    print("inp:", inp_rgb)

    # Id generation only reads the spatial size; batch and channels are 1.
    xs_scan_ids, xs_inverse_ids = mair_ids_generate((1, 1, inp_h, inp_w), scan_len=scan_len, K=4)
    xs = mair_ids_scan(inp_rgb, xs_scan_ids, bkdl=True)
    recovered = mair_ids_inverse(xs, xs_inverse_ids, shape=(inp_b, inp_c, inp_h, inp_w))

    print("recovered input:")
    for i, direction in enumerate(recovered.chunk(4, dim=1)):
        print("inp_flatten:", i)
        print(direction)
    print("end")
|
|
|
|
|
if __name__ == '__main__':
    # perf_counter is a monotonic, high-resolution clock intended for
    # measuring elapsed time. time.time() follows the wall clock and can jump.
    start_time = time.perf_counter()
    test_time()
    end_time = time.perf_counter()
    # Prints "function run time: <seconds> seconds".
    print(f"函数运行时间:{end_time - start_time} 秒")
|
|