endo-yuki-t committed
Commit d7dbcdd · 0 Parent(s)

initial commit

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. LICENSE +21 -0
  2. README.md +51 -0
  3. criteria/__init__.py +0 -0
  4. criteria/lpips/__init__.py +0 -0
  5. criteria/lpips/lpips.py +35 -0
  6. criteria/lpips/networks.py +96 -0
  7. criteria/lpips/utils.py +30 -0
  8. docs/teaser.jpg +0 -0
  9. docs/thumb.gif +0 -0
  10. env.yaml +380 -0
  11. expansion/__init__.py +0 -0
  12. expansion/dataloader/__init__.py +0 -0
  13. expansion/dataloader/__pycache__/__init__.cpython-38.pyc +0 -0
  14. expansion/dataloader/__pycache__/seqlist.cpython-38.pyc +0 -0
  15. expansion/dataloader/chairslist.py +33 -0
  16. expansion/dataloader/chairssdlist.py +30 -0
  17. expansion/dataloader/depth_transforms.py +471 -0
  18. expansion/dataloader/depthloader.py +222 -0
  19. expansion/dataloader/flow_transforms.py +440 -0
  20. expansion/dataloader/hd1klist.py +29 -0
  21. expansion/dataloader/kitti12list.py +29 -0
  22. expansion/dataloader/kitti15list.py +29 -0
  23. expansion/dataloader/kitti15list_train.py +31 -0
  24. expansion/dataloader/kitti15list_train_lidar.py +34 -0
  25. expansion/dataloader/kitti15list_val.py +31 -0
  26. expansion/dataloader/kitti15list_val_lidar.py +34 -0
  27. expansion/dataloader/kitti15list_val_mr.py +41 -0
  28. expansion/dataloader/robloader.py +133 -0
  29. expansion/dataloader/sceneflowlist.py +51 -0
  30. expansion/dataloader/seqlist.py +26 -0
  31. expansion/dataloader/sintellist.py +32 -0
  32. expansion/dataloader/sintellist_clean.py +31 -0
  33. expansion/dataloader/sintellist_final.py +32 -0
  34. expansion/dataloader/sintellist_train.py +32 -0
  35. expansion/dataloader/sintellist_val.py +34 -0
  36. expansion/dataloader/thingslist.py +122 -0
  37. expansion/models/VCN_exp.py +561 -0
  38. expansion/models/__init__.py +0 -0
  39. expansion/models/__pycache__/VCN_exp.cpython-38.pyc +0 -0
  40. expansion/models/__pycache__/__init__.cpython-38.pyc +0 -0
  41. expansion/models/__pycache__/conv4d.cpython-38.pyc +0 -0
  42. expansion/models/__pycache__/submodule.cpython-38.pyc +0 -0
  43. expansion/models/conv4d.py +296 -0
  44. expansion/models/submodule.py +450 -0
  45. expansion/submission.py +95 -0
  46. expansion/utils/__init__.py +0 -0
  47. expansion/utils/__pycache__/__init__.cpython-38.pyc +0 -0
  48. expansion/utils/__pycache__/flowlib.cpython-38.pyc +0 -0
  49. expansion/utils/__pycache__/io.cpython-38.pyc +0 -0
  50. expansion/utils/__pycache__/pfm.cpython-38.pyc +0 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2022 Yuki Endo
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,51 @@
+ # User-Controllable Latent Transformer for StyleGAN Image Layout Editing
+ <!--a href="https://arxiv.org/abs/2103.14877"><img src="https://img.shields.io/badge/arXiv-2103.14877-b31b1b.svg"></a-->
+ <a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-yellow.svg"></a>
+ <p align="center">
+ <img src="docs/teaser.jpg" width="800px"/>
+ </p>
+
+ This repository contains our implementation of the following paper:
+
+ Yuki Endo: "User-Controllable Latent Transformer for StyleGAN Image Layout Editing," Computer Graphics Forum (Pacific Graphics 2022) [[Project](http://www.cgg.cs.tsukuba.ac.jp/~endo/projects/UserControllableLT)] [[PDF (preprint)]()]
+
+ ## Prerequisites
+ 1. Python 3.8
+ 2. PyTorch 1.9.0
+ 3. Flask
+ 4. Others (see env.yaml)
+
+ ## Preparation
+ Download and decompress <a href="https://drive.google.com/file/d/1lBL_J-uROvqZ0BYu9gmEcMCNyaPo9cBY/view?usp=sharing">our pre-trained models</a>.
+
+ ## Inference with our pre-trained models
+ <img src="docs/thumb.gif" width="150px"/><br>
+ We provide an interactive interface based on Flask. The interface can be launched locally with
+ ```
+ python interface/flask_app.py --checkpoint_path=pretrained_models/latent_transformer/cat.pt
+ ```
+ and is then accessible at http://localhost:8000/.
+
+ ## Training
+ The latent transformer can be trained with
+ ```
+ python scripts/train.py --exp_dir=results --stylegan_weights=pretrained_models/stylegan2-cat-config-f.pt
+ ```
+
+ ## Citation
+ Please cite our paper if you find the code useful:
+ ```
+ @Article{endoPG2022,
+   Title   = {User-Controllable Latent Transformer for StyleGAN Image Layout Editing},
+   Author  = {Yuki Endo},
+   Journal = {Computer Graphics Forum},
+   volume  = {},
+   number  = {},
+   pages   = {},
+   doi     = {},
+   Year    = {2022}
+ }
+ ```
+
+ ## Acknowledgements
+ This code heavily borrows from the [pixel2style2pixel](https://github.com/eladrich/pixel2style2pixel) and [expansion](https://github.com/gengshan-y/expansion) repositories.
criteria/__init__.py ADDED
File without changes
criteria/lpips/__init__.py ADDED
File without changes
criteria/lpips/lpips.py ADDED
@@ -0,0 +1,35 @@
+ import torch
+ import torch.nn as nn
+
+ from criteria.lpips.networks import get_network, LinLayers
+ from criteria.lpips.utils import get_state_dict
+
+
+ class LPIPS(nn.Module):
+     r"""Creates a criterion that measures
+     Learned Perceptual Image Patch Similarity (LPIPS).
+     Arguments:
+         net_type (str): the network type to compare the features:
+                         'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
+         version (str): the version of LPIPS. Default: 0.1.
+     """
+     def __init__(self, net_type: str = 'alex', version: str = '0.1'):
+
+         assert version in ['0.1'], 'v0.1 is only supported now'
+
+         super(LPIPS, self).__init__()
+
+         # pretrained network
+         self.net = get_network(net_type).to("cuda")
+
+         # linear layers
+         self.lin = LinLayers(self.net.n_channels_list).to("cuda")
+         self.lin.load_state_dict(get_state_dict(net_type, version))
+
+     def forward(self, x: torch.Tensor, y: torch.Tensor):
+         feat_x, feat_y = self.net(x), self.net(y)
+
+         diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
+         res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
+
+         return torch.sum(torch.cat(res, 0)) / x.shape[0]
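For orientation, a minimal usage sketch of this criterion (not part of the commit; tensor shapes are illustrative, and a CUDA device is assumed because the module hard-codes `.to("cuda")`):
```python
import torch
from criteria.lpips.lpips import LPIPS

# Two batches of RGB images in [-1, 1], NCHW layout (shapes are illustrative).
x = torch.rand(4, 3, 256, 256, device="cuda") * 2 - 1
y = torch.rand(4, 3, 256, 256, device="cuda") * 2 - 1

criterion = LPIPS(net_type="alex")  # downloads the linear-layer weights on first use
loss = criterion(x, y)              # scalar: mean LPIPS distance over the batch
```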
criteria/lpips/networks.py ADDED
@@ -0,0 +1,96 @@
+ from typing import Sequence
+
+ from itertools import chain
+
+ import torch
+ import torch.nn as nn
+ from torchvision import models
+
+ from criteria.lpips.utils import normalize_activation
+
+
+ def get_network(net_type: str):
+     if net_type == 'alex':
+         return AlexNet()
+     elif net_type == 'squeeze':
+         return SqueezeNet()
+     elif net_type == 'vgg':
+         return VGG16()
+     else:
+         raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
+
+
+ class LinLayers(nn.ModuleList):
+     def __init__(self, n_channels_list: Sequence[int]):
+         super(LinLayers, self).__init__([
+             nn.Sequential(
+                 nn.Identity(),
+                 nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
+             ) for nc in n_channels_list
+         ])
+
+         for param in self.parameters():
+             param.requires_grad = False
+
+
+ class BaseNet(nn.Module):
+     def __init__(self):
+         super(BaseNet, self).__init__()
+
+         # register buffer
+         self.register_buffer(
+             'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
+         self.register_buffer(
+             'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
+
+     def set_requires_grad(self, state: bool):
+         for param in chain(self.parameters(), self.buffers()):
+             param.requires_grad = state
+
+     def z_score(self, x: torch.Tensor):
+         return (x - self.mean) / self.std
+
+     def forward(self, x: torch.Tensor):
+         x = self.z_score(x)
+
+         output = []
+         for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
+             x = layer(x)
+             if i in self.target_layers:
+                 output.append(normalize_activation(x))
+             if len(output) == len(self.target_layers):
+                 break
+         return output
+
+
+ class SqueezeNet(BaseNet):
+     def __init__(self):
+         super(SqueezeNet, self).__init__()
+
+         self.layers = models.squeezenet1_1(True).features
+         self.target_layers = [2, 5, 8, 10, 11, 12, 13]
+         self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
+
+         self.set_requires_grad(False)
+
+
+ class AlexNet(BaseNet):
+     def __init__(self):
+         super(AlexNet, self).__init__()
+
+         self.layers = models.alexnet(True).features
+         self.target_layers = [2, 5, 8, 10, 12]
+         self.n_channels_list = [64, 192, 384, 256, 256]
+
+         self.set_requires_grad(False)
+
+
+ class VGG16(BaseNet):
+     def __init__(self):
+         super(VGG16, self).__init__()
+
+         self.layers = models.vgg16(True).features
+         self.target_layers = [4, 9, 16, 23, 30]
+         self.n_channels_list = [64, 128, 256, 512, 512]
+
+         self.set_requires_grad(False)
criteria/lpips/utils.py ADDED
@@ -0,0 +1,30 @@
+ from collections import OrderedDict
+
+ import torch
+
+
+ def normalize_activation(x, eps=1e-10):
+     norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
+     return x / (norm_factor + eps)
+
+
+ def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
+     # build url
+     url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
+         + f'master/lpips/weights/v{version}/{net_type}.pth'
+
+     # download
+     old_state_dict = torch.hub.load_state_dict_from_url(
+         url, progress=True,
+         map_location=None if torch.cuda.is_available() else torch.device('cpu')
+     )
+
+     # rename keys
+     new_state_dict = OrderedDict()
+     for key, val in old_state_dict.items():
+         new_key = key
+         new_key = new_key.replace('lin', '')
+         new_key = new_key.replace('model.', '')
+         new_state_dict[new_key] = val
+
+     return new_state_dict
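The key renaming above maps the upstream checkpoint's naming onto the `LinLayers` ModuleList layout; a quick sketch of the effect on a single key:
```python
# 'lin0.model.1.weight' names linear layer 0 in the upstream checkpoint;
# LinLayers expects '0.1.weight' (ModuleList index 0, Sequential index 1).
key = 'lin0.model.1.weight'
key = key.replace('lin', '').replace('model.', '')
assert key == '0.1.weight'
```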
docs/teaser.jpg ADDED
docs/thumb.gif ADDED
env.yaml ADDED
@@ -0,0 +1,380 @@
+ name: uclt
+ channels:
+ - pytorch
+ - anaconda
+ - nvidia
+ - conda-forge
+ - defaults
+ dependencies:
+ - _ipyw_jlab_nb_ext_conf=0.1.0=py38_0
+ - _libgcc_mutex=0.1=conda_forge
+ - _openmp_mutex=4.5=1_llvm
+ - absl-py=0.13.0=pyhd8ed1ab_0
+ - aiohttp=3.7.4.post0=py38h497a2fe_0
+ - albumentations=1.0.3=pyhd8ed1ab_0
+ - alsa-lib=1.2.3=h516909a_0
+ - anaconda-client=1.8.0=py38h06a4308_0
+ - anaconda-navigator=2.0.4=py38_0
+ - anyio=2.2.0=py38h06a4308_1
+ - appdirs=1.4.4=pyh9f0ad1d_0
+ - argon2-cffi=20.1.0=py38h27cfd23_1
+ - async-timeout=3.0.1=py_1000
+ - async_generator=1.10=pyhd3eb1b0_0
+ - attrs=21.2.0=pyhd3eb1b0_0
+ - babel=2.9.1=pyhd3eb1b0_0
+ - backcall=0.2.0=pyhd3eb1b0_0
+ - backports=1.0=pyhd3eb1b0_2
+ - backports.functools_lru_cache=1.6.4=pyhd3eb1b0_0
+ - backports.tempfile=1.0=pyhd3eb1b0_1
+ - backports.weakref=1.0.post1=py_1
+ - beautifulsoup4=4.9.3=pyha847dfd_0
+ - blas=1.0=mkl
+ - bleach=4.0.0=pyhd3eb1b0_0
+ - blinker=1.4=py_1
+ - brotli=1.0.9=h7f98852_5
+ - brotli-bin=1.0.9=h7f98852_5
+ - brotlipy=0.7.0=py38h27cfd23_1003
+ - bzip2=1.0.8=h7b6447c_0
+ - c-ares=1.17.1=h27cfd23_0
+ - ca-certificates=2021.10.8=ha878542_0
+ - cachetools=4.2.2=pyhd8ed1ab_0
+ - cairo=1.16.0=hf32fb01_1
+ - certifi=2021.10.8=py38h578d9bd_1
+ - cffi=1.14.6=py38h400218f_0
+ - chardet=4.0.0=py38h06a4308_1003
+ - click=8.0.1=pyhd3eb1b0_0
+ - cloudpickle=1.6.0=py_0
+ - clyent=1.2.2=py38_1
+ - conda=4.11.0=py38h578d9bd_0
+ - conda-build=3.21.4=py38h06a4308_0
+ - conda-content-trust=0.1.1=pyhd3eb1b0_0
+ - conda-env=2.6.0=1
+ - conda-package-handling=1.7.3=py38h27cfd23_1
+ - conda-repo-cli=1.0.4=pyhd3eb1b0_0
+ - conda-token=0.3.0=pyhd3eb1b0_0
+ - conda-verify=3.4.2=py_1
+ - cryptography=3.4.7=py38hd23ed53_0
+ - cudatoolkit=11.1.74=h6bb024c_0
+ - cycler=0.10.0=py_2
+ - cytoolz=0.11.0=py38h497a2fe_3
+ - dask-core=2021.8.1=pyhd8ed1ab_0
+ - dbus=1.13.18=hb2f20db_0
+ - decorator=5.0.9=pyhd3eb1b0_0
+ - defusedxml=0.7.1=pyhd3eb1b0_0
+ - dill=0.3.4=pyhd8ed1ab_0
+ - dominate=2.6.0=pyhd8ed1ab_0
+ - entrypoints=0.3=py38_0
+ - enum34=1.1.10=py38h32f6830_2
+ - expat=2.4.1=h2531618_2
+ - ffmpeg=4.3.2=hca11adc_0
+ - filelock=3.0.12=pyhd3eb1b0_1
+ - flask=1.1.2=pyh9f0ad1d_0
+ - flask-httpauth=4.4.0=pyhd8ed1ab_0
+ - fontconfig=2.13.1=h6c09931_0
+ - fonttools=4.25.0=pyhd3eb1b0_0
+ - freetype=2.10.4=h5ab3b9f_0
+ - fsspec=2021.7.0=pyhd8ed1ab_0
+ - ftfy=6.0.3=pyhd8ed1ab_0
+ - func_timeout=4.3.5=py_0
+ - future=0.18.2=py38_1
+ - gdown=4.2.0=pyhd8ed1ab_0
+ - geos=3.10.0=h9c3ff4c_0
+ - gettext=0.19.8.1=h0b5b191_1005
+ - git=2.23.0=pl526hacde149_0
+ - glib=2.68.4=h9c3ff4c_0
+ - glib-tools=2.68.4=h9c3ff4c_0
+ - glob2=0.7=pyhd3eb1b0_0
+ - gmp=6.2.1=h58526e2_0
+ - gnutls=3.6.13=h85f3911_1
+ - google-auth=1.35.0=pyh6c4a22f_0
+ - google-auth-oauthlib=0.4.5=pyhd8ed1ab_0
+ - gputil=1.4.0=pyh9f0ad1d_0
+ - graphite2=1.3.13=h58526e2_1001
+ - gst-plugins-base=1.18.4=hf529b03_2
+ - gstreamer=1.18.4=h76c114f_2
+ - harfbuzz=2.9.0=h83ec7ef_0
+ - hdf5=1.10.6=nompi_h6a2412b_1114
+ - icu=68.1=h58526e2_0
+ - idna=2.10=pyhd3eb1b0_0
+ - imagecodecs-lite=2019.12.3=py38h5c078b8_3
+ - imageio=2.9.0=py_0
+ - imageio-ffmpeg=0.4.5=pyhd8ed1ab_0
+ - imgaug=0.4.0=py_1
+ - importlib-metadata=3.10.0=py38h06a4308_0
+ - importlib_metadata=3.10.0=hd3eb1b0_0
+ - intel-openmp=2021.3.0=h06a4308_3350
+ - ipykernel=5.3.4=py38h5ca1d4c_0
+ - ipympl=0.8.2=pyhd8ed1ab_0
+ - ipython=7.26.0=py38hb070fc8_0
+ - ipython_genutils=0.2.0=pyhd3eb1b0_1
+ - ipywidgets=7.6.3=pyhd3eb1b0_1
+ - itsdangerous=2.0.1=pyhd3eb1b0_0
+ - jasper=1.900.1=h07fcdf6_1006
+ - jedi=0.18.0=py38h06a4308_1
+ - jinja2=2.11.3=pyhd3eb1b0_0
+ - joblib=1.1.0=pyhd8ed1ab_0
+ - jpeg=9d=h36c2ea0_0
+ - json5=0.9.6=pyhd3eb1b0_0
+ - jsonnet=0.17.0=py38hadf7658_0
+ - jsonschema=3.2.0=py_2
+ - jupyter_client=6.1.12=pyhd3eb1b0_0
+ - jupyter_core=4.7.1=py38h06a4308_0
+ - jupyter_server=1.4.1=py38h06a4308_0
+ - jupyterlab=3.1.7=pyhd3eb1b0_0
+ - jupyterlab_pygments=0.1.2=py_0
+ - jupyterlab_server=2.7.1=pyhd3eb1b0_0
+ - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1
+ - kiwisolver=1.3.1=py38h1fd1430_1
+ - krb5=1.19.2=hcc1bbae_0
+ - lame=3.100=h7f98852_1001
+ - lcms2=2.12=h3be6417_0
+ - ld_impl_linux-64=2.35.1=h7274673_9
+ - libarchive=3.4.2=h62408e4_0
+ - libblas=3.9.0=11_linux64_mkl
+ - libbrotlicommon=1.0.9=h7f98852_5
+ - libbrotlidec=1.0.9=h7f98852_5
+ - libbrotlienc=1.0.9=h7f98852_5
+ - libcblas=3.9.0=11_linux64_mkl
+ - libcurl=7.78.0=h2574ce0_0
+ - libedit=3.1.20191231=he28a2e2_2
+ - libev=4.33=h516909a_1
+ - libevent=2.1.10=hcdb4288_3
+ - libffi=3.3=he6710b0_2
+ - libgcc-ng=11.1.0=hc902ee8_8
+ - libgfortran-ng=11.1.0=h69a702a_8
+ - libgfortran5=11.1.0=h6c583b3_8
+ - libglib=2.68.4=h3e27bee_0
+ - libiconv=1.16=h516909a_0
+ - liblapack=3.9.0=11_linux64_mkl
+ - liblapacke=3.9.0=11_linux64_mkl
+ - liblief=0.10.1=he6710b0_0
+ - libllvm11=11.1.0=hf817b99_2
+ - libnghttp2=1.43.0=h812cca2_0
+ - libogg=1.3.4=h7f98852_1
+ - libopencv=4.5.2=py38hcdf9bf1_0
+ - libopus=1.3.1=h7f98852_1
+ - libpng=1.6.37=hbc83047_0
+ - libpq=13.3=hd57d9b9_0
+ - libprotobuf=3.15.8=h780b84a_0
+ - libsodium=1.0.18=h7b6447c_0
+ - libssh2=1.9.0=ha56f1ee_6
+ - libstdcxx-ng=11.1.0=h56837e0_8
+ - libtiff=4.2.0=h85742a9_0
+ - libuuid=1.0.3=h1bed415_2
+ - libuv=1.40.0=h7b6447c_0
+ - libvorbis=1.3.7=h9c3ff4c_0
+ - libwebp-base=1.2.0=h27cfd23_0
+ - libxcb=1.14=h7b6447c_0
+ - libxkbcommon=1.0.3=he3ba5ed_0
+ - libxml2=2.9.12=h72842e0_0
+ - llvm-openmp=12.0.1=h4bd325d_1
+ - locket=0.2.0=py_2
+ - lz4-c=1.9.3=h295c915_1
+ - markdown=3.3.4=pyhd8ed1ab_0
+ - markupsafe=2.0.1=py38h27cfd23_0
+ - matplotlib=3.4.2=py38h578d9bd_0
+ - matplotlib-base=3.4.2=py38hab158f2_0
+ - matplotlib-inline=0.1.2=pyhd3eb1b0_2
+ - mistune=0.8.4=py38h7b6447c_1000
+ - mkl=2021.3.0=h06a4308_520
+ - mkl-service=2.4.0=py38h7f8727e_0
+ - mkl_fft=1.3.0=py38h42c9631_2
+ - mkl_random=1.2.2=py38h51133e4_0
+ - multidict=5.1.0=py38h497a2fe_1
+ - munkres=1.1.4=pyh9f0ad1d_0
+ - mysql-common=8.0.25=ha770c72_0
+ - mysql-libs=8.0.25=h935591d_0
+ - navigator-updater=0.2.1=py38_0
+ - nbclassic=0.2.6=pyhd3eb1b0_0
+ - nbclient=0.5.3=pyhd3eb1b0_0
+ - nbconvert=6.1.0=py38h06a4308_0
+ - nbformat=5.1.3=pyhd3eb1b0_0
+ - ncurses=6.2=he6710b0_1
+ - nest-asyncio=1.5.1=pyhd3eb1b0_0
+ - nettle=3.6=he412f7d_0
+ - networkx=2.3=py_0
+ - ninja=1.10.2=hff7bd54_1
+ - notebook=6.4.3=py38h06a4308_0
+ - nspr=4.30=h9c3ff4c_0
+ - nss=3.69=hb5efdd6_0
+ - numpy=1.20.3=py38hf144106_0
+ - numpy-base=1.20.3=py38h74d4b33_0
+ - oauthlib=3.1.1=pyhd8ed1ab_0
+ - olefile=0.46=py_0
+ - opencv=4.5.2=py38h578d9bd_0
+ - openh264=2.1.1=h780b84a_0
+ - openjpeg=2.3.0=h05c96fa_1
+ - openssl=1.1.1l=h7f98852_0
+ - packaging=21.0=pyhd3eb1b0_0
+ - pandas=1.3.2=py38h43a58ef_0
+ - pandocfilters=1.4.3=py38h06a4308_1
+ - parso=0.8.2=pyhd3eb1b0_0
+ - partd=1.2.0=pyhd8ed1ab_0
+ - patchelf=0.12=h2531618_1
+ - pathlib=1.0.1=py38h578d9bd_4
+ - patsy=0.5.1=py_0
+ - pcre=8.45=h295c915_0
+ - perl=5.26.2=h14c3975_0
+ - pexpect=4.8.0=pyhd3eb1b0_3
+ - pickleshare=0.7.5=pyhd3eb1b0_1003
+ - pillow=8.3.1=py38h2c7a002_0
+ - pip=21.2.2=py38h06a4308_0
+ - pixman=0.40.0=h36c2ea0_0
+ - pkginfo=1.7.1=py38h06a4308_0
+ - pooch=1.5.1=pyhd8ed1ab_0
+ - portalocker=1.7.0=py38h578d9bd_1
+ - prometheus_client=0.11.0=pyhd3eb1b0_0
+ - prompt-toolkit=3.0.17=pyh06a4308_0
+ - protobuf=3.15.8=py38h709712a_0
+ - psutil=5.8.0=py38h27cfd23_1
+ - ptyprocess=0.7.0=pyhd3eb1b0_2
+ - py-lief=0.10.1=py38h403a769_0
+ - py-opencv=4.5.2=py38hd0cf306_0
+ - pyasn1=0.4.8=py_0
+ - pyasn1-modules=0.2.7=py_0
+ - pycosat=0.6.3=py38h7b6447c_1
+ - pycparser=2.20=py_2
+ - pygments=2.10.0=pyhd3eb1b0_0
+ - pyjwt=2.1.0=pyhd8ed1ab_0
+ - pyopenssl=20.0.1=pyhd3eb1b0_1
+ - pyparsing=2.4.7=pyhd3eb1b0_0
+ - pypng=0.0.20=py_0
+ - pyqt=5.12.3=py38h578d9bd_7
+ - pyqt-impl=5.12.3=py38h7400c14_7
+ - pyqt5-sip=4.19.18=py38h709712a_7
+ - pyqtchart=5.12=py38h7400c14_7
+ - pyqtwebengine=5.12.1=py38h7400c14_7
+ - pyrsistent=0.17.3=py38h7b6447c_0
+ - pysocks=1.7.1=py38h06a4308_0
+ - python=3.8.10=h12debd9_8
+ - python-dateutil=2.8.2=pyhd3eb1b0_0
+ - python-libarchive-c=2.9=pyhd3eb1b0_1
+ - python-lmdb=0.99=py38h709712a_0
+ - python_abi=3.8=2_cp38
+ - pytorch=1.9.0=py3.8_cuda11.1_cudnn8.0.5_0
+ - pytz=2021.1=pyhd3eb1b0_0
+ - pyu2f=0.1.5=pyhd8ed1ab_0
+ - pywavelets=1.1.1=py38h5c078b8_3
+ - pyyaml=5.4.1=py38h27cfd23_1
+ - pyzmq=22.2.1=py38h295c915_1
+ - qt=5.12.9=hda022c4_4
+ - qtpy=1.9.0=py_0
+ - readline=8.1=h27cfd23_0
+ - regex=2021.8.28=py38h497a2fe_0
+ - requests=2.25.1=pyhd3eb1b0_0
+ - requests-oauthlib=1.3.0=pyh9f0ad1d_0
+ - ripgrep=12.1.1=0
+ - rsa=4.7.2=pyh44b312d_0
+ - ruamel_yaml=0.15.100=py38h27cfd23_0
+ - scikit-image=0.18.3=py38h43a58ef_0
+ - scikit-learn=1.0=py38hacb3eff_1
+ - scipy=1.7.1=py38h56a6a73_0
+ - seaborn=0.11.2=hd8ed1ab_0
+ - seaborn-base=0.11.2=pyhd8ed1ab_0
+ - send2trash=1.5.0=pyhd3eb1b0_1
+ - setuptools=52.0.0=py38h06a4308_0
+ - shapely=1.8.0=py38hf7953bd_1
+ - sip=4.19.13=py38he6710b0_0
+ - sniffio=1.2.0=py38h06a4308_1
+ - soupsieve=2.2.1=pyhd3eb1b0_0
+ - sqlite=3.36.0=hc218d9a_0
+ - statsmodels=0.12.2=py38h5c078b8_0
+ - tensorboard=2.6.0=pyhd8ed1ab_1
+ - tensorboard-data-server=0.6.0=py38h2b97feb_0
+ - tensorboard-plugin-wit=1.8.0=pyh44b312d_0
+ - tensorboardx=2.4=pyhd8ed1ab_0
+ - terminado=0.9.4=py38h06a4308_0
+ - testpath=0.5.0=pyhd3eb1b0_0
+ - threadpoolctl=3.0.0=pyh8a188c0_0
+ - tifffile=2019.7.26.2=py38_0
+ - tk=8.6.10=hbc83047_0
+ - toolz=0.11.1=py_0
+ - torchfile=0.1.0=py_0
+ - tornado=6.1=py38h27cfd23_0
+ - tqdm=4.62.1=pyhd3eb1b0_1
+ - traitlets=5.0.5=pyhd3eb1b0_0
+ - typing_extensions=3.10.0.0=pyh06a4308_0
+ - urllib3=1.26.6=pyhd3eb1b0_1
+ - wcwidth=0.2.5=py_0
+ - webencodings=0.5.1=py38_1
+ - werkzeug=1.0.1=pyhd3eb1b0_0
+ - wheel=0.37.0=pyhd3eb1b0_0
+ - widgetsnbextension=3.5.1=py38_0
+ - x264=1!161.3030=h7f98852_1
+ - xmltodict=0.12.0=py_0
+ - xz=5.2.5=h7b6447c_0
+ - yacs=0.1.6=py_0
+ - yaml=0.2.5=h7b6447c_0
+ - yarl=1.6.3=py38h497a2fe_2
+ - zeromq=4.3.4=h2531618_0
+ - zipp=3.5.0=pyhd3eb1b0_0
+ - zlib=1.2.11=h7b6447c_3
+ - zstd=1.4.9=haebb681_0
+ - pip:
+   - addict==2.4.0
+   - altair==4.2.0
+   - astor==0.8.1
+   - astunparse==1.6.3
+   - backports-zoneinfo==0.2.1
+   - base58==2.1.1
+   - basicsr==1.3.4.1
+   - boto3==1.18.33
+   - botocore==1.21.33
+   - clang==5.0
+   - clean-fid==0.1.22
+   - clip==1.0
+   - colorama==0.4.4
+   - commonmark==0.9.1
+   - cython==0.29.30
+   - einops==0.3.2
+   - enum-compat==0.0.3
+   - facexlib==0.2.0.3
+   - filterpy==1.4.5
+   - flatbuffers==1.12
+   - gast==0.4.0
+   - google-pasta==0.2.0
+   - grpcio==1.39.0
+   - h5py==3.1.0
+   - ipdb==0.13.9
+   - jacinle==1.0.0
+   - jmespath==0.10.0
+   - jsonpickle==2.2.0
+   - keras==2.7.0
+   - keras-preprocessing==1.1.2
+   - libclang==12.0.0
+   - llvmlite==0.37.0
+   - lpips==0.1.4
+   - numba==0.54.0
+   - opencv-python==4.5.3.56
+   - opt-einsum==3.3.0
+   - pkgconfig==1.5.5
+   - pyarrow==8.0.0
+   - pydantic==1.8.2
+   - pydeck==0.7.1
+   - pyhocon==0.3.58
+   - pytz-deprecation-shim==0.1.0.post0
+   - pyvis==0.2.1
+   - realesrgan==0.2.2.3
+   - rich==10.9.0
+   - s3transfer==0.5.0
+   - six==1.15.0
+   - sklearn==0.0
+   - streamlit==0.64.0
+   - tabulate==0.8.9
+   - tb-nightly==2.7.0a20210827
+   - tensorflow-estimator==2.7.0
+   - tensorflow-gpu==2.7.0
+   - tensorflow-io-gcs-filesystem==0.21.0
+   - tensorfn==0.1.19
+   - termcolor==1.1.0
+   - toml==0.10.2
+   - torchsample==0.1.3
+   - torchvision==0.10.0+cu111
+   - typing-extensions==3.7.4.3
+   - tzdata==2022.1
+   - tzlocal==4.2
+   - validators==0.19.0
+   - vit-pytorch==0.24.3
+   - watchdog==2.1.8
+   - wrapt==1.12.1
+   - yapf==0.31.0
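Assuming a standard conda installation, this environment should be reproducible with `conda env create -f env.yaml` (the README's prerequisites point to this file but do not spell out the command).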
expansion/__init__.py ADDED
File without changes
expansion/dataloader/__init__.py ADDED
File without changes
expansion/dataloader/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (162 Bytes).
expansion/dataloader/__pycache__/seqlist.cpython-38.pyc ADDED
Binary file (1.12 kB).
expansion/dataloader/chairslist.py ADDED
@@ -0,0 +1,33 @@
+ import torch.utils.data as data
+
+ from PIL import Image
+ import os
+ import os.path
+ import numpy as np
+ import glob
+
+ IMG_EXTENSIONS = [
+     '.jpg', '.JPG', '.jpeg', '.JPEG',
+     '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+ ]
+
+
+ def is_image_file(filename):
+     return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+ def dataloader(filepath):
+     l0_train = []
+     l1_train = []
+     flow_train = []
+     for flow_map in sorted(glob.glob(os.path.join(filepath,'*_flow.flo'))):
+         root_filename = flow_map[:-9]
+         img1 = root_filename+'_img1.ppm'
+         img2 = root_filename+'_img2.ppm'
+         if not (os.path.isfile(os.path.join(filepath,img1)) and os.path.isfile(os.path.join(filepath,img2))):
+             continue
+
+         l0_train.append(img1)
+         l1_train.append(img2)
+         flow_train.append(flow_map)
+
+     return l0_train, l1_train, flow_train
expansion/dataloader/chairssdlist.py ADDED
@@ -0,0 +1,30 @@
+ import torch.utils.data as data
+
+ from PIL import Image
+ import os
+ import os.path
+ import numpy as np
+ import glob
+
+ IMG_EXTENSIONS = [
+     '.jpg', '.JPG', '.jpeg', '.JPEG',
+     '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+ ]
+
+
+ def is_image_file(filename):
+     return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+ def dataloader(filepath):
+     l0_train = []
+     l1_train = []
+     flow_train = []
+     for flow_map in sorted(glob.glob('%s/flow/*.pfm'%filepath)):
+         img1 = flow_map.replace('flow','t0').replace('.pfm','.png')
+         img2 = flow_map.replace('flow','t1').replace('.pfm','.png')
+
+         l0_train.append(img1)
+         l1_train.append(img2)
+         flow_train.append(flow_map)
+
+     return l0_train, l1_train, flow_train
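The pairing of flow maps with the two image frames here relies purely on string substitution; a small illustration with a hypothetical ChairsSDHom-style path:
```python
# Hypothetical path, mirroring the replace-based pairing in dataloader() above.
flow_map = 'ChairsSDHom/data/train/flow/00001.pfm'
img1 = flow_map.replace('flow', 't0').replace('.pfm', '.png')
img2 = flow_map.replace('flow', 't1').replace('.pfm', '.png')
assert img1 == 'ChairsSDHom/data/train/t0/00001.png'
assert img2 == 'ChairsSDHom/data/train/t1/00001.png'
```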
expansion/dataloader/depth_transforms.py ADDED
@@ -0,0 +1,471 @@
+ from __future__ import division
+ import torch
+ import random
+ import numpy as np
+ import numbers
+ import types
+ import scipy.ndimage as ndimage
+ import pdb
+ import torchvision
+ import PIL.Image as Image
+ import cv2
+ from torch.nn import functional as F
+
+
+ class Compose(object):
+     """ Composes several co_transforms together.
+     For example:
+     >>> co_transforms.Compose([
+     >>>     co_transforms.CenterCrop(10),
+     >>>     co_transforms.ToTensor(),
+     >>> ])
+     """
+
+     def __init__(self, co_transforms):
+         self.co_transforms = co_transforms
+
+     def __call__(self, input, target, intr):
+         for t in self.co_transforms:
+             input, target, intr = t(input, target, intr)
+         return input, target, intr
+
+
+ class Scale(object):
+     """ Rescales the inputs and target arrays to the given 'size'.
+     'size' will be the size of the smaller edge.
+     For example, if height > width, then image will be
+     rescaled to (size * height / width, size)
+     size: size of the smaller edge
+     interpolation order: Default: 2 (bilinear)
+     """
+
+     def __init__(self, size, order=1):
+         self.ratio = size
+         self.order = order
+         if order==0:
+             self.code=cv2.INTER_NEAREST
+         elif order==1:
+             self.code=cv2.INTER_LINEAR
+         elif order==2:
+             self.code=cv2.INTER_CUBIC
+
+     def __call__(self, inputs, target):
+         if self.ratio==1:
+             return inputs, target
+         h, w, _ = inputs[0].shape
+         ratio = self.ratio
+
+         inputs[0] = cv2.resize(inputs[0], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_LINEAR)
+         inputs[1] = cv2.resize(inputs[1], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_LINEAR)
+         # keep the mask same
+         tmp = cv2.resize(target[:,:,2], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_NEAREST)
+         target = cv2.resize(target, None, fx=ratio,fy=ratio,interpolation=self.code) * ratio
+         target[:,:,2] = tmp
+
+         return inputs, target
+
+
+ class RandomCrop(object):
+     """Crops the given PIL.Image at a random location to have a region of
+     the given size. size can be a tuple (target_height, target_width)
+     or an integer, in which case the target will be of a square shape (size, size)
+     """
+
+     def __init__(self, size):
+         if isinstance(size, numbers.Number):
+             self.size = (int(size), int(size))
+         else:
+             self.size = size
+
+     def __call__(self, inputs, target, intr):
+         h, w, _ = inputs[0].shape
+         th, tw = self.size
+         if w < tw: tw=w
+         if h < th: th=h
+
+         x1 = random.randint(0, w - tw)
+         y1 = random.randint(0, h - th)
+         intr[1] -= x1
+         intr[2] -= y1
+
+         inputs[0] = inputs[0][y1: y1 + th, x1: x1 + tw].astype(float)
+         inputs[1] = inputs[1][y1: y1 + th, x1: x1 + tw].astype(float)
+         return inputs, target[y1: y1 + th, x1: x1 + tw].astype(float), list(np.asarray(intr).astype(float)) + list(np.asarray([1.,0.,0.,1.,0.,0.]).astype(float))
+
+
+ class SpatialAug(object):
+     def __init__(self, crop, scale=None, rot=None, trans=None, squeeze=None, schedule_coeff=1, order=1, black=False):
+         self.crop = crop
+         self.scale = scale
+         self.rot = rot
+         self.trans = trans
+         self.squeeze = squeeze
+         self.t = np.zeros(6)
+         self.schedule_coeff = schedule_coeff
+         self.order = order
+         self.black = black
+
+     def to_identity(self):
+         self.t[0] = 1; self.t[2] = 0; self.t[4] = 0; self.t[1] = 0; self.t[3] = 1; self.t[5] = 0;
+
+     def left_multiply(self, u0, u1, u2, u3, u4, u5):
+         result = np.zeros(6)
+         result[0] = self.t[0]*u0 + self.t[1]*u2;
+         result[1] = self.t[0]*u1 + self.t[1]*u3;
+
+         result[2] = self.t[2]*u0 + self.t[3]*u2;
+         result[3] = self.t[2]*u1 + self.t[3]*u3;
+
+         result[4] = self.t[4]*u0 + self.t[5]*u2 + u4;
+         result[5] = self.t[4]*u1 + self.t[5]*u3 + u5;
+         self.t = result
+
+     def inverse(self):
+         result = np.zeros(6)
+         a = self.t[0]; c = self.t[2]; e = self.t[4];
+         b = self.t[1]; d = self.t[3]; f = self.t[5];
+
+         denom = a*d - b*c;
+
+         result[0] = d / denom;
+         result[1] = -b / denom;
+         result[2] = -c / denom;
+         result[3] = a / denom;
+         result[4] = (c*f-d*e) / denom;
+         result[5] = (b*e-a*f) / denom;
+
+         return result
+
+     def grid_transform(self, meshgrid, t, normalize=True, gridsize=None):
+         if gridsize is None:
+             h, w = meshgrid[0].shape
+         else:
+             h, w = gridsize
+         vgrid = torch.cat([(meshgrid[0] * t[0] + meshgrid[1] * t[2] + t[4])[:,:,np.newaxis],
+                            (meshgrid[0] * t[1] + meshgrid[1] * t[3] + t[5])[:,:,np.newaxis]],-1)
+         if normalize:
+             vgrid[:,:,0] = 2.0*vgrid[:,:,0]/max(w-1,1)-1.0
+             vgrid[:,:,1] = 2.0*vgrid[:,:,1]/max(h-1,1)-1.0
+         return vgrid
+
+
+     def __call__(self, inputs, target, intr):
+         h, w, _ = inputs[0].shape
+         th, tw = self.crop
+         meshgrid = torch.meshgrid([torch.Tensor(range(th)), torch.Tensor(range(tw))])[::-1]
+         cornergrid = torch.meshgrid([torch.Tensor([0,th-1]), torch.Tensor([0,tw-1])])[::-1]
+
+         for i in range(50):
+             # im0
+             self.to_identity()
+             #TODO add mirror
+             if np.random.binomial(1,0.5):
+                 mirror = True
+             else:
+                 mirror = False
+             ##TODO
+             #mirror = False
+             if mirror:
+                 self.left_multiply(-1, 0, 0, 1, .5 * tw, -.5 * th);
+             else:
+                 self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th);
+             scale0 = 1; scale1 = 1; squeeze0 = 1; squeeze1 = 1;
+             if not self.rot is None:
+                 rot0 = np.random.uniform(-self.rot[0],+self.rot[0])
+                 rot1 = np.random.uniform(-self.rot[1]*self.schedule_coeff, self.rot[1]*self.schedule_coeff) + rot0
+                 self.left_multiply(np.cos(rot0), np.sin(rot0), -np.sin(rot0), np.cos(rot0), 0, 0)
+             if not self.trans is None:
+                 trans0 = np.random.uniform(-self.trans[0],+self.trans[0], 2)
+                 trans1 = np.random.uniform(-self.trans[1]*self.schedule_coeff,+self.trans[1]*self.schedule_coeff, 2) + trans0
+                 self.left_multiply(1, 0, 0, 1, trans0[0] * tw, trans0[1] * th)
+             if not self.squeeze is None:
+                 squeeze0 = np.exp(np.random.uniform(-self.squeeze[0], self.squeeze[0]))
+                 squeeze1 = np.exp(np.random.uniform(-self.squeeze[1]*self.schedule_coeff, self.squeeze[1]*self.schedule_coeff)) * squeeze0
+             if not self.scale is None:
+                 scale0 = np.exp(np.random.uniform(self.scale[2]-self.scale[0], self.scale[2]+self.scale[0]))
+                 scale1 = np.exp(np.random.uniform(-self.scale[1]*self.schedule_coeff, self.scale[1]*self.schedule_coeff)) * scale0
+             self.left_multiply(1.0/(scale0*squeeze0), 0, 0, 1.0/(scale0/squeeze0), 0, 0)
+
+             self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h);
+             transmat0 = self.t.copy()
+
+             # im1
+             self.to_identity()
+             if mirror:
+                 self.left_multiply(-1, 0, 0, 1, .5 * tw, -.5 * th);
+             else:
+                 self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th);
+             if not self.rot is None:
+                 self.left_multiply(np.cos(rot1), np.sin(rot1), -np.sin(rot1), np.cos(rot1), 0, 0)
+             if not self.trans is None:
+                 self.left_multiply(1, 0, 0, 1, trans1[0] * tw, trans1[1] * th)
+             self.left_multiply(1.0/(scale1*squeeze1), 0, 0, 1.0/(scale1/squeeze1), 0, 0)
+             self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h);
+             transmat1 = self.t.copy()
+             transmat1_inv = self.inverse()
+
+             if self.black:
+                 # black augmentation, allowing 0 values in the input images
+                 # https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/black_augmentation_layer.cu
+                 break
+             else:
+                 if ((self.grid_transform(cornergrid, transmat0, gridsize=[float(h),float(w)]).abs()>1).sum() +\
+                     (self.grid_transform(cornergrid, transmat1, gridsize=[float(h),float(w)]).abs()>1).sum()) == 0:
+                     break
+             if i==49:
+                 print('max_iter in augmentation')
+                 self.to_identity()
+                 self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th);
+                 self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h);
+                 transmat0 = self.t.copy()
+                 transmat1 = self.t.copy()
+
+         # do the real work
+         vgrid = self.grid_transform(meshgrid, transmat0, gridsize=[float(h),float(w)])
+         inputs_0 = F.grid_sample(torch.Tensor(inputs[0]).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)
+         if self.order == 0:
+             target_0 = F.grid_sample(torch.Tensor(target).permute(2,0,1)[np.newaxis], vgrid[np.newaxis], mode='nearest')[0].permute(1,2,0)
+         else:
+             target_0 = F.grid_sample(torch.Tensor(target).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)
+
+         mask_0 = target[:,:,2:3].copy(); mask_0[mask_0==0]=np.nan
+         if self.order == 0:
+             mask_0 = F.grid_sample(torch.Tensor(mask_0).permute(2,0,1)[np.newaxis], vgrid[np.newaxis], mode='nearest')[0].permute(1,2,0)
+         else:
+             mask_0 = F.grid_sample(torch.Tensor(mask_0).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)
+         mask_0[torch.isnan(mask_0)] = 0
+
+
+         vgrid = self.grid_transform(meshgrid, transmat1, gridsize=[float(h),float(w)])
+         inputs_1 = F.grid_sample(torch.Tensor(inputs[1]).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)
+
+         # flow
+         pos = target_0[:,:,:2] + self.grid_transform(meshgrid, transmat0, normalize=False)
+         pos = self.grid_transform(pos.permute(2,0,1), transmat1_inv, normalize=False)
+         if target_0.shape[2]>=4:
+             # scale
+             exp = target_0[:,:,3:] * scale1 / scale0
+             target = torch.cat([ (pos[:,:,0] - meshgrid[0]).unsqueeze(-1),
+                                  (pos[:,:,1] - meshgrid[1]).unsqueeze(-1),
+                                  mask_0,
+                                  exp], -1)
+         else:
+             target = torch.cat([ (pos[:,:,0] - meshgrid[0]).unsqueeze(-1),
+                                  (pos[:,:,1] - meshgrid[1]).unsqueeze(-1),
+                                  mask_0], -1)
+         inputs = [np.asarray(inputs_0).astype(float), np.asarray(inputs_1).astype(float)]
+         target = np.asarray(target).astype(float)
+         return inputs, target, list(np.asarray(intr+list(transmat0)).astype(float))
+
+
+
+ class pseudoPCAAug(object):
+     """
+     Chromatic Eigen Augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
+     This version is faster.
+     """
+     def __init__(self, schedule_coeff=1):
+         self.augcolor = torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.5, hue=0.5/3.14)
+
+     def __call__(self, inputs, target, intr):
+         img = np.concatenate([inputs[0],inputs[1]],0)
+         shape = img.shape[0]//2
+         aug_img = np.asarray(self.augcolor(Image.fromarray(np.uint8(img*255))))/255.
+         inputs[0] = aug_img[:shape]
+         inputs[1] = aug_img[shape:]
+         #inputs[0] = np.asarray(self.augcolor(Image.fromarray(np.uint8(inputs[0]*255))))/255.
+         #inputs[1] = np.asarray(self.augcolor(Image.fromarray(np.uint8(inputs[1]*255))))/255.
+         return inputs, target, intr
+
+
+ class PCAAug(object):
+     """
+     Chromatic Eigen Augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
+     """
+     def __init__(self, lmult_pow  =[0.4, 0,-0.2],
+                        lmult_mult =[0.4, 0,0,  ],
+                        lmult_add  =[0.03,0,0,  ],
+                        sat_pow    =[0.4, 0,0,  ],
+                        sat_mult   =[0.5, 0,-0.3],
+                        sat_add    =[0.03,0,0,  ],
+                        col_pow    =[0.4, 0,0,  ],
+                        col_mult   =[0.2, 0,0,  ],
+                        col_add    =[0.02,0,0,  ],
+                        ladd_pow   =[0.4, 0,0,  ],
+                        ladd_mult  =[0.4, 0,0,  ],
+                        ladd_add   =[0.04,0,0,  ],
+                        col_rotate =[1.,  0,0,  ],
+                        schedule_coeff=1):
+         # no mean
+         self.pow_nomean = [1,1,1]
+         self.add_nomean = [0,0,0]
+         self.mult_nomean = [1,1,1]
+         self.pow_withmean = [1,1,1]
+         self.add_withmean = [0,0,0]
+         self.mult_withmean = [1,1,1]
+         self.lmult_pow = 1
+         self.lmult_mult = 1
+         self.lmult_add = 0
+         self.col_angle = 0
+         if not ladd_pow is None:
+             self.pow_nomean[0] = np.exp(np.random.normal(ladd_pow[2], ladd_pow[0]))
+         if not col_pow is None:
+             self.pow_nomean[1] = np.exp(np.random.normal(col_pow[2], col_pow[0]))
+             self.pow_nomean[2] = np.exp(np.random.normal(col_pow[2], col_pow[0]))
+
+         if not ladd_add is None:
+             self.add_nomean[0] = np.random.normal(ladd_add[2], ladd_add[0])
+         if not col_add is None:
+             self.add_nomean[1] = np.random.normal(col_add[2], col_add[0])
+             self.add_nomean[2] = np.random.normal(col_add[2], col_add[0])
+
+         if not ladd_mult is None:
+             self.mult_nomean[0] = np.exp(np.random.normal(ladd_mult[2], ladd_mult[0]))
+         if not col_mult is None:
+             self.mult_nomean[1] = np.exp(np.random.normal(col_mult[2], col_mult[0]))
+             self.mult_nomean[2] = np.exp(np.random.normal(col_mult[2], col_mult[0]))
+
+         # with mean
+         if not sat_pow is None:
+             self.pow_withmean[1] = np.exp(np.random.uniform(sat_pow[2]-sat_pow[0], sat_pow[2]+sat_pow[0]))
+             self.pow_withmean[2] = self.pow_withmean[1]
+         if not sat_add is None:
+             self.add_withmean[1] = np.random.uniform(sat_add[2]-sat_add[0], sat_add[2]+sat_add[0])
+             self.add_withmean[2] = self.add_withmean[1]
+         if not sat_mult is None:
+             self.mult_withmean[1] = np.exp(np.random.uniform(sat_mult[2]-sat_mult[0], sat_mult[2]+sat_mult[0]))
+             self.mult_withmean[2] = self.mult_withmean[1]
+
+         if not lmult_pow is None:
+             self.lmult_pow = np.exp(np.random.uniform(lmult_pow[2]-lmult_pow[0], lmult_pow[2]+lmult_pow[0]))
+         if not lmult_mult is None:
+             self.lmult_mult = np.exp(np.random.uniform(lmult_mult[2]-lmult_mult[0], lmult_mult[2]+lmult_mult[0]))
+         if not lmult_add is None:
+             self.lmult_add = np.random.uniform(lmult_add[2]-lmult_add[0], lmult_add[2]+lmult_add[0])
+         if not col_rotate is None:
+             self.col_angle = np.random.uniform(col_rotate[2]-col_rotate[0], col_rotate[2]+col_rotate[0])
+
+         # eigen vectors
+         self.eigvec = np.reshape([0.51,0.56,0.65,0.79,0.01,-0.62,0.35,-0.83,0.44],[3,3]).transpose()
+
+
+     def __call__(self, inputs, target, intr):
+         inputs[0] = self.pca_image(inputs[0])
+         inputs[1] = self.pca_image(inputs[1])
+         return inputs, target, intr
+
+     def pca_image(self, rgb):
+         eig = np.dot(rgb, self.eigvec)
+         max_rgb = np.clip(rgb,0,np.inf).max((0,1))
+         min_rgb = rgb.min((0,1))
+         mean_rgb = rgb.mean((0,1))
+         max_abs_eig = np.abs(eig).max((0,1))
+         max_l = np.sqrt(np.sum(max_abs_eig*max_abs_eig))
+         mean_eig = np.dot(mean_rgb, self.eigvec)
+
+         # no-mean stuff
+         eig -= mean_eig[np.newaxis, np.newaxis]
+
+         for c in range(3):
+             if max_abs_eig[c] > 1e-2:
+                 mean_eig[c] /= max_abs_eig[c]
+                 eig[:,:,c] = eig[:,:,c] / max_abs_eig[c];
+                 eig[:,:,c] = np.power(np.abs(eig[:,:,c]),self.pow_nomean[c]) *\
+                              ((eig[:,:,c] > 0) -0.5)*2
+                 eig[:,:,c] = eig[:,:,c] + self.add_nomean[c]
+                 eig[:,:,c] = eig[:,:,c] * self.mult_nomean[c]
+         eig += mean_eig[np.newaxis,np.newaxis]
+
+         # withmean stuff
+         if max_abs_eig[0] > 1e-2:
+             eig[:,:,0] = np.power(np.abs(eig[:,:,0]),self.pow_withmean[0]) * \
+                          ((eig[:,:,0]>0)-0.5)*2;
+             eig[:,:,0] = eig[:,:,0] + self.add_withmean[0];
+             eig[:,:,0] = eig[:,:,0] * self.mult_withmean[0];
+
+         s = np.sqrt(eig[:,:,1]*eig[:,:,1] + eig[:,:,2] * eig[:,:,2])
+         smask = s > 1e-2
+         s1 = np.power(s, self.pow_withmean[1]);
+         s1 = np.clip(s1 + self.add_withmean[1], 0,np.inf)
+         s1 = s1 * self.mult_withmean[1]
+         s1 = s1 * smask + s*(1-smask)
+
+         # color angle
+         if self.col_angle!=0:
+             temp1 = np.cos(self.col_angle) * eig[:,:,1] - np.sin(self.col_angle) * eig[:,:,2]
+             temp2 = np.sin(self.col_angle) * eig[:,:,1] + np.cos(self.col_angle) * eig[:,:,2]
+             eig[:,:,1] = temp1
+             eig[:,:,2] = temp2
+
+         # to origin magnitude
+         for c in range(3):
+             if max_abs_eig[c] > 1e-2:
+                 eig[:,:,c] = eig[:,:,c] * max_abs_eig[c]
+
+         if max_l > 1e-2:
+             l1 = np.sqrt(eig[:,:,0]*eig[:,:,0] + eig[:,:,1]*eig[:,:,1] + eig[:,:,2]*eig[:,:,2])
+             l1 = l1 / max_l
+
+         eig[:,:,1][smask] = (eig[:,:,1] / s * s1)[smask]
+         eig[:,:,2][smask] = (eig[:,:,2] / s * s1)[smask]
+         #eig[:,:,1] = (eig[:,:,1] / s * s1) * smask + eig[:,:,1] * (1-smask)
+         #eig[:,:,2] = (eig[:,:,2] / s * s1) * smask + eig[:,:,2] * (1-smask)
+
+         if max_l > 1e-2:
+             l = np.sqrt(eig[:,:,0]*eig[:,:,0] + eig[:,:,1]*eig[:,:,1] + eig[:,:,2]*eig[:,:,2])
+             l1 = np.power(l1, self.lmult_pow)
+             l1 = np.clip(l1 + self.lmult_add, 0, np.inf)
+             l1 = l1 * self.lmult_mult
+             l1 = l1 * max_l
+             lmask = l > 1e-2
+             eig[lmask] = (eig / l[:,:,np.newaxis] * l1[:,:,np.newaxis])[lmask]
+             for c in range(3):
+                 eig[:,:,c][lmask] = (np.clip(eig[:,:,c], -np.inf, max_abs_eig[c]))[lmask]
+             # for c in range(3):
+             #     # eig[:,:,c][lmask] = (eig[:,:,c] / l * l1)[lmask] * lmask + eig[:,:,c] * (1-lmask)
+             #     eig[:,:,c][lmask] = (eig[:,:,c] / l * l1)[lmask]
+             #     eig[:,:,c] = (np.clip(eig[:,:,c], -np.inf, max_abs_eig[c])) * lmask + eig[:,:,c] * (1-lmask)
+
+         return np.clip(np.dot(eig, self.eigvec.transpose()), 0, 1)
+
+
+ class ChromaticAug(object):
+     """
+     Chromatic augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
+     """
+     def __init__(self, noise = 0.06,
+                        gamma = 0.02,
+                        brightness = 0.02,
+                        contrast = 0.02,
+                        color = 0.02,
+                        schedule_coeff=1):
+
+         self.noise = np.random.uniform(0,noise)
+         self.gamma = np.exp(np.random.normal(0, gamma*schedule_coeff))
+         self.brightness = np.random.normal(0, brightness*schedule_coeff)
+         self.contrast = np.exp(np.random.normal(0, contrast*schedule_coeff))
+         self.color = np.exp(np.random.normal(0, color*schedule_coeff,3))
+
+     def __call__(self, inputs, target, intr):
+         inputs[1] = self.chrom_aug(inputs[1])
+         # noise
+         inputs[0] += np.random.normal(0, self.noise, inputs[0].shape)
+         inputs[1] += np.random.normal(0, self.noise, inputs[0].shape)
+         return inputs, target, intr
+
+     def chrom_aug(self, rgb):
+         # color change
+         mean_in = rgb.sum(-1)
+         rgb = rgb*self.color[np.newaxis,np.newaxis]
+         brightness_coeff = mean_in / (rgb.sum(-1)+0.01)
+         rgb = np.clip(rgb*brightness_coeff[:,:,np.newaxis],0,1)
+         # gamma
+         rgb = np.power(rgb,self.gamma)
+         # brightness
+         rgb += self.brightness
+         # contrast
+         rgb = 0.5 + (rgb-0.5)*self.contrast
+         rgb = np.clip(rgb, 0, 1)
+         return rgb
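As a sanity check on the affine bookkeeping in `SpatialAug` above (`left_multiply` composes the stored 2x3 transform with a new one; `inverse` returns the inverse without modifying the stored state), composing a transform with its inverse should give the identity. A small sketch, not part of the commit:
```python
import numpy as np
from expansion.dataloader.depth_transforms import SpatialAug  # path as in this commit

aug = SpatialAug(crop=(320, 448))
aug.to_identity()
# A rotation by 0.1 rad plus a translation of (5, -3), in the class's layout.
aug.left_multiply(np.cos(0.1), np.sin(0.1), -np.sin(0.1), np.cos(0.1), 5.0, -3.0)
t_inv = aug.inverse()      # inverse of the current transform
aug.left_multiply(*t_inv)  # compose: T followed by T^-1
assert np.allclose(aug.t, [1, 0, 0, 1, 0, 0])
```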
expansion/dataloader/depthloader.py ADDED
@@ -0,0 +1,222 @@
+ import os
+ import numbers
+ import torch
+ import torch.utils.data as data
+ import torch
+ import torchvision.transforms as transforms
+ import random
+ from PIL import Image, ImageOps
+ import numpy as np
+ import torchvision
+ from . import depth_transforms as flow_transforms
+ import pdb
+ import cv2
+ from utils.flowlib import read_flow
+ from utils.util_flow import readPFM, load_calib_cam_to_cam
+
+ def default_loader(path):
+     return Image.open(path).convert('RGB')
+
+ def flow_loader(path):
+     if '.pfm' in path:
+         data = readPFM(path)[0]
+         data[:,:,2] = 1
+         return data
+     else:
+         return read_flow(path)
+
+ def load_exts(cam_file):
+     with open(cam_file, 'r') as f:
+         lines = f.readlines()
+
+     l_exts = []
+     r_exts = []
+     for l in lines:
+         if 'L ' in l:
+             l_exts.append(np.asarray([float(i) for i in l[2:].strip().split(' ')]).reshape(4,4))
+         if 'R ' in l:
+             r_exts.append(np.asarray([float(i) for i in l[2:].strip().split(' ')]).reshape(4,4))
+     return l_exts, r_exts
+
+ def disparity_loader(path):
+     if '.png' in path:
+         data = Image.open(path)
+         data = np.ascontiguousarray(data,dtype=np.float32)/256
+         return data
+     else:
+         return readPFM(path)[0]
+
+ # triangulation
+ def triangulation(disp, xcoord, ycoord, bl=1, fl = 450, cx = 479.5, cy = 269.5):
+     depth = bl*fl / disp # 450px->15mm focal length
+     X = (xcoord - cx) * depth / fl
+     Y = (ycoord - cy) * depth / fl
+     Z = depth
+     P = np.concatenate((X[np.newaxis],Y[np.newaxis],Z[np.newaxis]),0).reshape(3,-1)
+     P = np.concatenate((P,np.ones((1,P.shape[-1]))),0)
+     return P
+
+ class myImageFloder(data.Dataset):
+     def __init__(self, iml0, iml1, flowl0, loader=default_loader, dploader= flow_loader, scale=1.,shape=[320,448], order=1, noise=0.06, pca_augmentor=True, prob = 1.,sc=False,disp0=None,disp1=None,calib=None ):
+         self.iml0 = iml0
+         self.iml1 = iml1
+         self.flowl0 = flowl0
+         self.loader = loader
+         self.dploader = dploader
+         self.scale = scale
+         self.shape = shape
+         self.order = order
+         self.noise = noise
+         self.pca_augmentor = pca_augmentor
+         self.prob = prob
+         self.sc = sc
+         self.disp0 = disp0
+         self.disp1 = disp1
+         self.calib = calib
+
+     def __getitem__(self, index):
+         iml0 = self.iml0[index]
+         iml1 = self.iml1[index]
+         flowl0 = self.flowl0[index]
+         th, tw = self.shape
+
+         iml0 = self.loader(iml0)
+         iml1 = self.loader(iml1)
+
+         # get disparity
+         if self.sc:
+             flowl0 = self.dploader(flowl0)
+             flowl0 = np.ascontiguousarray(flowl0,dtype=np.float32)
+             flowl0[np.isnan(flowl0)] = 1e6 # set to max
+             if 'camera_data.txt' in self.calib[index]:
+                 bl = 1
+                 if '15mm_' in self.calib[index]:
+                     fl = 450 # 450
+                 else:
+                     fl = 1050
+                 cx = 479.5
+                 cy = 269.5
+                 # negative disp
+                 d1 = np.abs(disparity_loader(self.disp0[index]))
+                 d2 = np.abs(disparity_loader(self.disp1[index]) + d1)
+             elif 'Sintel' in self.calib[index]:
+                 fl = 1000
+                 bl = 1
+                 cx = 511.5
+                 cy = 217.5
+                 d1 = np.zeros(flowl0.shape[:2])
+                 d2 = np.zeros(flowl0.shape[:2])
+             else:
+                 ints = load_calib_cam_to_cam(self.calib[index])
+                 fl = ints['K_cam2'][0,0]
+                 cx = ints['K_cam2'][0,2]
+                 cy = ints['K_cam2'][1,2]
+                 bl = ints['b20']-ints['b30']
+                 d1 = disparity_loader(self.disp0[index])
+                 d2 = disparity_loader(self.disp1[index])
+             #flowl0[:,:,2] = (flowl0[:,:,2]==1).astype(float)
+             flowl0[:,:,2] = np.logical_and(np.logical_and(flowl0[:,:,2]==1, d1!=0), d2!=0).astype(float)
+
+             shape = d1.shape
+             mesh = np.meshgrid(range(shape[1]),range(shape[0]))
+             xcoord = mesh[0].astype(float)
+             ycoord = mesh[1].astype(float)
+
+             # triangulation in two frames
+             P0 = triangulation(d1, xcoord, ycoord, bl=bl, fl = fl, cx = cx, cy = cy)
+             P1 = triangulation(d2, xcoord + flowl0[:,:,0], ycoord + flowl0[:,:,1], bl=bl, fl = fl, cx = cx, cy = cy)
+             dis0 = P0[2]
+             dis1 = P1[2]
+
+             change_size = dis0.reshape(shape).astype(np.float32)
+             flow3d = (P1-P0)[:3].reshape((3,)+shape).transpose((1,2,0))
+
+             gt_normal = np.concatenate((d1[:,:,np.newaxis],d2[:,:,np.newaxis],d2[:,:,np.newaxis]),-1)
+             change_size = np.concatenate((change_size[:,:,np.newaxis],gt_normal,flow3d),2)
+         else:
+             shape = iml0.size
+             shape = [shape[1],shape[0]]
+             flowl0 = np.zeros((shape[0],shape[1],3))
+             change_size = np.zeros((shape[0],shape[1],7))
+             depth = disparity_loader(self.iml1[index].replace('camera','groundtruth'))
+             change_size[:,:,0] = depth
+
+             seqid = self.iml0[index].split('/')[-5].rsplit('_',3)[0]
+             ints = load_calib_cam_to_cam('/data/gengshay/KITTI/%s/calib_cam_to_cam.txt'%seqid)
+             fl = ints['K_cam2'][0,0]
+             cx = ints['K_cam2'][0,2]
+             cy = ints['K_cam2'][1,2]
+             bl = ints['b20']-ints['b30']
+
+
+         iml1 = np.asarray(iml1)/255.
+         iml0 = np.asarray(iml0)/255.
+         iml0 = iml0[:,:,::-1].copy()
+         iml1 = iml1[:,:,::-1].copy()
+
+         ## following data augmentation procedure in PWCNet
+         ## https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
+         import __main__ # a workaround for "discount_coeff"
+         try:
+             with open('/scratch/gengshay/iter_counts-%d.txt'%int(__main__.args.logname.split('-')[-1]), 'r') as f:
+                 iter_counts = int(f.readline())
+         except:
+             iter_counts = 0
+         schedule = [0.5, 1., 50000.] # initial coeff, final_coeff, half life
+         schedule_coeff = schedule[0] + (schedule[1] - schedule[0]) * \
+                          (2/(1+np.exp(-1.0986*iter_counts/schedule[2])) - 1)
+
+         if self.pca_augmentor:
+             pca_augmentor = flow_transforms.pseudoPCAAug( schedule_coeff=schedule_coeff)
+         else:
+             pca_augmentor = flow_transforms.Scale(1., order=0)
+
+         if np.random.binomial(1,self.prob):
+             co_transform1 = flow_transforms.Compose([
+                 flow_transforms.SpatialAug([th,tw],
+                                            scale=[0.2,0.,0.1],
+                                            rot=[0.4,0.],
+                                            trans=[0.4,0.],
+                                            squeeze=[0.3,0.], schedule_coeff=schedule_coeff, order=self.order),
+             ])
+         else:
+             co_transform1 = flow_transforms.Compose([
+                 flow_transforms.RandomCrop([th,tw]),
+             ])
+
+         co_transform2 = flow_transforms.Compose([
+             flow_transforms.pseudoPCAAug( schedule_coeff=schedule_coeff),
+             #flow_transforms.PCAAug(schedule_coeff=schedule_coeff),
+             flow_transforms.ChromaticAug( schedule_coeff=schedule_coeff, noise=self.noise),
+         ])
+
+         flowl0 = np.concatenate([flowl0,change_size],-1)
+         augmented, flowl0, intr = co_transform1([iml0, iml1], flowl0, [fl,cx,cy,bl])
+         imol0 = augmented[0]
+         imol1 = augmented[1]
+         augmented, flowl0, intr = co_transform2(augmented, flowl0, intr)
+
+         iml0 = augmented[0]
+         iml1 = augmented[1]
+         flowl0 = flowl0.astype(np.float32)
+         change_size = flowl0[:,:,3:]
+         flowl0 = flowl0[:,:,:3]
+
+         # randomly cover a region
+         sx=0;sy=0;cx=0;cy=0
+         if np.random.binomial(1,0.5):
+             sx = int(np.random.uniform(25,100))
+             sy = int(np.random.uniform(25,100))
+             #sx = int(np.random.uniform(50,150))
+             #sy = int(np.random.uniform(50,150))
+             cx = int(np.random.uniform(sx,iml1.shape[0]-sx))
+             cy = int(np.random.uniform(sy,iml1.shape[1]-sy))
+             iml1[cx-sx:cx+sx,cy-sy:cy+sy] = np.mean(np.mean(iml1,0),0)[np.newaxis,np.newaxis]
+
+         iml0 = torch.Tensor(np.transpose(iml0,(2,0,1)))
+         iml1 = torch.Tensor(np.transpose(iml1,(2,0,1)))
+
+         return iml0, iml1, flowl0, change_size, intr, imol0, imol1, np.asarray([cx-sx,cx+sx,cy-sy,cy+sy])
+
+     def __len__(self):
+         return len(self.iml0)
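The `triangulation` helper in this file is standard pinhole back-projection: depth Z = bl*fl/disp, then X = (x - cx)*Z/fl and Y = (y - cy)*Z/fl, stacked as homogeneous 4xN points. A tiny standalone numeric sketch mirroring that math (the stereo parameters are illustrative, not from the repository):
```python
import numpy as np

bl, fl, cx, cy = 0.54, 721.5, 609.6, 172.9   # illustrative KITTI-like stereo values
disp = np.full((2, 2), 50.0)                 # a tiny 2x2 disparity map
x, y = np.meshgrid(np.arange(2.0), np.arange(2.0))

Z = bl * fl / disp                           # depth from stereo disparity
X = (x - cx) * Z / fl
Y = (y - cy) * Z / fl
P = np.concatenate((X[None], Y[None], Z[None]), 0).reshape(3, -1)
P = np.concatenate((P, np.ones((1, P.shape[-1]))), 0)
print(P.shape)                               # (4, 4): one homogeneous point per pixel
```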
expansion/dataloader/flow_transforms.py ADDED
@@ -0,0 +1,440 @@
+ from __future__ import division
+ import torch
+ import random
+ import numpy as np
+ import numbers
+ import types
+ import scipy.ndimage as ndimage
+ import pdb
+ import torchvision
+ import PIL.Image as Image
+ import cv2
+ from torch.nn import functional as F
+
+
+ class Compose(object):
+     """ Composes several co_transforms together.
+     For example:
+     >>> co_transforms.Compose([
+     >>>     co_transforms.CenterCrop(10),
+     >>>     co_transforms.ToTensor(),
+     >>> ])
+     """
+
+     def __init__(self, co_transforms):
+         self.co_transforms = co_transforms
+
+     def __call__(self, input, target):
+         for t in self.co_transforms:
+             input, target = t(input, target)
+         return input, target
+
+
+ class Scale(object):
+     """ Rescales the inputs and target arrays to the given 'size'.
+     'size' will be the size of the smaller edge.
+     For example, if height > width, then image will be
+     rescaled to (size * height / width, size)
+     size: size of the smaller edge
+     interpolation order: Default: 2 (bilinear)
+     """
+
+     def __init__(self, size, order=1):
+         self.ratio = size
+         self.order = order
+         if order==0:
+             self.code=cv2.INTER_NEAREST
+         elif order==1:
+             self.code=cv2.INTER_LINEAR
+         elif order==2:
+             self.code=cv2.INTER_CUBIC
+
+     def __call__(self, inputs, target):
+         if self.ratio==1:
+             return inputs, target
+         h, w, _ = inputs[0].shape
+         ratio = self.ratio
+
+         inputs[0] = cv2.resize(inputs[0], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_LINEAR)
+         inputs[1] = cv2.resize(inputs[1], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_LINEAR)
+         # keep the mask same
+         tmp = cv2.resize(target[:,:,2], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_NEAREST)
+         target = cv2.resize(target, None, fx=ratio,fy=ratio,interpolation=self.code) * ratio
+         target[:,:,2] = tmp
+
+         return inputs, target
+
+
+
+
+ class SpatialAug(object):
+     def __init__(self, crop, scale=None, rot=None, trans=None, squeeze=None, schedule_coeff=1, order=1, black=False):
+         self.crop = crop
+         self.scale = scale
+         self.rot = rot
+         self.trans = trans
+         self.squeeze = squeeze
+         self.t = np.zeros(6)
+         self.schedule_coeff = schedule_coeff
+         self.order = order
+         self.black = black
+
+     def to_identity(self):
+         self.t[0] = 1; self.t[2] = 0; self.t[4] = 0; self.t[1] = 0; self.t[3] = 1; self.t[5] = 0;
+
+     def left_multiply(self, u0, u1, u2, u3, u4, u5):
+         result = np.zeros(6)
+         result[0] = self.t[0]*u0 + self.t[1]*u2;
+         result[1] = self.t[0]*u1 + self.t[1]*u3;
+
+         result[2] = self.t[2]*u0 + self.t[3]*u2;
+         result[3] = self.t[2]*u1 + self.t[3]*u3;
+
+         result[4] = self.t[4]*u0 + self.t[5]*u2 + u4;
+         result[5] = self.t[4]*u1 + self.t[5]*u3 + u5;
+         self.t = result
+
+     def inverse(self):
+         result = np.zeros(6)
+         a = self.t[0]; c = self.t[2]; e = self.t[4];
+         b = self.t[1]; d = self.t[3]; f = self.t[5];
+
+         denom = a*d - b*c;
+
+         result[0] = d / denom;
+         result[1] = -b / denom;
+         result[2] = -c / denom;
+         result[3] = a / denom;
+         result[4] = (c*f-d*e) / denom;
+         result[5] = (b*e-a*f) / denom;
+
+         return result
+
+     def grid_transform(self, meshgrid, t, normalize=True, gridsize=None):
+         if gridsize is None:
+             h, w = meshgrid[0].shape
+         else:
+             h, w = gridsize
+         vgrid = torch.cat([(meshgrid[0] * t[0] + meshgrid[1] * t[2] + t[4])[:,:,np.newaxis],
+                            (meshgrid[0] * t[1] + meshgrid[1] * t[3] + t[5])[:,:,np.newaxis]],-1)
+         if normalize:
+             vgrid[:,:,0] = 2.0*vgrid[:,:,0]/max(w-1,1)-1.0
+             vgrid[:,:,1] = 2.0*vgrid[:,:,1]/max(h-1,1)-1.0
+         return vgrid
+
+
+     def __call__(self, inputs, target):
+         h, w, _ = inputs[0].shape
+         th, tw = self.crop
+         meshgrid = torch.meshgrid([torch.Tensor(range(th)), torch.Tensor(range(tw))])[::-1]
+         cornergrid = torch.meshgrid([torch.Tensor([0,th-1]), torch.Tensor([0,tw-1])])[::-1]
+
+         for i in range(50):
+             # im0
+             self.to_identity()
+             #TODO add mirror
+             if np.random.binomial(1,0.5):
+                 mirror = True
+             else:
+                 mirror = False
+             ##TODO
1
+ from __future__ import division
2
+ import torch
3
+ import random
4
+ import numpy as np
5
+ import numbers
6
+ import types
7
+ import scipy.ndimage as ndimage
8
+ import pdb
9
+ import torchvision
10
+ import PIL.Image as Image
11
+ import cv2
12
+ from torch.nn import functional as F
13
+
14
+
15
+ class Compose(object):
16
+ """ Composes several co_transforms together.
17
+ For example:
18
+ >>> co_transforms.Compose([
19
+ >>> co_transforms.CenterCrop(10),
20
+ >>> co_transforms.ToTensor(),
21
+ >>> ])
22
+ """
23
+
24
+ def __init__(self, co_transforms):
25
+ self.co_transforms = co_transforms
26
+
27
+ def __call__(self, input, target):
28
+ for t in self.co_transforms:
29
+ input,target = t(input,target)
30
+ return input,target
31
+
32
+
33
+ class Scale(object):
34
+ """ Rescales the inputs and target arrays to the given 'size'.
35
+ 'size' will be the size of the smaller edge.
36
+ For example, if height > width, then image will be
37
+ rescaled to (size * height / width, size)
38
+ size: size of the smaller edge
39
+ interpolation order: Default: 2 (bilinear)
40
+ """
41
+
42
+ def __init__(self, size, order=1):
43
+ self.ratio = size
44
+ self.order = order
45
+ if order==0:
46
+ self.code=cv2.INTER_NEAREST
47
+ elif order==1:
48
+ self.code=cv2.INTER_LINEAR
49
+ elif order==2:
50
+ self.code=cv2.INTER_CUBIC
51
+
52
+ def __call__(self, inputs, target):
53
+ if self.ratio==1:
54
+ return inputs, target
55
+ h, w, _ = inputs[0].shape
56
+ ratio = self.ratio
57
+
58
+ inputs[0] = cv2.resize(inputs[0], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_LINEAR)
59
+ inputs[1] = cv2.resize(inputs[1], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_LINEAR)
60
+ # keep the mask same
61
+ tmp = cv2.resize(target[:,:,2], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_NEAREST)
62
+ target = cv2.resize(target, None, fx=ratio,fy=ratio,interpolation=self.code) * ratio
63
+ target[:,:,2] = tmp
64
+
65
+
66
+ return inputs, target
67
+
68
+
69
+
70
+
71
+ class SpatialAug(object):
72
+ def __init__(self, crop, scale=None, rot=None, trans=None, squeeze=None, schedule_coeff=1, order=1, black=False):
73
+ self.crop = crop
74
+ self.scale = scale
75
+ self.rot = rot
76
+ self.trans = trans
77
+ self.squeeze = squeeze
78
+ self.t = np.zeros(6)
79
+ self.schedule_coeff = schedule_coeff
80
+ self.order = order
81
+ self.black = black
82
+
83
+ def to_identity(self):
84
+ self.t[0] = 1; self.t[2] = 0; self.t[4] = 0; self.t[1] = 0; self.t[3] = 1; self.t[5] = 0;
85
+
86
+ def left_multiply(self, u0, u1, u2, u3, u4, u5):
87
+ result = np.zeros(6)
88
+ result[0] = self.t[0]*u0 + self.t[1]*u2;
89
+ result[1] = self.t[0]*u1 + self.t[1]*u3;
90
+
91
+ result[2] = self.t[2]*u0 + self.t[3]*u2;
92
+ result[3] = self.t[2]*u1 + self.t[3]*u3;
93
+
94
+ result[4] = self.t[4]*u0 + self.t[5]*u2 + u4;
95
+ result[5] = self.t[4]*u1 + self.t[5]*u3 + u5;
96
+ self.t = result
97
+
98
+ def inverse(self):
99
+ result = np.zeros(6)
100
+ a = self.t[0]; c = self.t[2]; e = self.t[4];
101
+ b = self.t[1]; d = self.t[3]; f = self.t[5];
102
+
103
+ denom = a*d - b*c;
104
+
105
+ result[0] = d / denom;
106
+ result[1] = -b / denom;
107
+ result[2] = -c / denom;
108
+ result[3] = a / denom;
109
+ result[4] = (c*f-d*e) / denom;
110
+ result[5] = (b*e-a*f) / denom;
111
+
112
+ return result
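+ # self.t stores a 2x3 affine as [a, b, c, d, e, f]: a point (x, y)
+ # maps to (a*x + c*y + e, b*x + d*y + f) (see grid_transform below);
+ # left_multiply composes a second affine after the current one, and
+ # inverse returns the parameters of the inverse mapping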
113
+
114
+ def grid_transform(self, meshgrid, t, normalize=True, gridsize=None):
115
+ if gridsize is None:
116
+ h, w = meshgrid[0].shape
117
+ else:
118
+ h, w = gridsize
119
+ vgrid = torch.cat([(meshgrid[0] * t[0] + meshgrid[1] * t[2] + t[4])[:,:,np.newaxis],
120
+ (meshgrid[0] * t[1] + meshgrid[1] * t[3] + t[5])[:,:,np.newaxis]],-1)
121
+ if normalize:
122
+ vgrid[:,:,0] = 2.0*vgrid[:,:,0]/max(w-1,1)-1.0
123
+ vgrid[:,:,1] = 2.0*vgrid[:,:,1]/max(h-1,1)-1.0
124
+ return vgrid
125
+
126
+
127
+ def __call__(self, inputs, target):
128
+ h, w, _ = inputs[0].shape
129
+ th, tw = self.crop
130
+ meshgrid = torch.meshgrid([torch.Tensor(range(th)), torch.Tensor(range(tw))])[::-1]
131
+ cornergrid = torch.meshgrid([torch.Tensor([0,th-1]), torch.Tensor([0,tw-1])])[::-1]
132
+
133
+ for i in range(50):
134
+ # im0
135
+ self.to_identity()
136
+ #TODO add mirror
137
+ if np.random.binomial(1,0.5):
138
+ mirror = True
139
+ else:
140
+ mirror = False
141
+ ##TODO
142
+ #mirror = False
143
+ if mirror:
144
+ self.left_multiply(-1, 0, 0, 1, .5 * tw, -.5 * th);
145
+ else:
146
+ self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th);
147
+ scale0 = 1; scale1 = 1; squeeze0 = 1; squeeze1 = 1;
148
+ if not self.rot is None:
149
+ rot0 = np.random.uniform(-self.rot[0],+self.rot[0])
150
+ rot1 = np.random.uniform(-self.rot[1]*self.schedule_coeff, self.rot[1]*self.schedule_coeff) + rot0
151
+ self.left_multiply(np.cos(rot0), np.sin(rot0), -np.sin(rot0), np.cos(rot0), 0, 0)
152
+ if not self.trans is None:
153
+ trans0 = np.random.uniform(-self.trans[0],+self.trans[0], 2)
154
+ trans1 = np.random.uniform(-self.trans[1]*self.schedule_coeff,+self.trans[1]*self.schedule_coeff, 2) + trans0
155
+ self.left_multiply(1, 0, 0, 1, trans0[0] * tw, trans0[1] * th)
156
+ if not self.squeeze is None:
157
+ squeeze0 = np.exp(np.random.uniform(-self.squeeze[0], self.squeeze[0]))
158
+ squeeze1 = np.exp(np.random.uniform(-self.squeeze[1]*self.schedule_coeff, self.squeeze[1]*self.schedule_coeff)) * squeeze0
159
+ if not self.scale is None:
160
+ scale0 = np.exp(np.random.uniform(self.scale[2]-self.scale[0], self.scale[2]+self.scale[0]))
161
+ scale1 = np.exp(np.random.uniform(-self.scale[1]*self.schedule_coeff, self.scale[1]*self.schedule_coeff)) * scale0
162
+ self.left_multiply(1.0/(scale0*squeeze0), 0, 0, 1.0/(scale0/squeeze0), 0, 0)
163
+
164
+ self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h);
165
+ transmat0 = self.t.copy()
166
+
167
+ # im1
168
+ self.to_identity()
169
+ if mirror:
170
+ self.left_multiply(-1, 0, 0, 1, .5 * tw, -.5 * th);
171
+ else:
172
+ self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th);
173
+ if not self.rot is None:
174
+ self.left_multiply(np.cos(rot1), np.sin(rot1), -np.sin(rot1), np.cos(rot1), 0, 0)
175
+ if not self.trans is None:
176
+ self.left_multiply(1, 0, 0, 1, trans1[0] * tw, trans1[1] * th)
177
+ self.left_multiply(1.0/(scale1*squeeze1), 0, 0, 1.0/(scale1/squeeze1), 0, 0)
178
+ self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h);
179
+ transmat1 = self.t.copy()
180
+ transmat1_inv = self.inverse()
181
+
182
+ if self.black:
183
+ # black augmentation, allowing 0 values in the input images
184
+ # https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/black_augmentation_layer.cu
185
+ break
186
+ else:
187
+ if ((self.grid_transform(cornergrid, transmat0, gridsize=[float(h),float(w)]).abs()>1).sum() +\
188
+ (self.grid_transform(cornergrid, transmat1, gridsize=[float(h),float(w)]).abs()>1).sum()) == 0:
189
+ break
190
+ if i==49:
191
+ print('max_iter in augmentation')
192
+ self.to_identity()
193
+ self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th);
194
+ self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h);
195
+ transmat0 = self.t.copy()
196
+ transmat1 = self.t.copy()
197
+
198
+ # do the real work
199
+ vgrid = self.grid_transform(meshgrid, transmat0,gridsize=[float(h),float(w)])
200
+ inputs_0 = F.grid_sample(torch.Tensor(inputs[0]).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)
201
+ if self.order == 0:
202
+ target_0 = F.grid_sample(torch.Tensor(target).permute(2,0,1)[np.newaxis], vgrid[np.newaxis], mode='nearest')[0].permute(1,2,0)
203
+ else:
204
+ target_0 = F.grid_sample(torch.Tensor(target).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)
205
+
206
+ mask_0 = target[:,:,2:3].copy(); mask_0[mask_0==0]=np.nan
207
+ if self.order == 0:
208
+ mask_0 = F.grid_sample(torch.Tensor(mask_0).permute(2,0,1)[np.newaxis], vgrid[np.newaxis], mode='nearest')[0].permute(1,2,0)
209
+ else:
210
+ mask_0 = F.grid_sample(torch.Tensor(mask_0).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)
211
+ mask_0[torch.isnan(mask_0)] = 0
212
+
213
+
214
+ vgrid = self.grid_transform(meshgrid, transmat1,gridsize=[float(h),float(w)])
215
+ inputs_1 = F.grid_sample(torch.Tensor(inputs[1]).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)
216
+
217
+ # flow
218
+ pos = target_0[:,:,:2] + self.grid_transform(meshgrid, transmat0,normalize=False)
219
+ pos = self.grid_transform(pos.permute(2,0,1),transmat1_inv,normalize=False)
220
+ if target_0.shape[2]>=4:
221
+ # scale
222
+ exp = target_0[:,:,3:] * scale1 / scale0
223
+ target = torch.cat([ (pos[:,:,0] - meshgrid[0]).unsqueeze(-1),
224
+ (pos[:,:,1] - meshgrid[1]).unsqueeze(-1),
225
+ mask_0,
226
+ exp], -1)
227
+ else:
228
+ target = torch.cat([ (pos[:,:,0] - meshgrid[0]).unsqueeze(-1),
229
+ (pos[:,:,1] - meshgrid[1]).unsqueeze(-1),
230
+ mask_0], -1)
231
+ # target_0[:,:,2].unsqueeze(-1) ], -1)
232
+ inputs = [np.asarray(inputs_0), np.asarray(inputs_1)]
233
+ target = np.asarray(target)
234
+
235
+ return inputs,target
236
+
237
+
238
+ class pseudoPCAAug(object):
239
+ """
240
+ Chromatic Eigen Augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
241
+ This version is faster.
242
+ """
243
+ def __init__(self, schedule_coeff=1):
244
+ self.augcolor = torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.5, hue=0.5/3.14)
245
+
246
+ def __call__(self, inputs, target):
247
+ inputs[0] = np.asarray(self.augcolor(Image.fromarray(np.uint8(inputs[0]*255))))/255.
248
+ inputs[1] = np.asarray(self.augcolor(Image.fromarray(np.uint8(inputs[1]*255))))/255.
249
+ return inputs,target
250
+
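+ # Note: pseudoPCAAug approximates the FlowNet2 chromatic eigen
+ # augmentation with torchvision's ColorJitter (brightness, contrast,
+ # saturation, hue), which is why it is faster than PCAAug below;
+ # schedule_coeff is accepted for interface compatibility but unused.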
251
+
252
+ class PCAAug(object):
253
+ """
254
+ Chromatic Eigen Augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
255
+ """
256
+ def __init__(self, lmult_pow =[0.4, 0,-0.2],
257
+ lmult_mult =[0.4, 0,0, ],
258
+ lmult_add =[0.03,0,0, ],
259
+ sat_pow =[0.4, 0,0, ],
260
+ sat_mult =[0.5, 0,-0.3],
261
+ sat_add =[0.03,0,0, ],
262
+ col_pow =[0.4, 0,0, ],
263
+ col_mult =[0.2, 0,0, ],
264
+ col_add =[0.02,0,0, ],
265
+ ladd_pow =[0.4, 0,0, ],
266
+ ladd_mult =[0.4, 0,0, ],
267
+ ladd_add =[0.04,0,0, ],
268
+ col_rotate =[1., 0,0, ],
269
+ schedule_coeff=1):
270
+ # no mean
271
+ self.pow_nomean = [1,1,1]
272
+ self.add_nomean = [0,0,0]
273
+ self.mult_nomean = [1,1,1]
274
+ self.pow_withmean = [1,1,1]
275
+ self.add_withmean = [0,0,0]
276
+ self.mult_withmean = [1,1,1]
277
+ self.lmult_pow = 1
278
+ self.lmult_mult = 1
279
+ self.lmult_add = 0
280
+ self.col_angle = 0
281
+ if not ladd_pow is None:
282
+ self.pow_nomean[0] =np.exp(np.random.normal(ladd_pow[2], ladd_pow[0]))
283
+ if not col_pow is None:
284
+ self.pow_nomean[1] =np.exp(np.random.normal(col_pow[2], col_pow[0]))
285
+ self.pow_nomean[2] =np.exp(np.random.normal(col_pow[2], col_pow[0]))
286
+
287
+ if not ladd_add is None:
288
+ self.add_nomean[0] =np.random.normal(ladd_add[2], ladd_add[0])
289
+ if not col_add is None:
290
+ self.add_nomean[1] =np.random.normal(col_add[2], col_add[0])
291
+ self.add_nomean[2] =np.random.normal(col_add[2], col_add[0])
292
+
293
+ if not ladd_mult is None:
294
+ self.mult_nomean[0] =np.exp(np.random.normal(ladd_mult[2], ladd_mult[0]))
295
+ if not col_mult is None:
296
+ self.mult_nomean[1] =np.exp(np.random.normal(col_mult[2], col_mult[0]))
297
+ self.mult_nomean[2] =np.exp(np.random.normal(col_mult[2], col_mult[0]))
298
+
299
+ # with mean
300
+ if not sat_pow is None:
301
+ self.pow_withmean[1] =np.exp(np.random.uniform(sat_pow[2]-sat_pow[0], sat_pow[2]+sat_pow[0]))
302
+ self.pow_withmean[2] =self.pow_withmean[1]
303
+ if not sat_add is None:
304
+ self.add_withmean[1] =np.random.uniform(sat_add[2]-sat_add[0], sat_add[2]+sat_add[0])
305
+ self.add_withmean[2] =self.add_withmean[1]
306
+ if not sat_mult is None:
307
+ self.mult_withmean[1] = np.exp(np.random.uniform(sat_mult[2]-sat_mult[0], sat_mult[2]+sat_mult[0]))
308
+ self.mult_withmean[2] = self.mult_withmean[1]
309
+
310
+ if not lmult_pow is None:
311
+ self.lmult_pow = np.exp(np.random.uniform(lmult_pow[2]-lmult_pow[0], lmult_pow[2]+lmult_pow[0]))
312
+ if not lmult_mult is None:
313
+ self.lmult_mult= np.exp(np.random.uniform(lmult_mult[2]-lmult_mult[0], lmult_mult[2]+lmult_mult[0]))
314
+ if not lmult_add is None:
315
+ self.lmult_add = np.random.uniform(lmult_add[2]-lmult_add[0], lmult_add[2]+lmult_add[0])
316
+ if not col_rotate is None:
317
+ self.col_angle= np.random.uniform(col_rotate[2]-col_rotate[0], col_rotate[2]+col_rotate[0])
318
+
319
+ # eigen vectors
320
+ self.eigvec = np.reshape([0.51,0.56,0.65,0.79,0.01,-0.62,0.35,-0.83,0.44],[3,3]).transpose()
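+ # fixed RGB eigenvectors of natural-image color statistics, hard-coded
+ # to match the FlowNet2 data_augmentation_layer referenced above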
321
+
322
+
323
+ def __call__(self, inputs, target):
324
+ inputs[0] = self.pca_image(inputs[0])
325
+ inputs[1] = self.pca_image(inputs[1])
326
+ return inputs,target
327
+
328
+ def pca_image(self, rgb):
329
+ eig = np.dot(rgb, self.eigvec)
330
+ max_rgb = np.clip(rgb,0,np.inf).max((0,1))
331
+ min_rgb = rgb.min((0,1))
332
+ mean_rgb = rgb.mean((0,1))
333
+ max_abs_eig = np.abs(eig).max((0,1))
334
+ max_l = np.sqrt(np.sum(max_abs_eig*max_abs_eig))
335
+ mean_eig = np.dot(mean_rgb, self.eigvec)
336
+
337
+ # no-mean stuff
338
+ eig -= mean_eig[np.newaxis, np.newaxis]
339
+
340
+ for c in range(3):
341
+ if max_abs_eig[c] > 1e-2:
342
+ mean_eig[c] /= max_abs_eig[c]
343
+ eig[:,:,c] = eig[:,:,c] / max_abs_eig[c];
344
+ eig[:,:,c] = np.power(np.abs(eig[:,:,c]),self.pow_nomean[c]) *\
345
+ ((eig[:,:,c] > 0) -0.5)*2
346
+ eig[:,:,c] = eig[:,:,c] + self.add_nomean[c]
347
+ eig[:,:,c] = eig[:,:,c] * self.mult_nomean[c]
348
+ eig += mean_eig[np.newaxis,np.newaxis]
349
+
350
+ # withmean stuff
351
+ if max_abs_eig[0] > 1e-2:
352
+ eig[:,:,0] = np.power(np.abs(eig[:,:,0]),self.pow_withmean[0]) * \
353
+ ((eig[:,:,0]>0)-0.5)*2;
354
+ eig[:,:,0] = eig[:,:,0] + self.add_withmean[0];
355
+ eig[:,:,0] = eig[:,:,0] * self.mult_withmean[0];
356
+
357
+ s = np.sqrt(eig[:,:,1]*eig[:,:,1] + eig[:,:,2] * eig[:,:,2])
358
+ smask = s > 1e-2
359
+ s1 = np.power(s, self.pow_withmean[1]);
360
+ s1 = np.clip(s1 + self.add_withmean[1], 0,np.inf)
361
+ s1 = s1 * self.mult_withmean[1]
362
+ s1 = s1 * smask + s*(1-smask)
363
+
364
+ # color angle
365
+ if self.col_angle!=0:
366
+ temp1 = np.cos(self.col_angle) * eig[:,:,1] - np.sin(self.col_angle) * eig[:,:,2]
367
+ temp2 = np.sin(self.col_angle) * eig[:,:,1] + np.cos(self.col_angle) * eig[:,:,2]
368
+ eig[:,:,1] = temp1
369
+ eig[:,:,2] = temp2
370
+
371
+ # to origin magnitude
372
+ for c in range(3):
373
+ if max_abs_eig[c] > 1e-2:
374
+ eig[:,:,c] = eig[:,:,c] * max_abs_eig[c]
375
+
376
+ if max_l > 1e-2:
377
+ l1 = np.sqrt(eig[:,:,0]*eig[:,:,0] + eig[:,:,1]*eig[:,:,1] + eig[:,:,2]*eig[:,:,2])
378
+ l1 = l1 / max_l
379
+
380
+ eig[:,:,1][smask] = (eig[:,:,1] / s * s1)[smask]
381
+ eig[:,:,2][smask] = (eig[:,:,2] / s * s1)[smask]
382
+ #eig[:,:,1] = (eig[:,:,1] / s * s1) * smask + eig[:,:,1] * (1-smask)
383
+ #eig[:,:,2] = (eig[:,:,2] / s * s1) * smask + eig[:,:,2] * (1-smask)
384
+
385
+ if max_l > 1e-2:
386
+ l = np.sqrt(eig[:,:,0]*eig[:,:,0] + eig[:,:,1]*eig[:,:,1] + eig[:,:,2]*eig[:,:,2])
387
+ l1 = np.power(l1, self.lmult_pow)
388
+ l1 = np.clip(l1 + self.lmult_add, 0, np.inf)
389
+ l1 = l1 * self.lmult_mult
390
+ l1 = l1 * max_l
391
+ lmask = l > 1e-2
392
+ eig[lmask] = (eig / l[:,:,np.newaxis] * l1[:,:,np.newaxis])[lmask]
393
+ for c in range(3):
394
+ eig[:,:,c][lmask] = (np.clip(eig[:,:,c], -np.inf, max_abs_eig[c]))[lmask]
395
+ # for c in range(3):
396
+ # # eig[:,:,c][lmask] = (eig[:,:,c] / l * l1)[lmask] * lmask + eig[:,:,c] * (1-lmask)
397
+ # eig[:,:,c][lmask] = (eig[:,:,c] / l * l1)[lmask]
398
+ # eig[:,:,c] = (np.clip(eig[:,:,c], -np.inf, max_abs_eig[c])) * lmask + eig[:,:,c] * (1-lmask)
399
+
400
+ return np.clip(np.dot(eig, self.eigvec.transpose()), 0, 1)
401
+
402
+
403
+ class ChromaticAug(object):
404
+ """
405
+ Chromatic augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
406
+ """
407
+ def __init__(self, noise = 0.06,
408
+ gamma = 0.02,
409
+ brightness = 0.02,
410
+ contrast = 0.02,
411
+ color = 0.02,
412
+ schedule_coeff=1):
413
+
414
+ self.noise = np.random.uniform(0,noise)
415
+ self.gamma = np.exp(np.random.normal(0, gamma*schedule_coeff))
416
+ self.brightness = np.random.normal(0, brightness*schedule_coeff)
417
+ self.contrast = np.exp(np.random.normal(0, contrast*schedule_coeff))
418
+ self.color = np.exp(np.random.normal(0, color*schedule_coeff,3))
419
+
420
+ def __call__(self, inputs, target):
421
+ inputs[1] = self.chrom_aug(inputs[1])
422
+ # noise
423
+ inputs[0]+=np.random.normal(0, self.noise, inputs[0].shape)
424
+ inputs[1]+=np.random.normal(0, self.noise, inputs[1].shape)
425
+ return inputs,target
426
+
427
+ def chrom_aug(self, rgb):
428
+ # color change
429
+ mean_in = rgb.sum(-1)
430
+ rgb = rgb*self.color[np.newaxis,np.newaxis]
431
+ brightness_coeff = mean_in / (rgb.sum(-1)+0.01)
432
+ rgb = np.clip(rgb*brightness_coeff[:,:,np.newaxis],0,1)
433
+ # gamma
434
+ rgb = np.power(rgb,self.gamma)
435
+ # brightness
436
+ rgb += self.brightness
437
+ # contrast
438
+ rgb = 0.5 + ( rgb-0.5)*self.contrast
439
+ rgb = np.clip(rgb, 0, 1)
440
+ return rgb
expansion/dataloader/hd1klist.py ADDED
@@ -0,0 +1,29 @@
1
+ import torch.utils.data as data
2
+
3
+ from PIL import Image
4
+ import os
5
+ import os.path
6
+ import numpy as np
7
+ import pdb
8
+
9
+ IMG_EXTENSIONS = [
10
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
11
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12
+ ]
13
+
14
+
15
+ def is_image_file(filename):
16
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17
+
18
+ def dataloader(filepath):
19
+
20
+ left_fold = 'image_2/'
21
+ train = [img for img in os.listdir(filepath+left_fold) if img.find('HD1K2018') > -1]
22
+ train = sorted(train)
23
+
24
+ l0_train = [filepath+left_fold+img for img in train]
25
+ l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%04d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]
26
+ l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%04d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train]
27
+ flow_train = [img.replace('image_2','flow_occ') for img in l0_train]
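+ # the comprehensions above keep only frames whose successor (index+1,
+ # zero-padded to 4 digits) exists, pairing each frame with the next one
+ # and with its flow_occ ground truth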
28
+
29
+ return l0_train, l1_train, flow_train
expansion/dataloader/kitti12list.py ADDED
@@ -0,0 +1,29 @@
1
+ import torch.utils.data as data
2
+
3
+ from PIL import Image
4
+ import os
5
+ import os.path
6
+ import numpy as np
7
+
8
+ IMG_EXTENSIONS = [
9
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
10
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11
+ ]
12
+
13
+
14
+ def is_image_file(filename):
15
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16
+
17
+ def dataloader(filepath):
18
+
19
+ left_fold = 'colored_0/'
20
+ flow_noc = 'flow_occ/'
21
+
22
+ train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
23
+
24
+ l0_train = [filepath+left_fold+img for img in train]
25
+ l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
26
+ flow_train = [filepath+flow_noc+img for img in train]
27
+
28
+
29
+ return l0_train, l1_train, flow_train
expansion/dataloader/kitti15list.py ADDED
@@ -0,0 +1,29 @@
1
+ import torch.utils.data as data
2
+
3
+ from PIL import Image
4
+ import os
5
+ import os.path
6
+ import numpy as np
7
+
8
+ IMG_EXTENSIONS = [
9
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
10
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11
+ ]
12
+
13
+
14
+ def is_image_file(filename):
15
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16
+
17
+ def dataloader(filepath):
18
+
19
+ left_fold = 'image_2/'
20
+ flow_noc = 'flow_occ/'
21
+
22
+ train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
23
+
24
+ l0_train = [filepath+left_fold+img for img in train]
25
+ l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
26
+ flow_train = [filepath+flow_noc+img for img in train]
27
+
28
+
29
+ return sorted(l0_train), sorted(l1_train), sorted(flow_train)
expansion/dataloader/kitti15list_train.py ADDED
@@ -0,0 +1,31 @@
1
+ import torch.utils.data as data
2
+
3
+ from PIL import Image
4
+ import os
5
+ import os.path
6
+ import numpy as np
7
+
8
+ IMG_EXTENSIONS = [
9
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
10
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11
+ ]
12
+
13
+
14
+ def is_image_file(filename):
15
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16
+
17
+ def dataloader(filepath):
18
+
19
+ left_fold = 'image_2/'
20
+ flow_noc = 'flow_occ/'
21
+
22
+ train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
23
+
24
+ train = [i for i in train if int(i.split('_')[0])%5!=0]
25
+
26
+ l0_train = [filepath+left_fold+img for img in train]
27
+ l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
28
+ flow_train = [filepath+flow_noc+img for img in train]
29
+
30
+
31
+ return sorted(l0_train), sorted(l1_train), sorted(flow_train)
expansion/dataloader/kitti15list_train_lidar.py ADDED
@@ -0,0 +1,34 @@
1
+ import torch.utils.data as data
2
+
3
+ from PIL import Image
4
+ import os
5
+ import os.path
6
+ import numpy as np
7
+
8
+ IMG_EXTENSIONS = [
9
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
10
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11
+ ]
12
+
13
+
14
+ def is_image_file(filename):
15
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16
+
17
+ def dataloader(filepath):
18
+
19
+ left_fold = 'image_2/'
20
+ flow_noc = 'flow_occ/'
21
+
22
+ train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
23
+
24
+ # train = [i for i in train if int(i.split('_')[0])%5!=0]
25
+ with open('/data/gengshay/kitti_scene/devkit/mapping/train_mapping.txt','r') as f:
26
+ flags = [True if len(i)>1 else False for i in f.readlines()]
27
+ train = [fn for (it,fn) in enumerate(sorted(train)) if flags[it] ][:100]
28
+
29
+ l0_train = [filepath+left_fold+img for img in train]
30
+ l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
31
+ flow_train = [filepath+flow_noc+img for img in train]
32
+
33
+
34
+ return sorted(l0_train), sorted(l1_train), sorted(flow_train)
expansion/dataloader/kitti15list_val.py ADDED
@@ -0,0 +1,31 @@
1
+ import torch.utils.data as data
2
+
3
+ from PIL import Image
4
+ import os
5
+ import os.path
6
+ import numpy as np
7
+
8
+ IMG_EXTENSIONS = [
9
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
10
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11
+ ]
12
+
13
+
14
+ def is_image_file(filename):
15
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16
+
17
+ def dataloader(filepath):
18
+
19
+ left_fold = 'image_2/'
20
+ flow_noc = 'flow_occ/'
21
+
22
+ train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
23
+
24
+ train = [i for i in train if int(i.split('_')[0])%5==0]
25
+
26
+ l0_train = [filepath+left_fold+img for img in train]
27
+ l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
28
+ flow_train = [filepath+flow_noc+img for img in train]
29
+
30
+
31
+ return sorted(l0_train), sorted(l1_train), sorted(flow_train)
expansion/dataloader/kitti15list_val_lidar.py ADDED
@@ -0,0 +1,34 @@
1
+ import torch.utils.data as data
2
+
3
+ from PIL import Image
4
+ import os
5
+ import os.path
6
+ import numpy as np
7
+
8
+ IMG_EXTENSIONS = [
9
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
10
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11
+ ]
12
+
13
+
14
+ def is_image_file(filename):
15
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16
+
17
+ def dataloader(filepath):
18
+
19
+ left_fold = 'image_2/'
20
+ flow_noc = 'flow_occ/'
21
+
22
+ train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
23
+
24
+ # train = [i for i in train if int(i.split('_')[0])%5!=0]
25
+ with open('/data/gengshay/kitti_scene/devkit/mapping/train_mapping.txt','r') as f:
26
+ flags = [True if len(i)>1 else False for i in f.readlines()]
27
+ train = [fn for (it,fn) in enumerate(sorted(train)) if flags[it] ][100:]
28
+
29
+ l0_train = [filepath+left_fold+img for img in train]
30
+ l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
31
+ flow_train = [filepath+flow_noc+img for img in train]
32
+
33
+
34
+ return sorted(l0_train), sorted(l1_train), sorted(flow_train)
expansion/dataloader/kitti15list_val_mr.py ADDED
@@ -0,0 +1,41 @@
1
+ import torch.utils.data as data
2
+
3
+ from PIL import Image
4
+ import os
5
+ import os.path
6
+ import numpy as np
7
+
8
+ IMG_EXTENSIONS = [
9
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
10
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11
+ ]
12
+
13
+
14
+ def is_image_file(filename):
15
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16
+
17
+ def dataloader(filepath):
18
+
19
+ left_fold = 'image_2/'
20
+ flow_noc = 'flow_occ/'
21
+
22
+ train = [img for img in os.listdir(filepath+left_fold) if 'Kitti' in img and img.find('_10') > -1]
23
+
24
+ # train = [i for i in train if int(i.split('_')[1])%5==0]
25
+ # import pdb; pdb.set_trace() # leftover breakpoint disabled; it would halt the dataloader
26
+ train = sorted([i for i in train if int(i.split('_')[1])%5==0])[0:1]
27
+
28
+ l0_train = [filepath+left_fold+img for img in train]
29
+ l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
30
+ flow_train = [filepath+flow_noc+img for img in train]
31
+
32
+ l0_train += [filepath+left_fold+img.replace('_10','_09') for img in train]
33
+ l1_train += [filepath+left_fold+img for img in train]
34
+ flow_train += flow_train
35
+
36
+ tmp = l0_train
37
+ l0_train = l0_train+ [i.replace('rob_flow', 'kitti_scene').replace('Kitti2015_','') for i in l1_train]
38
+ l1_train = l1_train+tmp
39
+ flow_train += flow_train
40
+
41
+ return l0_train, l1_train, flow_train
expansion/dataloader/robloader.py ADDED
@@ -0,0 +1,133 @@
1
+ import os
2
+ import numbers
3
+ import torch
4
+ import torch.utils.data as data
5
+
6
+ import torchvision.transforms as transforms
7
+ import random
8
+ from PIL import Image, ImageOps
9
+ import numpy as np
10
+ import torchvision
11
+ from . import flow_transforms
12
+ import pdb
13
+ import cv2
14
+ from utils.flowlib import read_flow
15
+ from utils.util_flow import readPFM
16
+
17
+
18
+ def default_loader(path):
19
+ return Image.open(path).convert('RGB')
20
+
21
+ def flow_loader(path):
22
+ if '.pfm' in path:
23
+ data = readPFM(path)[0]
24
+ data[:,:,2] = 1
25
+ return data
26
+ else:
27
+ return read_flow(path)
28
+
29
+
30
+ def disparity_loader(path):
31
+ if '.png' in path:
32
+ data = Image.open(path)
33
+ data = np.ascontiguousarray(data,dtype=np.float32)/256
34
+ return data
35
+ else:
36
+ return readPFM(path)[0]
37
+
38
+ class myImageFloder(data.Dataset):
39
+ def __init__(self, iml0, iml1, flowl0, loader=default_loader, dploader= flow_loader, scale=1.,shape=[320,448], order=1, noise=0.06, pca_augmentor=True, prob = 1., cover=False, black=False, scale_aug=[0.4,0.2]):
40
+ self.iml0 = iml0
41
+ self.iml1 = iml1
42
+ self.flowl0 = flowl0
43
+ self.loader = loader
44
+ self.dploader = dploader
45
+ self.scale=scale
46
+ self.shape=shape
47
+ self.order=order
48
+ self.noise = noise
49
+ self.pca_augmentor = pca_augmentor
50
+ self.prob = prob
51
+ self.cover = cover
52
+ self.black = black
53
+ self.scale_aug = scale_aug
54
+
55
+ def __getitem__(self, index):
56
+ iml0 = self.iml0[index]
57
+ iml1 = self.iml1[index]
58
+ flowl0= self.flowl0[index]
59
+ th, tw = self.shape
60
+
61
+ iml0 = self.loader(iml0)
62
+ iml1 = self.loader(iml1)
63
+ iml1 = np.asarray(iml1)/255.
64
+ iml0 = np.asarray(iml0)/255.
65
+ iml0 = iml0[:,:,::-1].copy()
66
+ iml1 = iml1[:,:,::-1].copy()
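+ # reverse the channel order (RGB -> BGR), presumably to match the cv2
+ # conventions used by the transforms below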
67
+ flowl0 = self.dploader(flowl0)
68
+ #flowl0[:,:,-1][flowl0[:,:,0]==np.inf]=0 # for gtav window pfm files
69
+ #flowl0[:,:,0][~flowl0[:,:,2].astype(bool)]=0
70
+ #flowl0[:,:,1][~flowl0[:,:,2].astype(bool)]=0 # avoid nan in grad
71
+ flowl0 = np.ascontiguousarray(flowl0,dtype=np.float32)
72
+ flowl0[np.isnan(flowl0)] = 1e6 # set to max
73
+
74
+ ## following data augmentation procedure in PWCNet
75
+ ## https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
76
+ import __main__ # a workaround for "discount_coeff"
77
+ try:
78
+ with open('iter_counts-%d.txt'%int(__main__.args.logname.split('-')[-1]), 'r') as f:
79
+ iter_counts = int(f.readline())
80
+ except:
81
+ iter_counts = 0
82
+ schedule = [0.5, 1., 50000.] # initial coeff, final_coeff, half life
83
+ schedule_coeff = schedule[0] + (schedule[1] - schedule[0]) * \
84
+ (2/(1+np.exp(-1.0986*iter_counts/schedule[2])) - 1)
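+ # since 1.0986 ~= ln(3), at iter_counts == 50000 (the half life) the
+ # bracketed term is 2/(1+1/3)-1 = 0.5, giving schedule_coeff = 0.75,
+ # halfway between the initial (0.5) and final (1.0) coefficients; the
+ # coefficient approaches 1.0 as training proceeds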
85
+
86
+ if self.pca_augmentor:
87
+ pca_augmentor = flow_transforms.pseudoPCAAug( schedule_coeff=schedule_coeff)
88
+ else:
89
+ pca_augmentor = flow_transforms.Scale(1., order=0)
90
+
91
+ if np.random.binomial(1,self.prob):
92
+ co_transform = flow_transforms.Compose([
93
+ flow_transforms.Scale(self.scale, order=self.order),
94
+ #flow_transforms.SpatialAug([th,tw], trans=[0.2,0.03], order=self.order, black=self.black),
95
+ flow_transforms.SpatialAug([th,tw],scale=[self.scale_aug[0],0.03,self.scale_aug[1]],
96
+ rot=[0.4,0.03],
97
+ trans=[0.4,0.03],
98
+ squeeze=[0.3,0.], schedule_coeff=schedule_coeff, order=self.order, black=self.black),
99
+ #flow_transforms.pseudoPCAAug(schedule_coeff=schedule_coeff),
100
+ flow_transforms.PCAAug(schedule_coeff=schedule_coeff),
101
+ flow_transforms.ChromaticAug( schedule_coeff=schedule_coeff, noise=self.noise),
102
+ ])
103
+ else:
104
+ co_transform = flow_transforms.Compose([
105
+ flow_transforms.Scale(self.scale, order=self.order),
106
+ flow_transforms.SpatialAug([th,tw], trans=[0.4,0.03], order=self.order, black=self.black)
107
+ ])
108
+
109
+ augmented,flowl0 = co_transform([iml0, iml1], flowl0)
110
+ iml0 = augmented[0]
111
+ iml1 = augmented[1]
112
+
113
+ if self.cover:
114
+ ## randomly cover a region
115
+ # following sec. 3.2 of http://openaccess.thecvf.com/content_CVPR_2019/html/Yang_Hierarchical_Deep_Stereo_Matching_on_High-Resolution_Images_CVPR_2019_paper.html
116
+ if np.random.binomial(1,0.5):
117
+ #sx = int(np.random.uniform(25,100))
118
+ #sy = int(np.random.uniform(25,100))
119
+ sx = int(np.random.uniform(50,125))
120
+ sy = int(np.random.uniform(50,125))
121
+ #sx = int(np.random.uniform(50,150))
122
+ #sy = int(np.random.uniform(50,150))
123
+ cx = int(np.random.uniform(sx,iml1.shape[0]-sx))
124
+ cy = int(np.random.uniform(sy,iml1.shape[1]-sy))
125
+ iml1[cx-sx:cx+sx,cy-sy:cy+sy] = np.mean(np.mean(iml1,0),0)[np.newaxis,np.newaxis]
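+ # the covered patch is filled with the image's mean color, simulating
+ # an occluder (sec. 3.2 of the HSM paper linked above)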
126
+
127
+ iml0 = torch.Tensor(np.transpose(iml0,(2,0,1)))
128
+ iml1 = torch.Tensor(np.transpose(iml1,(2,0,1)))
129
+
130
+ return iml0, iml1, flowl0
131
+
132
+ def __len__(self):
133
+ return len(self.iml0)
expansion/dataloader/sceneflowlist.py ADDED
@@ -0,0 +1,51 @@
1
+ import os
2
+ import os.path
3
+ import glob
4
+
5
+ def dataloader(filepath, level=6):
6
+ iml0 = []
7
+ iml1 = []
8
+ flowl0 = []
9
+ disp0 = []
10
+ dispc = []
11
+ calib = []
12
+ level_stars = '/*'*level
13
+ candidate_pool = glob.glob('%s/optical_flow%s'%(filepath,level_stars))
14
+ for flow_path in sorted(candidate_pool):
15
+ if 'TEST' in flow_path: continue
16
+ if 'flower_storm_x2/into_future/right/OpticalFlowIntoFuture_0023_R.pfm' in flow_path:
17
+ continue
18
+ if 'flower_storm_x2/into_future/left/OpticalFlowIntoFuture_0023_L.pfm' in flow_path:
19
+ continue
20
+ if 'flower_storm_augmented0_x2/into_future/right/OpticalFlowIntoFuture_0023_R.pfm' in flow_path:
21
+ continue
22
+ if 'flower_storm_augmented0_x2/into_future/left/OpticalFlowIntoFuture_0023_L.pfm' in flow_path:
23
+ continue
24
+ if 'FlyingThings' in flow_path and '_0014_' in flow_path:
25
+ continue
26
+ if 'FlyingThings' in flow_path and '_0015_' in flow_path:
27
+ continue
28
+ idd = flow_path.split('/')[-1].split('_')[-2]
29
+ if 'into_future' in flow_path:
30
+ idd_p1 = '%04d'%(int(idd)+1)
31
+ else:
32
+ idd_p1 = '%04d'%(int(idd)-1)
33
+ if os.path.exists(flow_path.replace(idd,idd_p1)):
34
+ d0_path = flow_path.replace('/into_future/','/').replace('/into_past/','/').replace('optical_flow','disparity')
35
+ d0_path = '%s/%s.pfm'%(d0_path.rsplit('/',1)[0],idd)
36
+ dc_path = flow_path.replace('optical_flow','disparity_change')
37
+ dc_path = '%s/%s.pfm'%(dc_path.rsplit('/',1)[0],idd)
38
+ im_path = flow_path.replace('/into_future/','/').replace('/into_past/','/').replace('optical_flow','frames_cleanpass')
39
+ im0_path = '%s/%s.png'%(im_path.rsplit('/',1)[0],idd)
40
+ im1_path = '%s/%s.png'%(im_path.rsplit('/',1)[0],idd_p1)
41
+ #with open('%s/camera_data.txt'%(im0_path.replace('frames_cleanpass','camera_data').rsplit('/',2)[0]),'r') as f:
42
+ # if 'FlyingThings' in flow_path and len(f.readlines())!=40:
43
+ # print(flow_path)
44
+ # continue
45
+ iml0.append(im0_path)
46
+ iml1.append(im1_path)
47
+ flowl0.append(flow_path)
48
+ disp0.append(d0_path)
49
+ dispc.append(dc_path)
50
+ calib.append('%s/camera_data.txt'%(im0_path.replace('frames_cleanpass','camera_data').rsplit('/',2)[0]))
51
+ return iml0, iml1, flowl0, disp0, dispc, calib
expansion/dataloader/seqlist.py ADDED
@@ -0,0 +1,26 @@
1
+ import torch.utils.data as data
2
+
3
+ from PIL import Image
4
+ import os
5
+ import os.path
6
+ import numpy as np
7
+ import glob
8
+
9
+ IMG_EXTENSIONS = [
10
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
11
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12
+ ]
13
+
14
+
15
+ def is_image_file(filename):
16
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17
+
18
+ def dataloader(filepath):
19
+
20
+ train = [img for img in sorted(glob.glob('%s/*'%filepath))]
21
+
22
+ l0_train = train[:-1]
23
+ l1_train = train[1:]
24
+
25
+
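+ # raw sequences have no ground-truth flow, so the l0 paths are returned
+ # again as placeholders in the flow slot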
26
+ return sorted(l0_train), sorted(l1_train), sorted(l0_train)
expansion/dataloader/sintellist.py ADDED
@@ -0,0 +1,32 @@
1
+ import torch.utils.data as data
2
+
3
+ from PIL import Image
4
+ import os
5
+ import os.path
6
+ import numpy as np
7
+ import pdb
8
+
9
+ IMG_EXTENSIONS = [
10
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
11
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12
+ ]
13
+
14
+
15
+ def is_image_file(filename):
16
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17
+
18
+ def dataloader(filepath):
19
+
20
+ left_fold = 'image_2/'
21
+ train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel') > -1]
22
+
23
+ l0_train = [filepath+left_fold+img for img in train]
24
+ l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]
25
+
26
+ #l0_train = [i for i in l0_train if not '10.png' in i] # remove 10 as val
27
+
28
+ l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train]
29
+ flow_train = [img.replace('image_2','flow_occ') for img in l0_train]
30
+
31
+
32
+ return l0_train, l1_train, flow_train
expansion/dataloader/sintellist_clean.py ADDED
@@ -0,0 +1,31 @@
1
+ import torch.utils.data as data
2
+
3
+ from PIL import Image
4
+ import os
5
+ import os.path
6
+ import numpy as np
7
+ import pdb
8
+
9
+ IMG_EXTENSIONS = [
10
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
11
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12
+ ]
13
+
14
+
15
+ def is_image_file(filename):
16
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17
+
18
+ def dataloader(filepath):
19
+
20
+ left_fold = 'image_2/'
21
+ train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel_clean') > -1]
22
+
23
+ l0_train = [filepath+left_fold+img for img in train]
24
+ l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]
25
+
26
+ #l0_train = [i for i in l0_train if not '10.png' in i] # remove 10 as val
27
+
28
+ l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train]
29
+ flow_train = [img.replace('image_2','flow_occ') for img in l0_train]
30
+
31
+ return l0_train, l1_train, flow_train
expansion/dataloader/sintellist_final.py ADDED
@@ -0,0 +1,32 @@
1
+ import torch.utils.data as data
2
+
3
+ from PIL import Image
4
+ import os
5
+ import os.path
6
+ import numpy as np
7
+ import pdb
8
+
9
+ IMG_EXTENSIONS = [
10
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
11
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12
+ ]
13
+
14
+
15
+ def is_image_file(filename):
16
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17
+
18
+ def dataloader(filepath):
19
+
20
+ left_fold = 'image_2/'
21
+ train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel_final') > -1]
22
+
23
+ l0_train = [filepath+left_fold+img for img in train]
24
+ l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]
25
+
26
+ #l0_train = [i for i in l0_train if not '10.png' in i] # remove 10 as val
27
+
28
+ l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train]
29
+ flow_train = [img.replace('image_2','flow_occ') for img in l0_train]
30
+
31
+ # pdb.set_trace() # leftover breakpoint disabled; it would halt the dataloader
32
+ return l0_train, l1_train, flow_train
expansion/dataloader/sintellist_train.py ADDED
@@ -0,0 +1,32 @@
1
+ import torch.utils.data as data
2
+
3
+ from PIL import Image
4
+ import os
5
+ import os.path
6
+ import numpy as np
7
+ import pdb
8
+
9
+ IMG_EXTENSIONS = [
10
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
11
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12
+ ]
13
+
14
+
15
+ def is_image_file(filename):
16
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17
+
18
+ def dataloader(filepath):
19
+
20
+ left_fold = 'image_2/'
21
+ train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel') > -1]
22
+
23
+ l0_train = [filepath+left_fold+img for img in train]
24
+ l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]
25
+
26
+ l0_train = [i for i in l0_train if not(('_2_' in i) and ('alley' not in i) and ('bandage' not in i) and ('sleeping' not in i))] # training split: drop the '_2_' scenes (except alley/bandage/sleeping) held out for validation
27
+
28
+ l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train]
29
+ flow_train = [img.replace('image_2','flow_occ') for img in l0_train]
30
+
31
+
32
+ return l0_train, l1_train, flow_train
expansion/dataloader/sintellist_val.py ADDED
@@ -0,0 +1,34 @@
1
+ import torch.utils.data as data
2
+
3
+ from PIL import Image
4
+ import os
5
+ import os.path
6
+ import numpy as np
7
+ import pdb
8
+
9
+ IMG_EXTENSIONS = [
10
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
11
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12
+ ]
13
+
14
+
15
+ def is_image_file(filename):
16
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17
+
18
+ def dataloader(filepath):
19
+
20
+ left_fold = 'image_2/'
21
+ train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel') > -1]
22
+
23
+ l0_train = [filepath+left_fold+img for img in train]
24
+ l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]
25
+
26
+ l0_train = [i for i in l0_train if ('_2_' in i) and ('alley' not in i) and ('bandage' not in i) and ('sleeping' not in i)] # remove 10 as val
27
+ #l0_train = [i for i in l0_train if not(('_2_' in i) and ('alley' not in i) and ('bandage' not in i) and ('sleeping' not in i))] # remove 10 as val
28
+
29
+ l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train]
30
+ flow_train = [img.replace('image_2','flow_occ') for img in l0_train]
31
+
32
+
33
+ return sorted(l0_train)[::3], sorted(l1_train)[::3], sorted(flow_train)[::3]
34
+ # return sorted(l0_train)[::10], sorted(l1_train)[::10], sorted(flow_train)[::10]
expansion/dataloader/thingslist.py ADDED
@@ -0,0 +1,122 @@
1
+ import torch.utils.data as data
2
+
3
+ from PIL import Image
4
+ import os
5
+ import os.path
6
+ import numpy as np
7
+
8
+ IMG_EXTENSIONS = [
9
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
10
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11
+ ]
12
+
13
+
14
+ def is_image_file(filename):
15
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16
+
17
+ def dataloader(filepath):
18
+ exc_list = [
19
+ '0004117.flo',
20
+ '0003149.flo',
21
+ '0001203.flo',
22
+ '0003147.flo',
23
+ '0003666.flo',
24
+ '0006337.flo',
25
+ '0006336.flo',
26
+ '0007126.flo',
27
+ '0004118.flo',
28
+ ]
29
+
30
+ left_fold = 'image_clean/left/'
31
+ flow_noc = 'flow/left/into_future/'
32
+ train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0]
33
+
34
+ l0_trainlf = [filepath+left_fold+img.replace('flo','png') for img in train]
35
+ l1_trainlf = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainlf]
36
+ flow_trainlf = [filepath+flow_noc+img for img in train]
37
+
38
+
39
+ exc_list = [
40
+ '0003148.flo',
41
+ '0004117.flo',
42
+ '0002890.flo',
43
+ '0003149.flo',
44
+ '0001203.flo',
45
+ '0003666.flo',
46
+ '0006337.flo',
47
+ '0006336.flo',
48
+ '0004118.flo',
49
+ ]
50
+
51
+ left_fold = 'image_clean/right/'
52
+ flow_noc = 'flow/right/into_future/'
53
+ train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0]
54
+
55
+ l0_trainrf = [filepath+left_fold+img.replace('flo','png') for img in train]
56
+ l1_trainrf = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainrf]
57
+ flow_trainrf = [filepath+flow_noc+img for img in train]
58
+
59
+
60
+ exc_list = [
61
+ '0004237.flo',
62
+ '0004705.flo',
63
+ '0004045.flo',
64
+ '0004346.flo',
65
+ '0000161.flo',
66
+ '0000931.flo',
67
+ '0000121.flo',
68
+ '0010822.flo',
69
+ '0004117.flo',
70
+ '0006023.flo',
71
+ '0005034.flo',
72
+ '0005054.flo',
73
+ '0000162.flo',
74
+ '0000053.flo',
75
+ '0005055.flo',
76
+ '0003147.flo',
77
+ '0004876.flo',
78
+ '0000163.flo',
79
+ '0006878.flo',
80
+ ]
81
+
82
+ left_fold = 'image_clean/left/'
83
+ flow_noc = 'flow/left/into_past/'
84
+ train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0]
85
+
86
+ l0_trainlp = [filepath+left_fold+img.replace('flo','png') for img in train]
87
+ l1_trainlp = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(-1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainlp]
88
+ flow_trainlp = [filepath+flow_noc+img for img in train]
89
+
90
+ exc_list = [
91
+ '0003148.flo',
92
+ '0004705.flo',
93
+ '0000161.flo',
94
+ '0000121.flo',
95
+ '0004117.flo',
96
+ '0000160.flo',
97
+ '0005034.flo',
98
+ '0005054.flo',
99
+ '0000162.flo',
100
+ '0000053.flo',
101
+ '0005055.flo',
102
+ '0003147.flo',
103
+ '0001549.flo',
104
+ '0000163.flo',
105
+ '0006336.flo',
106
+ '0001648.flo',
107
+ '0006878.flo',
108
+ ]
109
+
110
+ left_fold = 'image_clean/right/'
111
+ flow_noc = 'flow/right/into_past/'
112
+ train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0]
113
+
114
+ l0_trainrp = [filepath+left_fold+img.replace('flo','png') for img in train]
115
+ l1_trainrp = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(-1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainrp]
116
+ flow_trainrp = [filepath+flow_noc+img for img in train]
117
+
118
+
119
+ l0_train = l0_trainlf + l0_trainrf + l0_trainlp + l0_trainrp
120
+ l1_train = l1_trainlf + l1_trainrf + l1_trainlp + l1_trainrp
121
+ flow_train = flow_trainlf + flow_trainrf + flow_trainlp + flow_trainrp
122
+ return l0_train, l1_train, flow_train
expansion/models/VCN_exp.py ADDED
@@ -0,0 +1,561 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.autograd import Variable
5
+ import os
6
+ os.environ['PYTHON_EGG_CACHE'] = 'tmp/' # a writable directory
7
+ import numpy as np
8
+ import math
9
+ import pdb
10
+ import time
11
+
12
+ from .submodule import pspnet, bfmodule, conv
13
+ from .conv4d import sepConv4d, sepConv4dBlock, butterfly4D
14
+
15
+ class flow_reg(nn.Module):
16
+ """
17
+ Soft winner-take-all that selects the most likely displacement.
18
+ Set ent=True to enable entropy output.
19
+ Set maxdisp to adjust maximum allowed displacement towards one side.
20
+ maxdisp=4 searches for a 9x9 region.
21
+ Set fac to squeeze search window.
22
+ maxdisp=4 and fac=2 gives search window of 9x5
23
+ """
24
+ def __init__(self, size, ent=False, maxdisp = int(4), fac=1):
25
+ B,W,H = size
26
+ super(flow_reg, self).__init__()
27
+ self.ent = ent
28
+ self.md = maxdisp
29
+ self.fac = fac
30
+ self.truncated = True
31
+ self.wsize = 3 # by default using truncation 7x7
32
+
33
+ flowrangey = range(-maxdisp,maxdisp+1)
34
+ flowrangex = range(-int(maxdisp//self.fac),int(maxdisp//self.fac)+1)
35
+ meshgrid = np.meshgrid(flowrangex,flowrangey)
36
+ flowy = np.tile( np.reshape(meshgrid[0],[1,2*maxdisp+1,2*int(maxdisp//self.fac)+1,1,1]), (B,1,1,H,W) )
37
+ flowx = np.tile( np.reshape(meshgrid[1],[1,2*maxdisp+1,2*int(maxdisp//self.fac)+1,1,1]), (B,1,1,H,W) )
38
+ self.register_buffer('flowx',torch.Tensor(flowx))
39
+ self.register_buffer('flowy',torch.Tensor(flowy))
40
+
41
+ self.pool3d = nn.MaxPool3d((self.wsize*2+1,self.wsize*2+1,1),stride=1,padding=(self.wsize,self.wsize,0))
42
+
43
+ def forward(self, x):
44
+ b,u,v,h,w = x.shape
45
+ oldx = x
46
+
47
+ if self.truncated:
48
+ # truncated softmax
49
+ x = x.view(b,u*v,h,w)
50
+
51
+ idx = x.argmax(1)[:,np.newaxis]
52
+ if x.is_cuda:
53
+ mask = Variable(torch.cuda.HalfTensor(b,u*v,h,w)).fill_(0)
54
+ else:
55
+ mask = Variable(torch.FloatTensor(b,u*v,h,w)).fill_(0)
56
+ mask.scatter_(1,idx,1)
57
+ mask = mask.view(b,1,u,v,-1)
58
+ mask = self.pool3d(mask)[:,0].view(b,u,v,h,w)
59
+
60
+ ninf = x.clone().fill_(-np.inf).view(b,u,v,h,w)
61
+ x = torch.where(mask.byte(),oldx,ninf)
62
+ else:
63
+ self.wsize = (np.sqrt(u*v)-1)/2
64
+
65
+ b,u,v,h,w = x.shape
66
+ x = F.softmax(x.view(b,-1,h,w),1).view(b,u,v,h,w)
67
+ outx = torch.sum(torch.sum(x*self.flowx,1),1,keepdim=True)
68
+ outy = torch.sum(torch.sum(x*self.flowy,1),1,keepdim=True)
69
+
70
+ if self.ent:
71
+ # local
72
+ local_entropy = (-x*torch.clamp(x,1e-9,1-1e-9).log()).sum(1).sum(1)[:,np.newaxis]
73
+ if self.wsize == 0:
74
+ local_entropy[:] = 1.
75
+ else:
76
+ local_entropy /= np.log((self.wsize*2+1)**2)
77
+
78
+ # global
79
+ x = F.softmax(oldx.view(b,-1,h,w),1).view(b,u,v,h,w)
80
+ global_entropy = (-x*torch.clamp(x,1e-9,1-1e-9).log()).sum(1).sum(1)[:,np.newaxis]
81
+ global_entropy /= np.log(x.shape[1]*x.shape[2])
82
+ return torch.cat([outx,outy],1),torch.cat([local_entropy, global_entropy],1)
83
+ else:
84
+ return torch.cat([outx,outy],1),None
85
+
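+ # In short, flow_reg is a soft argmax: each pixel's (u, v) cost slice
+ # is softmax-normalized and the output flow is the expectation
+ # sum_{u,v} p(u,v) * (dx, dy). The truncated branch keeps only a
+ # (2*wsize+1)^2 window around the hard argmax before the softmax to
+ # sharpen the distribution; the local/global entropies are normalized
+ # to [0, 1] by the log of the number of candidates considered.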
86
+
87
+ class WarpModule(nn.Module):
88
+ """
89
+ taken from https://github.com/NVlabs/PWC-Net/blob/master/PyTorch/models/PWCNet.py
90
+ """
91
+ def __init__(self, size):
92
+ super(WarpModule, self).__init__()
93
+ B,W,H = size
94
+ # mesh grid
95
+ xx = torch.arange(0, W).view(1,-1).repeat(H,1)
96
+ yy = torch.arange(0, H).view(-1,1).repeat(1,W)
97
+ xx = xx.view(1,1,H,W).repeat(B,1,1,1)
98
+ yy = yy.view(1,1,H,W).repeat(B,1,1,1)
99
+ self.register_buffer('grid',torch.cat((xx,yy),1).float())
100
+
101
+ def forward(self, x, flo):
102
+ """
103
+ warp an image/tensor (im2) back to im1, according to the optical flow
104
+
105
+ x: [B, C, H, W] (im2)
106
+ flo: [B, 2, H, W] flow
107
+
108
+ """
109
+ B, C, H, W = x.size()
110
+ vgrid = self.grid + flo
111
+
112
+ # scale grid to [-1,1]
113
+ vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:]/max(W-1,1)-1.0
114
+ vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:]/max(H-1,1)-1.0
115
+
116
+ vgrid = vgrid.permute(0,2,3,1)
117
+ output = nn.functional.grid_sample(x, vgrid,align_corners=True)
118
+ mask = ((vgrid[:,:,:,0].abs()<1) * (vgrid[:,:,:,1].abs()<1)) >0
119
+ return output*mask.unsqueeze(1).float(), mask
120
+
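+ # usage sketch (shapes assumed): warp = WarpModule([B, W, H]);
+ # warped, valid = warp(im2, flow) resamples im2 at grid + flow, and
+ # valid is True wherever the sampling location stays inside the image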
121
+
122
+ def get_grid(B,H,W):
123
+ meshgrid_base = np.meshgrid(range(0,W), range(0,H))[::-1]
124
+ basey = np.reshape(meshgrid_base[0],[1,1,1,H,W])
125
+ basex = np.reshape(meshgrid_base[1],[1,1,1,H,W])
126
+ grid = torch.tensor(np.concatenate((basex.reshape((-1,H,W,1)),basey.reshape((-1,H,W,1))),-1)).cuda().float()
127
+ return grid.view(1,1,H,W,2)
128
+
129
+
130
+ class VCN(nn.Module):
131
+ """
132
+ VCN.
133
+ md defines maximum displacement for each level, following a coarse-to-fine-warping scheme
134
+ fac defines squeeze parameter for the coarsest level
135
+ """
136
+ def __init__(self, size, md=[4,4,4,4,4], fac=1.,exp_unc=False): # exp_uncertainty
137
+ super(VCN,self).__init__()
138
+ self.md = md
139
+ self.fac = fac
140
+ use_entropy = True
141
+ withbn = True
142
+
143
+ ## pspnet
144
+ self.pspnet = pspnet(is_proj=False)
145
+
146
+ ### Volumetric-UNet
147
+ fdima1 = 128 # 6/5/4
148
+ fdima2 = 64 # 3/2
149
+ fdimb1 = 16 # 6/5/4/3
150
+ fdimb2 = 12 # 2
151
+
152
+ full=False
153
+ self.f6 = butterfly4D(fdima1, fdimb1,withbn=withbn,full=full)
154
+ self.p6 = sepConv4d(fdimb1,fdimb1, with_bn=False, full=full)
155
+
156
+ self.f5 = butterfly4D(fdima1, fdimb1,withbn=withbn, full=full)
157
+ self.p5 = sepConv4d(fdimb1,fdimb1, with_bn=False,full=full)
158
+
159
+ self.f4 = butterfly4D(fdima1, fdimb1,withbn=withbn,full=full)
160
+ self.p4 = sepConv4d(fdimb1,fdimb1, with_bn=False,full=full)
161
+
162
+ self.f3 = butterfly4D(fdima2, fdimb1,withbn=withbn,full=full)
163
+ self.p3 = sepConv4d(fdimb1,fdimb1, with_bn=False,full=full)
164
+
165
+ full=True
166
+ self.f2 = butterfly4D(fdima2, fdimb2,withbn=withbn,full=full)
167
+ self.p2 = sepConv4d(fdimb2,fdimb2, with_bn=False,full=full)
168
+
169
+ self.flow_reg64 = flow_reg([fdimb1*size[0],size[1]//64,size[2]//64], ent=use_entropy, maxdisp=self.md[0], fac=self.fac)
170
+ self.flow_reg32 = flow_reg([fdimb1*size[0],size[1]//32,size[2]//32], ent=use_entropy, maxdisp=self.md[1])
171
+ self.flow_reg16 = flow_reg([fdimb1*size[0],size[1]//16,size[2]//16], ent=use_entropy, maxdisp=self.md[2])
172
+ self.flow_reg8 = flow_reg([fdimb1*size[0],size[1]//8,size[2]//8] , ent=use_entropy, maxdisp=self.md[3])
173
+ self.flow_reg4 = flow_reg([fdimb2*size[0],size[1]//4,size[2]//4] , ent=use_entropy, maxdisp=self.md[4])
174
+
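+ # one soft-argmax regressor per pyramid level (1/64 ... 1/4 of the
+ # input resolution); only the coarsest level passes fac to squeeze
+ # its search window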
175
+ self.warp5 = WarpModule([size[0],size[1]//32,size[2]//32])
176
+ self.warp4 = WarpModule([size[0],size[1]//16,size[2]//16])
177
+ self.warp3 = WarpModule([size[0],size[1]//8,size[2]//8])
178
+ self.warp2 = WarpModule([size[0],size[1]//4,size[2]//4])
179
+
180
+ ## hypotheses fusion modules, adopted from the refinement module of PWCNet
181
+ # https://github.com/NVlabs/PWC-Net/blob/master/PyTorch/models/PWCNet.py
182
+ # c6
183
+ self.dc6_conv1 = conv(128+4*fdimb1, 128, kernel_size=3, stride=1, padding=1, dilation=1)
184
+ self.dc6_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2)
185
+ self.dc6_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4)
186
+ self.dc6_conv4 = conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8)
187
+ self.dc6_conv5 = conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16)
188
+ self.dc6_conv6 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1)
189
+ self.dc6_conv7 = nn.Conv2d(32,2*fdimb1,kernel_size=3,stride=1,padding=1,bias=True)
190
+
191
+ # c5
192
+ self.dc5_conv1 = conv(128+4*fdimb1*2, 128, kernel_size=3, stride=1, padding=1, dilation=1)
193
+ self.dc5_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2)
194
+ self.dc5_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4)
195
+ self.dc5_conv4 = conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8)
196
+ self.dc5_conv5 = conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16)
197
+ self.dc5_conv6 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1)
198
+ self.dc5_conv7 = nn.Conv2d(32,2*fdimb1*2,kernel_size=3,stride=1,padding=1,bias=True)
199
+
200
+ # c4
201
+ self.dc4_conv1 = conv(128+4*fdimb1*3, 128, kernel_size=3, stride=1, padding=1, dilation=1)
202
+ self.dc4_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2)
203
+ self.dc4_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4)
204
+ self.dc4_conv4 = conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8)
205
+ self.dc4_conv5 = conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16)
206
+ self.dc4_conv6 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1)
207
+ self.dc4_conv7 = nn.Conv2d(32,2*fdimb1*3,kernel_size=3,stride=1,padding=1,bias=True)
208
+
209
+ # c3
210
+ self.dc3_conv1 = conv(64+16*fdimb1, 128, kernel_size=3, stride=1, padding=1, dilation=1)
211
+ self.dc3_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2)
212
+ self.dc3_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4)
213
+ self.dc3_conv4 = conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8)
214
+ self.dc3_conv5 = conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16)
215
+ self.dc3_conv6 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1)
216
+ self.dc3_conv7 = nn.Conv2d(32,8*fdimb1,kernel_size=3,stride=1,padding=1,bias=True)
217
+
218
+ # c2
219
+ self.dc2_conv1 = conv(64+16*fdimb1+4*fdimb2, 128, kernel_size=3, stride=1, padding=1, dilation=1)
220
+ self.dc2_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2)
221
+ self.dc2_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4)
222
+ self.dc2_conv4 = conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8)
223
+ self.dc2_conv5 = conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16)
224
+ self.dc2_conv6 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1)
225
+ self.dc2_conv7 = nn.Conv2d(32,4*2*fdimb1 + 2*fdimb2,kernel_size=3,stride=1,padding=1,bias=True)
226
+
227
+ self.dc6_conv = nn.Sequential( self.dc6_conv1,
228
+ self.dc6_conv2,
229
+ self.dc6_conv3,
230
+ self.dc6_conv4,
231
+ self.dc6_conv5,
232
+ self.dc6_conv6,
233
+ self.dc6_conv7)
234
+ self.dc5_conv = nn.Sequential( self.dc5_conv1,
235
+ self.dc5_conv2,
236
+ self.dc5_conv3,
237
+ self.dc5_conv4,
238
+ self.dc5_conv5,
239
+ self.dc5_conv6,
240
+ self.dc5_conv7)
241
+ self.dc4_conv = nn.Sequential( self.dc4_conv1,
242
+ self.dc4_conv2,
243
+ self.dc4_conv3,
244
+ self.dc4_conv4,
245
+ self.dc4_conv5,
246
+ self.dc4_conv6,
247
+ self.dc4_conv7)
248
+ self.dc3_conv = nn.Sequential( self.dc3_conv1,
249
+ self.dc3_conv2,
250
+ self.dc3_conv3,
251
+ self.dc3_conv4,
252
+ self.dc3_conv5,
253
+ self.dc3_conv6,
254
+ self.dc3_conv7)
255
+ self.dc2_conv = nn.Sequential( self.dc2_conv1,
256
+ self.dc2_conv2,
257
+ self.dc2_conv3,
258
+ self.dc2_conv4,
259
+ self.dc2_conv5,
260
+ self.dc2_conv6,
261
+ self.dc2_conv7)
262
+
263
+ ## Out-of-range detection
264
+ self.dc6_convo = nn.Sequential(conv(128+4*fdimb1, 128, kernel_size=3, stride=1, padding=1, dilation=1),
265
+ conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2),
266
+ conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4),
267
+ conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8),
268
+ conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16),
269
+ conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1),
270
+ nn.Conv2d(32,1,kernel_size=3,stride=1,padding=1,bias=True))
271
+
272
+ self.dc5_convo = nn.Sequential(conv(128+2*4*fdimb1, 128, kernel_size=3, stride=1, padding=1, dilation=1),
273
+ conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2),
274
+ conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4),
275
+ conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8),
276
+ conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16),
277
+ conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1),
278
+ nn.Conv2d(32,1,kernel_size=3,stride=1,padding=1,bias=True))
279
+
280
+ self.dc4_convo = nn.Sequential(conv(128+3*4*fdimb1, 128, kernel_size=3, stride=1, padding=1, dilation=1),
281
+ conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2),
282
+ conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4),
283
+ conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8),
284
+ conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16),
285
+ conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1),
286
+ nn.Conv2d(32,1,kernel_size=3,stride=1,padding=1,bias=True))
287
+
288
+ self.dc3_convo = nn.Sequential(conv(64+16*fdimb1, 128, kernel_size=3, stride=1, padding=1, dilation=1),
289
+ conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2),
290
+ conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4),
291
+ conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8),
292
+ conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16),
293
+ conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1),
294
+ nn.Conv2d(32,1,kernel_size=3,stride=1,padding=1,bias=True))
295
+
296
+ self.dc2_convo = nn.Sequential(conv(64+16*fdimb1+4*fdimb2, 128, kernel_size=3, stride=1, padding=1, dilation=1),
297
+ conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2),
298
+ conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4),
299
+ conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8),
300
+ conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16),
301
+ conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1),
302
+ nn.Conv2d(32,1,kernel_size=3,stride=1,padding=1,bias=True))
303
+
304
+ # affine-exp
305
+ self.f3d2v1 = conv(64, 32, kernel_size=3, stride=1, padding=1,dilation=1) #
306
+ self.f3d2v2 = conv(1, 32, kernel_size=3, stride=1, padding=1,dilation=1) #
307
+ self.f3d2v3 = conv(1, 32, kernel_size=3, stride=1, padding=1,dilation=1) #
308
+ self.f3d2v4 = conv(1, 32, kernel_size=3, stride=1, padding=1,dilation=1) #
309
+ self.f3d2v5 = conv(64, 32, kernel_size=3, stride=1, padding=1,dilation=1) #
310
+ self.f3d2v6 = conv(12*81, 32, kernel_size=3, stride=1, padding=1,dilation=1) #
311
+ self.f3d2 = bfmodule(128-64,1)
312
+
313
+ # depth change net
314
+ self.dcnetv1 = conv(64, 32, kernel_size=3, stride=1, padding=1,dilation=1) #
315
+ self.dcnetv2 = conv(1, 32, kernel_size=3, stride=1, padding=1,dilation=1) #
316
+ self.dcnetv3 = conv(1, 32, kernel_size=3, stride=1, padding=1,dilation=1) #
317
+ self.dcnetv4 = conv(1, 32, kernel_size=3, stride=1, padding=1,dilation=1) #
318
+ self.dcnetv5 = conv(12*81, 32, kernel_size=3, stride=1, padding=1,dilation=1) #
319
+ self.dcnetv6 = conv(4, 32, kernel_size=3, stride=1, padding=1,dilation=1) #
320
+ if exp_unc:
321
+ self.dcnet = bfmodule(128,2)
322
+ else:
323
+ self.dcnet = bfmodule(128,1)
324
+
325
+ for m in self.modules():
326
+ if isinstance(m, nn.Conv3d):
327
+ n = m.kernel_size[0] * m.kernel_size[1]*m.kernel_size[2] * m.out_channels
328
+ m.weight.data.normal_(0, math.sqrt(2. / n))
329
+ if hasattr(m.bias,'data'):
330
+ m.bias.data.zero_()
331
+
332
+ self.facs = [self.fac,1,1,1,1]
333
+ self.warp_modules = nn.ModuleList([None, self.warp5, self.warp4, self.warp3, self.warp2])
334
+ self.f_modules = nn.ModuleList([self.f6, self.f5, self.f4, self.f3, self.f2])
335
+ self.p_modules = nn.ModuleList([self.p6, self.p5, self.p4, self.p3, self.p2])
336
+ self.reg_modules = nn.ModuleList([self.flow_reg64, self.flow_reg32, self.flow_reg16, self.flow_reg8, self.flow_reg4])
337
+ self.oor_modules = nn.ModuleList([self.dc6_convo, self.dc5_convo, self.dc4_convo, self.dc3_convo, self.dc2_convo])
338
+ self.fuse_modules = nn.ModuleList([self.dc6_conv, self.dc5_conv, self.dc4_conv, self.dc3_conv, self.dc2_conv])
339
+
340
+ def corrf(self, refimg_fea, targetimg_fea,maxdisp, fac=1):
341
+ """
342
+ Slow correlation function: builds a (b, c, u, v, h, w) cost volume from products of shifted feature maps.
343
+ """
344
+ b,c,height,width = refimg_fea.shape
345
+ if refimg_fea.is_cuda:
346
+ cost = Variable(torch.cuda.FloatTensor(b,c,2*maxdisp+1,2*int(maxdisp//fac)+1,height,width)).fill_(0.) # b,c,u,v,h,w
347
+ else:
348
+ cost = Variable(torch.FloatTensor(b,c,2*maxdisp+1,2*int(maxdisp//fac)+1,height,width)).fill_(0.) # b,c,u,v,h,w
349
+ for i in range(2*maxdisp+1):
350
+ ind = i-maxdisp
351
+ for j in range(2*int(maxdisp//fac)+1):
352
+ indd = j-int(maxdisp//fac)
353
+ feata = refimg_fea[:,:,max(0,-indd):height-indd,max(0,-ind):width-ind]
354
+ featb = targetimg_fea[:,:,max(0,+indd):height+indd,max(0,ind):width+ind]
355
+ diff = (feata*featb)
356
+ cost[:, :, i,j,max(0,-indd):height-indd,max(0,-ind):width-ind] = diff # standard
357
+ cost = F.leaky_relu(cost, 0.1,inplace=True)
358
+ return cost
359
+
360
+ def cost_matching(self,up_flow, c1, c2, flowh, enth, level):
361
+ """
362
+ up_flow: upsampled coarse flow
363
+ c1: normalized feature of image 1
364
+ c2: normalized feature of image 2
365
+ flowh: flow hypotheses
366
+ enth: entropy of the flow hypotheses; level: pyramid level (0 = coarsest)
367
+ """
368
+
369
+ # normalize
370
+ c1n = c1 / (c1.norm(dim=1, keepdim=True)+1e-9)
371
+ c2n = c2 / (c2.norm(dim=1, keepdim=True)+1e-9)
372
+
373
+ # cost volume
374
+ if level == 0:
375
+ warp = c2n
376
+ else:
377
+ warp,_ = self.warp_modules[level](c2n, up_flow)
378
+
379
+ feat = self.corrf(c1n,warp,self.md[level],fac=self.facs[level])
380
+ feat = self.f_modules[level](feat)
381
+ cost = self.p_modules[level](feat) # b, 16, u,v,h,w
382
+
383
+ # soft WTA
384
+ b,c,u,v,h,w = cost.shape
385
+ cost = cost.view(-1,u,v,h,w) # bx16, 9,9,h,w, also predict uncertainty from here
386
+ flowhh,enthh = self.reg_modules[level](cost) # bx16, 2, h, w
387
+ flowhh = flowhh.view(b,c,2,h,w)
388
+ if level > 0:
389
+ flowhh = flowhh + up_flow[:,np.newaxis]
390
+ flowhh = flowhh.view(b,-1,h,w) # b, 16*2, h, w
391
+ enthh = enthh.view(b,-1,h,w) # b, 16*1, h, w
392
+
393
+ # append coarse hypotheses
394
+ if level == 0:
395
+ flowh = flowhh
396
+ enth = enthh
397
+ else:
398
+ flowh = torch.cat((flowhh, F.upsample(flowh.detach()*2, [flowhh.shape[2],flowhh.shape[3]], mode='bilinear')),1) # b, k2--k2, h, w
399
+ enth = torch.cat((enthh, F.upsample(enth, [flowhh.shape[2],flowhh.shape[3]], mode='bilinear')),1)
400
+
401
+ if self.training or level==4:
402
+ x = torch.cat((enth.detach(), flowh.detach(), c1),1)
403
+ oor = self.oor_modules[level](x)[:,0]
404
+ else: oor = None
405
+
406
+ # hypotheses fusion
407
+ x = torch.cat((enth.detach(), flowh.detach(), c1),1)
408
+ va = self.fuse_modules[level](x)
409
+ va = va.view(b,-1,2,h,w)
410
+ flow = ( flowh.view(b,-1,2,h,w) * F.softmax(va,1) ).sum(1) # b, 2k, 2, h, w
411
+
412
+ return flow, flowh, enth, oor
413
+
414
+ def affine(self,pref,flow, pw=1):
415
+ b,_,lh,lw=flow.shape
416
+ ptar = pref + flow
417
+ pw = 1 # patch half-width; fixed to 1 here, overriding the argument
418
+ pref = F.unfold(pref, (pw*2+1,pw*2+1), padding=(pw)).view(b,2,(pw*2+1)**2,lh,lw)-pref[:,:,np.newaxis]
419
+ ptar = F.unfold(ptar, (pw*2+1,pw*2+1), padding=(pw)).view(b,2,(pw*2+1)**2,lh,lw)-ptar[:,:,np.newaxis] # b, 2,9,h,w
420
+ pref = pref.permute(0,3,4,1,2).reshape(b*lh*lw,2,(pw*2+1)**2)
421
+ ptar = ptar.permute(0,3,4,1,2).reshape(b*lh*lw,2,(pw*2+1)**2)
422
+
423
+ prefprefT = pref.matmul(pref.permute(0,2,1))
424
+ ppdet = prefprefT[:,0,0]*prefprefT[:,1,1]-prefprefT[:,1,0]*prefprefT[:,0,1]
425
+ ppinv = torch.cat((prefprefT[:,1,1:],-prefprefT[:,0,1:], -prefprefT[:,1:,0], prefprefT[:,0:1,0]),1).view(-1,2,2)/ppdet.clamp(1e-10,np.inf)[:,np.newaxis,np.newaxis]
426
+
427
+ Affine = ptar.matmul(pref.permute(0,2,1)).matmul(ppinv)
428
+ Error = (Affine.matmul(pref)-ptar).norm(2,1).mean(1).view(b,1,lh,lw)
429
+
430
+ Avol = (Affine[:,0,0]*Affine[:,1,1]-Affine[:,1,0]*Affine[:,0,1]).view(b,1,lh,lw).abs().clamp(1e-10,np.inf)
431
+ exp = Avol.sqrt()
432
+ mask = (exp>0.5) & (exp<2) & (Error<0.1)
433
+ mask = mask[:,0]
434
+
435
+ exp = exp.clamp(0.5,2)
436
+ exp[Error>0.1]=1
437
+ return exp, Error, mask
438
+
439
+ def affine_mask(self,pref,flow, pw=3):
440
+ """
441
+ pref: reference coordinates
442
+ pw: patch width
443
+ """
444
+ flmask = flow[:,2:]
445
+ flow = flow[:,:2]
446
+ b,_,lh,lw=flow.shape
447
+ ptar = pref + flow
448
+ pref = F.unfold(pref, (pw*2+1,pw*2+1), padding=(pw)).view(b,2,(pw*2+1)**2,lh,lw)-pref[:,:,np.newaxis]
449
+ ptar = F.unfold(ptar, (pw*2+1,pw*2+1), padding=(pw)).view(b,2,(pw*2+1)**2,lh,lw)-ptar[:,:,np.newaxis] # b, 2,9,h,w
450
+
451
+ conf_flow = flmask
452
+ conf_flow = F.unfold(conf_flow,(pw*2+1,pw*2+1), padding=(pw)).view(b,1,(pw*2+1)**2,lh,lw)
453
+ count = conf_flow.sum(2,keepdim=True)
454
+ conf_flow = ((pw*2+1)**2)*conf_flow / count
455
+ pref = pref * conf_flow
456
+ ptar = ptar * conf_flow
457
+
458
+ pref = pref.permute(0,3,4,1,2).reshape(b*lh*lw,2,(pw*2+1)**2)
459
+ ptar = ptar.permute(0,3,4,1,2).reshape(b*lh*lw,2,(pw*2+1)**2)
460
+
461
+ prefprefT = pref.matmul(pref.permute(0,2,1))
462
+ ppdet = prefprefT[:,0,0]*prefprefT[:,1,1]-prefprefT[:,1,0]*prefprefT[:,0,1]
463
+ ppinv = torch.cat((prefprefT[:,1,1:],-prefprefT[:,0,1:], -prefprefT[:,1:,0], prefprefT[:,0:1,0]),1).view(-1,2,2)/ppdet.clamp(1e-10,np.inf)[:,np.newaxis,np.newaxis]
464
+
465
+ Affine = ptar.matmul(pref.permute(0,2,1)).matmul(ppinv)
466
+ Error = (Affine.matmul(pref)-ptar).norm(2,1).mean(1).view(b,1,lh,lw)
467
+
468
+ Avol = (Affine[:,0,0]*Affine[:,1,1]-Affine[:,1,0]*Affine[:,0,1]).view(b,1,lh,lw).abs().clamp(1e-10,np.inf)
469
+ exp = Avol.sqrt()
470
+ mask = (exp>0.5) & (exp<2) & (Error<0.2) & (flmask.bool()) & (count[:,0]>4)
471
+ mask = mask[:,0]
472
+
473
+ exp = exp.clamp(0.5,2)
474
+ exp[Error>0.2]=1
475
+ return exp, Error, mask
476
+
477
+ def weight_parameters(self):
478
+ return [param for name, param in self.named_parameters() if 'weight' in name]
479
+
480
+ def bias_parameters(self):
481
+ return [param for name, param in self.named_parameters() if 'bias' in name]
482
+
483
+ def forward(self,im,disc_aux=None):
484
+ bs = im.shape[0]//2
485
+
486
+ if self.training and disc_aux[-1]: # if only fine-tuning expansion
487
+ reset=True
488
+ self.eval()
489
+ torch.set_grad_enabled(False)
490
+ else: reset=False
491
+
492
+ c06,c05,c04,c03,c02 = self.pspnet(im)
493
+ c16 = c06[:bs]; c26 = c06[bs:]
494
+ c15 = c05[:bs]; c25 = c05[bs:]
495
+ c14 = c04[:bs]; c24 = c04[bs:]
496
+ c13 = c03[:bs]; c23 = c03[bs:]
497
+ c12 = c02[:bs]; c22 = c02[bs:]
498
+
499
+ ## matching 6
500
+ flow6, flow6h, ent6h, oor6 = self.cost_matching(None, c16, c26, None, None,level=0)
501
+
502
+ ## matching 5
503
+ up_flow6 = F.upsample(flow6, [im.size()[2]//32,im.size()[3]//32], mode='bilinear')*2
504
+ flow5, flow5h, ent5h, oor5 = self.cost_matching(up_flow6, c15, c25, flow6h, ent6h,level=1)
505
+
506
+ ## matching 4
507
+ up_flow5 = F.upsample(flow5, [im.size()[2]//16,im.size()[3]//16], mode='bilinear')*2
508
+ flow4, flow4h, ent4h, oor4 = self.cost_matching(up_flow5, c14, c24, flow5h, ent5h,level=2)
509
+
510
+ ## matching 3
511
+ up_flow4 = F.upsample(flow4, [im.size()[2]//8,im.size()[3]//8], mode='bilinear')*2
512
+ flow3, flow3h, ent3h, oor3 = self.cost_matching(up_flow4, c13, c23, flow4h, ent4h,level=3)
513
+
514
+ ## matching 2
515
+ up_flow3 = F.upsample(flow3, [im.size()[2]//4,im.size()[3]//4], mode='bilinear')*2
516
+ flow2, flow2h, ent2h, oor2 = self.cost_matching(up_flow3, c12, c22, flow3h, ent3h,level=4)
517
+
518
+ if reset:
519
+ torch.set_grad_enabled(True)
520
+ self.train()
521
+
522
+ # expansion
523
+ b,_,h,w = flow2.shape
524
+ exp2,err2,_ = self.affine(get_grid(b,h,w)[:,0].permute(0,3,1,2).repeat(b,1,1,1).clone(), flow2.detach(),pw=1)
525
+ x = torch.cat((
526
+ self.f3d2v2(-exp2.log()),
527
+ self.f3d2v3(err2),
528
+ ),1)
529
+ dchange2 = -exp2.log()+1./200*self.f3d2(x)[0]
530
+
531
+ # depth change net
532
+ iexp2 = F.upsample(dchange2.clone(), [im.size()[2],im.size()[3]], mode='bilinear')
533
+
534
+ x = torch.cat((self.dcnetv1(c12.detach()),
535
+ self.dcnetv2(dchange2.detach()),
536
+ self.dcnetv3(-exp2.log()),
537
+ self.dcnetv4(err2),
538
+ ),1)
539
+ dcneto = 1./200*self.dcnet(x)[0]
540
+ dchange2 = dchange2.detach() + dcneto[:,:1]
541
+
542
+ flow2 = F.upsample(flow2.detach(), [im.size()[2],im.size()[3]], mode='bilinear')*4
543
+ dchange2 = F.upsample(dchange2, [im.size()[2],im.size()[3]], mode='bilinear')
544
+
545
+ if self.training:
546
+ flowl0 = disc_aux[0].permute(0,3,1,2).clone()
547
+ gt_depth = disc_aux[2][:,:,:,0]
548
+ gt_f3d = disc_aux[2][:,:,:,4:7].permute(0,3,1,2).clone()
549
+ gt_dchange = (1+gt_f3d[:,2]/gt_depth)
550
+ maskdc = (gt_dchange < 2) & (gt_dchange > 0.5) & disc_aux[1]
551
+
552
+ gt_expi,gt_expi_err,maskoe = self.affine_mask(get_grid(b,4*h,4*w)[:,0].permute(0,3,1,2).repeat(b,1,1,1), flowl0,pw=3)
553
+ gt_exp = 1./gt_expi[:,0]
554
+
555
+ loss = 0.1* (dchange2[:,0]-gt_dchange.log()).abs()[maskdc].mean()
556
+ loss += 0.1* (iexp2[:,0]-gt_exp.log()).abs()[maskoe].mean()
557
+ return flow2*4, flow3*8,flow4*16,flow5*32,flow6*64,loss, dchange2[:,0], iexp2[:,0]
558
+
559
+ else:
560
+ return flow2, oor2, dchange2, iexp2
561
+
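The affine and affine_mask methods above estimate per-pixel scale change by fitting a 2x2 affine transform between matched patch coordinates in closed form, then taking the square root of its determinant as the expansion factor. Below is a minimal standalone sketch of that least-squares fit; the function name and tensor sizes are illustrative, not from this repo, and it uses torch.inverse where the model hand-rolls the 2x2 inverse with a clamped determinant to survive degenerate patches.

import torch

def local_affine_expansion(pref, ptar):
    # pref, ptar: (N, 2, K) centered patch coordinates before/after the flow,
    # as built with F.unfold in VCN_exp.affine (K = (2*pw+1)**2 samples)
    PPt = pref @ pref.transpose(1, 2)                     # (N, 2, 2)
    A = ptar @ pref.transpose(1, 2) @ torch.inverse(PPt)  # least-squares affine
    exp = torch.det(A).abs().clamp(min=1e-10).sqrt()      # isotropic expansion
    err = (A @ pref - ptar).norm(dim=1).mean(dim=1)       # fitting residual
    return exp, err

# A uniform 2x zoom yields exp close to 2 with near-zero residual:
pref = torch.randn(4, 2, 9)
exp, err = local_affine_expansion(pref, 2.0 * pref)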
expansion/models/__init__.py ADDED
File without changes
expansion/models/__pycache__/VCN_exp.cpython-38.pyc ADDED
Binary file (17.7 kB).
expansion/models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (158 Bytes).
expansion/models/__pycache__/conv4d.cpython-38.pyc ADDED
Binary file (8.26 kB).
expansion/models/__pycache__/submodule.cpython-38.pyc ADDED
Binary file (12 kB).
expansion/models/conv4d.py ADDED
@@ -0,0 +1,296 @@
1
+ import pdb
2
+ import torch.nn as nn
3
+ import math
4
+ import torch
5
+ from torch.nn.parameter import Parameter
6
+ import torch.nn.functional as F
7
+ from torch.nn import Module
8
+ from torch.nn.modules.conv import _ConvNd
9
+ from torch.nn.modules.utils import _quadruple
10
+ from torch.autograd import Variable
11
+ from torch.nn import Conv2d
12
+
13
+ def conv4d(data,filters,bias=None,permute_filters=True,use_half=False):
14
+ """
15
+ 4D convolution implemented by stacking the results of multiple 3D convolutions; very slow.
16
+ Taken from https://github.com/ignacio-rocco/ncnet
17
+ """
18
+ b,c,h,w,d,t=data.size()
19
+
20
+ data=data.permute(2,0,1,3,4,5).contiguous() # permute to avoid making contiguous inside loop
21
+
22
+ # The same permutation is applied to the filters, unless they were already permuted
23
+ if permute_filters:
24
+ filters=filters.permute(2,0,1,3,4,5).contiguous() # permute to avoid making contiguous inside loop
25
+
26
+ c_out=filters.size(1)
27
+ if use_half:
28
+ output = Variable(torch.HalfTensor(h,b,c_out,w,d,t),requires_grad=data.requires_grad)
29
+ else:
30
+ output = Variable(torch.zeros(h,b,c_out,w,d,t),requires_grad=data.requires_grad)
31
+
32
+ padding=filters.size(0)//2
33
+ if use_half:
34
+ Z=Variable(torch.zeros(padding,b,c,w,d,t).half())
35
+ else:
36
+ Z=Variable(torch.zeros(padding,b,c,w,d,t))
37
+
38
+ if data.is_cuda:
39
+ Z=Z.cuda(data.get_device())
40
+ output=output.cuda(data.get_device())
41
+
42
+ data_padded = torch.cat((Z,data,Z),0)
43
+
44
+
45
+ for i in range(output.size(0)): # loop on first feature dimension
46
+ # convolve with center channel of filter (at position=padding)
47
+ output[i,:,:,:,:,:]=F.conv3d(data_padded[i+padding,:,:,:,:,:],
48
+ filters[padding,:,:,:,:,:], bias=bias, stride=1, padding=padding)
49
+ # convolve with upper/lower channels of filter (at positions [:padding] [padding+1:])
50
+ for p in range(1,padding+1):
51
+ output[i,:,:,:,:,:]=output[i,:,:,:,:,:]+F.conv3d(data_padded[i+padding-p,:,:,:,:,:],
52
+ filters[padding-p,:,:,:,:,:], bias=None, stride=1, padding=padding)
53
+ output[i,:,:,:,:,:]=output[i,:,:,:,:,:]+F.conv3d(data_padded[i+padding+p,:,:,:,:,:],
54
+ filters[padding+p,:,:,:,:,:], bias=None, stride=1, padding=padding)
55
+
56
+ output=output.permute(1,2,0,3,4,5).contiguous()
57
+ return output
58
+
59
+ class Conv4d(_ConvNd):
60
+ """Applies a 4D convolution over an input signal composed of several input
61
+ planes.
62
+ """
63
+
64
+ def __init__(self, in_channels, out_channels, kernel_size, bias=True, pre_permuted_filters=True):
65
+ # stride, dilation and groups !=1 functionality not tested
66
+ stride=1
67
+ dilation=1
68
+ groups=1
69
+ # zero padding is added automatically in conv4d function to preserve tensor size
70
+ padding = 0
71
+ kernel_size = _quadruple(kernel_size)
72
+ stride = _quadruple(stride)
73
+ padding = _quadruple(padding)
74
+ dilation = _quadruple(dilation)
75
+ super(Conv4d, self).__init__(
76
+ in_channels, out_channels, kernel_size, stride, padding, dilation,
77
+ False, _quadruple(0), groups, bias)
78
+ # weights will be sliced along one dimension during convolution loop
79
+ # make the looping dimension to be the first one in the tensor,
80
+ # so that we don't need to call contiguous() inside the loop
81
+ self.pre_permuted_filters=pre_permuted_filters
82
+ if self.pre_permuted_filters:
83
+ self.weight.data=self.weight.data.permute(2,0,1,3,4,5).contiguous()
84
+ self.use_half=False
85
+ # self.isbias = bias
86
+ # if not self.isbias:
87
+ # self.bn = torch.nn.BatchNorm1d(out_channels)
88
+
89
+
90
+ def forward(self, input):
91
+ out = conv4d(input, self.weight, bias=self.bias,permute_filters=not self.pre_permuted_filters,use_half=self.use_half) # filters pre-permuted in constructor
92
+ # if not self.isbias:
93
+ # b,c,u,v,h,w = out.shape
94
+ # out = self.bn(out.view(b,c,-1)).view(b,c,u,v,h,w)
95
+ return out
96
+
97
+ class fullConv4d(torch.nn.Module):
98
+ def __init__(self, in_channels, out_channels, kernel_size, bias=True, pre_permuted_filters=True):
99
+ super(fullConv4d, self).__init__()
100
+ self.conv = Conv4d(in_channels, out_channels, kernel_size, bias=bias, pre_permuted_filters=pre_permuted_filters)
101
+ self.isbias = bias
102
+ if not self.isbias:
103
+ self.bn = torch.nn.BatchNorm1d(out_channels)
104
+
105
+ def forward(self, input):
106
+ out = self.conv(input)
107
+ if not self.isbias:
108
+ b,c,u,v,h,w = out.shape
109
+ out = self.bn(out.view(b,c,-1)).view(b,c,u,v,h,w)
110
+ return out
111
+
112
+ class butterfly4D(torch.nn.Module):
113
+ '''
114
+ butterfly 4d
115
+ '''
116
+ def __init__(self, fdima, fdimb, withbn=True, full=True,groups=1):
117
+ super(butterfly4D, self).__init__()
118
+ self.proj = nn.Sequential(projfeat4d(fdima, fdimb, 1, with_bn=withbn,groups=groups),
119
+ nn.ReLU(inplace=True),)
120
+ self.conva1 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(2,1,1),full=full,groups=groups)
121
+ self.conva2 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(2,1,1),full=full,groups=groups)
122
+ self.convb3 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(1,1,1),full=full,groups=groups)
123
+ self.convb2 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(1,1,1),full=full,groups=groups)
124
+ self.convb1 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(1,1,1),full=full,groups=groups)
125
+
126
+ #@profile
127
+ def forward(self,x):
128
+ out = self.proj(x)
129
+ b,c,u,v,h,w = out.shape # 9x9
130
+
131
+ out1 = self.conva1(out) # 5x5, 3
132
+ _,c1,u1,v1,h1,w1 = out1.shape
133
+
134
+ out2 = self.conva2(out1) # 3x3, 9
135
+ _,c2,u2,v2,h2,w2 = out2.shape
136
+
137
+ out2 = self.convb3(out2) # 3x3, 9
138
+
139
+ tout1 = F.upsample(out2.view(b,c,u2,v2,-1),(u1,v1,h2*w2),mode='trilinear').view(b,c,u1,v1,h2,w2) # 5x5
140
+ tout1 = F.upsample(tout1.view(b,c,-1,h2,w2),(u1*v1,h1,w1),mode='trilinear').view(b,c,u1,v1,h1,w1) # 5x5
141
+ out1 = tout1 + out1
142
+ out1 = self.convb2(out1)
143
+
144
+ tout = F.upsample(out1.view(b,c,u1,v1,-1),(u,v,h1*w1),mode='trilinear').view(b,c,u,v,h1,w1)
145
+ tout = F.upsample(tout.view(b,c,-1,h1,w1),(u*v,h,w),mode='trilinear').view(b,c,u,v,h,w)
146
+ out = tout + out
147
+ out = self.convb1(out)
148
+
149
+ return out
150
+
151
+
152
+
153
+ class projfeat4d(torch.nn.Module):
154
+ '''
155
+ 1x1 channel projection for 4D cost volumes, implemented as a 3D convolution over a flattened view
156
+ '''
157
+ def __init__(self, in_planes, out_planes, stride, with_bn=True,groups=1):
158
+ super(projfeat4d, self).__init__()
159
+ self.with_bn = with_bn
160
+ self.stride = stride
161
+ self.conv1 = nn.Conv3d(in_planes, out_planes, 1, (stride,stride,1), padding=0,bias=not with_bn,groups=groups)
162
+ self.bn = nn.BatchNorm3d(out_planes)
163
+
164
+ def forward(self,x):
165
+ b,c,u,v,h,w = x.size()
166
+ x = self.conv1(x.view(b,c,u,v,h*w))
167
+ if self.with_bn:
168
+ x = self.bn(x)
169
+ _,c,u,v,_ = x.shape
170
+ x = x.view(b,c,u,v,h,w)
171
+ return x
172
+
173
+ class sepConv4d(torch.nn.Module):
174
+ '''
175
+ Separable 4d convolution block as 2 3D convolutions
176
+ '''
177
+ def __init__(self, in_planes, out_planes, stride=(1,1,1), with_bn=True, ksize=3, full=True,groups=1):
178
+ super(sepConv4d, self).__init__()
179
+ bias = not with_bn
180
+ self.isproj = False
181
+ self.stride = stride[0]
182
+ expand = 1
183
+
184
+ if with_bn:
185
+ if in_planes != out_planes:
186
+ self.isproj = True
187
+ self.proj = nn.Sequential(nn.Conv2d(in_planes, out_planes, 1, bias=bias, padding=0,groups=groups),
188
+ nn.BatchNorm2d(out_planes))
189
+ if full:
190
+ self.conv1 = nn.Sequential(nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=(1,self.stride,self.stride), bias=bias, padding=(0,ksize//2,ksize//2),groups=groups),
191
+ nn.BatchNorm3d(in_planes))
192
+ else:
193
+ self.conv1 = nn.Sequential(nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=1, bias=bias, padding=(0,ksize//2,ksize//2),groups=groups),
194
+ nn.BatchNorm3d(in_planes))
195
+ self.conv2 = nn.Sequential(nn.Conv3d(in_planes, in_planes*expand, (ksize,ksize,1), stride=(self.stride,self.stride,1), bias=bias, padding=(ksize//2,ksize//2,0),groups=groups),
196
+ nn.BatchNorm3d(in_planes*expand))
197
+ else:
198
+ if in_planes != out_planes:
199
+ self.isproj = True
200
+ self.proj = nn.Conv2d(in_planes, out_planes, 1, bias=bias, padding=0,groups=groups)
201
+ if full:
202
+ self.conv1 = nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=(1,self.stride,self.stride), bias=bias, padding=(0,ksize//2,ksize//2),groups=groups)
203
+ else:
204
+ self.conv1 = nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=1, bias=bias, padding=(0,ksize//2,ksize//2),groups=groups)
205
+ self.conv2 = nn.Conv3d(in_planes, in_planes*expand, (ksize,ksize,1), stride=(self.stride,self.stride,1), bias=bias, padding=(ksize//2,ksize//2,0),groups=groups)
206
+ self.relu = nn.ReLU(inplace=True)
207
+
208
+ #@profile
209
+ def forward(self,x):
210
+ b,c,u,v,h,w = x.shape
211
+ x = self.conv2(x.view(b,c,u,v,-1))
212
+ b,c,u,v,_ = x.shape
213
+ x = self.relu(x)
214
+ x = self.conv1(x.view(b,c,-1,h,w))
215
+ b,c,_,h,w = x.shape
216
+
217
+ if self.isproj:
218
+ x = self.proj(x.view(b,c,-1,w))
219
+ x = x.view(b,-1,u,v,h,w)
220
+ return x
221
+
222
+
223
+ class sepConv4dBlock(torch.nn.Module):
224
+ '''
225
+ Residual block of two separable 4D convolutions with an optional
226
+ downsampling/projection shortcut
227
+ '''
228
+ def __init__(self, in_planes, out_planes, stride=(1,1,1), with_bn=True, full=True,groups=1):
229
+ super(sepConv4dBlock, self).__init__()
230
+ if in_planes == out_planes and stride==(1,1,1):
231
+ self.downsample = None
232
+ else:
233
+ if full:
234
+ self.downsample = sepConv4d(in_planes, out_planes, stride, with_bn=with_bn,ksize=1, full=full,groups=groups)
235
+ else:
236
+ self.downsample = projfeat4d(in_planes, out_planes,stride[0], with_bn=with_bn,groups=groups)
237
+ self.conv1 = sepConv4d(in_planes, out_planes, stride, with_bn=with_bn, full=full ,groups=groups)
238
+ self.conv2 = sepConv4d(out_planes, out_planes,(1,1,1), with_bn=with_bn, full=full,groups=groups)
239
+ self.relu1 = nn.ReLU(inplace=True)
240
+ self.relu2 = nn.ReLU(inplace=True)
241
+
242
+ #@profile
243
+ def forward(self,x):
244
+ out = self.relu1(self.conv1(x))
245
+ if self.downsample:
246
+ x = self.downsample(x)
247
+ out = self.relu2(x + self.conv2(out))
248
+ return out
249
+
250
+
251
+ ##import torch.backends.cudnn as cudnn
252
+ ##cudnn.benchmark = True
253
+ #import time
254
+ ##im = torch.randn(9,64,9,160,224).cuda()
255
+ ##net = torch.nn.Conv3d(64, 64, 3).cuda()
256
+ ##net = Conv4d(1,1,3,bias=True,pre_permuted_filters=True).cuda()
257
+ ##net = sepConv4dBlock(2,2,stride=(1,1,1)).cuda()
258
+ #
259
+ ##im = torch.randn(1,16,9,9,96,320).cuda()
260
+ ##net = sepConv4d(16,16,with_bn=False).cuda()
261
+ #
262
+ ##im = torch.randn(1,16,81,96,320).cuda()
263
+ ##net = torch.nn.Conv3d(16,16,(1,3,3),padding=(0,1,1)).cuda()
264
+ #
265
+ ##im = torch.randn(1,16,9,9,96*320).cuda()
266
+ ##net = torch.nn.Conv3d(16,16,(3,3,1),padding=(1,1,0)).cuda()
267
+ #
268
+ ##im = torch.randn(10000,10,9,9).cuda()
269
+ ##net = torch.nn.Conv2d(10,10,3,padding=1).cuda()
270
+ #
271
+ ##im = torch.randn(81,16,96,320).cuda()
272
+ ##net = torch.nn.Conv2d(16,16,3,padding=1).cuda()
273
+ #c= int(16 *1)
274
+ #cp = int(16 *1)
275
+ #h=int(96 *4)
276
+ #w=int(320 *4)
277
+ #k=3
278
+ #im = torch.randn(1,c,h,w).cuda()
279
+ #net = torch.nn.Conv2d(c,cp,k,padding=k//2).cuda()
280
+ #
281
+ #im2 = torch.randn(cp,k*k*c).cuda()
282
+ #im1 = F.unfold(im, (k,k), padding=k//2)[0]
283
+ #
284
+ #
285
+ #net(im)
286
+ #net(im)
287
+ #torch.mm(im2,im1)
288
+ #torch.mm(im2,im1)
289
+ #torch.cuda.synchronize()
290
+ #beg = time.time()
291
+ #for i in range(100):
292
+ # net(im)
293
+ # #im1 = F.unfold(im, (k,k), padding=k//2)[0]
294
+ # torch.mm(im2,im1)
295
+ #torch.cuda.synchronize()
296
+ #print('%f'%((time.time()-beg)*10.))
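The conv4d function above realizes a true 4D convolution by looping a bank of 3D convolutions over slices of the first spatial dimension, zero-padding that dimension so sizes are preserved. A quick shape check, assuming the made-up tensor sizes below:

import torch
from expansion.models.conv4d import conv4d

data = torch.randn(1, 2, 5, 5, 6, 6)      # (b, c_in, h, w, d, t)
filters = torch.randn(4, 2, 3, 3, 3, 3)   # (c_out, c_in, k, k, k, k)

out = conv4d(data, filters)               # padding handled internally
assert out.shape == (1, 4, 5, 5, 6, 6)    # spatial sizes preserved, c_out = 4

In practice the model avoids this full-cost path: butterfly4D runs the cheaper sepConv4d blocks, which factor the 4D kernel into a (1,k,k) and a (k,k,1) 3D convolution.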
expansion/models/submodule.py ADDED
@@ -0,0 +1,450 @@
1
+ from __future__ import print_function
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.utils.data
5
+ from torch.autograd import Variable
6
+ import torch.nn.functional as F
7
+ import math
8
+ import numpy as np
9
+ import pdb
10
+
11
+ class residualBlock(nn.Module):
12
+ expansion = 1
13
+
14
+ def __init__(self, in_channels, n_filters, stride=1, downsample=None,dilation=1,with_bn=True):
15
+ super(residualBlock, self).__init__()
16
+ if dilation > 1:
17
+ padding = dilation
18
+ else:
19
+ padding = 1
20
+
21
+ if with_bn:
22
+ self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters, 3, stride, padding, dilation=dilation)
23
+ self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1)
24
+ else:
25
+ self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters, 3, stride, padding, dilation=dilation,with_bn=False)
26
+ self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1, with_bn=False)
27
+ self.downsample = downsample
28
+ self.relu = nn.LeakyReLU(0.1, inplace=True)
29
+
30
+ def forward(self, x):
31
+ residual = x
32
+
33
+ out = self.convbnrelu1(x)
34
+ out = self.convbn2(out)
35
+
36
+ if self.downsample is not None:
37
+ residual = self.downsample(x)
38
+
39
+ out += residual
40
+ return self.relu(out)
41
+
42
+ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
43
+ return nn.Sequential(
44
+ nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
45
+ padding=padding, dilation=dilation, bias=True),
46
+ nn.BatchNorm2d(out_planes),
47
+ nn.LeakyReLU(0.1,inplace=True))
48
+
49
+
50
+ class conv2DBatchNorm(nn.Module):
51
+ def __init__(self, in_channels, n_filters, k_size, stride, padding, dilation=1, with_bn=True):
52
+ super(conv2DBatchNorm, self).__init__()
53
+ bias = not with_bn
54
+
55
+ if dilation > 1:
56
+ conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
57
+ padding=padding, stride=stride, bias=bias, dilation=dilation)
58
+
59
+ else:
60
+ conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
61
+ padding=padding, stride=stride, bias=bias, dilation=1)
62
+
63
+
64
+ if with_bn:
65
+ self.cb_unit = nn.Sequential(conv_mod,
66
+ nn.BatchNorm2d(int(n_filters)),)
67
+ else:
68
+ self.cb_unit = nn.Sequential(conv_mod,)
69
+
70
+ def forward(self, inputs):
71
+ outputs = self.cb_unit(inputs)
72
+ return outputs
73
+
74
+ class conv2DBatchNormRelu(nn.Module):
75
+ def __init__(self, in_channels, n_filters, k_size, stride, padding, dilation=1, with_bn=True):
76
+ super(conv2DBatchNormRelu, self).__init__()
77
+ bias = not with_bn
78
+ if dilation > 1:
79
+ conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
80
+ padding=padding, stride=stride, bias=bias, dilation=dilation)
81
+
82
+ else:
83
+ conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
84
+ padding=padding, stride=stride, bias=bias, dilation=1)
85
+
86
+ if with_bn:
87
+ self.cbr_unit = nn.Sequential(conv_mod,
88
+ nn.BatchNorm2d(int(n_filters)),
89
+ nn.LeakyReLU(0.1, inplace=True),)
90
+ else:
91
+ self.cbr_unit = nn.Sequential(conv_mod,
92
+ nn.LeakyReLU(0.1, inplace=True),)
93
+
94
+ def forward(self, inputs):
95
+ outputs = self.cbr_unit(inputs)
96
+ return outputs
97
+
98
+ class pyramidPooling(nn.Module):
99
+
100
+ def __init__(self, in_channels, with_bn=True, levels=4):
101
+ super(pyramidPooling, self).__init__()
102
+ self.levels = levels
103
+
104
+ self.paths = []
105
+ for i in range(levels):
106
+ self.paths.append(conv2DBatchNormRelu(in_channels, in_channels, 1, 1, 0, with_bn=with_bn))
107
+ self.path_module_list = nn.ModuleList(self.paths)
108
+ self.relu = nn.LeakyReLU(0.1, inplace=True)
109
+
110
+ def forward(self, x):
111
+ h, w = x.shape[2:]
112
+
113
+ k_sizes = []
114
+ strides = []
115
+ for pool_size in np.linspace(1,min(h,w)//2,self.levels,dtype=int):
116
+ k_sizes.append((int(h/pool_size), int(w/pool_size)))
117
+ strides.append((int(h/pool_size), int(w/pool_size)))
118
+ k_sizes = k_sizes[::-1]
119
+ strides = strides[::-1]
120
+
121
+ pp_sum = x
122
+
123
+ for i, module in enumerate(self.path_module_list):
124
+ out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0)
125
+ out = module(out)
126
+ out = F.upsample(out, size=(h,w), mode='bilinear')
127
+ pp_sum = pp_sum + 1./self.levels*out
128
+ pp_sum = self.relu(pp_sum/2.)
129
+
130
+ return pp_sum
131
+
132
+ class pspnet(nn.Module):
133
+ """
134
+ Modified PSPNet. https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/models/pspnet.py
135
+ """
136
+ def __init__(self, is_proj=True,groups=1):
137
+ super(pspnet, self).__init__()
138
+ self.inplanes = 32
139
+ self.is_proj = is_proj
140
+
141
+ # Encoder
142
+ self.convbnrelu1_1 = conv2DBatchNormRelu(in_channels=3, k_size=3, n_filters=16,
143
+ padding=1, stride=2)
144
+ self.convbnrelu1_2 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=16,
145
+ padding=1, stride=1)
146
+ self.convbnrelu1_3 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=32,
147
+ padding=1, stride=1)
148
+ # Vanilla Residual Blocks
149
+ self.res_block3 = self._make_layer(residualBlock,64,1,stride=2)
150
+ self.res_block5 = self._make_layer(residualBlock,128,1,stride=2)
151
+ self.res_block6 = self._make_layer(residualBlock,128,1,stride=2)
152
+ self.res_block7 = self._make_layer(residualBlock,128,1,stride=2)
153
+ self.pyramid_pooling = pyramidPooling(128, levels=3)
154
+
155
+ # Iconvs
156
+ self.upconv6 = nn.Sequential(nn.Upsample(scale_factor=2),
157
+ conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
158
+ padding=1, stride=1))
159
+ self.iconv5 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128,
160
+ padding=1, stride=1)
161
+ self.upconv5 = nn.Sequential(nn.Upsample(scale_factor=2),
162
+ conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
163
+ padding=1, stride=1))
164
+ self.iconv4 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128,
165
+ padding=1, stride=1)
166
+ self.upconv4 = nn.Sequential(nn.Upsample(scale_factor=2),
167
+ conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
168
+ padding=1, stride=1))
169
+ self.iconv3 = conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
170
+ padding=1, stride=1)
171
+ self.upconv3 = nn.Sequential(nn.Upsample(scale_factor=2),
172
+ conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32,
173
+ padding=1, stride=1))
174
+ self.iconv2 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=64,
175
+ padding=1, stride=1)
176
+
177
+ if self.is_proj:
178
+ self.proj6 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1)
179
+ self.proj5 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1)
180
+ self.proj4 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1)
181
+ self.proj3 = conv2DBatchNormRelu(in_channels=64, k_size=1,n_filters=64//groups, padding=0,stride=1)
182
+ self.proj2 = conv2DBatchNormRelu(in_channels=64, k_size=1,n_filters=64//groups, padding=0,stride=1)
183
+
184
+ for m in self.modules():
185
+ if isinstance(m, nn.Conv2d):
186
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
187
+ m.weight.data.normal_(0, math.sqrt(2. / n))
188
+ if hasattr(m.bias,'data'):
189
+ m.bias.data.zero_()
190
+
191
+
192
+ def _make_layer(self, block, planes, blocks, stride=1):
193
+ downsample = None
194
+ if stride != 1 or self.inplanes != planes * block.expansion:
195
+ downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * block.expansion,
196
+ kernel_size=1, stride=stride, bias=False),
197
+ nn.BatchNorm2d(planes * block.expansion),)
198
+ layers = []
199
+ layers.append(block(self.inplanes, planes, stride, downsample))
200
+ self.inplanes = planes * block.expansion
201
+ for i in range(1, blocks):
202
+ layers.append(block(self.inplanes, planes))
203
+ return nn.Sequential(*layers)
204
+
205
+ def forward(self, x):
206
+ # H, W -> H/2, W/2
207
+ conv1 = self.convbnrelu1_1(x)
208
+ conv1 = self.convbnrelu1_2(conv1)
209
+ conv1 = self.convbnrelu1_3(conv1)
210
+
211
+ ## H/2, W/2 -> H/4, W/4
212
+ pool1 = F.max_pool2d(conv1, 3, 2, 1)
213
+
214
+ # H/4, W/4 -> H/8, ..., H/64 via the strided residual blocks
215
+ rconv3 = self.res_block3(pool1)
216
+ conv4 = self.res_block5(rconv3)
217
+ conv5 = self.res_block6(conv4)
218
+ conv6 = self.res_block7(conv5)
219
+ conv6 = self.pyramid_pooling(conv6)
220
+
221
+ conv6x = F.upsample(conv6, [conv5.size()[2],conv5.size()[3]],mode='bilinear')
222
+ concat5 = torch.cat((conv5,self.upconv6[1](conv6x)),dim=1)
223
+ conv5 = self.iconv5(concat5)
224
+
225
+ conv5x = F.upsample(conv5, [conv4.size()[2],conv4.size()[3]],mode='bilinear')
226
+ concat4 = torch.cat((conv4,self.upconv5[1](conv5x)),dim=1)
227
+ conv4 = self.iconv4(concat4)
228
+
229
+ conv4x = F.upsample(conv4, [rconv3.size()[2],rconv3.size()[3]],mode='bilinear')
230
+ concat3 = torch.cat((rconv3,self.upconv4[1](conv4x)),dim=1)
231
+ conv3 = self.iconv3(concat3)
232
+
233
+ conv3x = F.upsample(conv3, [pool1.size()[2],pool1.size()[3]],mode='bilinear')
234
+ concat2 = torch.cat((pool1,self.upconv3[1](conv3x)),dim=1)
235
+ conv2 = self.iconv2(concat2)
236
+
237
+ if self.is_proj:
238
+ proj6 = self.proj6(conv6)
239
+ proj5 = self.proj5(conv5)
240
+ proj4 = self.proj4(conv4)
241
+ proj3 = self.proj3(conv3)
242
+ proj2 = self.proj2(conv2)
243
+ return proj6,proj5,proj4,proj3,proj2
244
+ else:
245
+ return conv6, conv5, conv4, conv3, conv2
246
+
247
+
248
+ class pspnet_s(nn.Module):
249
+ """
250
+ Modified PSPNet. https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/models/pspnet.py
251
+ """
252
+ def __init__(self, is_proj=True,groups=1):
253
+ super(pspnet_s, self).__init__()
254
+ self.inplanes = 32
255
+ self.is_proj = is_proj
256
+
257
+ # Encoder
258
+ self.convbnrelu1_1 = conv2DBatchNormRelu(in_channels=3, k_size=3, n_filters=16,
259
+ padding=1, stride=2)
260
+ self.convbnrelu1_2 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=16,
261
+ padding=1, stride=1)
262
+ self.convbnrelu1_3 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=32,
263
+ padding=1, stride=1)
264
+ # Vanilla Residual Blocks
265
+ self.res_block3 = self._make_layer(residualBlock,64,1,stride=2)
266
+ self.res_block5 = self._make_layer(residualBlock,128,1,stride=2)
267
+ self.res_block6 = self._make_layer(residualBlock,128,1,stride=2)
268
+ self.res_block7 = self._make_layer(residualBlock,128,1,stride=2)
269
+ self.pyramid_pooling = pyramidPooling(128, levels=3)
270
+
271
+ # Iconvs
272
+ self.upconv6 = nn.Sequential(nn.Upsample(scale_factor=2),
273
+ conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
274
+ padding=1, stride=1))
275
+ self.iconv5 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128,
276
+ padding=1, stride=1)
277
+ self.upconv5 = nn.Sequential(nn.Upsample(scale_factor=2),
278
+ conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
279
+ padding=1, stride=1))
280
+ self.iconv4 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128,
281
+ padding=1, stride=1)
282
+ self.upconv4 = nn.Sequential(nn.Upsample(scale_factor=2),
283
+ conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
284
+ padding=1, stride=1))
285
+ self.iconv3 = conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
286
+ padding=1, stride=1)
287
+ #self.upconv3 = nn.Sequential(nn.Upsample(scale_factor=2),
288
+ # conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32,
289
+ # padding=1, stride=1))
290
+ #self.iconv2 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=64,
291
+ # padding=1, stride=1)
292
+
293
+ if self.is_proj:
294
+ self.proj6 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1)
295
+ self.proj5 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1)
296
+ self.proj4 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1)
297
+ self.proj3 = conv2DBatchNormRelu(in_channels=64, k_size=1,n_filters=64//groups, padding=0,stride=1)
298
+ #self.proj2 = conv2DBatchNormRelu(in_channels=64, k_size=1,n_filters=64//groups, padding=0,stride=1)
299
+
300
+ for m in self.modules():
301
+ if isinstance(m, nn.Conv2d):
302
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
303
+ m.weight.data.normal_(0, math.sqrt(2. / n))
304
+ if hasattr(m.bias,'data'):
305
+ m.bias.data.zero_()
306
+
307
+
308
+ def _make_layer(self, block, planes, blocks, stride=1):
309
+ downsample = None
310
+ if stride != 1 or self.inplanes != planes * block.expansion:
311
+ downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * block.expansion,
312
+ kernel_size=1, stride=stride, bias=False),
313
+ nn.BatchNorm2d(planes * block.expansion),)
314
+ layers = []
315
+ layers.append(block(self.inplanes, planes, stride, downsample))
316
+ self.inplanes = planes * block.expansion
317
+ for i in range(1, blocks):
318
+ layers.append(block(self.inplanes, planes))
319
+ return nn.Sequential(*layers)
320
+
321
+ def forward(self, x):
322
+ # H, W -> H/2, W/2
323
+ conv1 = self.convbnrelu1_1(x)
324
+ conv1 = self.convbnrelu1_2(conv1)
325
+ conv1 = self.convbnrelu1_3(conv1)
326
+
327
+ ## H/2, W/2 -> H/4, W/4
328
+ pool1 = F.max_pool2d(conv1, 3, 2, 1)
329
+
330
+ # H/4, W/4 -> H/8, ..., H/64 via the strided residual blocks
331
+ rconv3 = self.res_block3(pool1)
332
+ conv4 = self.res_block5(rconv3)
333
+ conv5 = self.res_block6(conv4)
334
+ conv6 = self.res_block7(conv5)
335
+ conv6 = self.pyramid_pooling(conv6)
336
+
337
+ conv6x = F.upsample(conv6, [conv5.size()[2],conv5.size()[3]],mode='bilinear')
338
+ concat5 = torch.cat((conv5,self.upconv6[1](conv6x)),dim=1)
339
+ conv5 = self.iconv5(concat5)
340
+
341
+ conv5x = F.upsample(conv5, [conv4.size()[2],conv4.size()[3]],mode='bilinear')
342
+ concat4 = torch.cat((conv4,self.upconv5[1](conv5x)),dim=1)
343
+ conv4 = self.iconv4(concat4)
344
+
345
+ conv4x = F.upsample(conv4, [rconv3.size()[2],rconv3.size()[3]],mode='bilinear')
346
+ concat3 = torch.cat((rconv3,self.upconv4[1](conv4x)),dim=1)
347
+ conv3 = self.iconv3(concat3)
348
+
349
+ #conv3x = F.upsample(conv3, [pool1.size()[2],pool1.size()[3]],mode='bilinear')
350
+ #concat2 = torch.cat((pool1,self.upconv3[1](conv3x)),dim=1)
351
+ #conv2 = self.iconv2(concat2)
352
+
353
+ if self.is_proj:
354
+ proj6 = self.proj6(conv6)
355
+ proj5 = self.proj5(conv5)
356
+ proj4 = self.proj4(conv4)
357
+ proj3 = self.proj3(conv3)
358
+ # proj2 = self.proj2(conv2)
359
+ # return proj6,proj5,proj4,proj3,proj2
360
+ return proj6,proj5,proj4,proj3
361
+ else:
362
+ # return conv6, conv5, conv4, conv3, conv2
363
+ return conv6, conv5, conv4, conv3
364
+
365
+ class bfmodule(nn.Module):
366
+ def __init__(self, inplanes, outplanes):
367
+ super(bfmodule, self).__init__()
368
+ self.proj = conv2DBatchNormRelu(in_channels=inplanes,k_size=1,n_filters=64,padding=0,stride=1)
369
+ self.inplanes = 64
370
+ # Vanilla Residual Blocks
371
+ self.res_block3 = self._make_layer(residualBlock,64,1,stride=2)
372
+ self.res_block5 = self._make_layer(residualBlock,64,1,stride=2)
373
+ self.res_block6 = self._make_layer(residualBlock,64,1,stride=2)
374
+ self.res_block7 = self._make_layer(residualBlock,128,1,stride=2)
375
+ self.pyramid_pooling = pyramidPooling(128, levels=3)
376
+ # Iconvs
377
+ self.upconv6 = conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
378
+ padding=1, stride=1)
379
+ self.upconv5 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32,
380
+ padding=1, stride=1)
381
+ self.upconv4 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32,
382
+ padding=1, stride=1)
383
+ self.upconv3 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32,
384
+ padding=1, stride=1)
385
+ self.iconv5 = conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
386
+ padding=1, stride=1)
387
+ self.iconv4 = conv2DBatchNormRelu(in_channels=96, k_size=3, n_filters=64,
388
+ padding=1, stride=1)
389
+ self.iconv3 = conv2DBatchNormRelu(in_channels=96, k_size=3, n_filters=64,
390
+ padding=1, stride=1)
391
+ self.iconv2 = nn.Sequential(conv2DBatchNormRelu(in_channels=96, k_size=3, n_filters=64,
392
+ padding=1, stride=1),
393
+ nn.Conv2d(64, outplanes,kernel_size=3, stride=1, padding=1, bias=True))
394
+
395
+ self.proj6 = nn.Conv2d(128, outplanes,kernel_size=3, stride=1, padding=1, bias=True)
396
+ self.proj5 = nn.Conv2d(64, outplanes,kernel_size=3, stride=1, padding=1, bias=True)
397
+ self.proj4 = nn.Conv2d(64, outplanes,kernel_size=3, stride=1, padding=1, bias=True)
398
+ self.proj3 = nn.Conv2d(64, outplanes,kernel_size=3, stride=1, padding=1, bias=True)
399
+
400
+ for m in self.modules():
401
+ if isinstance(m, nn.Conv2d):
402
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
403
+ m.weight.data.normal_(0, math.sqrt(2. / n))
404
+ if hasattr(m.bias,'data'):
405
+ m.bias.data.zero_()
406
+
407
+
408
+ def _make_layer(self, block, planes, blocks, stride=1):
409
+ downsample = None
410
+ if stride != 1 or self.inplanes != planes * block.expansion:
411
+ downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * block.expansion,
412
+ kernel_size=1, stride=stride, bias=False),
413
+ nn.BatchNorm2d(planes * block.expansion),)
414
+ layers = []
415
+ layers.append(block(self.inplanes, planes, stride, downsample))
416
+ self.inplanes = planes * block.expansion
417
+ for i in range(1, blocks):
418
+ layers.append(block(self.inplanes, planes))
419
+ return nn.Sequential(*layers)
420
+
421
+ def forward(self, x):
422
+ proj = self.proj(x) # 4x
423
+ rconv3 = self.res_block3(proj) #8x
424
+ conv4 = self.res_block5(rconv3) #16x
425
+ conv5 = self.res_block6(conv4) #32x
426
+ conv6 = self.res_block7(conv5) #64x
427
+ conv6 = self.pyramid_pooling(conv6) #64x
428
+ pred6 = self.proj6(conv6)
429
+
430
+ conv6u = F.upsample(conv6, [conv5.size()[2],conv5.size()[3]], mode='bilinear')
431
+ concat5 = torch.cat((conv5,self.upconv6(conv6u)),dim=1)
432
+ conv5 = self.iconv5(concat5) #32x
433
+ pred5 = self.proj5(conv5)
434
+
435
+ conv5u = F.upsample(conv5, [conv4.size()[2],conv4.size()[3]], mode='bilinear')
436
+ concat4 = torch.cat((conv4,self.upconv5(conv5u)),dim=1)
437
+ conv4 = self.iconv4(concat4) #16x
438
+ pred4 = self.proj4(conv4)
439
+
440
+ conv4u = F.upsample(conv4, [rconv3.size()[2],rconv3.size()[3]], mode='bilinear')
441
+ concat3 = torch.cat((rconv3,self.upconv4(conv4u)),dim=1)
442
+ conv3 = self.iconv3(concat3) # 8x
443
+ pred3 = self.proj3(conv3)
444
+
445
+ conv3u = F.upsample(conv3, [x.size()[2],x.size()[3]], mode='bilinear')
446
+ concat2 = torch.cat((proj,self.upconv3(conv3u)),dim=1)
447
+ pred2 = self.iconv2(concat2) # 4x
448
+
449
+ return pred2, pred3, pred4, pred5, pred6
450
+
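pspnet above is the shared feature extractor: a small encoder-decoder that returns a five-level feature pyramid consumed by the matching stages. A rough usage sketch follows (the input size and batch are arbitrary; with the default groups=1 the projected channel counts are 128/128/128/64/64):

import torch
from expansion.models.submodule import pspnet

net = pspnet().eval()
x = torch.randn(2, 3, 256, 256)     # e.g., two frames stacked on the batch axis
with torch.no_grad():
    p6, p5, p4, p3, p2 = net(x)     # strides 64, 32, 16, 8, 4
print([p.shape[-1] for p in (p6, p5, p4, p3, p2)])   # [4, 8, 16, 32, 64]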
expansion/submission.py ADDED
@@ -0,0 +1,95 @@
1
+ from __future__ import print_function
2
+ import sys
3
+ import cv2
4
+ import argparse
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.backends.cudnn as cudnn
9
+ import torch.optim as optim
10
+ import torch.nn.functional as F
11
+ cudnn.benchmark = False
12
+
13
+ class Expansion():
14
+
15
+ def __init__(self, loadmodel = 'pretrained_models/optical_expansion/robust.pth', testres = 1, maxdisp = 256, fac = 1):
16
+
17
+ maxw,maxh = [int(testres*1280), int(testres*384)]
18
+
19
+ max_h = int(maxh // 64 * 64)
20
+ max_w = int(maxw // 64 * 64)
21
+ if max_h < maxh: max_h += 64
22
+ if max_w < maxw: max_w += 64
23
+ maxh = max_h
24
+ maxw = max_w
25
+
26
+ mean_L = [[0.33,0.33,0.33]]
27
+ mean_R = [[0.33,0.33,0.33]]
28
+
29
+ # construct model, VCN-expansion
30
+ from expansion.models.VCN_exp import VCN
31
+ model = VCN([1, maxw, maxh], md=[int(4*(maxdisp/256)),4,4,4,4], fac=fac,
32
+ exp_unc=('robust' in loadmodel)) # expansion uncertainty only in the new model
33
+ model = nn.DataParallel(model, device_ids=[0])
34
+ model.cuda()
35
+
36
+ if loadmodel is not None:
37
+ pretrained_dict = torch.load(loadmodel)
38
+ mean_L=pretrained_dict['mean_L']
39
+ mean_R=pretrained_dict['mean_R']
40
+ pretrained_dict['state_dict'] = {k:v for k,v in pretrained_dict['state_dict'].items()}
41
+ model.load_state_dict(pretrained_dict['state_dict'],strict=False)
42
+ else:
43
+ print('dry run')
44
+
45
+ model.eval()
46
+ # resize
47
+ maxh = 256
48
+ maxw = 256
49
+ max_h = int(maxh // 64 * 64)
50
+ max_w = int(maxw // 64 * 64)
51
+ if max_h < maxh: max_h += 64
52
+ if max_w < maxw: max_w += 64
53
+
54
+ # modify module according to inputs
55
+ from expansion.models.VCN_exp import WarpModule, flow_reg
56
+ for i in range(len(model.module.reg_modules)):
57
+ model.module.reg_modules[i] = flow_reg([1,max_w//(2**(6-i)), max_h//(2**(6-i))],
58
+ ent=getattr(model.module, 'flow_reg%d'%2**(6-i)).ent,\
59
+ maxdisp=getattr(model.module, 'flow_reg%d'%2**(6-i)).md,\
60
+ fac=getattr(model.module, 'flow_reg%d'%2**(6-i)).fac).cuda()
61
+ for i in range(len(model.module.warp_modules)):
62
+ model.module.warp_modules[i] = WarpModule([1,max_w//(2**(6-i)), max_h//(2**(6-i))]).cuda()
63
+
64
+ mean_L = torch.from_numpy(np.asarray(mean_L).astype(np.float32).mean(0)[np.newaxis,:,np.newaxis,np.newaxis]).cuda()
65
+ mean_R = torch.from_numpy(np.asarray(mean_R).astype(np.float32).mean(0)[np.newaxis,:,np.newaxis,np.newaxis]).cuda()
66
+
67
+ self.max_h = max_h
68
+ self.max_w = max_w
69
+ self.model = model
70
+ self.mean_L = mean_L
71
+ self.mean_R = mean_R
72
+
73
+ def run(self, imgL_o, imgR_o):
74
+ model = self.model
75
+ mean_L = self.mean_L
76
+ mean_R = self.mean_R
77
+
78
+ imgL_o[imgL_o<-1] = -1
79
+ imgL_o[imgL_o>1] = 1
80
+ imgR_o[imgR_o<-1] = -1
81
+ imgR_o[imgR_o>1] = 1
82
+ imgL = (imgL_o+1.)*0.5-mean_L # map [-1,1] to [0,1] before mean subtraction
83
+ imgR = (imgR_o+1.)*0.5-mean_R
84
+
85
+ with torch.no_grad():
86
+ imgLR = torch.cat([imgL,imgR],0)
87
+ model.eval()
88
+ torch.cuda.synchronize()
89
+ rts = model(imgLR)
90
+ torch.cuda.synchronize()
91
+ flow, occ, logmid, logexp = rts
92
+
93
+ torch.cuda.empty_cache()
94
+
95
+ return flow, logexp
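Putting the pieces together, the Expansion wrapper above builds the VCN model for fixed 256x256 inputs, loads the robust checkpoint, and exposes run. A hypothetical call, assuming a CUDA device and the pretrained weights on disk (the random inputs are placeholders for real frames normalized to [-1, 1]):

import torch
from expansion.submission import Expansion

exp_net = Expansion()    # loads pretrained_models/optical_expansion/robust.pth
imgL = torch.rand(1, 3, 256, 256).cuda() * 2 - 1    # frame t
imgR = torch.rand(1, 3, 256, 256).cuda() * 2 - 1    # frame t+1
flow, logexp = exp_net.run(imgL, imgR)   # 2D optical flow and log expansion map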
expansion/utils/__init__.py ADDED
File without changes
expansion/utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (157 Bytes).
expansion/utils/__pycache__/flowlib.cpython-38.pyc ADDED
Binary file (16 kB).
expansion/utils/__pycache__/io.cpython-38.pyc ADDED
Binary file (3.97 kB).
expansion/utils/__pycache__/pfm.cpython-38.pyc ADDED
Binary file (1.65 kB).