VanLinLin commited on
Commit
3f7c489
·
0 Parent(s):
.gitignore ADDED
File without changes
README.md ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # [TEAM ACVLAB][NTIRE 2025 Image Shadow Removal Challenge](https://cvlai.net/ntire/2025/) @ [CVPR 2025](https://cvpr.thecvf.com/)
2
+
3
+ ## Link to the codes/executables of the solution(s):
4
+ * [Checkpoints](https://drive.google.com/file/d/1USD5sLvEcgFqIg7BDzc1OuInzSx3GnUN/view?usp=drive_link)
5
+ * Input / Output file
6
+
7
+ ## Environments
8
+ ```bash
9
+ conda create -n ntire_shadow python=3.9 -y
10
+
11
+ conda activate ntire_shadow
12
+
13
+ pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
14
+
15
+ pip install -r requirements.txt
16
+
17
+ ```
18
+
19
+ ## Folder Structure
20
+ ```bash
21
+ test_dir
22
+ ├── Origin <- Put the shadow affected images in this folder
23
+ │ ├── 0000.png
24
+ │ ├── 0001.png
25
+ │ ├── ...
26
+ ├── Depth
27
+ ├── Normal
28
+
29
+
30
+ output_dir
31
+ ├── 0000.png
32
+ ├── 0001.png
33
+ ├──...
34
+ ```
35
+
36
+ ## How to test?
37
+ 1. Clone [Depth anything v2](https://github.com/DepthAnything/Depth-Anything-V2.git)
38
+
39
+ ```bash
40
+ git clone https://github.com/DepthAnything/Depth-Anything-V2.git
41
+ ```
42
+ 2. Download the [pretrain model of depth anything v2](https://huggingface.co/depth-anything/Depth-Anything-V2-Large/resolve/main/depth_anything_v2_vitl.pth?download=true)
43
+
44
+ 3. Run ```python Depth-Anything-V2/get_depth_normap.py```
45
+
46
+ Now folder structure will be
47
+ ```bash
48
+ test_dir
49
+ ├── Origin
50
+ │ ├── 0000.png
51
+ │ ├── 0001.png
52
+ │ ├── ...
53
+ ├── Depth
54
+ │ ├── 0000.npy
55
+ │ ├── 0001.npy
56
+ │ ├── ...
57
+ ├── Normal
58
+ │ ├── 0000.npy
59
+ │ ├── 0001.npy
60
+ │ ├── ...
61
+
62
+ output_dir
63
+ ├── 0000.png
64
+ ├── 0001.png
65
+ ├──...
66
+ ```
67
+
68
+ 4. Clone [DINOv2](https://github.com/facebookresearch/dinov2.git)
69
+ ```bash
70
+ git clone https://github.com/facebookresearch/dinov2.git
71
+ ```
72
+
73
+ 5. Download [shadow removal weight](https://drive.google.com/file/d/1USD5sLvEcgFqIg7BDzc1OuInzSx3GnUN/view?usp=drive_link)
74
+
75
+ ```bash
76
+ gdown 1USD5sLvEcgFqIg7BDzc1OuInzSx3GnUN
77
+ ```
78
+
79
+ 6. Run ```run_test.sh``` to get inference results.
80
+
81
+ ```bash
82
+ bash run_test.sh
83
+ ```
dataset.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+ from torch.utils.data import Dataset
4
+ import torch
5
+ from utils import load_normal, load_ssao, load_img, depthToPoint, process_normal, load_depth, Augment_RGB_torch
6
+ import torch.nn.functional as F
7
+ import random
8
+
9
+ augment = Augment_RGB_torch()
10
+ transforms_aug = [method for method in dir(augment) if callable(getattr(augment, method)) if not method.startswith('_')]
11
+
12
+ ##################################################################################################
13
+ class DataLoaderTrain(Dataset):
14
+ def __init__(self, rgb_dir, img_options=None, target_transform=None, debug=False):
15
+ super(DataLoaderTrain, self).__init__()
16
+
17
+ self.target_transform = target_transform
18
+
19
+ gt_dir = 'shadow_free'
20
+ input_dir = 'origin'
21
+ depth_dir = 'depth'
22
+ normal_dir = 'normal'
23
+
24
+ clean_files = sorted(os.listdir(os.path.join(rgb_dir, gt_dir))) # shadow free
25
+ noisy_files = sorted(os.listdir(os.path.join(rgb_dir, input_dir))) # origin
26
+ depth_files = sorted(os.listdir(os.path.join(rgb_dir, depth_dir))) # depth
27
+ normal_files = sorted(os.listdir(os.path.join(rgb_dir, normal_dir))) # noraml map
28
+
29
+ self.clean_filenames = [os.path.join(rgb_dir, gt_dir, x) for x in clean_files] # shadow free
30
+ self.noisy_filenames = [os.path.join(rgb_dir, input_dir, x) for x in noisy_files] # origin
31
+ self.depth_filenames = [os.path.join(rgb_dir, depth_dir, x) for x in depth_files] # depth
32
+ self.normal_filenames = [os.path.join(rgb_dir, normal_dir, x) for x in normal_files] # noraml map
33
+ self.img_options = img_options
34
+
35
+ if debug:
36
+ self.tar_size = 100
37
+ else:
38
+ self.tar_size = len(self.noisy_filenames)
39
+
40
+ def __len__(self):
41
+ return self.tar_size
42
+
43
+ def __getitem__(self, index):
44
+ tar_index = index % self.tar_size
45
+
46
+ clean = np.float32(load_img(self.clean_filenames[tar_index]))
47
+ noisy = np.float32(load_img(self.noisy_filenames[tar_index]))
48
+ depth = np.float32(load_depth(self.depth_filenames[tar_index]))
49
+ normal = np.float32(load_normal(self.normal_filenames[tar_index]))
50
+
51
+ point = depthToPoint(60, depth)
52
+
53
+ normal = process_normal(normal)
54
+
55
+ clean = torch.from_numpy(clean)
56
+ noisy = torch.from_numpy(noisy)
57
+ depth = torch.from_numpy(depth)
58
+ point = torch.from_numpy(point)
59
+ normal = torch.from_numpy(normal)
60
+
61
+ point = point / (2 * point[:,:,2].mean())
62
+
63
+ clean = clean.permute(2,0,1)
64
+ noisy = noisy.permute(2,0,1)
65
+ point = point.permute(2,0,1)
66
+ normal = normal.permute(2,0,1)
67
+
68
+ clean_filename = os.path.split(self.clean_filenames[tar_index])[-1]
69
+ noisy_filename = os.path.split(self.noisy_filenames[tar_index])[-1]
70
+ depth_filename = os.path.split(self.depth_filenames[tar_index])[-1]
71
+ normal_filename = os.path.split(self.normal_filenames[tar_index])[-1]
72
+
73
+
74
+ augment.rotate = random.randint(-20,20)
75
+ apply_trans = transforms_aug[random.randint(0, 2)]
76
+
77
+ # [0, 1]
78
+ clean = getattr(augment, apply_trans)(clean)
79
+ noisy = getattr(augment, apply_trans)(noisy)
80
+ point = getattr(augment, apply_trans)(point)
81
+ normal = getattr(augment, apply_trans)(normal)
82
+
83
+
84
+ #Crop Input and Target
85
+ ps = self.img_options['patch_size']
86
+ scale = 1#random.uniform(1, 1.5)
87
+
88
+ H = noisy.shape[1]
89
+ W = noisy.shape[2]
90
+ scaled_ps = (int)(scale * ps)
91
+ if H - scaled_ps != 0 or W - scaled_ps != 0:
92
+ r = np.random.randint(0, H - scaled_ps + 1)
93
+ c = np.random.randint(0, W - scaled_ps + 1)
94
+ clean = clean [:, r:r + scaled_ps, c:c + scaled_ps]
95
+ noisy = noisy [:, r:r + scaled_ps, c:c + scaled_ps]
96
+ point = point [:, r:r + scaled_ps, c:c + scaled_ps]
97
+ normal = normal [:, r:r + scaled_ps, c:c + scaled_ps]
98
+
99
+ # scale back to the patch_size
100
+ if scale != 1:
101
+ clean = F.interpolate(clean.unsqueeze(0), size=[ps, ps], mode='bilinear')
102
+ noisy = F.interpolate(noisy.unsqueeze(0), size=[ps, ps], mode='bilinear')
103
+ point = F.interpolate(point.unsqueeze(0), size=[ps, ps], mode='nearest')
104
+ normal = F.interpolate(normal.unsqueeze(0), size=[ps, ps], mode='nearest')
105
+ return clean.squeeze(0), noisy.squeeze(0), point.squeeze(0), normal.squeeze(0), noisy_filename
106
+
107
+ return clean, noisy, point, normal, clean_filename, noisy_filename
108
+
109
+
110
+ ##################################################################################################
111
+ class DataLoaderVal(Dataset):
112
+ def __init__(self, rgb_dir, target_transform=None, debug=False):
113
+ super(DataLoaderVal, self).__init__()
114
+
115
+ self.target_transform = target_transform
116
+
117
+ gt_dir = 'shadow_free'
118
+ input_dir = 'origin'
119
+ depth_dir = 'depth'
120
+ normal_dir = 'normal'
121
+
122
+ clean_files = sorted(os.listdir(os.path.join(rgb_dir, gt_dir)))
123
+ noisy_files = sorted(os.listdir(os.path.join(rgb_dir, input_dir)))
124
+ depth_files = sorted(os.listdir(os.path.join(rgb_dir, depth_dir)))
125
+ normal_files = sorted(os.listdir(os.path.join(rgb_dir, normal_dir)))
126
+
127
+ self.clean_filenames = [os.path.join(rgb_dir, gt_dir, x) for x in clean_files]
128
+ self.noisy_filenames = [os.path.join(rgb_dir, input_dir, x) for x in noisy_files]
129
+ self.depth_filenames = [os.path.join(rgb_dir, depth_dir, x) for x in depth_files]
130
+ self.normal_filenames = [os.path.join(rgb_dir, normal_dir, x) for x in normal_files]
131
+
132
+ if debug:
133
+ self.tar_size = 10
134
+ else:
135
+ self.tar_size = len(self.noisy_filenames)
136
+
137
+ def __len__(self):
138
+ return self.tar_size
139
+
140
+ def __getitem__(self, index):
141
+ tar_index = index % self.tar_size
142
+ clean = np.float32(load_img(self.clean_filenames[tar_index]))
143
+ noisy = np.float32(load_img(self.noisy_filenames[tar_index]))
144
+ depth = np.float32(load_depth(self.depth_filenames[tar_index]))
145
+ normal = np.float32(load_normal(self.normal_filenames[tar_index]))
146
+
147
+ point = depthToPoint(60, depth)
148
+ normal = process_normal(normal)
149
+ point = point / (2 * point[:,:,2].mean())
150
+
151
+ clean_filename = os.path.split(self.clean_filenames[tar_index])[-1]
152
+ noisy_filename = os.path.split(self.noisy_filenames[tar_index])[-1]
153
+
154
+ clean = torch.from_numpy(clean)
155
+ noisy = torch.from_numpy(noisy)
156
+ point = torch.from_numpy(point)
157
+ normal = torch.from_numpy(normal)
158
+
159
+
160
+ clean = clean.permute(2,0,1)
161
+ noisy = noisy.permute(2,0,1)
162
+ point = point.permute(2,0,1)
163
+ normal = normal.permute(2,0,1)
164
+
165
+
166
+ return clean, noisy, point, normal, clean_filename, noisy_filename
167
+
168
+
169
+
170
+
171
+ ##################################################################################################
172
+ class DataLoaderTest(Dataset):
173
+ def __init__(self, rgb_dir, target_transform=None, debug=False):
174
+ super(DataLoaderTest, self).__init__()
175
+
176
+ self.target_transform = target_transform
177
+
178
+ # gt_dir = 'shadow_free'
179
+ input_dir = 'origin'
180
+ depth_dir = 'depth'
181
+ normal_dir = 'normal'
182
+
183
+ # clean_files = sorted(os.listdir(os.path.join(rgb_dir, gt_dir)))
184
+ noisy_files = sorted(os.listdir(os.path.join(rgb_dir, input_dir)))
185
+ depth_files = sorted(os.listdir(os.path.join(rgb_dir, depth_dir)))
186
+ normal_files = sorted(os.listdir(os.path.join(rgb_dir, normal_dir)))
187
+
188
+ # self.clean_filenames = [os.path.join(rgb_dir, gt_dir, x) for x in clean_files]
189
+ self.noisy_filenames = [os.path.join(rgb_dir, input_dir, x) for x in noisy_files]
190
+ self.depth_filenames = [os.path.join(rgb_dir, depth_dir, x) for x in depth_files]
191
+ self.normal_filenames = [os.path.join(rgb_dir, normal_dir, x) for x in normal_files]
192
+
193
+ if debug:
194
+ self.tar_size = 10
195
+ else:
196
+ self.tar_size = len(self.noisy_filenames)
197
+
198
+ def __len__(self):
199
+ return self.tar_size
200
+
201
+ def __getitem__(self, index):
202
+ tar_index = index % self.tar_size
203
+ # clean = np.float32(load_img(self.clean_filenames[tar_index]))
204
+ noisy = np.float32(load_img(self.noisy_filenames[tar_index]))
205
+ depth = np.float32(load_depth(self.depth_filenames[tar_index]))
206
+ normal = np.float32(load_normal(self.normal_filenames[tar_index]))
207
+
208
+ point = depthToPoint(60, depth)
209
+ normal = process_normal(normal)
210
+ point = point / (2 * point[:,:,2].mean())
211
+
212
+ # clean_filename = os.path.split(self.clean_filenames[tar_index])[-1]
213
+ noisy_filename = os.path.split(self.noisy_filenames[tar_index])[-1]
214
+
215
+ # clean = torch.from_numpy(clean)
216
+ noisy = torch.from_numpy(noisy)
217
+ point = torch.from_numpy(point)
218
+ normal = torch.from_numpy(normal)
219
+
220
+
221
+ # clean = clean.permute(2,0,1)
222
+ noisy = noisy.permute(2,0,1)
223
+ point = point.permute(2,0,1)
224
+ normal = normal.permute(2,0,1)
225
+
226
+
227
+ return noisy, noisy, point, normal, noisy_filename, noisy_filename
228
+
freqfusion.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TPAMI 2024:Frequency-aware Feature Fusion for Dense Image Prediction
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.utils.checkpoint import checkpoint
7
+ import warnings
8
+ import numpy as np
9
+
10
+ try:
11
+ from mmcv.ops.carafe import normal_init, xavier_init, carafe
12
+ except ImportError:
13
+
14
+ def xavier_init(module: nn.Module,
15
+ gain: float = 1,
16
+ bias: float = 0,
17
+ distribution: str = 'normal') -> None:
18
+ assert distribution in ['uniform', 'normal']
19
+ if hasattr(module, 'weight') and module.weight is not None:
20
+ if distribution == 'uniform':
21
+ nn.init.xavier_uniform_(module.weight, gain=gain)
22
+ else:
23
+ nn.init.xavier_normal_(module.weight, gain=gain)
24
+ if hasattr(module, 'bias') and module.bias is not None:
25
+ nn.init.constant_(module.bias, bias)
26
+
27
+ def carafe(x, normed_mask, kernel_size, group=1, up=1):
28
+ b, c, h, w = x.shape
29
+ _, m_c, m_h, m_w = normed_mask.shape
30
+ # print('x', x.shape)
31
+ # print('normed_mask', normed_mask.shape)
32
+ # assert m_c == kernel_size ** 2 * up ** 2
33
+ assert m_h == up * h
34
+ assert m_w == up * w
35
+ pad = kernel_size // 2
36
+ # print(pad)
37
+ pad_x = F.pad(x, pad=[pad] * 4, mode='reflect')
38
+ # print(pad_x.shape)
39
+ unfold_x = F.unfold(pad_x, kernel_size=(kernel_size, kernel_size), stride=1, padding=0)
40
+ # unfold_x = unfold_x.reshape(b, c, 1, kernel_size, kernel_size, h, w).repeat(1, 1, up ** 2, 1, 1, 1, 1)
41
+ unfold_x = unfold_x.reshape(b, c * kernel_size * kernel_size, h, w)
42
+ unfold_x = F.interpolate(unfold_x, scale_factor=up, mode='nearest')
43
+ # normed_mask = normed_mask.reshape(b, 1, up ** 2, kernel_size, kernel_size, h, w)
44
+ unfold_x = unfold_x.reshape(b, c, kernel_size * kernel_size, m_h, m_w)
45
+ normed_mask = normed_mask.reshape(b, 1, kernel_size * kernel_size, m_h, m_w)
46
+ res = unfold_x * normed_mask
47
+ # test
48
+ # res[:, :, 0] = 1
49
+ # res[:, :, 1] = 2
50
+ # res[:, :, 2] = 3
51
+ # res[:, :, 3] = 4
52
+ res = res.sum(dim=2).reshape(b, c, m_h, m_w)
53
+ # res = F.pixel_shuffle(res, up)
54
+ # print(res.shape)
55
+ # print(res)
56
+ return res
57
+
58
+ def normal_init(module, mean=0, std=1, bias=0):
59
+ if hasattr(module, 'weight') and module.weight is not None:
60
+ nn.init.normal_(module.weight, mean, std)
61
+ if hasattr(module, 'bias') and module.bias is not None:
62
+ nn.init.constant_(module.bias, bias)
63
+
64
+
65
+ def constant_init(module, val, bias=0):
66
+ if hasattr(module, 'weight') and module.weight is not None:
67
+ nn.init.constant_(module.weight, val)
68
+ if hasattr(module, 'bias') and module.bias is not None:
69
+ nn.init.constant_(module.bias, bias)
70
+
71
+ def resize(input,
72
+ size=None,
73
+ scale_factor=None,
74
+ mode='nearest',
75
+ align_corners=None,
76
+ warning=True):
77
+ if warning:
78
+ if size is not None and align_corners:
79
+ input_h, input_w = tuple(int(x) for x in input.shape[2:])
80
+ output_h, output_w = tuple(int(x) for x in size)
81
+ if output_h > input_h or output_w > input_w:
82
+ if ((output_h > 1 and output_w > 1 and input_h > 1
83
+ and input_w > 1) and (output_h - 1) % (input_h - 1)
84
+ and (output_w - 1) % (input_w - 1)):
85
+ warnings.warn(
86
+ f'When align_corners={align_corners}, '
87
+ 'the output would more aligned if '
88
+ f'input size {(input_h, input_w)} is `x+1` and '
89
+ f'out size {(output_h, output_w)} is `nx+1`')
90
+ return F.interpolate(input, size, scale_factor, mode, align_corners)
91
+
92
+ def hamming2D(M, N):
93
+ """
94
+ 生成二维Hamming窗
95
+
96
+ 参数:
97
+ - M:窗口的行数
98
+ - N:窗口的列数
99
+
100
+ 返回:
101
+ - 二维Hamming窗
102
+ """
103
+ # 生成水平和垂直方向上的Hamming窗
104
+ # hamming_x = np.blackman(M)
105
+ # hamming_x = np.kaiser(M)
106
+ hamming_x = np.hamming(M)
107
+ hamming_y = np.hamming(N)
108
+ # 通过外积生成二维Hamming窗
109
+ hamming_2d = np.outer(hamming_x, hamming_y)
110
+ return hamming_2d
111
+
112
+ class FreqFusion(nn.Module):
113
+ def __init__(self,
114
+ hr_channels,
115
+ lr_channels,
116
+ scale_factor=1,
117
+ lowpass_kernel=5,
118
+ highpass_kernel=3,
119
+ up_group=1,
120
+ encoder_kernel=3,
121
+ encoder_dilation=1,
122
+ compressed_channels=64,
123
+ align_corners=False,
124
+ upsample_mode='nearest',
125
+ feature_resample=False, # use offset generator or not
126
+ feature_resample_group=4,
127
+ comp_feat_upsample=True, # use ALPF & AHPF for init upsampling
128
+ use_high_pass=True,
129
+ use_low_pass=True,
130
+ hr_residual=True,
131
+ semi_conv=True,
132
+ hamming_window=True, # for regularization, do not matter really
133
+ feature_resample_norm=True,
134
+ **kwargs):
135
+ super().__init__()
136
+ self.scale_factor = scale_factor
137
+ self.lowpass_kernel = lowpass_kernel
138
+ self.highpass_kernel = highpass_kernel
139
+ self.up_group = up_group
140
+ self.encoder_kernel = encoder_kernel
141
+ self.encoder_dilation = encoder_dilation
142
+ self.compressed_channels = compressed_channels
143
+ self.hr_channel_compressor = nn.Conv2d(hr_channels, self.compressed_channels,1)
144
+ self.lr_channel_compressor = nn.Conv2d(lr_channels, self.compressed_channels,1)
145
+ self.content_encoder = nn.Conv2d( # ALPF generator
146
+ self.compressed_channels,
147
+ lowpass_kernel ** 2 * self.up_group * self.scale_factor * self.scale_factor,
148
+ self.encoder_kernel,
149
+ padding=int((self.encoder_kernel - 1) * self.encoder_dilation / 2),
150
+ dilation=self.encoder_dilation,
151
+ groups=1)
152
+
153
+ self.align_corners = align_corners
154
+ self.upsample_mode = upsample_mode
155
+ self.hr_residual = hr_residual
156
+ self.use_high_pass = use_high_pass
157
+ self.use_low_pass = use_low_pass
158
+ self.semi_conv = semi_conv
159
+ self.feature_resample = feature_resample
160
+ self.comp_feat_upsample = comp_feat_upsample
161
+ if self.feature_resample:
162
+ self.dysampler = LocalSimGuidedSampler(in_channels=compressed_channels, scale=2, style='lp', groups=feature_resample_group, use_direct_scale=True, kernel_size=encoder_kernel, norm=feature_resample_norm)
163
+ if self.use_high_pass:
164
+ self.content_encoder2 = nn.Conv2d( # AHPF generator
165
+ self.compressed_channels,
166
+ highpass_kernel ** 2 * self.up_group * self.scale_factor * self.scale_factor,
167
+ self.encoder_kernel,
168
+ padding=int((self.encoder_kernel - 1) * self.encoder_dilation / 2),
169
+ dilation=self.encoder_dilation,
170
+ groups=1)
171
+ self.hamming_window = hamming_window
172
+ lowpass_pad=0
173
+ highpass_pad=0
174
+ if self.hamming_window:
175
+ self.register_buffer('hamming_lowpass', torch.FloatTensor(hamming2D(lowpass_kernel + 2 * lowpass_pad, lowpass_kernel + 2 * lowpass_pad))[None, None,])
176
+ self.register_buffer('hamming_highpass', torch.FloatTensor(hamming2D(highpass_kernel + 2 * highpass_pad, highpass_kernel + 2 * highpass_pad))[None, None,])
177
+ else:
178
+ self.register_buffer('hamming_lowpass', torch.FloatTensor([1.0]))
179
+ self.register_buffer('hamming_highpass', torch.FloatTensor([1.0]))
180
+ self.init_weights()
181
+
182
+ def init_weights(self):
183
+ for m in self.modules():
184
+ # print(m)
185
+ if isinstance(m, nn.Conv2d):
186
+ xavier_init(m, distribution='uniform')
187
+ normal_init(self.content_encoder, std=0.001)
188
+ if self.use_high_pass:
189
+ normal_init(self.content_encoder2, std=0.001)
190
+
191
+ def kernel_normalizer(self, mask, kernel, scale_factor=None, hamming=1):
192
+ if scale_factor is not None:
193
+ mask = F.pixel_shuffle(mask, self.scale_factor)
194
+ n, mask_c, h, w = mask.size()
195
+ mask_channel = int(mask_c / float(kernel**2)) # group
196
+ # mask = mask.view(n, mask_channel, -1, h, w)
197
+ # mask = F.softmax(mask, dim=2, dtype=mask.dtype)
198
+ # mask = mask.view(n, mask_c, h, w).contiguous()
199
+
200
+ mask = mask.view(n, mask_channel, -1, h, w)
201
+ mask = F.softmax(mask, dim=2, dtype=mask.dtype)
202
+ mask = mask.view(n, mask_channel, kernel, kernel, h, w)
203
+ mask = mask.permute(0, 1, 4, 5, 2, 3).view(n, -1, kernel, kernel)
204
+ # mask = F.pad(mask, pad=[padding] * 4, mode=self.padding_mode) # kernel + 2 * padding
205
+ mask = mask * hamming
206
+ mask /= mask.sum(dim=(-1, -2), keepdims=True)
207
+ # print(hamming)
208
+ # print(mask.shape)
209
+ mask = mask.view(n, mask_channel, h, w, -1)
210
+ mask = mask.permute(0, 1, 4, 2, 3).view(n, -1, h, w).contiguous()
211
+ return mask
212
+
213
+ def forward(self, hr_feat, lr_feat, use_checkpoint=False): # use check_point to save GPU memory
214
+ if use_checkpoint:
215
+ return checkpoint(self._forward, hr_feat, lr_feat)
216
+ else:
217
+ return self._forward(hr_feat, lr_feat)
218
+
219
+ def _forward(self, hr_feat, lr_feat):
220
+ compressed_hr_feat = self.hr_channel_compressor(hr_feat)
221
+ compressed_lr_feat = self.lr_channel_compressor(lr_feat)
222
+ if self.semi_conv:
223
+ if self.comp_feat_upsample:
224
+ if self.use_high_pass:
225
+ mask_hr_hr_feat = self.content_encoder2(compressed_hr_feat) #从hr_feat得到初始高通滤波特征
226
+ mask_hr_init = self.kernel_normalizer(mask_hr_hr_feat, self.highpass_kernel, hamming=self.hamming_highpass) #kernel归一化得到初始高通滤波
227
+ compressed_hr_feat = compressed_hr_feat + compressed_hr_feat - carafe(compressed_hr_feat, mask_hr_init, self.highpass_kernel, self.up_group, 1) #利用初始高通滤波对压缩hr_feat的高频增强 (x-x的低通结果=x的高通结果)
228
+
229
+ mask_lr_hr_feat = self.content_encoder(compressed_hr_feat) #从hr_feat得到初始低通滤波特征
230
+ mask_lr_init = self.kernel_normalizer(mask_lr_hr_feat, self.lowpass_kernel, hamming=self.hamming_lowpass) #kernel归一化得到初始低通滤波
231
+
232
+ mask_lr_lr_feat_lr = self.content_encoder(compressed_lr_feat) #从hr_feat得到另一部分初始低通滤波特征
233
+ mask_lr_lr_feat = F.interpolate( #利用初始低通滤波对另一部分初始低通滤波特征上采样
234
+ carafe(mask_lr_lr_feat_lr, mask_lr_init, self.lowpass_kernel, self.up_group, 2), size=compressed_hr_feat.shape[-2:], mode='nearest')
235
+ mask_lr = mask_lr_hr_feat + mask_lr_lr_feat #将两部分初始低通滤波特征合在一起
236
+
237
+ mask_lr_init = self.kernel_normalizer(mask_lr, self.lowpass_kernel, hamming=self.hamming_lowpass) #得到初步融合的初始低通滤波
238
+ mask_hr_lr_feat = F.interpolate( #使用初始低通滤波对lr_feat处理,分辨率得到提高
239
+ carafe(self.content_encoder2(compressed_lr_feat), mask_lr_init, self.lowpass_kernel, self.up_group, 2), size=compressed_hr_feat.shape[-2:], mode='nearest')
240
+ mask_hr = mask_hr_hr_feat + mask_hr_lr_feat # 最终高通滤波特征
241
+ else: raise NotImplementedError
242
+ else:
243
+ mask_lr = self.content_encoder(compressed_hr_feat) + F.interpolate(self.content_encoder(compressed_lr_feat), size=compressed_hr_feat.shape[-2:], mode='nearest')
244
+ if self.use_high_pass:
245
+ mask_hr = self.content_encoder2(compressed_hr_feat) + F.interpolate(self.content_encoder2(compressed_lr_feat), size=compressed_hr_feat.shape[-2:], mode='nearest')
246
+ else:
247
+ compressed_x = F.interpolate(compressed_lr_feat, size=compressed_hr_feat.shape[-2:], mode='nearest') + compressed_hr_feat
248
+ mask_lr = self.content_encoder(compressed_x)
249
+ if self.use_high_pass:
250
+ mask_hr = self.content_encoder2(compressed_x)
251
+
252
+ mask_lr = self.kernel_normalizer(mask_lr, self.lowpass_kernel, hamming=self.hamming_lowpass)
253
+ if self.semi_conv:
254
+ lr_feat = carafe(lr_feat, mask_lr, self.lowpass_kernel, self.up_group, 2)
255
+ else:
256
+ lr_feat = resize(
257
+ input=lr_feat,
258
+ size=hr_feat.shape[2:],
259
+ mode=self.upsample_mode,
260
+ align_corners=None if self.upsample_mode == 'nearest' else self.align_corners)
261
+ lr_feat = carafe(lr_feat, mask_lr, self.lowpass_kernel, self.up_group, 1)
262
+
263
+ if self.use_high_pass:
264
+ mask_hr = self.kernel_normalizer(mask_hr, self.highpass_kernel, hamming=self.hamming_highpass)
265
+ hr_feat_hf = hr_feat - carafe(hr_feat, mask_hr, self.highpass_kernel, self.up_group, 1)
266
+ if self.hr_residual:
267
+ # print('using hr_residual')
268
+ hr_feat = hr_feat_hf + hr_feat
269
+ else:
270
+ hr_feat = hr_feat_hf
271
+
272
+ if self.feature_resample:
273
+ # print(lr_feat.shape)
274
+ lr_feat = self.dysampler(hr_x=compressed_hr_feat,
275
+ lr_x=compressed_lr_feat, feat2sample=lr_feat)
276
+
277
+ return mask_lr, hr_feat, lr_feat
278
+
279
+
280
+
281
+ class LocalSimGuidedSampler(nn.Module):
282
+ """
283
+ offset generator in FreqFusion
284
+ """
285
+ def __init__(self, in_channels, scale=2, style='lp', groups=4, use_direct_scale=True, kernel_size=1, local_window=3, sim_type='cos', norm=True, direction_feat='sim_concat'):
286
+ super().__init__()
287
+ assert scale==2
288
+ assert style=='lp'
289
+
290
+ self.scale = scale
291
+ self.style = style
292
+ self.groups = groups
293
+ self.local_window = local_window
294
+ self.sim_type = sim_type
295
+ self.direction_feat = direction_feat
296
+
297
+ if style == 'pl':
298
+ assert in_channels >= scale ** 2 and in_channels % scale ** 2 == 0
299
+ assert in_channels >= groups and in_channels % groups == 0
300
+
301
+ if style == 'pl':
302
+ in_channels = in_channels // scale ** 2
303
+ out_channels = 2 * groups
304
+ else:
305
+ out_channels = 2 * groups * scale ** 2
306
+ if self.direction_feat == 'sim':
307
+ self.offset = nn.Conv2d(local_window**2 - 1, out_channels, kernel_size=kernel_size, padding=kernel_size//2)
308
+ elif self.direction_feat == 'sim_concat':
309
+ self.offset = nn.Conv2d(in_channels + local_window**2 - 1, out_channels, kernel_size=kernel_size, padding=kernel_size//2)
310
+ else: raise NotImplementedError
311
+ normal_init(self.offset, std=0.001)
312
+ if use_direct_scale:
313
+ if self.direction_feat == 'sim':
314
+ self.direct_scale = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size//2)
315
+ elif self.direction_feat == 'sim_concat':
316
+ self.direct_scale = nn.Conv2d(in_channels + local_window**2 - 1, out_channels, kernel_size=kernel_size, padding=kernel_size//2)
317
+ else: raise NotImplementedError
318
+ constant_init(self.direct_scale, val=0.)
319
+
320
+ out_channels = 2 * groups
321
+ if self.direction_feat == 'sim':
322
+ self.hr_offset = nn.Conv2d(local_window**2 - 1, out_channels, kernel_size=kernel_size, padding=kernel_size//2)
323
+ elif self.direction_feat == 'sim_concat':
324
+ self.hr_offset = nn.Conv2d(in_channels + local_window**2 - 1, out_channels, kernel_size=kernel_size, padding=kernel_size//2)
325
+ else: raise NotImplementedError
326
+ normal_init(self.hr_offset, std=0.001)
327
+
328
+ if use_direct_scale:
329
+ if self.direction_feat == 'sim':
330
+ self.hr_direct_scale = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size//2)
331
+ elif self.direction_feat == 'sim_concat':
332
+ self.hr_direct_scale = nn.Conv2d(in_channels + local_window**2 - 1, out_channels, kernel_size=kernel_size, padding=kernel_size//2)
333
+ else: raise NotImplementedError
334
+ constant_init(self.hr_direct_scale, val=0.)
335
+
336
+ self.norm = norm
337
+ if self.norm:
338
+ self.norm_hr = nn.GroupNorm(in_channels // 8, in_channels)
339
+ self.norm_lr = nn.GroupNorm(in_channels // 8, in_channels)
340
+ else:
341
+ self.norm_hr = nn.Identity()
342
+ self.norm_lr = nn.Identity()
343
+ self.register_buffer('init_pos', self._init_pos())
344
+
345
+ def _init_pos(self):
346
+ h = torch.arange((-self.scale + 1) / 2, (self.scale - 1) / 2 + 1) / self.scale
347
+ return torch.stack(torch.meshgrid([h, h])).transpose(1, 2).repeat(1, self.groups, 1).reshape(1, -1, 1, 1)
348
+
349
+ def sample(self, x, offset, scale=None):
350
+ if scale is None: scale = self.scale
351
+ B, _, H, W = offset.shape
352
+ offset = offset.view(B, 2, -1, H, W)
353
+ coords_h = torch.arange(H) + 0.5
354
+ coords_w = torch.arange(W) + 0.5
355
+ coords = torch.stack(torch.meshgrid([coords_w, coords_h])
356
+ ).transpose(1, 2).unsqueeze(1).unsqueeze(0).type(x.dtype).to(x.device)
357
+ normalizer = torch.tensor([W, H], dtype=x.dtype, device=x.device).view(1, 2, 1, 1, 1)
358
+ coords = 2 * (coords + offset) / normalizer - 1
359
+ coords = F.pixel_shuffle(coords.view(B, -1, H, W), scale).view(
360
+ B, 2, -1, scale * H, scale * W).permute(0, 2, 3, 4, 1).contiguous().flatten(0, 1)
361
+ return F.grid_sample(x.reshape(B * self.groups, -1, x.size(-2), x.size(-1)), coords, mode='bilinear',
362
+ align_corners=False, padding_mode="border").view(B, -1, scale * H, scale * W)
363
+
364
+ def forward(self, hr_x, lr_x, feat2sample):
365
+ hr_x = self.norm_hr(hr_x)
366
+ lr_x = self.norm_lr(lr_x)
367
+
368
+ if self.direction_feat == 'sim':
369
+ hr_sim = compute_similarity(hr_x, self.local_window, dilation=2, sim='cos')
370
+ lr_sim = compute_similarity(lr_x, self.local_window, dilation=2, sim='cos')
371
+ elif self.direction_feat == 'sim_concat':
372
+ hr_sim = torch.cat([hr_x, compute_similarity(hr_x, self.local_window, dilation=2, sim='cos')], dim=1)
373
+ lr_sim = torch.cat([lr_x, compute_similarity(lr_x, self.local_window, dilation=2, sim='cos')], dim=1)
374
+ hr_x, lr_x = hr_sim, lr_sim
375
+ # offset = self.get_offset(hr_x, lr_x)
376
+ offset = self.get_offset_lp(hr_x, lr_x, hr_sim, lr_sim)
377
+ return self.sample(feat2sample, offset)
378
+
379
+ # def get_offset_lp(self, hr_x, lr_x):
380
+ def get_offset_lp(self, hr_x, lr_x, hr_sim, lr_sim):
381
+ if hasattr(self, 'direct_scale'):
382
+ # offset = (self.offset(lr_x) + F.pixel_unshuffle(self.hr_offset(hr_x), self.scale)) * (self.direct_scale(lr_x) + F.pixel_unshuffle(self.hr_direct_scale(hr_x), self.scale)).sigmoid() + self.init_pos
383
+ offset = (self.offset(lr_sim) + F.pixel_unshuffle(self.hr_offset(hr_sim), self.scale)) * (self.direct_scale(lr_x) + F.pixel_unshuffle(self.hr_direct_scale(hr_x), self.scale)).sigmoid() + self.init_pos
384
+ # offset = (self.offset(lr_sim) + F.pixel_unshuffle(self.hr_offset(hr_sim), self.scale)) * (self.direct_scale(lr_sim) + F.pixel_unshuffle(self.hr_direct_scale(hr_sim), self.scale)).sigmoid() + self.init_pos
385
+ else:
386
+ offset = (self.offset(lr_x) + F.pixel_unshuffle(self.hr_offset(hr_x), self.scale)) * 0.25 + self.init_pos
387
+ return offset
388
+
389
+ def get_offset(self, hr_x, lr_x):
390
+ if self.style == 'pl':
391
+ raise NotImplementedError
392
+ return self.get_offset_lp(hr_x, lr_x)
393
+
394
+
395
+ def compute_similarity(input_tensor, k=3, dilation=1, sim='cos'):
396
+ """
397
+ 计算输入张量中每一点与周围KxK范围内的点的余弦相似度。
398
+
399
+ 参数:
400
+ - input_tensor: 输入张量,形状为[B, C, H, W]
401
+ - k: 范围大小,表示周围KxK范围内的点
402
+
403
+ 返回:
404
+ - 输出张量,形状为[B, KxK-1, H, W]
405
+ """
406
+ B, C, H, W = input_tensor.shape
407
+ # 使用零填充来处理边界情况
408
+ # padded_input = F.pad(input_tensor, (k // 2, k // 2, k // 2, k // 2), mode='constant', value=0)
409
+
410
+ # 展平输入张量中每个点及其周围KxK范围内的点
411
+ unfold_tensor = F.unfold(input_tensor, k, padding=(k // 2) * dilation, dilation=dilation) # B, CxKxK, HW
412
+ # print(unfold_tensor.shape)
413
+ unfold_tensor = unfold_tensor.reshape(B, C, k**2, H, W)
414
+
415
+ # 计算余弦相似度
416
+ if sim == 'cos':
417
+ similarity = F.cosine_similarity(unfold_tensor[:, :, k * k // 2:k * k // 2 + 1], unfold_tensor[:, :, :], dim=1)
418
+ elif sim == 'dot':
419
+ similarity = unfold_tensor[:, :, k * k // 2:k * k // 2 + 1] * unfold_tensor[:, :, :]
420
+ similarity = similarity.sum(dim=1)
421
+ else:
422
+ raise NotImplementedError
423
+
424
+ # 移除中心点的余弦相似度,得到[KxK-1]的结果
425
+ similarity = torch.cat((similarity[:, :k * k // 2], similarity[:, k * k // 2 + 1:]), dim=1)
426
+
427
+ # 将结果重塑回[B, KxK-1, H, W]的形状
428
+ similarity = similarity.view(B, k * k - 1, H, W)
429
+ return similarity
430
+
431
+
432
+ if __name__ == '__main__':
433
+ # x = torch.rand(4, 128, 16, 16)
434
+ # mask = torch.rand(4, 4 * 25, 16, 16)
435
+ # carafe(x, mask, kernel_size=5, group=1, up=2)
436
+
437
+ hr_feat = torch.rand(1, 128, 512, 512)
438
+ lr_feat = torch.rand(1, 128, 256, 256)
439
+ model = FreqFusion(hr_channels=128, lr_channels=128)
440
+ mask_lr, hr_feat, lr_feat = model(hr_feat=hr_feat, lr_feat=lr_feat)
441
+ print(mask_lr.shape)
model.py ADDED
@@ -0,0 +1,1418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.utils.checkpoint as checkpoint
4
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
5
+ import torch.nn.functional as F
6
+ from einops import rearrange, repeat
7
+ import math
8
+ from utils import grid_sample
9
+
10
+ from freqfusion import FreqFusion
11
+
12
+ #########################################
13
+
14
+ class SepConv2d(torch.nn.Module):
15
+ def __init__(self,
16
+ in_channels,
17
+ out_channels,
18
+ kernel_size,
19
+ stride=1,
20
+ padding=0,
21
+ dilation=1,act_layer=nn.ReLU):
22
+ super(SepConv2d, self).__init__()
23
+ self.depthwise = torch.nn.Conv2d(in_channels,
24
+ in_channels,
25
+ kernel_size=kernel_size,
26
+ stride=stride,
27
+ padding=padding,
28
+ dilation=dilation,
29
+ groups=in_channels)
30
+ self.pointwise = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1)
31
+ self.act_layer = act_layer() if act_layer is not None else nn.Identity()
32
+ self.in_channels = in_channels
33
+ self.out_channels = out_channels
34
+ self.kernel_size = kernel_size
35
+ self.stride = stride
36
+
37
+ def forward(self, x):
38
+ x = self.depthwise(x)
39
+ x = self.act_layer(x)
40
+ x = self.pointwise(x)
41
+ return x
42
+
43
+ def flops(self, H, W):
44
+ flops = 0
45
+ flops += H*W*self.in_channels*self.kernel_size**2/self.stride**2
46
+ flops += H*W*self.in_channels*self.out_channels
47
+ return flops
48
+
49
+ ##########################################################################
50
+ ## Channel Attention Layer
51
+ class CALayer(nn.Module):
52
+ def __init__(self, channel, reduction=16, bias=False):
53
+ super(CALayer, self).__init__()
54
+ # global average pooling: feature --> point
55
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
56
+ # feature channel downscale and upscale --> channel weight
57
+ self.conv_du = nn.Sequential(
58
+ nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias),
59
+ nn.ReLU(inplace=True),
60
+ nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias),
61
+ nn.Sigmoid()
62
+ )
63
+
64
+ def forward(self, x):
65
+ y = self.avg_pool(x)
66
+ y = self.conv_du(y)
67
+ return x * y
68
+
69
+ def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1):
70
+ return nn.Conv2d(
71
+ in_channels, out_channels, kernel_size,
72
+ padding=(kernel_size//2), bias=bias, stride = stride)
73
+
74
+ ##########################################################################
75
+ ## Channel Attention Block (CAB)
76
+ class CAB(nn.Module):
77
+ def __init__(self, n_feat, kernel_size, reduction, bias, act):
78
+ super(CAB, self).__init__()
79
+ modules_body = []
80
+ modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
81
+ modules_body.append(act)
82
+ modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
83
+
84
+ self.CA = CALayer(n_feat, reduction, bias=bias)
85
+ self.body = nn.Sequential(*modules_body)
86
+
87
+ def forward(self, x):
88
+ res = self.body(x)
89
+ res = self.CA(res)
90
+ res += x
91
+ return res
92
+
93
+ #########################################
94
+ ######## Embedding for q,k,v ########
95
+ class ConvProjection(nn.Module):
96
+ def __init__(self, dim, heads = 8, dim_head = 64, kernel_size=3, q_stride=1, k_stride=1, v_stride=1, dropout = 0.,
97
+ last_stage=False,bias=True):
98
+
99
+ super().__init__()
100
+
101
+ inner_dim = dim_head * heads
102
+ self.heads = heads
103
+ pad = (kernel_size - q_stride)//2
104
+ self.to_q = SepConv2d(dim, inner_dim, kernel_size, q_stride, pad, bias)
105
+ self.to_k = SepConv2d(dim, inner_dim, kernel_size, k_stride, pad, bias)
106
+ self.to_v = SepConv2d(dim, inner_dim, kernel_size, v_stride, pad, bias)
107
+
108
+ def forward(self, x, attn_kv=None):
109
+ b, n, c, h = *x.shape, self.heads
110
+ l = int(math.sqrt(n))
111
+ w = int(math.sqrt(n))
112
+
113
+ attn_kv = x if attn_kv is None else attn_kv
114
+ x = rearrange(x, 'b (l w) c -> b c l w', l=l, w=w)
115
+ attn_kv = rearrange(attn_kv, 'b (l w) c -> b c l w', l=l, w=w)
116
+ # print(attn_kv)
117
+ q = self.to_q(x)
118
+ q = rearrange(q, 'b (h d) l w -> b h (l w) d', h=h)
119
+
120
+ k = self.to_k(attn_kv)
121
+ v = self.to_v(attn_kv)
122
+ k = rearrange(k, 'b (h d) l w -> b h (l w) d', h=h)
123
+ v = rearrange(v, 'b (h d) l w -> b h (l w) d', h=h)
124
+ return q,k,v
125
+
126
+ def flops(self, H, W):
127
+ flops = 0
128
+ flops += self.to_q.flops(H, W)
129
+ flops += self.to_k.flops(H, W)
130
+ flops += self.to_v.flops(H, W)
131
+ return flops
132
+
133
+ class LinearProjection(nn.Module):
134
+ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., bias=True):
135
+ super().__init__()
136
+ inner_dim = dim_head * heads
137
+ self.heads = heads
138
+ self.to_q = nn.Linear(dim, inner_dim, bias = bias)
139
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias = bias)
140
+ self.dim = dim
141
+ self.inner_dim = inner_dim
142
+
143
+ def forward(self, x, attn_kv=None):
144
+ B_, N, C = x.shape
145
+ attn_kv = x if attn_kv is None else attn_kv
146
+ q = self.to_q(x).reshape(B_, N, 1, self.heads, C // self.heads).permute(2, 0, 3, 1, 4).contiguous()
147
+ kv = self.to_kv(attn_kv).reshape(B_, N, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4).contiguous()
148
+ q = q[0]
149
+ k, v = kv[0], kv[1]
150
+ return q,k,v
151
+
152
+ def flops(self, H, W):
153
+ flops = H*W*self.dim*self.inner_dim*3
154
+ return flops
155
+
156
+ class LinearProjection_Concat_kv(nn.Module):
157
+ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., bias=True):
158
+ super().__init__()
159
+ inner_dim = dim_head * heads
160
+ self.heads = heads
161
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = bias)
162
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias = bias)
163
+ self.dim = dim
164
+ self.inner_dim = inner_dim
165
+
166
+ def forward(self, x, attn_kv=None):
167
+ B_, N, C = x.shape
168
+ attn_kv = x if attn_kv is None else attn_kv
169
+ qkv_dec = self.to_qkv(x).reshape(B_, N, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4).contiguous()
170
+ kv_enc = self.to_kv(attn_kv).reshape(B_, N, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4).contiguous()
171
+ q, k_d, v_d = qkv_dec[0], qkv_dec[1], qkv_dec[2] # make torchscript happy (cannot use tensor as tuple)
172
+ k_e, v_e = kv_enc[0], kv_enc[1]
173
+ k = torch.cat((k_d,k_e),dim=2)
174
+ v = torch.cat((v_d,v_e),dim=2)
175
+ return q,k,v
176
+
177
+ def flops(self, H, W):
178
+ flops = H*W*self.dim*self.inner_dim*5
179
+ return flops
180
+
181
+ #########################################
182
+
183
+ ########### SIA #############
184
+ class WindowAttention(nn.Module):
185
+ def __init__(self, dim, win_size, num_heads, token_projection='linear', qkv_bias=True, qk_scale=None, attn_drop=0.,
186
+ proj_drop=0., se_layer=False):
187
+
188
+ super().__init__()
189
+ self.dim = dim
190
+ self.win_size = win_size # Wh, Ww
191
+ self.num_heads = num_heads
192
+ head_dim = dim // num_heads
193
+ self.scale = qk_scale or head_dim ** -0.5
194
+
195
+ # define a parameter table of relative position bias
196
+ self.relative_position_bias_table = nn.Parameter(
197
+ torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
198
+
199
+ # get pair-wise relative position index for each token inside the window
200
+ coords_h = torch.arange(self.win_size[0]) # [0,...,Wh-1]
201
+ coords_w = torch.arange(self.win_size[1]) # [0,...,Ww-1]
202
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
203
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
204
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
205
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
206
+ relative_coords[:, :, 0] += self.win_size[0] - 1 # shift to start from 0
207
+ relative_coords[:, :, 1] += self.win_size[1] - 1
208
+ relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1
209
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
210
+ self.register_buffer("relative_position_index", relative_position_index)
211
+
212
+ # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
213
+ if token_projection == 'conv':
214
+ self.qkv = ConvProjection(dim, num_heads, dim // num_heads, bias=qkv_bias)
215
+ elif token_projection == 'linear_concat':
216
+ self.qkv = LinearProjection_Concat_kv(dim, num_heads, dim // num_heads, bias=qkv_bias)
217
+ else:
218
+ self.qkv = LinearProjection(dim, num_heads, dim // num_heads, bias=qkv_bias)
219
+
220
+ self.token_projection = token_projection
221
+ self.attn_drop = nn.Dropout(attn_drop)
222
+ self.proj = nn.Linear(dim, dim)
223
+ self.ll = nn.Identity()
224
+ self.proj_drop = nn.Dropout(proj_drop)
225
+ self.sigmoid = nn.Sigmoid()
226
+
227
+ trunc_normal_(self.relative_position_bias_table, std=.02)
228
+ self.softmax = nn.Softmax(dim=-1)
229
+
230
+ def forward(self, x, dino_mat, point_feature, normal, attn_kv=None, mask=None):
231
+ B_, N, C = x.shape
232
+
233
+ dino_mat = dino_mat.unsqueeze(2)
234
+ normalizer = torch.sqrt((dino_mat @ dino_mat.transpose(-2, -1)).squeeze(-2)).detach()
235
+ normalizer = torch.clamp(normalizer, 1.0e-20, 1.0e10)
236
+ dino_mat = dino_mat.squeeze(2) / normalizer
237
+ dino_mat_correlation_map = dino_mat @ dino_mat.transpose(-2, -1).contiguous()
238
+ dino_mat_correlation_map = torch.clamp(dino_mat_correlation_map, 0.0, 1.0e10)
239
+ dino_mat_correlation_map = torch.unsqueeze(dino_mat_correlation_map, dim=1)
240
+
241
+ point_feature = point_feature.unsqueeze(2)
242
+ Point = point_feature.repeat(1, 1, self.win_size[0] * self.win_size[1],1)
243
+ Point = Point - Point.transpose(-2, -3)
244
+ normal = normal.unsqueeze(2).repeat(1,1,self.win_size[0] * self.win_size[1],1)
245
+ # print(f'{Point.shape=}')
246
+ # print(f'{normal.shape=}')
247
+ Point = Point * normal
248
+ Point = torch.abs(torch.sum(Point, dim=3))
249
+
250
+ plane_correlation_map = 0.5 * (Point + Point.transpose(-1, -2))
251
+ plane_correlation_map = plane_correlation_map.unsqueeze(1)
252
+ plane_correlation_map = torch.exp(-plane_correlation_map)
253
+
254
+
255
+ q, k, v = self.qkv(x, attn_kv)
256
+ q = q * self.scale
257
+ attn = (q @ k.transpose(-2, -1))
258
+
259
+ attn = dino_mat_correlation_map * attn
260
+ attn = plane_correlation_map * attn
261
+
262
+
263
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
264
+ self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1) # Wh*Ww,Wh*Ww,nH
265
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
266
+ ratio = attn.size(-1) // relative_position_bias.size(-1)
267
+ relative_position_bias = repeat(relative_position_bias, 'nH l c -> nH l (c d)', d=ratio)
268
+
269
+ attn = attn + relative_position_bias.unsqueeze(0)
270
+
271
+ if mask is not None:
272
+ nW = mask.shape[0]
273
+ mask = repeat(mask, 'nW m n -> nW m (n d)', d=ratio)
274
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N * ratio) + mask.unsqueeze(1).unsqueeze(0)
275
+ attn = attn.view(-1, self.num_heads, N, N * ratio)
276
+ attn = self.softmax(attn)
277
+ else:
278
+ attn = self.softmax(attn)
279
+
280
+ attn = self.attn_drop(attn)
281
+ x = (attn @ v).transpose(1, 2).contiguous().reshape(B_, N, C)
282
+ x = self.proj(x)
283
+ x = self.ll(x)
284
+ x = self.proj_drop(x)
285
+ return x
286
+
287
+ def extra_repr(self) -> str:
288
+ return f'dim={self.dim}, win_size={self.win_size}, num_heads={self.num_heads}'
289
+
290
+ def flops(self, H, W):
291
+ # calculate flops for 1 window with token length of N
292
+ # print(N, self.dim)
293
+ flops = 0
294
+ N = self.win_size[0] * self.win_size[1]
295
+ nW = H * W / N
296
+ # qkv = self.qkv(x)
297
+ # flops += N * self.dim * 3 * self.dim
298
+ flops += self.qkv.flops(H, W)
299
+ # attn = (q @ k.transpose(-2, -1))
300
+ if self.token_projection != 'linear_concat':
301
+ flops += nW * self.num_heads * N * (self.dim // self.num_heads) * N
302
+ # x = (attn @ v)
303
+ flops += nW * self.num_heads * N * N * (self.dim // self.num_heads)
304
+ else:
305
+ flops += nW * self.num_heads * N * (self.dim // self.num_heads) * N * 2
306
+ # x = (attn @ v)
307
+ flops += nW * self.num_heads * N * N * 2 * (self.dim // self.num_heads)
308
+ # x = self.proj(x)
309
+ flops += nW * N * self.dim * self.dim
310
+ print("W-MSA:{%.2f}" % (flops / 1e9))
311
+ return flops
312
+
313
+
314
+ #########################################
315
+ ########### feed-forward network #############
316
+ class Mlp(nn.Module):
317
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
318
+ super().__init__()
319
+ out_features = out_features or in_features
320
+ hidden_features = hidden_features or in_features
321
+ self.fc1 = nn.Linear(in_features, hidden_features)
322
+ self.act = act_layer()
323
+ self.fc2 = nn.Linear(hidden_features, out_features)
324
+ self.drop = nn.Dropout(drop)
325
+ self.in_features = in_features
326
+ self.hidden_features = hidden_features
327
+ self.out_features = out_features
328
+
329
+ def forward(self, x):
330
+ x = self.fc1(x)
331
+ x = self.act(x)
332
+ x = self.drop(x)
333
+ x = self.fc2(x)
334
+ x = self.drop(x)
335
+ return x
336
+
337
+ def flops(self, H, W):
338
+ flops = 0
339
+ # fc1
340
+ flops += H*W*self.in_features*self.hidden_features
341
+ # fc2
342
+ flops += H*W*self.hidden_features*self.out_features
343
+ print("MLP:{%.2f}"%(flops/1e9))
344
+ return flops
345
+
346
+ class LeFF(nn.Module):
347
+ def __init__(self, dim=32, hidden_dim=128, act_layer=nn.GELU,drop = 0.):
348
+ super().__init__()
349
+ self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim),
350
+ act_layer())
351
+ self.dwconv = nn.Sequential(nn.Conv2d(hidden_dim,hidden_dim,groups=hidden_dim,kernel_size=3,stride=1,padding=1),
352
+ act_layer())
353
+ self.linear2 = nn.Sequential(nn.Linear(hidden_dim, dim))
354
+ self.dim = dim
355
+ self.hidden_dim = hidden_dim
356
+
357
+ def forward(self, x, img_size=(128,128)):
358
+ # bs x hw x c
359
+ bs, hw, c = x.size()
360
+ # hh = int(math.sqrt(hw))
361
+ hh = img_size[0]
362
+ ww = img_size[1]
363
+
364
+ x = self.linear1(x)
365
+
366
+ # spatial restore
367
+ x = rearrange(x, ' b (h w) (c) -> b c h w ', h = hh, w = ww)
368
+ # bs,hidden_dim,32x32
369
+
370
+ x = self.dwconv(x)
371
+
372
+ # flaten
373
+ x = rearrange(x, ' b c h w -> b (h w) c', h = hh, w = ww)
374
+
375
+ x = self.linear2(x)
376
+
377
+ return x
378
+
379
+ def flops(self, H, W):
380
+ flops = 0
381
+ # fc1
382
+ flops += H*W*self.dim*self.hidden_dim
383
+ # dwconv
384
+ flops += H*W*self.hidden_dim*3*3
385
+ # fc2
386
+ flops += H*W*self.hidden_dim*self.dim
387
+ print("LeFF:{%.2f}"%(flops/1e9))
388
+ return flops
389
+
390
+ #########################################
391
+ ########### window operation#############
392
+ def window_partition(x, win_size, dilation_rate=1):
393
+ B, H, W, C = x.shape
394
+ if dilation_rate !=1:
395
+ x = x.permute(0,3,1,2).contiguous() # B, C, H, W
396
+ assert type(dilation_rate) is int, 'dilation_rate should be a int'
397
+ x = F.unfold(x, kernel_size=win_size,dilation=dilation_rate,padding=4*(dilation_rate-1),stride=win_size) # B, C*Wh*Ww, H/Wh*W/Ww
398
+ windows = x.permute(0,2,1).contiguous().view(-1, C, win_size, win_size) # B' ,C ,Wh ,Ww
399
+ windows = windows.permute(0,2,3,1).contiguous() # B' ,Wh ,Ww ,C
400
+ else:
401
+ x = x.view(B, H // win_size, win_size, W // win_size, win_size, C)
402
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C) # B' ,Wh ,Ww ,C
403
+ return windows
404
+
405
+ def window_reverse(windows, win_size, H, W, dilation_rate=1):
406
+ # B' ,Wh ,Ww ,C
407
+ B = int(windows.shape[0] / (H * W / win_size / win_size))
408
+ x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1)
409
+ if dilation_rate !=1:
410
+ x = windows.permute(0,5,3,4,1,2).contiguous() # B, C*Wh*Ww, H/Wh*W/Ww
411
+ x = F.fold(x, (H, W), kernel_size=win_size, dilation=dilation_rate, padding=4*(dilation_rate-1),stride=win_size)
412
+ else:
413
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
414
+ return x
415
+
416
+ #########################################
417
+ # Downsample Block
418
+ class Downsample(nn.Module):
419
+ def __init__(self, in_channel, out_channel):
420
+ super(Downsample, self).__init__()
421
+ self.conv = nn.Sequential(
422
+ nn.Conv2d(in_channel, out_channel, kernel_size=4, stride=2, padding=1)
423
+ # nn.Conv2d(in_channel * 4, out_channel, kernel_size=3, padding=1)
424
+ )
425
+ self.in_channel = in_channel
426
+ self.out_channel = out_channel
427
+
428
+ def forward(self, x, img_size=(128,128)):
429
+ B, L, C = x.shape
430
+ H = img_size[0]
431
+ W = img_size[1]
432
+ x = x.transpose(1, 2).contiguous().view(B, C, H, W)
433
+
434
+ out = self.conv(x).flatten(2).transpose(1,2).contiguous() # B H*W C
435
+ return out
436
+
437
+ # def forward(self, x, img_size=(128,128)):
438
+ # B, L, C = x.shape
439
+ # H = img_size[0]
440
+ # W = img_size[1]
441
+ # x = x.transpose(1, 2).contiguous().view(B, C, H, W)
442
+
443
+ # # new add
444
+ # x = x.permute(0,2,3,1)
445
+
446
+ # x0 = x[:, 0::2, 0::2, :]
447
+ # x1 = x[:, 0::2, 1::2, :]
448
+ # x2 = x[:, 1::2, 0::2, :]
449
+ # x3 = x[:, 1::2, 1::2, :]
450
+ # x = torch.cat([x0, x1, x2, x3], axis=-1)
451
+ # x = x.permute(0,3,1,2)
452
+
453
+ # out = self.conv(x).flatten(2).transpose(1,2).contiguous() # B H*W C
454
+ # return out
455
+
456
+ def flops(self, H, W):
457
+ flops = 0
458
+ # conv
459
+ flops += H/2*W/2*self.in_channel*self.out_channel*4*4
460
+ print("Downsample:{%.2f}"%(flops/1e9))
461
+ return flops
462
+
463
+ # Upsample Block
464
+ class Upsample(nn.Module):
465
+ def __init__(self, in_channel, out_channel):
466
+ super(Upsample, self).__init__()
467
+ self.deconv = nn.Sequential(
468
+ nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2)
469
+ )
470
+
471
+ # self.conv = nn.Sequential(
472
+ # nn.Conv2d(in_channel, out_channel * 4, kernel_size=3, padding=1)
473
+ # )
474
+
475
+ self.in_channel = in_channel
476
+ self.out_channel = out_channel
477
+
478
+ def forward(self, x, img_size=(128,128)):
479
+ B, L, C = x.shape
480
+ H = img_size[0]
481
+ W = img_size[1]
482
+ x = x.transpose(1, 2).contiguous().view(B, C, H, W)
483
+ out = self.deconv(x)
484
+
485
+ out = out.flatten(2).transpose(1,2).contiguous() # B H*W C
486
+ return out
487
+
488
+ # def forward(self, x, img_size=(128,128)):
489
+ # B, L, C = x.shape
490
+ # H = img_size[0]
491
+ # W = img_size[1]
492
+ # x = x.transpose(1, 2).contiguous().view(B, C, H, W)
493
+ # out = self.conv(x)
494
+ # # new add
495
+ # pixel_shuffle = nn.PixelShuffle(2)
496
+ # out = pixel_shuffle(out)
497
+
498
+ # out = out.flatten(2).transpose(1,2).contiguous() # B H*W C
499
+ # return out
500
+
501
+ def flops(self, H, W):
502
+ flops = 0
503
+ # conv
504
+ flops += H*2*W*2*self.in_channel*self.out_channel*2*2
505
+ print("Upsample:{%.2f}"%(flops/1e9))
506
+ return flops
507
+
508
+ # Input Projection
509
+ class InputProj(nn.Module):
510
+ def __init__(self, in_channel=3, out_channel=64, kernel_size=3, stride=1, norm_layer=None,act_layer=nn.LeakyReLU):
511
+ super().__init__()
512
+ self.proj = nn.Sequential(
513
+ nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=kernel_size//2),
514
+ act_layer(inplace=True)
515
+ )
516
+ if norm_layer is not None:
517
+ self.norm = norm_layer(out_channel)
518
+ else:
519
+ self.norm = None
520
+ self.in_channel = in_channel
521
+ self.out_channel = out_channel
522
+
523
+ def forward(self, x):
524
+ B, C, H, W = x.shape
525
+ x = self.proj(x).flatten(2).transpose(1, 2).contiguous() # B H*W C
526
+ if self.norm is not None:
527
+ x = self.norm(x)
528
+ return x
529
+
530
+ def flops(self, H, W):
531
+ flops = 0
532
+ # conv
533
+ flops += H*W*self.in_channel*self.out_channel*3*3
534
+
535
+ if self.norm is not None:
536
+ flops += H*W*self.out_channel
537
+ print("Input_proj:{%.2f}"%(flops/1e9))
538
+ return flops
539
+
540
+ # Output Projection
541
+ class OutputProj(nn.Module):
542
+ def __init__(self, in_channel=64, out_channel=3, kernel_size=3, stride=1, norm_layer=None,act_layer=None):
543
+ super().__init__()
544
+ self.proj = nn.Sequential(
545
+ nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=kernel_size//2),
546
+ )
547
+ if act_layer is not None:
548
+ self.proj.add_module(act_layer(inplace=True))
549
+ if norm_layer is not None:
550
+ self.norm = norm_layer(out_channel)
551
+ else:
552
+ self.norm = None
553
+ self.in_channel = in_channel
554
+ self.out_channel = out_channel
555
+
556
+ def forward(self, x, img_size=(128,128)):
557
+ B, L, C = x.shape
558
+ H = img_size[0]
559
+ W = img_size[1]
560
+ # H = int(math.sqrt(L))
561
+ # W = int(math.sqrt(L))
562
+ x = x.transpose(1, 2).contiguous().view(B, C, H, W)
563
+ x = self.proj(x)
564
+ if self.norm is not None:
565
+ x = self.norm(x)
566
+ return x
567
+
568
+ def flops(self, H, W):
569
+ flops = 0
570
+ # conv
571
+ flops += H*W*self.in_channel*self.out_channel*3*3
572
+
573
+ if self.norm is not None:
574
+ flops += H*W*self.out_channel
575
+ print("Output_proj:{%.2f}"%(flops/1e9))
576
+ return flops
577
+
578
+
579
+ #########################################
580
+ ########### CA Transformer #############
581
+ class CATransformerBlock(nn.Module):
582
+ def __init__(self, dim, input_resolution, num_heads, win_size=10, shift_size=0,
583
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
584
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, token_projection='linear', token_mlp='leff',
585
+ se_layer=False):
586
+ super().__init__()
587
+ self.dim = dim
588
+ self.input_resolution = input_resolution
589
+ self.num_heads = num_heads
590
+ self.win_size = win_size
591
+ self.shift_size = shift_size
592
+ self.mlp_ratio = mlp_ratio
593
+ self.token_mlp = token_mlp
594
+ if min(self.input_resolution) <= self.win_size:
595
+ self.shift_size = 0
596
+ self.win_size = min(self.input_resolution)
597
+ assert 0 <= self.shift_size < self.win_size, "shift_size must in 0-win_size"
598
+
599
+ self.norm1 = norm_layer(dim)
600
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
601
+ self.norm2 = norm_layer(dim)
602
+ mlp_hidden_dim = int(dim * mlp_ratio)
603
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer,
604
+ drop=drop) if token_mlp == 'ffn' else LeFF(dim, mlp_hidden_dim, act_layer=act_layer, drop=drop)
605
+ self.CAB = CAB(dim, kernel_size=3, reduction=4, bias=False, act=nn.PReLU())
606
+
607
+
608
+ def extra_repr(self) -> str:
609
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
610
+ f"win_size={self.win_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
611
+
612
+ def forward(self, x, dino_mat, point, normal, mask=None, img_size=(128, 128)):
613
+ B, L, C = x.shape
614
+ H = img_size[0]
615
+ W = img_size[1]
616
+ assert L == W * H, \
617
+ f"Input image size ({H}*{W} doesn't match model ({L})."
618
+
619
+ shortcut = x
620
+ x = self.norm1(x)
621
+
622
+ # spatial restore
623
+ x = rearrange(x, ' b (h w) (c) -> b c h w ', h=H, w=W)
624
+ # bs,hidden_dim,32x32
625
+
626
+ x = self.CAB(x)
627
+
628
+ # flaten
629
+ x = rearrange(x, ' b c h w -> b (h w) c', h=H, w=W)
630
+ x = x.view(B, H * W, C)
631
+
632
+ # FFN
633
+ x = shortcut + self.drop_path(x)
634
+ x = x + self.drop_path(self.mlp(self.norm2(x), img_size=img_size))
635
+
636
+ return x
637
+
638
+ def flops(self):
639
+ flops = 0
640
+ H, W = self.input_resolution
641
+ # norm1
642
+ flops += self.dim * H * W
643
+ # W-MSA/SW-MSA
644
+ flops += self.attn.flops(H, W)
645
+ # norm2
646
+ flops += self.dim * H * W
647
+ # mlp
648
+ flops += self.mlp.flops(H, W)
649
+ print("LeWin:{%.2f}" % (flops / 1e9))
650
+ return flops
651
+
652
+ #########################################
653
+ ########### SIM Transformer #############
654
+ class SIMTransformerBlock(nn.Module):
655
+ def __init__(self, dim, input_resolution, num_heads, win_size=10, shift_size=0,
656
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
657
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm,token_projection='linear',token_mlp='leff',se_layer=False):
658
+ super().__init__()
659
+ self.dim = dim
660
+ self.input_resolution = input_resolution
661
+ self.num_heads = num_heads
662
+ self.win_size = win_size
663
+ self.shift_size = shift_size
664
+ self.mlp_ratio = mlp_ratio
665
+ self.token_mlp = token_mlp
666
+ if min(self.input_resolution) <= self.win_size:
667
+ self.shift_size = 0
668
+ self.win_size = min(self.input_resolution)
669
+ assert 0 <= self.shift_size < self.win_size, "shift_size must in 0-win_size"
670
+
671
+ self.norm1 = norm_layer(dim)
672
+ self.attn = WindowAttention(
673
+ dim, win_size=to_2tuple(self.win_size), num_heads=num_heads,
674
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
675
+ token_projection=token_projection,se_layer=se_layer)
676
+
677
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
678
+ self.norm2 = norm_layer(dim)
679
+ mlp_hidden_dim = int(dim * mlp_ratio)
680
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,act_layer=act_layer, drop=drop) if token_mlp=='ffn' else LeFF(dim,mlp_hidden_dim,act_layer=act_layer, drop=drop)
681
+ self.CAB = CAB(dim, kernel_size=3, reduction=4, bias=False, act=nn.PReLU())
682
+
683
+
684
+ def extra_repr(self) -> str:
685
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
686
+ f"win_size={self.win_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
687
+
688
+ def forward(self, x, dino_mat, point, normal, mask=None, img_size = (128, 128)):
689
+ B, L, C = x.shape
690
+ H = img_size[0]
691
+ W = img_size[1]
692
+ assert L == W * H, \
693
+ f"Input image size ({H}*{W} doesn't match model ({L})."
694
+
695
+ C_dino_mat = dino_mat.shape[1]
696
+ C_point = point.shape[1]
697
+ C_normal = normal.shape[1]
698
+
699
+ if mask != None:
700
+ input_mask = F.interpolate(mask, size=(H,W)).permute(0,2,3,1).contiguous()
701
+ input_mask_windows = window_partition(input_mask, self.win_size) # nW, win_size, win_size, 1
702
+ attn_mask = input_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size
703
+ attn_mask = attn_mask.unsqueeze(2)*attn_mask.unsqueeze(1) # nW, win_size*win_size, win_size*win_size
704
+ attn_mask = attn_mask.masked_fill(attn_mask!=0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
705
+ else:
706
+ attn_mask = None
707
+
708
+ ## shift mask
709
+ if self.shift_size > 0:
710
+ # calculate attention mask for SW-MSA
711
+ shift_mask = torch.zeros((1, H, W, 1)).type_as(x)
712
+ h_slices = (slice(0, -self.win_size),
713
+ slice(-self.win_size, -self.shift_size),
714
+ slice(-self.shift_size, None))
715
+ w_slices = (slice(0, -self.win_size),
716
+ slice(-self.win_size, -self.shift_size),
717
+ slice(-self.shift_size, None))
718
+ cnt = 0
719
+ for h in h_slices:
720
+ for w in w_slices:
721
+ shift_mask[:, h, w, :] = cnt
722
+ cnt += 1
723
+ shift_mask_windows = window_partition(shift_mask, self.win_size) # nW, win_size, win_size, 1
724
+ shift_mask_windows = shift_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size
725
+ shift_attn_mask = shift_mask_windows.unsqueeze(1) - shift_mask_windows.unsqueeze(2) # nW, win_size*win_size, win_size*win_size
726
+ shift_attn_mask = shift_attn_mask.masked_fill(shift_attn_mask != 0, float(-100.0)).masked_fill(shift_attn_mask == 0, float(0.0))
727
+ attn_mask = attn_mask + shift_attn_mask if attn_mask is not None else shift_attn_mask
728
+
729
+ shortcut = x
730
+ x = self.norm1(x)
731
+
732
+
733
+ x = x.view(B, H, W, C)
734
+ dino_mat = dino_mat.permute(0, 2, 3, 1).contiguous()
735
+ point = point.permute(0, 2, 3, 1).contiguous()
736
+ normal = normal.permute(0, 2, 3, 1).contiguous()
737
+
738
+ # cyclic shift
739
+ if self.shift_size > 0:
740
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
741
+ shifted_dino_mat = torch.roll(dino_mat, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
742
+ shifted_point = torch.roll(point, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
743
+ shifted_normal = torch.roll(normal, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
744
+ else:
745
+ shifted_x = x
746
+ shifted_dino_mat = dino_mat
747
+ shifted_point = point
748
+ shifted_normal = normal
749
+
750
+ x_windows = window_partition(shifted_x, self.win_size) # nW*B, win_size, win_size, C N*C->C
751
+ x_windows = x_windows.view(-1, self.win_size * self.win_size, C) # nW*B, win_size*win_size, C
752
+
753
+
754
+ dino_mat_windows = window_partition(shifted_dino_mat, self.win_size) # nW*B, win_size, win_size, C N*C->C
755
+ dino_mat_windows = dino_mat_windows.view(-1, self.win_size * self.win_size, C_dino_mat) # nW*B, win_size*win_size, C
756
+
757
+ point_windows = window_partition(shifted_point, self.win_size) # nW*B, win_size, win_size, C N*C->C
758
+ point_windows = point_windows.view(-1, self.win_size * self.win_size, C_point) # nW*B, win_size*win_size, C
759
+
760
+ normal_windows = window_partition(shifted_normal, self.win_size) # nW*B, win_size, win_size, C N*C->C
761
+ normal_windows = normal_windows.view(-1, self.win_size * self.win_size, C_normal) # nW*B, win_size*win_size, C
762
+
763
+ # W-MSA/SW-MSA
764
+ attn_windows = self.attn(x_windows, dino_mat_windows, point_windows, normal_windows, mask=attn_mask) # nW*B, win_size*win_size, C
765
+
766
+
767
+ # merge windows
768
+ attn_windows = attn_windows.view(-1, self.win_size, self.win_size, C)
769
+
770
+
771
+ shifted_x = window_reverse(attn_windows, self.win_size, H, W) # B H' W' C
772
+
773
+ # reverse cyclic shift
774
+ if self.shift_size > 0:
775
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
776
+ else:
777
+ x = shifted_x
778
+ x = x.view(B, H * W, C)
779
+
780
+ x = rearrange(x, ' b (h w) (c) -> b c h w ', h=H, w=W)
781
+ # bs,hidden_dim,32x32
782
+
783
+ x = self.CAB(x)
784
+
785
+ x = rearrange(x, ' b c h w -> b (h w) c', h=H, w=W)
786
+
787
+ # FFN
788
+ x = shortcut + self.drop_path(x)
789
+ x = x + self.drop_path(self.mlp(self.norm2(x), img_size=img_size))
790
+ del attn_mask
791
+ return x
792
+
793
+ def flops(self):
794
+ flops = 0
795
+ H, W = self.input_resolution
796
+ # norm1
797
+ flops += self.dim * H * W
798
+ # W-MSA/SW-MSA
799
+ flops += self.attn.flops(H, W)
800
+ # norm2
801
+ flops += self.dim * H * W
802
+ # mlp
803
+ flops += self.mlp.flops(H,W)
804
+ print("LeWin:{%.2f}"%(flops/1e9))
805
+ return flops
806
+
807
+
808
+ #########################################
809
+ ########### Basic layer of ShadowFormer ################
810
+ class BasicShadowFormer(nn.Module):
811
+ def __init__(self, dim, output_dim, input_resolution, depth, num_heads, win_size,
812
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
813
+ drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False,
814
+ token_projection='linear',token_mlp='ffn',se_layer=False,cab=False):
815
+
816
+ super().__init__()
817
+ self.dim = dim
818
+ self.input_resolution = input_resolution
819
+ self.depth = depth
820
+ self.use_checkpoint = use_checkpoint
821
+ # build blocks
822
+ if cab:
823
+ self.blocks = nn.ModuleList([
824
+ CATransformerBlock(dim=dim, input_resolution=input_resolution,
825
+ num_heads=num_heads, win_size=win_size,
826
+ shift_size=0 if (i % 2 == 0) else win_size // 2,
827
+ mlp_ratio=mlp_ratio,
828
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
829
+ drop=drop, attn_drop=attn_drop,
830
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
831
+ norm_layer=norm_layer, token_projection=token_projection, token_mlp=token_mlp,
832
+ se_layer=se_layer)
833
+ for i in range(depth)])
834
+ else:
835
+ self.blocks = nn.ModuleList([
836
+ SIMTransformerBlock(dim=dim, input_resolution=input_resolution,
837
+ num_heads=num_heads, win_size=win_size,
838
+ shift_size=0 if (i % 2 == 0) else win_size // 2,
839
+ mlp_ratio=mlp_ratio,
840
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
841
+ drop=drop, attn_drop=attn_drop,
842
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
843
+ norm_layer=norm_layer,token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer)
844
+ for i in range(depth)])
845
+
846
+ def extra_repr(self) -> str:
847
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
848
+
849
+ def forward(self, x, dino_mat=None, point=None, normal=None, mask=None, img_size=(128,128)):
850
+ for blk in self.blocks:
851
+ if self.use_checkpoint:
852
+ x = checkpoint.checkpoint(blk, x)
853
+ else:
854
+ x = blk(x, dino_mat, point, normal, mask, img_size)
855
+ return x
856
+
857
+ def flops(self):
858
+ flops = 0
859
+ for blk in self.blocks:
860
+ flops += blk.flops()
861
+ return flops
862
+
863
+ class ShadowFormer(nn.Module):
864
+ def __init__(self, img_size=256, in_chans=3,
865
+ embed_dim=32, depths=[2, 2, 2, 2, 2, 2, 2, 2, 2], num_heads=[1, 2, 4, 8, 16, 16, 8, 4, 2],
866
+ win_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
867
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
868
+ norm_layer=nn.LayerNorm, patch_norm=True,
869
+ use_checkpoint=False, token_projection='linear', token_mlp='leff', se_layer=True,
870
+ dowsample=Downsample, upsample=Upsample, **kwargs):
871
+ super().__init__()
872
+
873
+ self.num_enc_layers = len(depths)//2
874
+ self.num_dec_layers = len(depths)//2
875
+ self.embed_dim = embed_dim
876
+ self.patch_norm = patch_norm
877
+ self.mlp_ratio = mlp_ratio
878
+ self.token_projection = token_projection
879
+ self.mlp = token_mlp
880
+ self.win_size =win_size
881
+ self.reso = img_size
882
+ self.pos_drop = nn.Dropout(p=drop_rate)
883
+ self.DINO_channel = 1024
884
+
885
+ # stochastic depth
886
+ enc_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths[:self.num_enc_layers]))]
887
+ conv_dpr = [drop_path_rate]*depths[4]
888
+ dec_dpr = enc_dpr[::-1]
889
+
890
+ # build layers
891
+
892
+ # Input/Output
893
+ self.input_proj = InputProj(in_channel=4, out_channel=embed_dim, kernel_size=3, stride=1, act_layer=nn.LeakyReLU)
894
+ self.output_proj = OutputProj(in_channel=2*embed_dim, out_channel=in_chans, kernel_size=3, stride=1)
895
+
896
+ # Encoder
897
+ self.encoderlayer_0 = BasicShadowFormer(dim=embed_dim,
898
+ output_dim=embed_dim,
899
+ input_resolution=(img_size,
900
+ img_size),
901
+ depth=depths[0],
902
+ num_heads=num_heads[0],
903
+ win_size=win_size,
904
+ mlp_ratio=self.mlp_ratio,
905
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
906
+ drop=drop_rate, attn_drop=attn_drop_rate,
907
+ drop_path=enc_dpr[sum(depths[:0]):sum(depths[:1])],
908
+ norm_layer=norm_layer,
909
+ use_checkpoint=use_checkpoint,
910
+ token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer,cab=True)
911
+ self.dowsample_0 = dowsample(embed_dim, embed_dim*2)
912
+ self.encoderlayer_1 = BasicShadowFormer(dim=embed_dim*2,
913
+ output_dim=embed_dim*2,
914
+ input_resolution=(img_size // 2,
915
+ img_size // 2),
916
+ depth=depths[1],
917
+ num_heads=num_heads[1],
918
+ win_size=win_size,
919
+ mlp_ratio=self.mlp_ratio,
920
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
921
+ drop=drop_rate, attn_drop=attn_drop_rate,
922
+ drop_path=enc_dpr[sum(depths[:1]):sum(depths[:2])],
923
+ norm_layer=norm_layer,
924
+ use_checkpoint=use_checkpoint,
925
+ token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer, cab=True)
926
+ self.dowsample_1 = dowsample(embed_dim*2, embed_dim*4)
927
+ self.encoderlayer_2 = BasicShadowFormer(dim=embed_dim*4,
928
+ output_dim=embed_dim*4,
929
+ input_resolution=(img_size // (2 ** 2),
930
+ img_size // (2 ** 2)),
931
+ depth=depths[2],
932
+ num_heads=num_heads[2],
933
+ win_size=win_size,
934
+ mlp_ratio=self.mlp_ratio,
935
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
936
+ drop=drop_rate, attn_drop=attn_drop_rate,
937
+ drop_path=enc_dpr[sum(depths[:2]):sum(depths[:3])],
938
+ norm_layer=norm_layer,
939
+ use_checkpoint=use_checkpoint,
940
+ token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer)
941
+ self.dowsample_2 = dowsample(embed_dim*4, embed_dim*8)
942
+
943
+ # Bottleneck
944
+ channel_conv = embed_dim*16
945
+ # channel_conv = embed_dim*8 if self.add_shadow_detect_dino_conact else embed_dim*4
946
+ self.conv = BasicShadowFormer(dim=channel_conv,
947
+ output_dim=channel_conv,
948
+ input_resolution=(img_size // (2 ** 3),
949
+ img_size // (2 ** 3)),
950
+ depth=depths[4],
951
+ num_heads=num_heads[4],
952
+ win_size=win_size,
953
+ mlp_ratio=self.mlp_ratio,
954
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
955
+ drop=drop_rate, attn_drop=attn_drop_rate,
956
+ drop_path=conv_dpr,
957
+ norm_layer=norm_layer,
958
+ use_checkpoint=use_checkpoint,
959
+ token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer)
960
+
961
+ # # Decoder
962
+ self.upsample_0 = upsample(channel_conv, embed_dim*4)
963
+ channel_0 = embed_dim*8
964
+ self.decoderlayer_0 = BasicShadowFormer(dim=channel_0,
965
+ output_dim=channel_0,
966
+ input_resolution=(img_size // (2 ** 2),
967
+ img_size // (2 ** 2)),
968
+ depth=depths[6],
969
+ num_heads=num_heads[6],
970
+ win_size=win_size,
971
+ mlp_ratio=self.mlp_ratio,
972
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
973
+ drop=drop_rate, attn_drop=attn_drop_rate,
974
+ drop_path=dec_dpr[sum(depths[5:6]):sum(depths[5:7])],
975
+ norm_layer=norm_layer,
976
+ use_checkpoint=use_checkpoint,
977
+ token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer)
978
+ self.upsample_1 = upsample(channel_0, embed_dim*2)
979
+ channel_1 = embed_dim*4
980
+ self.decoderlayer_1 = BasicShadowFormer(dim=channel_1,
981
+ output_dim=channel_1,
982
+ input_resolution=(img_size // 2,
983
+ img_size // 2),
984
+ depth=depths[7],
985
+ num_heads=num_heads[7],
986
+ win_size=win_size,
987
+ mlp_ratio=self.mlp_ratio,
988
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
989
+ drop=drop_rate, attn_drop=attn_drop_rate,
990
+ drop_path=dec_dpr[sum(depths[5:7]):sum(depths[5:8])],
991
+ norm_layer=norm_layer,
992
+ use_checkpoint=use_checkpoint,
993
+ token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer, cab=True)
994
+ self.upsample_2 = upsample(channel_1, embed_dim)
995
+ self.decoderlayer_2 = BasicShadowFormer(dim=embed_dim*2,
996
+ output_dim=embed_dim*2,
997
+ input_resolution=(img_size,
998
+ img_size),
999
+ depth=depths[8],
1000
+ num_heads=num_heads[8],
1001
+ win_size=win_size,
1002
+ mlp_ratio=self.mlp_ratio,
1003
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
1004
+ drop=drop_rate, attn_drop=attn_drop_rate,
1005
+ drop_path=dec_dpr[sum(depths[5:8]):sum(depths[5:9])],
1006
+ norm_layer=norm_layer,
1007
+ use_checkpoint=use_checkpoint,
1008
+ token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer,cab=True)
1009
+
1010
+ self.Conv = nn.Conv2d(self.DINO_channel * 4, embed_dim * 8, kernel_size=1)
1011
+ self.relu = nn.LeakyReLU()
1012
+ self.apply(self._init_weights)
1013
+
1014
+
1015
+ def _init_weights(self, m):
1016
+ if isinstance(m, nn.Linear):
1017
+ trunc_normal_(m.weight, std=.02)
1018
+ if isinstance(m, nn.Linear) and m.bias is not None:
1019
+ nn.init.constant_(m.bias, 0)
1020
+ elif isinstance(m, nn.LayerNorm):
1021
+ nn.init.constant_(m.bias, 0)
1022
+ nn.init.constant_(m.weight, 1.0)
1023
+
1024
+ @torch.jit.ignore
1025
+ def no_weight_decay(self):
1026
+ return {'absolute_pos_embed'}
1027
+
1028
+ @torch.jit.ignore
1029
+ def no_weight_decay_keywords(self):
1030
+ return {'relative_position_bias_table'}
1031
+
1032
+ def extra_repr(self) -> str:
1033
+ return f"embed_dim={self.embed_dim}, token_projection={self.token_projection}, token_mlp={self.mlp}, win_size={self.win_size}"
1034
+
1035
+ def forward(self, x, DINO_Mat_features=None, point=None, normal=None, mask=None):
1036
+ point_feature=None
1037
+ dino_mat =None
1038
+ dino_mat1=None
1039
+
1040
+ self.img_size = torch.tensor((x.shape[2], x.shape[3]))
1041
+ point_feature1 = grid_sample(point, self.img_size // 2)
1042
+ point_feature2 = grid_sample(point, self.img_size // 4)
1043
+ point_feature3 = grid_sample(point, self.img_size // 8)
1044
+ normal1= grid_sample(normal, self.img_size // 2)
1045
+ normal2= grid_sample(normal, self.img_size // 4)
1046
+ normal3= grid_sample(normal, self.img_size // 8)
1047
+
1048
+ patch_features_0 = DINO_Mat_features[0]
1049
+ patch_features_1 = DINO_Mat_features[1]
1050
+ patch_features_2 = DINO_Mat_features[2]
1051
+ patch_features_3 = DINO_Mat_features[3]
1052
+ patch_feature_all = torch.cat((patch_features_0, patch_features_1,
1053
+ patch_features_2, patch_features_3), dim=1)
1054
+
1055
+ # Get concatenate DINO Feature
1056
+ dino_mat_cat = self.Conv(patch_feature_all)
1057
+ dino_mat_cat = self.relu(dino_mat_cat)
1058
+ B, C, W, H = dino_mat_cat.shape
1059
+ dino_mat_cat_flat = dino_mat_cat.view(B, C, W * H).permute(0,2,1)
1060
+
1061
+
1062
+ dino_mat2 = F.upsample_bilinear(DINO_Mat_features[-1], scale_factor=2)
1063
+ dino_mat3 = DINO_Mat_features[-1]
1064
+
1065
+ # RGBD
1066
+ xi = torch.cat((x, point[:,2,:].unsqueeze(1)), dim=1)
1067
+
1068
+ y = self.input_proj(xi)
1069
+ y = self.pos_drop(y)
1070
+
1071
+ # Encoder
1072
+ self.img_size = (int(self.img_size[0]), int(self.img_size[1]))
1073
+ conv0 = self.encoderlayer_0(y, dino_mat, point_feature, normal, mask, img_size = self.img_size)
1074
+ pool0 = self.dowsample_0(conv0, img_size = self.img_size)
1075
+
1076
+ self.img_size = (int(self.img_size[0]/2), int(self.img_size[1]/2))
1077
+ conv1 = self.encoderlayer_1(pool0, dino_mat1, point_feature1, normal1, img_size = self.img_size)
1078
+ pool1 = self.dowsample_1(conv1, img_size = self.img_size)
1079
+
1080
+ self.img_size = (int(self.img_size[0] / 2), int(self.img_size[1] / 2))
1081
+ conv2 = self.encoderlayer_2(pool1, dino_mat2, point_feature2, normal2, img_size = self.img_size)
1082
+ pool2 = self.dowsample_2(conv2, img_size = self.img_size)
1083
+
1084
+ # Bottleneck
1085
+ self.img_size = (int(self.img_size[0] / 2), int(self.img_size[1] / 2))
1086
+ pool2 = torch.cat([pool2, dino_mat_cat_flat],-1)
1087
+ conv3 = self.conv(pool2, dino_mat3, point_feature3, normal3, img_size = self.img_size)
1088
+ print(f'{conv3.shape=}')
1089
+
1090
+ #Decoder
1091
+ up0 = self.upsample_0(conv3, img_size = self.img_size)
1092
+ self.img_size = (int(self.img_size[0] * 2), int(self.img_size[1] * 2))
1093
+ print(f'{conv2.shape=}, {up0.shape=}') # conv2.shape=torch.Size([1, 4096, 128]), up0.shape=torch.Size([1, 4096, 128])
1094
+ deconv0 = torch.cat([up0,conv2],-1)
1095
+ deconv0 = self.decoderlayer_0(deconv0, dino_mat2, point_feature2, normal2, img_size = self.img_size)
1096
+ print(f'{deconv0.shape=}')
1097
+
1098
+ up1 = self.upsample_1(deconv0, img_size = self.img_size)
1099
+ self.img_size = (int(self.img_size[0] * 2), int(self.img_size[1] * 2))
1100
+ print(f'{conv1.shape=}, {up1.shape=}') # conv1.shape=torch.Size([1, 16384, 64]), up1.shape=torch.Size([1, 16384, 64])
1101
+ deconv1 = torch.cat([up1,conv1],-1)
1102
+ deconv1 = self.decoderlayer_1(deconv1, dino_mat1, point_feature1, normal1, img_size = self.img_size)
1103
+ print(f'{deconv1.shape=}')
1104
+
1105
+ up2 = self.upsample_2(deconv1, img_size = self.img_size)
1106
+ self.img_size = (int(self.img_size[0] * 2), int(self.img_size[1] * 2))
1107
+ print(f'{conv0.shape=}, {up2.shape=}') # conv0.shape=torch.Size([1, 65536, 32]), up2.shape=torch.Size([1, 65536, 32])
1108
+ deconv2 = torch.cat([up2,conv0],-1)
1109
+ deconv2 = self.decoderlayer_2(deconv2, dino_mat, point_feature, normal, mask, img_size = self.img_size)
1110
+
1111
+ # Output Projection
1112
+ y = self.output_proj(deconv2, img_size = self.img_size) + x
1113
+ return y
1114
+
1115
+
1116
+
1117
+ class ShadowFormerFreq(nn.Module):
1118
+ def __init__(self, img_size=256, in_chans=3,
1119
+ embed_dim=32, depths=[2, 2, 2, 2, 2, 2, 2, 2, 2], num_heads=[1, 2, 4, 8, 16, 16, 8, 4, 2],
1120
+ win_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
1121
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
1122
+ norm_layer=nn.LayerNorm, patch_norm=True,
1123
+ use_checkpoint=False, token_projection='linear', token_mlp='leff', se_layer=True,
1124
+ dowsample=Downsample, upsample=Upsample, **kwargs):
1125
+ super().__init__()
1126
+
1127
+ self.num_enc_layers = len(depths)//2
1128
+ self.num_dec_layers = len(depths)//2
1129
+ self.embed_dim = embed_dim
1130
+ self.patch_norm = patch_norm
1131
+ self.mlp_ratio = mlp_ratio
1132
+ self.token_projection = token_projection
1133
+ self.mlp = token_mlp
1134
+ self.win_size =win_size
1135
+ self.reso = img_size
1136
+ self.pos_drop = nn.Dropout(p=drop_rate)
1137
+ self.DINO_channel = 1024
1138
+
1139
+ # stochastic depth
1140
+ enc_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths[:self.num_enc_layers]))]
1141
+ conv_dpr = [drop_path_rate]*depths[4]
1142
+ dec_dpr = enc_dpr[::-1]
1143
+
1144
+ # build layers
1145
+
1146
+ # Input/Output
1147
+ self.input_proj = InputProj(in_channel=4, out_channel=embed_dim, kernel_size=3, stride=1, act_layer=nn.LeakyReLU)
1148
+ self.output_proj = OutputProj(in_channel=2*embed_dim, out_channel=in_chans, kernel_size=3, stride=1)
1149
+
1150
+ # Encoder
1151
+ self.encoderlayer_0 = BasicShadowFormer(dim=embed_dim,
1152
+ output_dim=embed_dim,
1153
+ input_resolution=(img_size,
1154
+ img_size),
1155
+ depth=depths[0],
1156
+ num_heads=num_heads[0],
1157
+ win_size=win_size,
1158
+ mlp_ratio=self.mlp_ratio,
1159
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
1160
+ drop=drop_rate, attn_drop=attn_drop_rate,
1161
+ drop_path=enc_dpr[sum(depths[:0]):sum(depths[:1])],
1162
+ norm_layer=norm_layer,
1163
+ use_checkpoint=use_checkpoint,
1164
+ token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer,cab=True)
1165
+ self.dowsample_0 = dowsample(embed_dim, embed_dim*2)
1166
+ self.encoderlayer_1 = BasicShadowFormer(dim=embed_dim*2,
1167
+ output_dim=embed_dim*2,
1168
+ input_resolution=(img_size // 2,
1169
+ img_size // 2),
1170
+ depth=depths[1],
1171
+ num_heads=num_heads[1],
1172
+ win_size=win_size,
1173
+ mlp_ratio=self.mlp_ratio,
1174
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
1175
+ drop=drop_rate, attn_drop=attn_drop_rate,
1176
+ drop_path=enc_dpr[sum(depths[:1]):sum(depths[:2])],
1177
+ norm_layer=norm_layer,
1178
+ use_checkpoint=use_checkpoint,
1179
+ token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer, cab=True)
1180
+ self.dowsample_1 = dowsample(embed_dim*2, embed_dim*4)
1181
+ self.encoderlayer_2 = BasicShadowFormer(dim=embed_dim*4,
1182
+ output_dim=embed_dim*4,
1183
+ input_resolution=(img_size // (2 ** 2),
1184
+ img_size // (2 ** 2)),
1185
+ depth=depths[2],
1186
+ num_heads=num_heads[2],
1187
+ win_size=win_size,
1188
+ mlp_ratio=self.mlp_ratio,
1189
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
1190
+ drop=drop_rate, attn_drop=attn_drop_rate,
1191
+ drop_path=enc_dpr[sum(depths[:2]):sum(depths[:3])],
1192
+ norm_layer=norm_layer,
1193
+ use_checkpoint=use_checkpoint,
1194
+ token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer)
1195
+ self.dowsample_2 = dowsample(embed_dim*4, embed_dim*8)
1196
+
1197
+ # Bottleneck
1198
+ channel_conv = embed_dim*16
1199
+ # channel_conv = embed_dim*8 if self.add_shadow_detect_dino_conact else embed_dim*4
1200
+ self.conv = BasicShadowFormer(dim=channel_conv,
1201
+ output_dim=channel_conv,
1202
+ input_resolution=(img_size // (2 ** 3),
1203
+ img_size // (2 ** 3)),
1204
+ depth=depths[4],
1205
+ num_heads=num_heads[4],
1206
+ win_size=win_size,
1207
+ mlp_ratio=self.mlp_ratio,
1208
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
1209
+ drop=drop_rate, attn_drop=attn_drop_rate,
1210
+ drop_path=conv_dpr,
1211
+ norm_layer=norm_layer,
1212
+ use_checkpoint=use_checkpoint,
1213
+ token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer)
1214
+
1215
+ # # Decoder
1216
+ self.upsample_0 = upsample(channel_conv, embed_dim*4)
1217
+ channel_0 = embed_dim*8
1218
+ self.decoderlayer_0 = BasicShadowFormer(dim=channel_0,
1219
+ output_dim=channel_0,
1220
+ input_resolution=(img_size // (2 ** 2),
1221
+ img_size // (2 ** 2)),
1222
+ depth=depths[6],
1223
+ num_heads=num_heads[6],
1224
+ win_size=win_size,
1225
+ mlp_ratio=self.mlp_ratio,
1226
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
1227
+ drop=drop_rate, attn_drop=attn_drop_rate,
1228
+ drop_path=dec_dpr[sum(depths[5:6]):sum(depths[5:7])],
1229
+ norm_layer=norm_layer,
1230
+ use_checkpoint=use_checkpoint,
1231
+ token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer)
1232
+ self.upsample_1 = upsample(channel_0, embed_dim*2)
1233
+ channel_1 = embed_dim*4
1234
+ self.decoderlayer_1 = BasicShadowFormer(dim=channel_1,
1235
+ output_dim=channel_1,
1236
+ input_resolution=(img_size // 2,
1237
+ img_size // 2),
1238
+ depth=depths[7],
1239
+ num_heads=num_heads[7],
1240
+ win_size=win_size,
1241
+ mlp_ratio=self.mlp_ratio,
1242
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
1243
+ drop=drop_rate, attn_drop=attn_drop_rate,
1244
+ drop_path=dec_dpr[sum(depths[5:7]):sum(depths[5:8])],
1245
+ norm_layer=norm_layer,
1246
+ use_checkpoint=use_checkpoint,
1247
+ token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer, cab=True)
1248
+ self.upsample_2 = upsample(channel_1, embed_dim)
1249
+ self.decoderlayer_2 = BasicShadowFormer(dim=embed_dim*2,
1250
+ output_dim=embed_dim*2,
1251
+ input_resolution=(img_size,
1252
+ img_size),
1253
+ depth=depths[8],
1254
+ num_heads=num_heads[8],
1255
+ win_size=win_size,
1256
+ mlp_ratio=self.mlp_ratio,
1257
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
1258
+ drop=drop_rate, attn_drop=attn_drop_rate,
1259
+ drop_path=dec_dpr[sum(depths[5:8]):sum(depths[5:9])],
1260
+ norm_layer=norm_layer,
1261
+ use_checkpoint=use_checkpoint,
1262
+ token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer,cab=True)
1263
+
1264
+ self.Conv = nn.Conv2d(self.DINO_channel * 4, embed_dim * 8, kernel_size=1)
1265
+ self.relu = nn.LeakyReLU()
1266
+ self.apply(self._init_weights)
1267
+
1268
+ self.freqfusion1 = FreqFusion(hr_channels=256,
1269
+ lr_channels=512)
1270
+
1271
+ self.freqfusion2 = FreqFusion(hr_channels=128,
1272
+ lr_channels=256)
1273
+
1274
+ self.freqfusion3 = FreqFusion(hr_channels=64,
1275
+ lr_channels=128)
1276
+
1277
+ def _init_weights(self, m):
1278
+ if isinstance(m, nn.Linear):
1279
+ trunc_normal_(m.weight, std=.02)
1280
+ if isinstance(m, nn.Linear) and m.bias is not None:
1281
+ nn.init.constant_(m.bias, 0)
1282
+ elif isinstance(m, nn.LayerNorm):
1283
+ nn.init.constant_(m.bias, 0)
1284
+ nn.init.constant_(m.weight, 1.0)
1285
+
1286
+ @torch.jit.ignore
1287
+ def no_weight_decay(self):
1288
+ return {'absolute_pos_embed'}
1289
+
1290
+ @torch.jit.ignore
1291
+ def no_weight_decay_keywords(self):
1292
+ return {'relative_position_bias_table'}
1293
+
1294
+ def extra_repr(self) -> str:
1295
+ return f"embed_dim={self.embed_dim}, token_projection={self.token_projection}, token_mlp={self.mlp}, win_size={self.win_size}"
1296
+
1297
+ def forward(self, x, DINO_Mat_features=None, point=None, normal=None, mask=None):
1298
+ point_feature=None
1299
+ dino_mat =None
1300
+ dino_mat1=None
1301
+
1302
+ self.img_size = torch.tensor((x.shape[2], x.shape[3]))
1303
+ point_feature1 = grid_sample(point, self.img_size // 2)
1304
+ point_feature2 = grid_sample(point, self.img_size // 4)
1305
+ point_feature3 = grid_sample(point, self.img_size // 8)
1306
+ normal1= grid_sample(normal, self.img_size // 2)
1307
+ normal2= grid_sample(normal, self.img_size // 4)
1308
+ normal3= grid_sample(normal, self.img_size // 8)
1309
+
1310
+ patch_features_0 = DINO_Mat_features[0]
1311
+ patch_features_1 = DINO_Mat_features[1]
1312
+ patch_features_2 = DINO_Mat_features[2]
1313
+ patch_features_3 = DINO_Mat_features[3]
1314
+ patch_feature_all = torch.cat((patch_features_0, patch_features_1,
1315
+ patch_features_2, patch_features_3), dim=1)
1316
+
1317
+ # Get concatenate DINO Feature
1318
+ dino_mat_cat = self.Conv(patch_feature_all)
1319
+ dino_mat_cat = self.relu(dino_mat_cat)
1320
+ B, C, W, H = dino_mat_cat.shape
1321
+ dino_mat_cat_flat = dino_mat_cat.view(B, C, W * H).permute(0,2,1)
1322
+
1323
+
1324
+ dino_mat2 = F.upsample_bilinear(DINO_Mat_features[-1], scale_factor=2)
1325
+ dino_mat3 = DINO_Mat_features[-1]
1326
+
1327
+ # RGBD
1328
+ xi = torch.cat((x, point[:,2,:].unsqueeze(1)), dim=1)
1329
+
1330
+ y = self.input_proj(xi)
1331
+ y = self.pos_drop(y)
1332
+
1333
+ # Encoder
1334
+ self.img_size = (int(self.img_size[0]), int(self.img_size[1]))
1335
+ conv0 = self.encoderlayer_0(y, dino_mat, point_feature, normal, mask, img_size = self.img_size)
1336
+ pool0 = self.dowsample_0(conv0, img_size = self.img_size)
1337
+
1338
+ self.img_size = (int(self.img_size[0]/2), int(self.img_size[1]/2))
1339
+ conv1 = self.encoderlayer_1(pool0, dino_mat1, point_feature1, normal1, img_size = self.img_size)
1340
+ pool1 = self.dowsample_1(conv1, img_size = self.img_size)
1341
+
1342
+ self.img_size = (int(self.img_size[0] / 2), int(self.img_size[1] / 2))
1343
+ conv2 = self.encoderlayer_2(pool1, dino_mat2, point_feature2, normal2, img_size = self.img_size)
1344
+ pool2 = self.dowsample_2(conv2, img_size = self.img_size)
1345
+
1346
+ # Bottleneck
1347
+ self.img_size = (int(self.img_size[0] / 2), int(self.img_size[1] / 2))
1348
+ pool2 = torch.cat([pool2, dino_mat_cat_flat],-1)
1349
+ conv3 = self.conv(pool2, dino_mat3, point_feature3, normal3, img_size = self.img_size)
1350
+ # print(f'{conv3.shape=}') # conv3.shape=torch.Size([1, 1024, 512]) 1, 32, 32, 512
1351
+ # conv3_B_C_H_W = conv3.view(conv3.shape[0], 32, 32, 512).permute(0, 3, 1, 2)
1352
+ conv3_B_C_H_W = conv3.view(conv3.shape[0], int(conv3.shape[1]**0.5), int(conv3.shape[1]**0.5), 512).permute(0, 3, 1, 2)
1353
+ # print(f'{conv3_B_C_H_W.shape=}') # 1, 512, 32, 32
1354
+
1355
+ #Decoder
1356
+ up0 = self.upsample_0(conv3, img_size = self.img_size)
1357
+ self.img_size = (int(self.img_size[0] * 2), int(self.img_size[1] * 2))
1358
+ # print(f'1.{conv2.shape=}, {up0.shape=}') # conv2.shape=torch.Size([1, 4096, 128]), up0.shape=torch.Size([1, 4096, 128])
1359
+ deconv0 = torch.cat([up0,conv2],-1)
1360
+ deconv0 = self.decoderlayer_0(deconv0, dino_mat2, point_feature2, normal2, img_size = self.img_size)
1361
+ # print(f'1.{deconv0.shape=}') # deconv0.shape=torch.Size([1, 4096, 256]) 1, 64, 64, 256
1362
+ deconv0_B_C_H_W = deconv0.view(deconv0.shape[0], int(deconv0.shape[1]**0.5), int(deconv0.shape[1]**0.5), 256).permute(0, 3, 1, 2)
1363
+ # print(f'1.{deconv0_B_C_H_W.shape=}') # 1, 256, 64, 64
1364
+
1365
+ _, deconv0_B_C_H_W, lr_feat = self.freqfusion1(hr_feat=deconv0_B_C_H_W, lr_feat=conv3_B_C_H_W) # 1, 256, 64, 64 & 1, 512, 32, 32
1366
+ # print(f'1.{deconv0.shape=}, {lr_feat.shape=}') # deconv0.shape=torch.Size([1, 256, 64, 64]), lr_feat.shape=torch.Size([1, 512, 64, 64])
1367
+
1368
+ deconv0 = deconv0_B_C_H_W.view(deconv0_B_C_H_W.shape[0], 256, -1).permute(0, 2, 1)
1369
+
1370
+
1371
+ # print(f'1.{deconv0.shape=}') # conv2.shape=torch.Size([1, 4096, 256]) 1, 64, 64, 256
1372
+
1373
+ deconv0_B_C_H_W = deconv0.view(deconv0.shape[0], int(deconv0.shape[1]**0.5), int(deconv0.shape[1]**0.5), 256).permute(0, 3, 1, 2) # 1, 256, 64, 64
1374
+ # print(f'2.{deconv0_B_C_H_W.shape=}') # 1, 256, 64, 64
1375
+
1376
+ up1 = self.upsample_1(deconv0, img_size = self.img_size)
1377
+ self.img_size = (int(self.img_size[0] * 2), int(self.img_size[1] * 2))
1378
+ # print(f'2.{conv1.shape=}, {up1.shape=}') # conv1.shape=torch.Size([1, 16384, 64]), up1.shape=torch.Size([1, 16384, 64]) 1, 128, 128, 64
1379
+ deconv1 = torch.cat([up1,conv1],-1) # 1, 16384, 128
1380
+ deconv1 = self.decoderlayer_1(deconv1, dino_mat1, point_feature1, normal1, img_size = self.img_size)
1381
+ # print(f'2.{deconv1.shape=}') # 1, 16384, 128 1, 128, 128, 128
1382
+ deconv1_B_C_H_W = deconv1.view(deconv1.shape[0], int(deconv1.shape[1]**0.5), int(deconv1.shape[1]**0.5), 128).permute(0, 3, 1, 2)
1383
+ # print(f'2.{deconv1_B_C_H_W.shape=}') # 1, 128, 128, 128
1384
+
1385
+ _, deconv1_B_C_H_W, lr_feat = self.freqfusion2(hr_feat=deconv1_B_C_H_W, lr_feat=deconv0_B_C_H_W) # 1, 128, 128, 128 & 1, 256, 64, 64
1386
+
1387
+ # print(f'2.{deconv1_B_C_H_W.shape=}, {lr_feat.shape=}') # hr_feat.shape=torch.Size([1, 128, 128, 128]), lr_feat.shape=torch.Size([1, 256, 128, 128])
1388
+
1389
+ deconv1 = deconv1_B_C_H_W.view(deconv1_B_C_H_W.shape[0], 128, -1).permute(0, 2, 1)
1390
+ # print()
1391
+ # print(f'3.{deconv1.shape=}') # deconv1.shape=torch.Size([1, 16384, 128]) 1, 128, 128, 128
1392
+
1393
+ deconv1_B_C_H_W = deconv1.view(deconv1.shape[0], int(deconv1.shape[1]**0.5), int(deconv1.shape[1]**0.5), 128).permute(0, 3, 1, 2) # 1, 128, 128, 128
1394
+
1395
+ up2 = self.upsample_2(deconv1, img_size = self.img_size)
1396
+ self.img_size = (int(self.img_size[0] * 2), int(self.img_size[1] * 2))
1397
+ # print(f'3.{conv0.shape=}, {up2.shape=}') # conv0.shape=torch.Size([1, 65536, 32]), up2.shape=torch.Size([1, 65536, 32]) 1, 256, 256, 32
1398
+ deconv2 = torch.cat([up2,conv0],-1) # 1, 256, 256, 64
1399
+ # print(f'3.{deconv2.shape=}') # 1, 65536, 64 1, 256, 256, 64
1400
+ deconv2 = self.decoderlayer_2(deconv2, dino_mat, point_feature, normal, mask, img_size = self.img_size)
1401
+ # print(f'3.{deconv2.shape=}') # 1, 65536, 64 1, 256, 256, 64
1402
+
1403
+ deconv2_B_C_H_W = deconv2.view(deconv2.shape[0], int(deconv2.shape[1]**0.5), int(deconv2.shape[1]**0.5), 64).permute(0, 3, 1, 2)
1404
+ # print(f'3.{deconv2_B_C_H_W.shape=}')
1405
+
1406
+ _, deconv2_B_C_H_W, lr_feat = self.freqfusion3(hr_feat=deconv2_B_C_H_W, lr_feat=deconv1_B_C_H_W) # 1, 64, 256, 256 & 1, 128, 128, 128
1407
+
1408
+ # print('*'*5, f'3.{deconv2_B_C_H_W.shape=}, {lr_feat.shape=}')
1409
+
1410
+ deconv2 = deconv2_B_C_H_W.view(deconv2_B_C_H_W.shape[0], 64, -1).permute(0, 2, 1)
1411
+ # print(f'4.{deconv2.shape=}')
1412
+
1413
+ # Output Projection
1414
+ # print(f'4.{deconv2.shape=}')
1415
+ # print(f'4.{x.shape=}')
1416
+ y = self.output_proj(deconv2, img_size = self.img_size) + x
1417
+ return y
1418
+
requirements.txt ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==1.4.0
2
+ addict==2.4.0
3
+ aiofiles==23.2.1
4
+ aiohttp==3.8.5
5
+ aiosignal==1.3.1
6
+ altair==5.1.0
7
+ annotated-types==0.5.0
8
+ ansi2html==1.8.0
9
+ antlr4-python3-runtime==4.9.3
10
+ anyio==3.7.1
11
+ asttokens==2.4.1
12
+ async-timeout==4.0.3
13
+ attrs==23.1.0
14
+ blinker==1.6.2
15
+ boto3==1.28.62
16
+ botocore==1.31.62
17
+ bottle==0.12.25
18
+ Brotli
19
+ cachetools==5.3.1
20
+ cattrs==23.1.2
21
+ certifi
22
+ cffi
23
+ chainmap==1.0.3
24
+ chardet==5.2.0
25
+ charset-normalizer
26
+ click==8.1.7
27
+ cloudpickle==2.2.1
28
+ cmake==3.27.0
29
+ colorama
30
+ combomethod==1.0.12
31
+ comm==0.1.4
32
+ common==0.1.2
33
+ ConfigArgParse==1.7
34
+ contourpy==1.1.0
35
+ cryptography
36
+ cycler==0.11.0
37
+ Cython==3.0.5
38
+ dash==2.14.1
39
+ dash-core-components==2.0.0
40
+ dash-html-components==2.0.0
41
+ dash-table==5.0.0
42
+ data==0.4
43
+ debugpy==1.8.1
44
+ decorator==5.1.1
45
+ dual==0.0.10
46
+ dynamo3==0.4.10
47
+ easydict==1.10
48
+ efficientnet-pytorch==0.7.1
49
+ einops==0.3.2
50
+ exceptiongroup==1.1.3
51
+ executing==2.0.1
52
+ fastapi==0.103.0
53
+ fastjsonschema==2.18.1
54
+ ffmpy==0.3.1
55
+ filelock==3.12.2
56
+ Flask==2.3.3
57
+ Flask-Cors==4.0.0
58
+ flywheel==0.5.4
59
+ fonttools==4.42.0
60
+ frozenlist==1.4.0
61
+ fsspec==2023.6.0
62
+ funcsigs==1.0.2
63
+ future
64
+ fvcore
65
+ google-auth==2.22.0
66
+ google-auth-oauthlib==1.0.0
67
+ gradio
68
+ gradio_client
69
+ grpcio==1.57.0
70
+ h11==0.14.0
71
+ h5py==3.9.0
72
+ httpcore==0.17.3
73
+ httpx==0.24.1
74
+ huggingface-hub==0.16.4
75
+ idna
76
+ imageio==2.31.1
77
+ importlib-metadata==6.8.0
78
+ importlib-resources==6.0.1
79
+ iopath==0.1.9
80
+ ipython==8.17.2
81
+ ipywidgets==8.1.1
82
+ itsdangerous==2.1.2
83
+ jedi==0.19.1
84
+ Jinja2==3.1.2
85
+ jmespath==1.0.1
86
+ joblib==1.3.2
87
+ jsonschema==4.19.0
88
+ jsonschema-specifications==2023.7.1
89
+ jupyter_core==5.5.0
90
+ jupyterlab-widgets==3.0.9
91
+ kiwisolver==1.4.4
92
+ kornia==0.7.0
93
+ lazy_loader==0.3
94
+ lightning-utilities==0.9.0
95
+ lit==16.0.6
96
+ Markdown==3.4.4
97
+ markdown-it-py==3.0.0
98
+ MarkupSafe==2.1.3
99
+ matplotlib==3.3.4
100
+ matplotlib-inline==0.1.6
101
+ mdurl==0.1.2
102
+ mpmath==1.3.0
103
+ multidict==6.0.4
104
+ mypy-extensions==1.0.0
105
+ natsort==8.4.0
106
+ nbformat==5.7.0
107
+ ndim
108
+ nest-asyncio==1.5.8
109
+ networkx==3.1
110
+ ntplib==0.4.0
111
+ nulltype==2.3.1
112
+ numpy
113
+ nvidia-cublas-cu11==11.10.3.66
114
+ nvidia-cuda-cupti-cu11==11.7.101
115
+ nvidia-cuda-nvrtc-cu11==11.7.99
116
+ nvidia-cuda-runtime-cu11==11.7.99
117
+ nvidia-cudnn-cu11==8.5.0.96
118
+ nvidia-cufft-cu11==10.9.0.58
119
+ nvidia-curand-cu11==10.2.10.91
120
+ nvidia-cusolver-cu11==11.4.0.1
121
+ nvidia-cusparse-cu11==11.7.4.91
122
+ nvidia-nccl-cu11==2.14.3
123
+ nvidia-nvtx-cu11==11.7.91
124
+ oauthlib==3.2.2
125
+ omegaconf==2.3.0
126
+ open3d==0.17.0
127
+ opencv-python==4.8.0.74
128
+ options==1.4.10
129
+ orjson==3.9.5
130
+ packaging==23.1
131
+
132
+ pandas==2.0.3
133
+ parso==0.8.3
134
+ peewee==3.16.3
135
+ pexpect==4.8.0
136
+ Pillow==10.0.0
137
+ platformdirs==3.10.0
138
+ plotly==5.18.0
139
+ portalocker
140
+ progressbar2==4.2.0
141
+ prompt-toolkit==3.0.39
142
+ protobuf==4.24.2
143
+ prox==0.0.17
144
+ ptflops==0.7.2.2
145
+ ptyprocess==0.7.0
146
+ pure-eval==0.2.2
147
+ py-machineid==0.4.3
148
+ pyasn1==0.5.0
149
+ pyasn1-modules==0.3.0
150
+ pybind11==2.11.1
151
+ pycparser
152
+ pydantic==2.3.0
153
+ pydantic_core==2.6.3
154
+
155
+ pyDeprecate==0.3.2
156
+ pydub==0.25.1
157
+ Pygments==2.16.1
158
+ PyNaCl==1.5.0
159
+ pyOpenSSL
160
+ pyparsing==3.0.9
161
+ pyquaternion==0.9.9
162
+ pyre-extensions==0.0.29
163
+ PySocks
164
+ python-dateutil==2.8.2
165
+ python-geoip-python3==1.3
166
+
167
+ python-multipart==0.0.6
168
+ python-utils==3.7.0
169
+
170
+ pytz==2023.3
171
+ PyWavelets==1.4.1
172
+ PyYAML==6.0.1
173
+ referencing==0.30.2
174
+ requests
175
+ requests-cache
176
+ requests-oauthlib==1.3.1
177
+ retrying==1.3.4
178
+ rich==13.5.2
179
+ rich-argparse==1.3.0
180
+ rpds-py==0.10.0
181
+ rsa==4.9
182
+ s3transfer==0.7.0
183
+ safetensors==0.3.1
184
+ scikit-image==0.21.0
185
+ scikit-learn==1.3.0
186
+ scipy==1.11.1
187
+ seaborn==0.12.2
188
+ semantic-version==2.10.0
189
+ six==1.12.0
190
+
191
+ sniffio==1.3.0
192
+ stack-data==0.6.3
193
+ starlette==0.27.0
194
+ sympy==1.12
195
+ tabulate
196
+ tenacity==8.2.3
197
+ tensorboard==2.14.0
198
+ tensorboard-data-server==0.7.1
199
+
200
+ termcolor
201
+ thop==0.1.1.post2209072238
202
+ threadpoolctl==3.2.0
203
+ tifffile==2023.7.18
204
+ tight==0.1.0
205
+ timm==0.9.5
206
+ tomli==2.0.1
207
+ tomli_w==1.0.0
208
+ toolz==0.12.0
209
+ gdown
210
+
211
+ torchmetrics==1.1.1
212
+ torchstat==0.0.7
213
+ torchsummary==1.5.1
214
+
215
+ tqdm==4.65.0
216
+ traitlets==5.13.0
217
+ triton==2.0.0
218
+ typing-inspect==0.9.0
219
+ typing_extensions
220
+ tzdata==2023.3
221
+ url-normalize==1.4.3
222
+ urllib3==1.26.16
223
+ uvicorn==0.23.2
224
+ wcwidth==0.2.9
225
+ websockets==11.0.3
226
+ Werkzeug==2.3.7
227
+ widgetsnbextension==4.0.9
228
+ wrapt==1.15.0
229
+ x21
230
+ xformers==0.0.20
231
+ yacs
232
+ yarl==1.9.2
233
+ zipp==3.16.2
run_test.sh ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node 1 --master_port 29508 ./test_shadow.py --save_images
test_shadow.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+ import argparse
4
+ from tqdm import tqdm
5
+ from torch.utils.data.distributed import DistributedSampler
6
+ import torch.nn as nn
7
+ import torch
8
+ from torch.utils.data import DataLoader
9
+ from torch.nn.parallel import DistributedDataParallel as DDP
10
+ import torch.nn.functional as F
11
+ import random
12
+ # from utils.loader import get_validation_data
13
+ from utils.loader import get_test_data
14
+ import utils
15
+ import cv2
16
+ import torch.distributed as dist
17
+ from skimage.metrics import peak_signal_noise_ratio as psnr_loss
18
+ from skimage.metrics import structural_similarity as ssim_loss
19
+ parser = argparse.ArgumentParser(description='RGB denoising evaluation on the validation set of SIDD')
20
+ parser.add_argument('--input_dir', default='test_dir',
21
+ type=str, help='Directory of validation images')
22
+ parser.add_argument('--result_dir', default='./output_dir',
23
+ type=str, help='Directory for results')
24
+ parser.add_argument('--weights', default='ACVLab_shadow.pth'
25
+ ,type=str, help='Path to weights')
26
+ # parser.add_argument('--arch', default='ShadowFormer', type=str, help='arch')
27
+ parser.add_argument('--arch', type=str, default='ShadowFormerFreq', help='archtechture')
28
+ parser.add_argument('--batch_size', default=1, type=int, help='Batch size for dataloader')
29
+ parser.add_argument('--save_images', action='store_true', default=False, help='Save denoised images in result directory')
30
+ parser.add_argument('--cal_metrics', action='store_true', default=False, help='Measure denoised images with GT')
31
+ parser.add_argument('--embed_dim', type=int, default=32, help='number of data loading workers')
32
+ parser.add_argument('--win_size', type=int, default=16, help='number of data loading workers')
33
+ parser.add_argument('--token_projection', type=str, default='linear', help='linear/conv token projection')
34
+ parser.add_argument('--token_mlp', type=str,default='leff', help='ffn/leff token mlp')
35
+
36
+ parser.add_argument('--train_ps', type=int, default=256, help='patch size of training sample')
37
+ parser.add_argument("--local-rank", type=int)
38
+
39
+ args = parser.parse_args()
40
+
41
+ local_rank = args.local_rank
42
+ torch.cuda.set_device(local_rank)
43
+ dist.init_process_group(backend='nccl')
44
+ device = torch.device("cuda", local_rank)
45
+
46
+
47
+ class SlidingWindowInference:
48
+ def __init__(self, window_size=512, overlap=64, img_multiple_of=64):
49
+ self.window_size = window_size
50
+ self.overlap = overlap
51
+ self.img_multiple_of = img_multiple_of
52
+
53
+ def _pad_input(self, x, h_pad, w_pad):
54
+ """Handle padding using reflection padding"""
55
+ return F.pad(x, (0, w_pad, 0, h_pad), 'reflect')
56
+
57
+ def __call__(self, model, input_, point, normal, dino_net, device):
58
+ # Save original dimensions
59
+ original_height, original_width = input_.shape[2], input_.shape[3]
60
+ # print(f"Original size: {original_height}x{original_width}")
61
+
62
+ # Calculate minimum dimensions needed (at least window_size and multiple of img_multiple_of)
63
+ H = max(self.window_size,
64
+ ((original_height + self.img_multiple_of - 1) // self.img_multiple_of) * self.img_multiple_of)
65
+ W = max(self.window_size,
66
+ ((original_width + self.img_multiple_of - 1) // self.img_multiple_of) * self.img_multiple_of)
67
+ # print(f"Target padded size: {H}x{W}")
68
+
69
+ # Calculate required padding
70
+ padh = H - original_height
71
+ padw = W - original_width
72
+ # print(f"Padding: h={padh}, w={padw}")
73
+
74
+ # Pad all inputs
75
+ input_pad = self._pad_input(input_, padh, padw)
76
+ point_pad = self._pad_input(point, padh, padw)
77
+ normal_pad = self._pad_input(normal, padh, padw)
78
+
79
+ # If image was smaller than window_size, process it as a single window
80
+ if original_height <= self.window_size and original_width <= self.window_size:
81
+ # print("Image smaller than window size, processing as single padded window")
82
+
83
+ # For DINO features
84
+ DINO_patch_size = 14
85
+ h_size = H * DINO_patch_size // 8
86
+ w_size = W * DINO_patch_size // 8
87
+
88
+ UpSample_window = torch.nn.UpsamplingBilinear2d(size=(h_size, w_size))
89
+
90
+ # Get DINO features
91
+ with torch.no_grad():
92
+ input_DINO = UpSample_window(input_pad)
93
+ dino_features = dino_net.module.get_intermediate_layers(input_DINO, 4, True)
94
+
95
+ # Model inference
96
+ with torch.cuda.amp.autocast():
97
+ restored = model(input_pad, dino_features, point_pad, normal_pad)
98
+
99
+ # Crop back to original size
100
+ output = restored[:, :, :original_height, :original_width]
101
+ return output
102
+
103
+ # For larger images, proceed with sliding window approach
104
+ stride = self.window_size - self.overlap
105
+ h_steps = (H - self.window_size + stride - 1) // stride + 1
106
+ w_steps = (W - self.window_size + stride - 1) // stride + 1
107
+ # print(f"Steps: h={h_steps}, w={w_steps}")
108
+
109
+ # Create output tensor and counter
110
+ output = torch.zeros_like(input_pad)
111
+ count = torch.zeros_like(input_pad)
112
+
113
+ for h_idx in range(h_steps):
114
+ for w_idx in range(w_steps):
115
+ # Calculate current window position
116
+ h_start = min(h_idx * stride, H - self.window_size)
117
+ w_start = min(w_idx * stride, W - self.window_size)
118
+ h_end = h_start + self.window_size
119
+ w_end = w_start + self.window_size
120
+
121
+ # Get current window
122
+ input_window = input_pad[:, :, h_start:h_end, w_start:w_end]
123
+ point_window = point_pad[:, :, h_start:h_end, w_start:w_end]
124
+ normal_window = normal_pad[:, :, h_start:h_end, w_start:w_end]
125
+
126
+ # print(f"Processing window at ({h_idx}, {w_idx}): {input_window.shape}")
127
+
128
+ # For DINO features
129
+ DINO_patch_size = 14
130
+ h_size = self.window_size * DINO_patch_size // 8
131
+ w_size = self.window_size * DINO_patch_size // 8
132
+
133
+ UpSample_window = torch.nn.UpsamplingBilinear2d(size=(h_size, w_size))
134
+
135
+ # Get DINO features
136
+ with torch.no_grad():
137
+ input_DINO = UpSample_window(input_window)
138
+ dino_features = dino_net.module.get_intermediate_layers(input_DINO, 4, True)
139
+
140
+ # Model inference
141
+ with torch.cuda.amp.autocast():
142
+ restored = model(input_window, dino_features, point_window, normal_window)
143
+
144
+ # Create weight mask for smooth transition
145
+ weight = torch.ones_like(restored)
146
+ if self.overlap > 0:
147
+ # Create gradual weights for overlap regions
148
+ for i in range(self.overlap):
149
+ ratio = i / self.overlap
150
+ weight[:, :, i, :] *= ratio
151
+ weight[:, :, -(i+1), :] *= ratio
152
+ weight[:, :, :, i] *= ratio
153
+ weight[:, :, :, -(i+1)] *= ratio
154
+
155
+ # Accumulate results and weights
156
+ output[:, :, h_start:h_end, w_start:w_end] += restored * weight
157
+ count[:, :, h_start:h_end, w_start:w_end] += weight
158
+
159
+ # Normalize output
160
+ output = output / (count + 1e-6)
161
+
162
+ # Crop back to original size
163
+ output = output[:, :, :original_height, :original_width]
164
+ return output
165
+
166
+
167
+ utils.mkdir(args.result_dir)
168
+
169
+ # ######### Set Seeds ###########
170
+ random.seed(1234)
171
+ np.random.seed(1234)
172
+ torch.manual_seed(1234)
173
+ torch.cuda.manual_seed(1234)
174
+ torch.cuda.manual_seed_all(1234)
175
+
176
+ def worker_init_fn(worker_id):
177
+ random.seed(1234 + worker_id)
178
+
179
+ g = torch.Generator()
180
+ g.manual_seed(1234)
181
+
182
+ torch.backends.cudnn.benchmark = True
183
+ # torch.backends.cudnn.deterministic = True
184
+ ######### Model ###########
185
+ model_restoration = utils.get_arch(args)
186
+ model_restoration.to(device)
187
+ model_restoration.eval()
188
+ DINO_Net = torch.hub.load('./dinov2', 'dinov2_vitl14', source='local')
189
+ DINO_Net.to(device)
190
+ DINO_Net.eval()
191
+ ######### Load ###########
192
+ utils.load_checkpoint(model_restoration, args.weights)
193
+ print("===>Testing using weights: ", args.weights)
194
+
195
+ ######### DDP ###########
196
+
197
+ model_restoration = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_restoration).to(device)
198
+ model_restoration = DDP(model_restoration, device_ids=[local_rank], output_device=local_rank)
199
+ DINO_Net = DDP(DINO_Net, device_ids=[local_rank], output_device=local_rank)
200
+
201
+ ######### Test ###########
202
+ img_multiple_of = 8 * args.win_size
203
+ DINO_patch_size = 14
204
+
205
+ def UpSample(img):
206
+ upsample = nn.UpsamplingBilinear2d(
207
+ size=((int)(img.shape[2] * (DINO_patch_size / 8)),
208
+ (int)(img.shape[3] * (DINO_patch_size / 8))))
209
+ return upsample(img)
210
+
211
+ img_options_train = {'patch_size':args.train_ps}
212
+ test_dataset = get_test_data(args.input_dir, False)
213
+ test_sampler = DistributedSampler(test_dataset, shuffle=False)
214
+ test_loader = DataLoader(dataset=test_dataset, batch_size=1, num_workers=0, sampler=test_sampler, drop_last=False, worker_init_fn=worker_init_fn, generator=g)
215
+ with torch.no_grad():
216
+ psnr_val_rgb_list = []
217
+ psnr_val_mask_list = []
218
+ ssim_val_rgb_list = []
219
+ rmse_val_rgb_list = []
220
+ for ii, data_test in enumerate(tqdm(test_loader), 0):
221
+ # rgb_gt = data_test[0].numpy().squeeze().transpose((1, 2, 0))
222
+ rgb_noisy = data_test[1].to(device)
223
+ point = data_test[2].to(device)
224
+ normal = data_test[3].to(device)
225
+ filenames = data_test[4]
226
+
227
+ # Pad the input if not_multiple_of win_size * 8
228
+ # height, width = rgb_noisy.shape[2], rgb_noisy.shape[3]
229
+ # H, W = ((height + img_multiple_of) // img_multiple_of) * img_multiple_of, (
230
+ # (width + img_multiple_of) // img_multiple_of) * img_multiple_of
231
+
232
+ # padh = H - height if height % img_multiple_of != 0 else 0
233
+ # padw = W - width if width % img_multiple_of != 0 else 0
234
+ # rgb_noisy = F.pad(rgb_noisy, (0, padw, 0, padh), 'reflect')
235
+ # point = F.pad(point, (0, padw, 0, padh), 'reflect')
236
+ # normal = F.pad(normal, (0, padw, 0, padh), 'reflect')
237
+ # print(f'{rgb_noisy.shape=} {point.shape=} {normal.shape=}')
238
+ # UpSample_val = nn.UpsamplingBilinear2d(
239
+ # size=((int)(rgb_noisy.shape[2] * (DINO_patch_size / 8)),
240
+ # (int)(rgb_noisy.shape[3] * (DINO_patch_size / 8))))
241
+ # with torch.cuda.amp.autocast():
242
+ # # DINO_V2
243
+ # input_DINO = UpSample_val(rgb_noisy)
244
+ # dino_mat_features = DINO_Net.module.get_intermediate_layers(input_DINO, 4, True)
245
+ # rgb_restored = model_restoration(rgb_noisy, dino_mat_features, point, normal)
246
+ sliding_window = SlidingWindowInference(
247
+ window_size=512, # 與訓練相同的 patch size
248
+ overlap=64, # 相應調整 overlap
249
+ img_multiple_of=8 * args.win_size
250
+ )
251
+
252
+ with torch.cuda.amp.autocast():
253
+ rgb_restored = sliding_window(
254
+ model=model_restoration,
255
+ input_=rgb_noisy,
256
+ point=point,
257
+ normal=normal,
258
+ dino_net=DINO_Net,
259
+ device=device
260
+ )
261
+
262
+
263
+ rgb_restored = torch.clamp(rgb_restored, 0.0, 1.0)
264
+ # rgb_restored = rgb_restored[:, : ,:height, :width]
265
+ rgb_restored = torch.clamp(rgb_restored, 0, 1).cpu().numpy().squeeze().transpose((1, 2, 0))
266
+
267
+
268
+ if args.save_images:
269
+ utils.save_img(rgb_restored * 255.0, os.path.join(args.result_dir, filenames[0]))
270
+
271
+
utils/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .dir_utils import *
2
+ from .dataset_utils import *
3
+ from .image_utils import *
4
+ from .model_utils import *
5
+ from .shadow_mask_evaluate import *
6
+ from .tta import *
utils/antialias.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2019, Adobe Inc. All rights reserved.
2
+ #
3
+ # This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
4
+ # 4.0 International Public License. To view a copy of this license, visit
5
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.
6
+
7
+
8
+
9
+ ######## https://github.com/adobe/antialiased-cnns/blob/master/models_lpf/__init__.py
10
+
11
+
12
+
13
+ import torch
14
+ import torch.nn.parallel
15
+ import numpy as np
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+
19
+ class Downsample(nn.Module):
20
+ def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
21
+ super(Downsample, self).__init__()
22
+ self.filt_size = filt_size
23
+ self.pad_off = pad_off
24
+ self.pad_sizes = [int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)), int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2))]
25
+ self.pad_sizes = [pad_size+pad_off for pad_size in self.pad_sizes]
26
+ self.stride = stride
27
+ self.off = int((self.stride-1)/2.)
28
+ self.channels = channels
29
+
30
+ # print('Filter size [%i]'%filt_size)
31
+ if(self.filt_size==1):
32
+ a = np.array([1.,])
33
+ elif(self.filt_size==2):
34
+ a = np.array([1., 1.])
35
+ elif(self.filt_size==3):
36
+ a = np.array([1., 2., 1.])
37
+ elif(self.filt_size==4):
38
+ a = np.array([1., 3., 3., 1.])
39
+ elif(self.filt_size==5):
40
+ a = np.array([1., 4., 6., 4., 1.])
41
+ elif(self.filt_size==6):
42
+ a = np.array([1., 5., 10., 10., 5., 1.])
43
+ elif(self.filt_size==7):
44
+ a = np.array([1., 6., 15., 20., 15., 6., 1.])
45
+
46
+ filt = torch.Tensor(a[:,None]*a[None,:])
47
+ filt = filt/torch.sum(filt)
48
+ self.register_buffer('filt', filt[None,None,:,:].repeat((self.channels,1,1,1)))
49
+
50
+ self.pad = get_pad_layer(pad_type)(self.pad_sizes)
51
+
52
+ def forward(self, inp):
53
+ if(self.filt_size==1):
54
+ if(self.pad_off==0):
55
+ return inp[:,:,::self.stride,::self.stride]
56
+ else:
57
+ return self.pad(inp)[:,:,::self.stride,::self.stride]
58
+ else:
59
+ return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
60
+
61
+ def get_pad_layer(pad_type):
62
+ if(pad_type in ['refl','reflect']):
63
+ PadLayer = nn.ReflectionPad2d
64
+ elif(pad_type in ['repl','replicate']):
65
+ PadLayer = nn.ReplicationPad2d
66
+ elif(pad_type=='zero'):
67
+ PadLayer = nn.ZeroPad2d
68
+ else:
69
+ print('Pad type [%s] not recognized'%pad_type)
70
+ return PadLayer
71
+
72
+
73
+ class Downsample1D(nn.Module):
74
+ def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
75
+ super(Downsample1D, self).__init__()
76
+ self.filt_size = filt_size
77
+ self.pad_off = pad_off
78
+ self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))]
79
+ self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
80
+ self.stride = stride
81
+ self.off = int((self.stride - 1) / 2.)
82
+ self.channels = channels
83
+
84
+ # print('Filter size [%i]' % filt_size)
85
+ if(self.filt_size == 1):
86
+ a = np.array([1., ])
87
+ elif(self.filt_size == 2):
88
+ a = np.array([1., 1.])
89
+ elif(self.filt_size == 3):
90
+ a = np.array([1., 2., 1.])
91
+ elif(self.filt_size == 4):
92
+ a = np.array([1., 3., 3., 1.])
93
+ elif(self.filt_size == 5):
94
+ a = np.array([1., 4., 6., 4., 1.])
95
+ elif(self.filt_size == 6):
96
+ a = np.array([1., 5., 10., 10., 5., 1.])
97
+ elif(self.filt_size == 7):
98
+ a = np.array([1., 6., 15., 20., 15., 6., 1.])
99
+
100
+ filt = torch.Tensor(a)
101
+ filt = filt / torch.sum(filt)
102
+ self.register_buffer('filt', filt[None, None, :].repeat((self.channels, 1, 1)))
103
+
104
+ self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes)
105
+
106
+ def forward(self, inp):
107
+ if(self.filt_size == 1):
108
+ if(self.pad_off == 0):
109
+ return inp[:, :, ::self.stride]
110
+ else:
111
+ return self.pad(inp)[:, :, ::self.stride]
112
+ else:
113
+ return F.conv1d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
114
+
115
+
116
+ def get_pad_layer_1d(pad_type):
117
+ if(pad_type in ['refl', 'reflect']):
118
+ PadLayer = nn.ReflectionPad1d
119
+ elif(pad_type in ['repl', 'replicate']):
120
+ PadLayer = nn.ReplicationPad1d
121
+ elif(pad_type == 'zero'):
122
+ PadLayer = nn.ZeroPad1d
123
+ else:
124
+ print('Pad type [%s] not recognized' % pad_type)
125
+ return PadLayer
utils/bundle_submissions.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Author: Tobias Plötz, TU Darmstadt ([email protected])
2
+
3
+ # This file is part of the implementation as described in the CVPR 2017 paper:
4
+ # Tobias Plötz and Stefan Roth, Benchmarking Denoising Algorithms with Real Photographs.
5
+ # Please see the file LICENSE.txt for the license governing this code.
6
+
7
+
8
+ import numpy as np
9
+ import scipy.io as sio
10
+ import os
11
+ import h5py
12
+
13
+ def bundle_submissions_raw(submission_folder,session):
14
+ '''
15
+ Bundles submission data for raw denoising
16
+ submission_folder Folder where denoised images reside
17
+ Output is written to <submission_folder>/bundled/. Please submit
18
+ the content of this folder.
19
+ '''
20
+
21
+ out_folder = os.path.join(submission_folder, session)
22
+ # out_folder = os.path.join(submission_folder, "bundled/")
23
+ try:
24
+ os.mkdir(out_folder)
25
+ except:pass
26
+
27
+ israw = True
28
+ eval_version="1.0"
29
+
30
+ for i in range(50):
31
+ Idenoised = np.zeros((20,), dtype=np.object)
32
+ for bb in range(20):
33
+ filename = '%04d_%02d.mat'%(i+1,bb+1)
34
+ s = sio.loadmat(os.path.join(submission_folder,filename))
35
+ Idenoised_crop = s["Idenoised_crop"]
36
+ Idenoised[bb] = Idenoised_crop
37
+ filename = '%04d.mat'%(i+1)
38
+ sio.savemat(os.path.join(out_folder, filename),
39
+ {"Idenoised": Idenoised,
40
+ "israw": israw,
41
+ "eval_version": eval_version},
42
+ )
43
+
44
+ def bundle_submissions_srgb(submission_folder,session):
45
+ '''
46
+ Bundles submission data for sRGB denoising
47
+
48
+ submission_folder Folder where denoised images reside
49
+ Output is written to <submission_folder>/bundled/. Please submit
50
+ the content of this folder.
51
+ '''
52
+ out_folder = os.path.join(submission_folder, session)
53
+ # out_folder = os.path.join(submission_folder, "bundled/")
54
+ try:
55
+ os.mkdir(out_folder)
56
+ except:pass
57
+ israw = False
58
+ eval_version="1.0"
59
+
60
+ for i in range(50):
61
+ Idenoised = np.zeros((20,), dtype=np.object)
62
+ for bb in range(20):
63
+ filename = '%04d_%02d.mat'%(i+1,bb+1)
64
+ s = sio.loadmat(os.path.join(submission_folder,filename))
65
+ Idenoised_crop = s["Idenoised_crop"]
66
+ Idenoised[bb] = Idenoised_crop
67
+ filename = '%04d.mat'%(i+1)
68
+ sio.savemat(os.path.join(out_folder, filename),
69
+ {"Idenoised": Idenoised,
70
+ "israw": israw,
71
+ "eval_version": eval_version},
72
+ )
73
+
74
+
75
+
76
+ def bundle_submissions_srgb_v1(submission_folder,session):
77
+ '''
78
+ Bundles submission data for sRGB denoising
79
+
80
+ submission_folder Folder where denoised images reside
81
+ Output is written to <submission_folder>/bundled/. Please submit
82
+ the content of this folder.
83
+ '''
84
+ out_folder = os.path.join(submission_folder, session)
85
+ # out_folder = os.path.join(submission_folder, "bundled/")
86
+ try:
87
+ os.mkdir(out_folder)
88
+ except:pass
89
+ israw = False
90
+ eval_version="1.0"
91
+
92
+ for i in range(50):
93
+ Idenoised = np.zeros((20,), dtype=np.object)
94
+ for bb in range(20):
95
+ filename = '%04d_%d.mat'%(i+1,bb+1)
96
+ s = sio.loadmat(os.path.join(submission_folder,filename))
97
+ Idenoised_crop = s["Idenoised_crop"]
98
+ Idenoised[bb] = Idenoised_crop
99
+ filename = '%04d.mat'%(i+1)
100
+ sio.savemat(os.path.join(out_folder, filename),
101
+ {"Idenoised": Idenoised,
102
+ "israw": israw,
103
+ "eval_version": eval_version},
104
+ )
utils/dataset_utils.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import torchvision.transforms as transforms
4
+
5
+
6
+ class Augment_RGB_torch:
7
+ ### rotate and flip
8
+ def __init__(self, rotate=0):
9
+ self.rotate = rotate
10
+ pass
11
+ def transform0(self, torch_tensor):
12
+ return torch_tensor
13
+
14
+ def transform1(self, torch_tensor):
15
+ H, W = torch_tensor.shape[1], torch_tensor.shape[2]
16
+ train_transform = transforms.Compose([
17
+ transforms.RandomRotation((self.rotate,self.rotate), interpolation=transforms.InterpolationMode.BILINEAR, expand=False),
18
+ transforms.Resize((int(H * 1.3), int(W * 1.3)), antialias=True),
19
+ # CenterCrop,if the size is larger than the original size, the excess will be filled with black
20
+ transforms.CenterCrop([H, W])
21
+ ])
22
+ return train_transform(torch_tensor)
23
+
24
+ def transform2(self, torch_tensor):
25
+ torch_tensor = torch.rot90(torch_tensor, k=1, dims=[-1,-2])
26
+ return torch_tensor
27
+ def transform3(self, torch_tensor):
28
+ torch_tensor = torch.rot90(torch_tensor, k=2, dims=[-1,-2])
29
+ return torch_tensor
30
+ def transform4(self, torch_tensor):
31
+ torch_tensor = torch.rot90(torch_tensor, k=3, dims=[-1,-2])
32
+ return torch_tensor
33
+ def transform5(self, torch_tensor):
34
+ torch_tensor = torch_tensor.flip(-2)
35
+ return torch_tensor
36
+ def transform6(self, torch_tensor):
37
+ torch_tensor = (torch.rot90(torch_tensor, k=1, dims=[-1,-2])).flip(-2)
38
+ return torch_tensor
39
+ def transform7(self, torch_tensor):
40
+ torch_tensor = (torch.rot90(torch_tensor, k=2, dims=[-1,-2])).flip(-2)
41
+ return torch_tensor
42
+ def transform8(self, torch_tensor):
43
+ torch_tensor = (torch.rot90(torch_tensor, k=3, dims=[-1,-2])).flip(-2)
44
+ return torch_tensor
45
+
46
+
utils/dir_utils.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from natsort import natsorted
3
+ from glob import glob
4
+
5
+ def mkdirs(paths):
6
+ if isinstance(paths, list) and not isinstance(paths, str):
7
+ for path in paths:
8
+ mkdir(path)
9
+ else:
10
+ mkdir(paths)
11
+
12
+ def mkdir(path):
13
+ if not os.path.exists(path):
14
+ os.makedirs(path)
15
+
16
+ def mknod(path):
17
+ if not os.path.exists(path):
18
+ os.mknod(path)
19
+
20
+ def get_last_path(path, session):
21
+ x = natsorted(glob(os.path.join(path,'*%s'%session)))[-1]
22
+ return x
utils/image_utils.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import pickle
4
+ import cv2
5
+ from skimage.color import rgb2lab
6
+ import os
7
+ import math
8
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1"
9
+
10
+ def is_numpy_file(filename):
11
+ return any(filename.endswith(extension) for extension in [".npy"])
12
+
13
+ def is_image_file(filename):
14
+ return any(filename.endswith(extension) for extension in [".jpg"])
15
+
16
+ def is_png_file(filename):
17
+ return any(filename.endswith(extension) for extension in [".png"])
18
+
19
+ def is_pkl_file(filename):
20
+ return any(filename.endswith(extension) for extension in [".pkl"])
21
+
22
+ def load_pkl(filename_):
23
+ with open(filename_, 'rb') as f:
24
+ ret_dict = pickle.load(f)
25
+ return ret_dict
26
+
27
+ def save_dict(dict_, filename_):
28
+ with open(filename_, 'wb') as f:
29
+ pickle.dump(dict_, f)
30
+
31
+ def load_npy(filepath):
32
+ img = np.load(filepath)
33
+ return img
34
+
35
+ def load_SSAO(filepath):
36
+ img = cv2.imread(filepath)
37
+ # img = cv2.resize(img, [256, 256], interpolation=cv2.INTER_AREA)
38
+ img = img.astype(np.float32)
39
+ img = img/255.
40
+ return img
41
+ def load_img(filepath):
42
+ img = cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB)
43
+ # img = cv2.resize(img, [256, 256], interpolation=cv2.INTER_AREA)
44
+ img = img.astype(np.float32)
45
+ img = img/255.
46
+ return img
47
+
48
+ def load_val_img(filepath):
49
+ img = cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB)
50
+ # img = cv2.resize(img, [256, 256], interpolation=cv2.INTER_AREA)
51
+ resized_img = img.astype(np.float32)
52
+ resized_img = resized_img/255.
53
+ return resized_img
54
+
55
+ def load_mask(filepath):
56
+ img = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
57
+ # img = cv2.resize(img, [256, 256], interpolation=cv2.INTER_AREA)
58
+ # img = cv2.cvtColor(img, cv2.COLOR_BGR2YUV)
59
+ # kernel = np.ones((8,8), np.uint8)
60
+ # erosion = cv2.erode(img, kernel, iterations=1)
61
+ # dilation = cv2.dilate(img, kernel, iterations=1)
62
+ # contour = dilation - erosion
63
+ img = img.astype(np.float32)
64
+ # contour = contour.astype(np.float32)
65
+ # contour = contour/255.
66
+ img = img/255.
67
+ return img
68
+
69
+ def load_ssao(filepath):
70
+ img = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
71
+ # img = cv2.resize(img, [256, 256], interpolation=cv2.INTER_AREA)
72
+ img = img.astype(np.float32)
73
+ # contour = contour.astype(np.float32)
74
+ # contour = contour/255.
75
+ img = img/255.
76
+ return img
77
+
78
+ def load_depth(filepath):
79
+ img = np.load(filepath)
80
+ # img = cv2.resize(img, [256, 256], interpolation=cv2.INTER_AREA)
81
+ # img = cv2.imread(filepath, cv2.IMREAD_UNCHANGED)
82
+ # img = img / 255
83
+ return img
84
+
85
+ def load_normal(filepath):
86
+ img = np.load(filepath).transpose(1,2,0)
87
+ # img = cv2.resize(img, [256, 256], interpolation=cv2.INTER_AREA)
88
+ return img
89
+
90
+ def load_val_mask(filepath):
91
+ img = cv2.imread(filepath, 0)
92
+ resized_img = img
93
+ # resized_img = cv2.resize(img, [256, 256], interpolation=cv2.INTER_AREA)
94
+ resized_img = resized_img.astype(np.float32)
95
+ resized_img = resized_img/255.
96
+ return resized_img
97
+
98
+ def save_img(img, filepath):
99
+ cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
100
+
101
+ def myPSNR(tar_img, prd_img):
102
+ imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1)
103
+ rmse = (imdff**2).mean().sqrt()
104
+ ps = 20*torch.log10(1/rmse)
105
+ return ps
106
+
107
+ # imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1)
108
+ # rmse = (imdff**2).mean()
109
+ # ps = 10*torch.log10(1/rmse)
110
+ # return ps
111
+
112
+ def batch_PSNR(img1, img2, average=True):
113
+ PSNR = []
114
+ for im1, im2 in zip(img1, img2):
115
+ psnr = myPSNR(im1, im2)
116
+ PSNR.append(psnr)
117
+ return sum(PSNR)/len(PSNR) if average else sum(PSNR)
118
+
119
+ def tensor2im(input_image, imtype=np.uint8):
120
+ """"Converts a Tensor array into a numpy image array.
121
+ Parameters:
122
+ input_image (tensor) -- the input image tensor array
123
+ imtype (type) -- the desired type of the converted numpy array
124
+ """
125
+ if not isinstance(input_image, np.ndarray):
126
+
127
+ if isinstance(input_image, torch.Tensor): # get the data from a variable
128
+ image_tensor = input_image.data
129
+ else:
130
+ return input_image
131
+ image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array
132
+ if image_numpy.shape[0] == 1: # grayscale to RGB
133
+ image_numpy = np.tile(image_numpy, (3, 1, 1))
134
+ # image_numpy = image_numpy.convert('L')
135
+
136
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling
137
+ else: # if it is a numpy array, do nothing
138
+ image_numpy = input_image
139
+ # image_numpy =
140
+ return np.clip(image_numpy, 0, 255).astype(imtype)
141
+
142
+ def calc_RMSE(real_img, fake_img):
143
+ # convert to LAB color space
144
+ real_lab = rgb2lab(real_img)
145
+ fake_lab = rgb2lab(fake_img)
146
+ return real_lab - fake_lab
147
+
148
+ def tensor2uint(img):
149
+ img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
150
+ if img.ndim == 3:
151
+ img = np.transpose(img, (1, 2, 0))
152
+ return np.uint8((img*255.0).round())
153
+
154
+ def imsave(img, img_path):
155
+ img = np.squeeze(img)
156
+ if img.ndim == 3:
157
+ img = img[:, :, [2, 1, 0]]
158
+ cv2.imwrite(img_path, img)
159
+
160
+ def process_normal(normal):
161
+ normal = normal * 2.0 - 1.0
162
+ normal = normal[:,:,np.newaxis,:]
163
+ normalizer = np.sqrt(normal @ normal.transpose(0,1,3,2))
164
+ normalizer = np.squeeze(normalizer, axis=-2)
165
+ normalizer = np.clip(normalizer, 1.0e-20, 1.0e10)
166
+ normal = np.squeeze(normal, axis=-2)
167
+ normal = normal / normalizer
168
+ return normal
169
+
170
+
171
+ def depthToPoint(fov, depth):
172
+ # width = 512
173
+ # height = 512
174
+ height, width = depth.shape
175
+ fov_radians = np.deg2rad(fov)
176
+
177
+ focal_length = width / (2 * np.tan(fov_radians / 2))
178
+ fx = focal_length
179
+ fy = focal_length
180
+ cx = (width - 1) / 2.0
181
+ cy = (height - 1) / 2.0
182
+
183
+ x, y = np.meshgrid(range(width), range(height))
184
+ z = depth
185
+ x_3d = (x - cx) * z / fx
186
+ y_3d = (y - cy) * z / fy
187
+ x_3d = x_3d.astype(np.float32)
188
+ y_3d = y_3d.astype(np.float32)
189
+
190
+
191
+ point_cloud = np.stack((x_3d, y_3d, z), axis=-1)
192
+
193
+ return point_cloud
194
+
195
+ def grid_sample(input, img_size):
196
+ x = torch.linspace(-1, 1, img_size[0])
197
+ y = torch.linspace(-1, 1, img_size[1])
198
+ meshx, meshy = torch.meshgrid((x, y))
199
+ grid = torch.stack((meshy, meshx),2).unsqueeze(0).cuda()
200
+ grid = grid.repeat(input.shape[0],1,1,1)
201
+ input = torch.nn.functional.grid_sample(input, grid, mode="nearest", align_corners=False)
202
+ return input
203
+
204
+
utils/loader.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from dataset import DataLoaderTrain, DataLoaderVal, DataLoaderTest
4
+ def get_training_data(rgb_dir, img_options, debug):
5
+ assert os.path.exists(rgb_dir)
6
+ return DataLoaderTrain(rgb_dir, img_options, None, debug)
7
+
8
+ def get_validation_data(rgb_dir, debug=False):
9
+ assert os.path.exists(rgb_dir)
10
+ return DataLoaderVal(rgb_dir, None, debug)
11
+
12
+ def get_test_data(rgb_dir, debug=False):
13
+ assert os.path.exists(rgb_dir)
14
+ return DataLoaderTest(rgb_dir, None, debug)
15
+
16
+
17
+
utils/misc.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torchvision.utils as vutils
3
+
4
+
5
+ def get_np_imgrid(array, nrow=3, padding=0, pad_value=0):
6
+ '''
7
+ achieves the same function of torchvision.utils.make_grid for
8
+ numpy array
9
+ '''
10
+ # assume every image has smae size
11
+ n, h, w, c = array.shape
12
+ row_num = n // nrow + (n % nrow != 0)
13
+ gh, gw = row_num*h + padding*(row_num-1), nrow*w + padding*(nrow - 1)
14
+ grid = np.ones((gh, gw, c), dtype=array.dtype) * pad_value
15
+ for i in range(n):
16
+ grow, gcol = i // nrow, i % nrow
17
+ off_y, off_x = grow * (h + padding), gcol * (w + padding)
18
+ grid[off_y : off_y + h, off_x : off_x + w] = array[i]
19
+ return grid
20
+
21
+
22
+ def split_np_imgrid(imgrid, nimg, nrow, padding=0):
23
+ '''
24
+ reverse operation of make_grid.
25
+ args:
26
+ imgrid: HWC image grid
27
+ nimg: number of images in the grid
28
+ nrow: number of columns in image grid
29
+ return:
30
+ images: list, contains splitted images
31
+ '''
32
+ row_num = nimg // nrow + (nimg % nrow != 0)
33
+ gh, gw, _ = imgrid.shape
34
+ h, w = (gh - (row_num-1)*padding)//row_num, (gw - (nrow-1)*padding)//nrow
35
+ images = []
36
+ for gid in range(nimg):
37
+ grow, gcol = gid // nrow, gid % nrow
38
+ off_i, off_j = grow * (h + padding), gcol * (w + padding)
39
+ images.append(imgrid[off_i:off_i+h, off_j:off_j+w])
40
+ return images
41
+
42
+
43
+ class MDTableConvertor:
44
+
45
+ def __init__(self, col_num):
46
+ self.col_num = col_num
47
+
48
+ def _get_table_row(self, items):
49
+ row = ''
50
+ for item in items:
51
+ row += '| {:s} '.format(item)
52
+ row += '|\n'
53
+ return row
54
+
55
+ def convert(self, item_list, title=None):
56
+ '''
57
+ args:
58
+ item_list: a list of items (str or can be converted to str)
59
+ that want to be presented in table.
60
+
61
+ title: None, or a list of strings. When set to None, empty title
62
+ row is used and column number is determined by col_num; Otherwise,
63
+ it will be used as title row, its length will override col_num.
64
+
65
+ return:
66
+ table: markdown table string.
67
+ '''
68
+ table = ''
69
+ if title: # not None or not [] both equal to true
70
+ col_num = len(title)
71
+ table += self._get_table_row(title)
72
+ else:
73
+ col_num=self.col_num
74
+ table += self._get_table_row([' ']*col_num) # empty title row
75
+ table += self._get_table_row(['-'] * col_num) # header spliter
76
+ for i in range(0, len(item_list), col_num):
77
+ table += self._get_table_row(item_list[i:i+col_num])
78
+ return table
79
+
80
+
81
+ def visual_dict_to_imgrid(visual_dict, col_num=4, padding=0):
82
+ '''
83
+ args:
84
+ visual_dict: a dictionary of images of the same size
85
+ col_num: number of columns in image grid
86
+ padding: number of padding pixels to seperate images
87
+ '''
88
+ im_names = []
89
+ im_tensors = []
90
+ for name, visual in visual_dict.items():
91
+ im_names.append(name)
92
+ im_tensors.append(visual)
93
+ im_grid = vutils.make_grid(im_tensors,
94
+ nrow=col_num ,
95
+ padding=0,
96
+ pad_value=1.0)
97
+ layout = MDTableConvertor(col_num).convert(im_names)
98
+
99
+ return im_grid, layout
100
+
101
+
102
+ def count_parameters(model, trainable_only=False):
103
+ return sum(p.numel() for p in model.parameters())
104
+
105
+
106
+
107
+ class WarmupExpLRScheduler(object):
108
+ def __init__(self, lr_start=1e-4, lr_max=4e-4, lr_min=5e-6, rampup_epochs=4, sustain_epochs=0, exp_decay=0.75):
109
+ self.lr_start = lr_start
110
+ self.lr_max = lr_max
111
+ self.lr_min = lr_min
112
+ self.rampup_epochs = rampup_epochs
113
+ self.sustain_epochs = sustain_epochs
114
+ self.exp_decay = exp_decay
115
+
116
+ def __call__(self, epoch):
117
+ if epoch < self.rampup_epochs:
118
+ lr = (self.lr_max - self.lr_start) / self.rampup_epochs * epoch + self.lr_start
119
+ elif epoch < self.rampup_epochs + self.sustain_epochs:
120
+ lr = self.lr_max
121
+ else:
122
+ lr = (self.lr_max - self.lr_min) * self.exp_decay**(epoch - self.rampup_epochs - self.sustain_epochs) + self.lr_min
123
+ # print(lr)
124
+ return lr
utils/model_utils.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ from collections import OrderedDict
4
+
5
+ def freeze(model):
6
+ for p in model.parameters():
7
+ p.requires_grad=False
8
+
9
+ def unfreeze(model):
10
+ for p in model.parameters():
11
+ p.requires_grad=True
12
+
13
+ def is_frozen(model):
14
+ x = [p.requires_grad for p in model.parameters()]
15
+ return not all(x)
16
+
17
+ def save_checkpoint(model_dir, state, session):
18
+ epoch = state['epoch']
19
+ model_out_path = os.path.join(model_dir,"model_epoch_{}_{}.pth".format(epoch,session))
20
+ torch.save(state, model_out_path)
21
+
22
+ def load_checkpoint(model, weights, strict=True):
23
+ checkpoint = torch.load(weights, map_location=torch.device('cpu'))
24
+ try:
25
+ state_dict = checkpoint["state_dict"]
26
+ new_state_dict = OrderedDict()
27
+ for k, v in state_dict.items():
28
+ new_state_dict[k] = v
29
+ model.load_state_dict(new_state_dict, strict=strict)
30
+ except:
31
+ state_dict = checkpoint["state_dict"]
32
+ new_state_dict = OrderedDict()
33
+ for k, v in state_dict.items():
34
+ name = k[7:] if 'module.' in k else k
35
+ new_state_dict[name] = v
36
+ model.load_state_dict(new_state_dict, strict=strict)
37
+
38
+ def load_checkpoint_multigpu(model, weights):
39
+ checkpoint = torch.load(weights)
40
+ state_dict = checkpoint["state_dict"]
41
+ new_state_dict = OrderedDict()
42
+ for k, v in state_dict.items():
43
+ name = k[7:]
44
+ new_state_dict[name] = v
45
+ model.load_state_dict(new_state_dict)
46
+
47
+ def load_start_epoch(weights):
48
+ checkpoint = torch.load(weights, map_location=torch.device('cpu'))
49
+ epoch = checkpoint["epoch"]
50
+ return epoch
51
+
52
+ def load_optim(optimizer, weights):
53
+ checkpoint = torch.load(weights, map_location=torch.device('cpu'))
54
+ optimizer.load_state_dict(checkpoint['optimizer'])
55
+ for p in optimizer.param_groups: lr = p['lr']
56
+ return lr
57
+
58
+ def get_arch(opt):
59
+ from model import ShadowFormer, ShadowFormerFreq
60
+ arch = opt.arch
61
+
62
+ print('You choose '+arch+'...')
63
+ if arch == 'ShadowFormer':
64
+ model_restoration = ShadowFormer(img_size=opt.train_ps,embed_dim=opt.embed_dim,
65
+ win_size=opt.win_size,token_projection=opt.token_projection,
66
+ token_mlp=opt.token_mlp)
67
+ elif arch == 'ShadowFormerFreq':
68
+ model_restoration = ShadowFormerFreq(img_size=opt.train_ps,embed_dim=opt.embed_dim,
69
+ win_size=opt.win_size,token_projection=opt.token_projection,
70
+ token_mlp=opt.token_mlp)
71
+ else:
72
+ raise Exception("Arch error!")
73
+
74
+ return model_restoration
75
+
76
+
77
+ def window_partition(x, win_size):
78
+ B, C, H, W = x.shape
79
+ x = x.permute(0,2,3,1)
80
+ x = x.reshape(B, H // win_size, win_size, W // win_size, win_size, C)
81
+ x = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, win_size, win_size, C)
82
+ return x.permute(0,3,1,2)
83
+
84
+ # def distributed_concat(var, num_total):
85
+ # var_list = [torch.zeros(1, dtype=var.dtype).cuda() for _ in range(torch.distributed.get_world_size())]
86
+ # torch.distributed.all_gather(var_list, var)
87
+ # # truncate the dummy elements added by SequentialDistributedSampler
88
+ # return var_list[:num_total]
89
+
90
+ def distributed_concat(var, num_total):
91
+ # 確保 var 是一個 1D tensor (shape: [1])
92
+ var = var.view(1) if var.dim() == 0 else var
93
+
94
+ var_list = [torch.zeros_like(var).cuda() for _ in range(torch.distributed.get_world_size())]
95
+ torch.distributed.all_gather(var_list, var)
96
+
97
+ # truncate the dummy elements added by SequentialDistributedSampler
98
+ return var_list[:num_total]
utils/shadow_mask_evaluate.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from collections import OrderedDict
4
+ import pandas as pd
5
+ import os
6
+ from tqdm import tqdm
7
+ import cv2
8
+ from utils.misc import split_np_imgrid, get_np_imgrid
9
+
10
+
11
+ def cal_ber(tn, tp, fn, fp):
12
+ return 0.5*(fp/(tn+fp) + fn/(fn+tp))
13
+
14
+ def cal_acc(tn, tp, fn, fp):
15
+ return (tp + tn) / (tp + tn + fp + fn)
16
+
17
+
18
+ def get_binary_classification_metrics(pred, gt, threshold=None):
19
+ if threshold is not None:
20
+ gt = (gt > threshold)
21
+ pred = (pred > threshold)
22
+ TP = np.logical_and(gt, pred).sum()
23
+ TN = np.logical_and(np.logical_not(gt), np.logical_not(pred)).sum()
24
+ FN = np.logical_and(gt, np.logical_not(pred)).sum()
25
+ FP = np.logical_and(np.logical_not(gt), pred).sum()
26
+ BER = cal_ber(TN, TP, FN, FP)
27
+ ACC = cal_acc(TN, TP, FN, FP)
28
+ return OrderedDict( [('TP', TP),
29
+ ('TN', TN),
30
+ ('FP', FP),
31
+ ('FN', FN),
32
+ ('BER', BER),
33
+ ('ACC', ACC)]
34
+ )
35
+
36
+
37
+ def evaluate(res_root, pred_id, gt_id, nimg, nrow, threshold):
38
+ img_names = os.listdir(res_root)
39
+ score_dict = OrderedDict()
40
+
41
+ for img_name in img_names:
42
+ im_grid_path = os.path.join(res_root, img_name)
43
+ im_grid = cv2.imread(im_grid_path)
44
+ ims = split_np_imgrid(im_grid, nimg, nrow)
45
+ pred = ims[pred_id]
46
+ gt = ims[gt_id]
47
+ score_dict[img_name] = get_binary_classification_metrics(pred,
48
+ gt,
49
+ threshold)
50
+
51
+ df = pd.DataFrame(score_dict)
52
+ df['ave'] = df.mean(axis=1)
53
+
54
+ tn = df['ave']['TN']
55
+ tp = df['ave']['TP']
56
+ fn = df['ave']['FN']
57
+ fp = df['ave']['FP']
58
+
59
+ pos_err = (1 - tp / (tp + fn)) * 100
60
+ neg_err = (1 - tn / (tn + fp)) * 100
61
+ ber = (pos_err + neg_err) / 2
62
+ acc = (tn + tp) / (tn + tp + fn + fp)
63
+
64
+ return pos_err, neg_err, ber, acc, df
65
+
66
+
67
+
68
+
69
+
70
+
71
+
72
+ ###############################################
73
+
74
+ class AverageMeter(object):
75
+ """Computes and stores the average and current value"""
76
+ def __init__(self):
77
+ self.sum = 0
78
+ self.count = 0
79
+
80
+ def update(self, val, weight=1):
81
+ self.sum += val * weight
82
+ self.count += weight
83
+
84
+ def average(self):
85
+ if self.count == 0:
86
+ return 0
87
+ else:
88
+ return self.sum / self.count
89
+
90
+ def clear(self):
91
+ self.sum = 0
92
+ self.count = 0
93
+
94
+ def compute_cm_torch(y_pred, y_label, n_class):
95
+ mask = (y_label >= 0) & (y_label < n_class)
96
+ hist = torch.bincount(n_class * y_label[mask] + y_pred[mask],
97
+ minlength=n_class**2).reshape(n_class, n_class)
98
+ return hist
99
+
100
+ class MyConfuseMatrixMeter(AverageMeter):
101
+ """More Clear Confusion Matrix Meter"""
102
+ def __init__(self, n_class):
103
+ super(MyConfuseMatrixMeter, self).__init__()
104
+ self.n_class = n_class
105
+
106
+ def update_cm(self, y_pred, y_label, weight=1):
107
+ y_label = y_label.type(torch.int64)
108
+ val = compute_cm_torch(y_pred=y_pred.flatten(), y_label=y_label.flatten(),
109
+ n_class=self.n_class)
110
+ self.update(val, weight)
111
+
112
+ # def get_scores_binary(self):
113
+ # assert self.n_class == 2, "this function can only be called for binary calssification problem"
114
+ # tn, fp, fn, tp = self.sum.flatten()
115
+ # eps = torch.finfo(torch.float32).eps
116
+ # precision = tp / (tp + fp + eps)
117
+ # recall = tp / (tp + fn + eps)
118
+ # f1 = 2*recall*precision / (recall + precision + eps)
119
+ # iou = tp / (tp + fn + fp + eps)
120
+ # oa = (tp + tn) / (tp + tn + fn + fp + eps)
121
+ # score_dict = {}
122
+ # score_dict['precision'] = precision.item()
123
+ # score_dict['recall'] = recall.item()
124
+ # score_dict['f1'] = f1.item()
125
+ # score_dict['iou'] = iou.item()
126
+ # score_dict['oa'] = oa.item()
127
+ # return score_dict
128
+ def get_scores_binary(self):
129
+ assert self.n_class == 2, "this function can only be called for binary calssification problem"
130
+ tn, fp, fn, tp = self.sum.flatten()
131
+ eps = torch.finfo(torch.float32).eps
132
+ pos_err = (1 - tp / (tp + fn + eps)) * 100
133
+ neg_err = (1 - tn / (tn + fp + eps)) * 100
134
+ ber = (pos_err + neg_err) / 2
135
+ acc = (tn + tp) / (tn + tp + fn + fp + eps)
136
+ score_dict = {}
137
+ score_dict['pos_err'] = pos_err
138
+ score_dict['neg_err'] = neg_err
139
+ score_dict['ber'] = ber
140
+ score_dict['acc'] = acc
141
+ return score_dict
utils/tta.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ class TestTimeAugmentation:
5
+ """Test-Time Augmentation for image restoration models"""
6
+
7
+ def __init__(self, model, dino_net, device, use_flip=True, use_rot=True, use_multi_scale=False, scales=None):
8
+ """
9
+ Args:
10
+ model: The model to apply TTA to
11
+ dino_net: DINO feature extractor
12
+ device: Device to run inference on
13
+ use_flip: Whether to use horizontal and vertical flips
14
+ use_rot: Whether to use 90-degree rotations
15
+ use_multi_scale: Whether to use multi-scale testing
16
+ scales: List of scales to use for multi-scale testing, e.g. [0.8, 1.0, 1.2]
17
+ """
18
+ self.model = model
19
+ self.dino_net = dino_net
20
+ self.device = device
21
+ self.use_flip = use_flip
22
+ self.use_rot = use_rot
23
+ self.use_multi_scale = use_multi_scale
24
+ self.scales = scales or [1.0]
25
+
26
+ def _apply_augmentation(self, image, point, normal, aug_type):
27
+ """Apply single augmentation to input images
28
+
29
+ Args:
30
+ image: Input RGB image
31
+ point: Point map
32
+ normal: Normal map
33
+ aug_type: Augmentation type string (e.g., 'original', 'h_flip', etc.)
34
+
35
+ Returns:
36
+ Augmented versions of image, point map and normal map
37
+ """
38
+ if aug_type == 'original':
39
+ return image, point, normal
40
+
41
+ elif aug_type == 'h_flip':
42
+ # Horizontal flip
43
+ img_aug = torch.flip(image, dims=[3])
44
+ point_aug = torch.flip(point, dims=[3])
45
+ normal_aug = torch.flip(normal, dims=[3])
46
+ # For normal map, x direction needs to be flipped
47
+ normal_aug[:, 0, :, :] = -normal_aug[:, 0, :, :]
48
+ return img_aug, point_aug, normal_aug
49
+
50
+ elif aug_type == 'v_flip':
51
+ # Vertical flip
52
+ img_aug = torch.flip(image, dims=[2])
53
+ point_aug = torch.flip(point, dims=[2])
54
+ normal_aug = torch.flip(normal, dims=[2])
55
+ # For normal map, y direction needs to be flipped
56
+ normal_aug[:, 1, :, :] = -normal_aug[:, 1, :, :]
57
+ return img_aug, point_aug, normal_aug
58
+
59
+ elif aug_type == 'rot90':
60
+ # 90-degree rotation
61
+ img_aug = torch.rot90(image, k=1, dims=[2, 3])
62
+ point_aug = torch.rot90(point, k=1, dims=[2, 3])
63
+ normal_aug = torch.rot90(normal, k=1, dims=[2, 3])
64
+ # Swap x and y channels in normal map and negate x
65
+ normal_x = -normal_aug[:, 1, :, :].clone()
66
+ normal_y = normal_aug[:, 0, :, :].clone()
67
+ normal_aug[:, 0, :, :] = normal_x
68
+ normal_aug[:, 1, :, :] = normal_y
69
+ return img_aug, point_aug, normal_aug
70
+
71
+ elif aug_type == 'rot180':
72
+ # 180-degree rotation
73
+ img_aug = torch.rot90(image, k=2, dims=[2, 3])
74
+ point_aug = torch.rot90(point, k=2, dims=[2, 3])
75
+ normal_aug = torch.rot90(normal, k=2, dims=[2, 3])
76
+ # For normal map, both x and y directions need to be flipped
77
+ normal_aug[:, 0, :, :] = -normal_aug[:, 0, :, :]
78
+ normal_aug[:, 1, :, :] = -normal_aug[:, 1, :, :]
79
+ return img_aug, point_aug, normal_aug
80
+
81
+ elif aug_type == 'rot270':
82
+ # 270-degree rotation
83
+ img_aug = torch.rot90(image, k=3, dims=[2, 3])
84
+ point_aug = torch.rot90(point, k=3, dims=[2, 3])
85
+ normal_aug = torch.rot90(normal, k=3, dims=[2, 3])
86
+ # Swap x and y channels in normal map and negate y
87
+ normal_x = normal_aug[:, 1, :, :].clone()
88
+ normal_y = -normal_aug[:, 0, :, :].clone()
89
+ normal_aug[:, 0, :, :] = normal_x
90
+ normal_aug[:, 1, :, :] = normal_y
91
+ return img_aug, point_aug, normal_aug
92
+
93
+ else:
94
+ raise ValueError(f"Unknown augmentation type: {aug_type}")
95
+
96
+ def _reverse_augmentation(self, result, aug_type):
97
+ """Reverse the augmentation on the result
98
+
99
+ Args:
100
+ result: Model output to reverse augmentation on
101
+ aug_type: Augmentation type string
102
+
103
+ Returns:
104
+ De-augmented result
105
+ """
106
+ if aug_type == 'original':
107
+ return result
108
+
109
+ elif aug_type == 'h_flip':
110
+ return torch.flip(result, dims=[3])
111
+
112
+ elif aug_type == 'v_flip':
113
+ return torch.flip(result, dims=[2])
114
+
115
+ elif aug_type == 'rot90':
116
+ return torch.rot90(result, k=3, dims=[2, 3])
117
+
118
+ elif aug_type == 'rot180':
119
+ return torch.rot90(result, k=2, dims=[2, 3])
120
+
121
+ elif aug_type == 'rot270':
122
+ return torch.rot90(result, k=1, dims=[2, 3])
123
+
124
+ else:
125
+ raise ValueError(f"Unknown augmentation type: {aug_type}")
126
+
127
+ def __call__(self, sliding_window, input_img, point, normal):
128
+ """
129
+ Apply TTA to the model and return ensemble result
130
+
131
+ Args:
132
+ sliding_window: SlidingWindowInference class instance
133
+ input_img: Input RGB image [B, C, H, W]
134
+ point: Point map [B, C, H, W]
135
+ normal: Normal map [B, C, H, W]
136
+
137
+ Returns:
138
+ Ensemble result with TTA [B, C, H, W]
139
+ """
140
+ # Define all augmentations to use
141
+ augmentations = ['original']
142
+ if self.use_flip:
143
+ augmentations.extend(['h_flip', 'v_flip'])
144
+ if self.use_rot:
145
+ augmentations.extend(['rot90', 'rot180', 'rot270'])
146
+
147
+ # Initialize the result tensor
148
+ ensemble_result = torch.zeros_like(input_img)
149
+ ensemble_weight = 0.0
150
+
151
+ # For each scale and augmentation
152
+ for scale in self.scales:
153
+ scale_weight = 1.0
154
+ if scale != 1.0:
155
+ # Resize inputs for multi-scale testing
156
+ h, w = input_img.shape[2], input_img.shape[3]
157
+ new_h, new_w = int(h * scale), int(w * scale)
158
+
159
+ # Resize all inputs
160
+ resize_fn = torch.nn.functional.interpolate
161
+ input_img_scaled = resize_fn(input_img, size=(new_h, new_w), mode='bilinear', align_corners=False)
162
+ point_scaled = resize_fn(point, size=(new_h, new_w), mode='bilinear', align_corners=False)
163
+ normal_scaled = resize_fn(normal, size=(new_h, new_w), mode='bilinear', align_corners=False)
164
+
165
+ # Normalize normal vectors after resizing
166
+ normal_norm = torch.sqrt(torch.sum(normal_scaled**2, dim=1, keepdim=True) + 1e-6)
167
+ normal_scaled = normal_scaled / normal_norm
168
+ else:
169
+ input_img_scaled = input_img
170
+ point_scaled = point
171
+ normal_scaled = normal
172
+
173
+ # Apply each augmentation
174
+ for aug_type in augmentations:
175
+ # Apply augmentation
176
+ img_aug, point_aug, normal_aug = self._apply_augmentation(
177
+ input_img_scaled, point_scaled, normal_scaled, aug_type
178
+ )
179
+
180
+ # Run model inference with sliding window
181
+ with torch.cuda.amp.autocast():
182
+ result_aug = sliding_window(
183
+ model=self.model,
184
+ input_=img_aug,
185
+ point=point_aug,
186
+ normal=normal_aug,
187
+ dino_net=self.dino_net,
188
+ device=self.device
189
+ )
190
+
191
+ # Reverse augmentation on the result
192
+ result_aug = self._reverse_augmentation(result_aug, aug_type)
193
+
194
+ # Resize back to original size if using multi-scale
195
+ if scale != 1.0:
196
+ result_aug = resize_fn(result_aug, size=(h, w), mode='bilinear', align_corners=False)
197
+
198
+ # Add to ensemble
199
+ ensemble_result += result_aug * scale_weight
200
+ ensemble_weight += scale_weight
201
+
202
+ # Average results
203
+ ensemble_result = ensemble_result / ensemble_weight
204
+
205
+ return ensemble_result