Spaces · endo-yuki-t
Commit d7dbcdd · "initial commit" · 0 parents
This view is limited to 50 files because the commit contains too many changes.
- LICENSE +21 -0
- README.md +51 -0
- criteria/__init__.py +0 -0
- criteria/lpips/__init__.py +0 -0
- criteria/lpips/lpips.py +35 -0
- criteria/lpips/networks.py +96 -0
- criteria/lpips/utils.py +30 -0
- docs/teaser.jpg +0 -0
- docs/thumb.gif +0 -0
- env.yaml +380 -0
- expansion/__init__.py +0 -0
- expansion/dataloader/__init__.py +0 -0
- expansion/dataloader/__pycache__/__init__.cpython-38.pyc +0 -0
- expansion/dataloader/__pycache__/seqlist.cpython-38.pyc +0 -0
- expansion/dataloader/chairslist.py +33 -0
- expansion/dataloader/chairssdlist.py +30 -0
- expansion/dataloader/depth_transforms.py +471 -0
- expansion/dataloader/depthloader.py +222 -0
- expansion/dataloader/flow_transforms.py +440 -0
- expansion/dataloader/hd1klist.py +29 -0
- expansion/dataloader/kitti12list.py +29 -0
- expansion/dataloader/kitti15list.py +29 -0
- expansion/dataloader/kitti15list_train.py +31 -0
- expansion/dataloader/kitti15list_train_lidar.py +34 -0
- expansion/dataloader/kitti15list_val.py +31 -0
- expansion/dataloader/kitti15list_val_lidar.py +34 -0
- expansion/dataloader/kitti15list_val_mr.py +41 -0
- expansion/dataloader/robloader.py +133 -0
- expansion/dataloader/sceneflowlist.py +51 -0
- expansion/dataloader/seqlist.py +26 -0
- expansion/dataloader/sintellist.py +32 -0
- expansion/dataloader/sintellist_clean.py +31 -0
- expansion/dataloader/sintellist_final.py +32 -0
- expansion/dataloader/sintellist_train.py +32 -0
- expansion/dataloader/sintellist_val.py +34 -0
- expansion/dataloader/thingslist.py +122 -0
- expansion/models/VCN_exp.py +561 -0
- expansion/models/__init__.py +0 -0
- expansion/models/__pycache__/VCN_exp.cpython-38.pyc +0 -0
- expansion/models/__pycache__/__init__.cpython-38.pyc +0 -0
- expansion/models/__pycache__/conv4d.cpython-38.pyc +0 -0
- expansion/models/__pycache__/submodule.cpython-38.pyc +0 -0
- expansion/models/conv4d.py +296 -0
- expansion/models/submodule.py +450 -0
- expansion/submission.py +95 -0
- expansion/utils/__init__.py +0 -0
- expansion/utils/__pycache__/__init__.cpython-38.pyc +0 -0
- expansion/utils/__pycache__/flowlib.cpython-38.pyc +0 -0
- expansion/utils/__pycache__/io.cpython-38.pyc +0 -0
- expansion/utils/__pycache__/pfm.cpython-38.pyc +0 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Yuki Endo
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md
ADDED
@@ -0,0 +1,51 @@
+# User-Controllable Latent Transformer for StyleGAN Image Layout Editing
+<!--a href="https://arxiv.org/abs/2103.14877"><img src="https://img.shields.io/badge/arXiv-2103.14877-b31b1b.svg"></a-->
+<a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-yellow.svg"></a>
+<p align="center">
+<img src="docs/teaser.jpg" width="800px"/>
+</p>
+
+This repository contains our implementation of the following paper:
+
+Yuki Endo: "User-Controllable Latent Transformer for StyleGAN Image Layout Editing," Computer Graphics Forum (Pacific Graphics 2022) [[Project](http://www.cgg.cs.tsukuba.ac.jp/~endo/projects/UserControllableLT)] [[PDF (preprint)]()]
+
+## Prerequisites
+1. Python 3.8
+2. PyTorch 1.9.0
+3. Flask
+4. Others (see env.yaml)
+
+## Preparation
+Download and decompress <a href="https://drive.google.com/file/d/1lBL_J-uROvqZ0BYu9gmEcMCNyaPo9cBY/view?usp=sharing">our pre-trained models</a>.
+
+## Inference with our pre-trained models
+<img src="docs/thumb.gif" width="150px"/><br>
+We provide an interactive interface based on Flask. The interface can be launched locally with
+```
+python interface/flask_app.py --checkpoint_path=pretrained_models/latent_transformer/cat.pt
+```
+It can then be accessed via http://localhost:8000/.
+
+## Training
+The latent transformer can be trained with
+```
+python scripts/train.py --exp_dir=results --stylegan_weights=pretrained_models/stylegan2-cat-config-f.pt
+```
+
+## Citation
+Please cite our paper if you find the code useful:
+```
+@Article{endoPG2022,
+  Title   = {User-Controllable Latent Transformer for StyleGAN Image Layout Editing},
+  Author  = {Yuki Endo},
+  Journal = {Computer Graphics Forum},
+  volume  = {},
+  number  = {},
+  pages   = {},
+  doi     = {},
+  Year    = {2022}
+}
+```
+
+## Acknowledgements
+This code heavily borrows from the [pixel2style2pixel](https://github.com/eladrich/pixel2style2pixel) and [expansion](https://github.com/gengshan-y/expansion) repositories.
criteria/__init__.py
ADDED
File without changes

criteria/lpips/__init__.py
ADDED
File without changes
criteria/lpips/lpips.py
ADDED
@@ -0,0 +1,35 @@
+import torch
+import torch.nn as nn
+
+from criteria.lpips.networks import get_network, LinLayers
+from criteria.lpips.utils import get_state_dict
+
+
+class LPIPS(nn.Module):
+    r"""Creates a criterion that measures
+    Learned Perceptual Image Patch Similarity (LPIPS).
+    Arguments:
+        net_type (str): the network type to compare the features:
+                        'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
+        version (str): the version of LPIPS. Default: 0.1.
+    """
+    def __init__(self, net_type: str = 'alex', version: str = '0.1'):
+
+        assert version in ['0.1'], 'v0.1 is only supported now'
+
+        super(LPIPS, self).__init__()
+
+        # pretrained network
+        self.net = get_network(net_type).to("cuda")
+
+        # linear layers
+        self.lin = LinLayers(self.net.n_channels_list).to("cuda")
+        self.lin.load_state_dict(get_state_dict(net_type, version))
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor):
+        feat_x, feat_y = self.net(x), self.net(y)
+
+        diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
+        res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
+
+        return torch.sum(torch.cat(res, 0)) / x.shape[0]
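For orientation, here is a minimal usage sketch for this criterion (not part of the commit); it assumes a CUDA device is available, since the module hard-codes `.to("cuda")`, and uses random tensors as stand-ins for image batches:

```python
# Hypothetical usage sketch for criteria/lpips/lpips.py; assumes a CUDA device.
import torch
from criteria.lpips.lpips import LPIPS

lpips_loss = LPIPS(net_type='alex')            # downloads AlexNet and linear weights
x = torch.rand(4, 3, 256, 256, device='cuda')  # e.g. a batch of generated images
y = torch.rand(4, 3, 256, 256, device='cuda')  # e.g. the corresponding targets
loss = lpips_loss(x, y)                        # scalar: summed per-layer LPIPS / batch size
print(loss.item())
```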
criteria/lpips/networks.py
ADDED
@@ -0,0 +1,96 @@
+from typing import Sequence
+
+from itertools import chain
+
+import torch
+import torch.nn as nn
+from torchvision import models
+
+from criteria.lpips.utils import normalize_activation
+
+
+def get_network(net_type: str):
+    if net_type == 'alex':
+        return AlexNet()
+    elif net_type == 'squeeze':
+        return SqueezeNet()
+    elif net_type == 'vgg':
+        return VGG16()
+    else:
+        raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
+
+
+class LinLayers(nn.ModuleList):
+    def __init__(self, n_channels_list: Sequence[int]):
+        super(LinLayers, self).__init__([
+            nn.Sequential(
+                nn.Identity(),
+                nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
+            ) for nc in n_channels_list
+        ])
+
+        for param in self.parameters():
+            param.requires_grad = False
+
+
+class BaseNet(nn.Module):
+    def __init__(self):
+        super(BaseNet, self).__init__()
+
+        # register buffer
+        self.register_buffer(
+            'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
+        self.register_buffer(
+            'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
+
+    def set_requires_grad(self, state: bool):
+        for param in chain(self.parameters(), self.buffers()):
+            param.requires_grad = state
+
+    def z_score(self, x: torch.Tensor):
+        return (x - self.mean) / self.std
+
+    def forward(self, x: torch.Tensor):
+        x = self.z_score(x)
+
+        output = []
+        for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
+            x = layer(x)
+            if i in self.target_layers:
+                output.append(normalize_activation(x))
+            if len(output) == len(self.target_layers):
+                break
+        return output
+
+
+class SqueezeNet(BaseNet):
+    def __init__(self):
+        super(SqueezeNet, self).__init__()
+
+        self.layers = models.squeezenet1_1(True).features
+        self.target_layers = [2, 5, 8, 10, 11, 12, 13]
+        self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
+
+        self.set_requires_grad(False)
+
+
+class AlexNet(BaseNet):
+    def __init__(self):
+        super(AlexNet, self).__init__()
+
+        self.layers = models.alexnet(True).features
+        self.target_layers = [2, 5, 8, 10, 12]
+        self.n_channels_list = [64, 192, 384, 256, 256]
+
+        self.set_requires_grad(False)
+
+
+class VGG16(BaseNet):
+    def __init__(self):
+        super(VGG16, self).__init__()
+
+        self.layers = models.vgg16(True).features
+        self.target_layers = [4, 9, 16, 23, 30]
+        self.n_channels_list = [64, 128, 256, 512, 512]
+
+        self.set_requires_grad(False)
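These backbones return one channel-normalized activation map per entry in `target_layers`, which `LinLayers` then weights. A small sketch of that contract (not part of the commit; input size is illustrative):

```python
# Hypothetical sketch: per-layer features from the frozen AlexNet backbone.
import torch
from criteria.lpips.networks import get_network

net = get_network('alex')               # frozen torchvision AlexNet feature extractor
x = torch.rand(1, 3, 64, 64)            # toy input; z-scoring happens inside forward()
feats = net(x)                          # list of unit-normalized feature maps
print([tuple(f.shape) for f in feats])  # channel counts match n_channels_list
```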
criteria/lpips/utils.py
ADDED
@@ -0,0 +1,30 @@
+from collections import OrderedDict
+
+import torch
+
+
+def normalize_activation(x, eps=1e-10):
+    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
+    return x / (norm_factor + eps)
+
+
+def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
+    # build url
+    url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
+        + f'master/lpips/weights/v{version}/{net_type}.pth'
+
+    # download
+    old_state_dict = torch.hub.load_state_dict_from_url(
+        url, progress=True,
+        map_location=None if torch.cuda.is_available() else torch.device('cpu')
+    )
+
+    # rename keys
+    new_state_dict = OrderedDict()
+    for key, val in old_state_dict.items():
+        new_key = key
+        new_key = new_key.replace('lin', '')
+        new_key = new_key.replace('model.', '')
+        new_state_dict[new_key] = val
+
+    return new_state_dict
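As a quick sanity check (not from the commit), `normalize_activation` rescales each spatial position to unit L2 norm across channels, which is what makes the per-layer feature differences comparable:

```python
# Hypothetical check: every (h, w) position ends up with (approximately) unit channel norm.
import torch
from criteria.lpips.utils import normalize_activation

x = torch.randn(1, 8, 4, 4)
y = normalize_activation(x)
print(torch.linalg.norm(y, dim=1))  # all entries close to 1.0
```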
docs/teaser.jpg
ADDED

docs/thumb.gif
ADDED
env.yaml
ADDED
@@ -0,0 +1,380 @@
+name: uclt
+channels:
+  - pytorch
+  - anaconda
+  - nvidia
+  - conda-forge
+  - defaults
+dependencies:
+  - _ipyw_jlab_nb_ext_conf=0.1.0=py38_0
+  - _libgcc_mutex=0.1=conda_forge
+  - _openmp_mutex=4.5=1_llvm
+  - absl-py=0.13.0=pyhd8ed1ab_0
+  - aiohttp=3.7.4.post0=py38h497a2fe_0
+  - albumentations=1.0.3=pyhd8ed1ab_0
+  - alsa-lib=1.2.3=h516909a_0
+  - anaconda-client=1.8.0=py38h06a4308_0
+  - anaconda-navigator=2.0.4=py38_0
+  - anyio=2.2.0=py38h06a4308_1
+  - appdirs=1.4.4=pyh9f0ad1d_0
+  - argon2-cffi=20.1.0=py38h27cfd23_1
+  - async-timeout=3.0.1=py_1000
+  - async_generator=1.10=pyhd3eb1b0_0
+  - attrs=21.2.0=pyhd3eb1b0_0
+  - babel=2.9.1=pyhd3eb1b0_0
+  - backcall=0.2.0=pyhd3eb1b0_0
+  - backports=1.0=pyhd3eb1b0_2
+  - backports.functools_lru_cache=1.6.4=pyhd3eb1b0_0
+  - backports.tempfile=1.0=pyhd3eb1b0_1
+  - backports.weakref=1.0.post1=py_1
+  - beautifulsoup4=4.9.3=pyha847dfd_0
+  - blas=1.0=mkl
+  - bleach=4.0.0=pyhd3eb1b0_0
+  - blinker=1.4=py_1
+  - brotli=1.0.9=h7f98852_5
+  - brotli-bin=1.0.9=h7f98852_5
+  - brotlipy=0.7.0=py38h27cfd23_1003
+  - bzip2=1.0.8=h7b6447c_0
+  - c-ares=1.17.1=h27cfd23_0
+  - ca-certificates=2021.10.8=ha878542_0
+  - cachetools=4.2.2=pyhd8ed1ab_0
+  - cairo=1.16.0=hf32fb01_1
+  - certifi=2021.10.8=py38h578d9bd_1
+  - cffi=1.14.6=py38h400218f_0
+  - chardet=4.0.0=py38h06a4308_1003
+  - click=8.0.1=pyhd3eb1b0_0
+  - cloudpickle=1.6.0=py_0
+  - clyent=1.2.2=py38_1
+  - conda=4.11.0=py38h578d9bd_0
+  - conda-build=3.21.4=py38h06a4308_0
+  - conda-content-trust=0.1.1=pyhd3eb1b0_0
+  - conda-env=2.6.0=1
+  - conda-package-handling=1.7.3=py38h27cfd23_1
+  - conda-repo-cli=1.0.4=pyhd3eb1b0_0
+  - conda-token=0.3.0=pyhd3eb1b0_0
+  - conda-verify=3.4.2=py_1
+  - cryptography=3.4.7=py38hd23ed53_0
+  - cudatoolkit=11.1.74=h6bb024c_0
+  - cycler=0.10.0=py_2
+  - cytoolz=0.11.0=py38h497a2fe_3
+  - dask-core=2021.8.1=pyhd8ed1ab_0
+  - dbus=1.13.18=hb2f20db_0
+  - decorator=5.0.9=pyhd3eb1b0_0
+  - defusedxml=0.7.1=pyhd3eb1b0_0
+  - dill=0.3.4=pyhd8ed1ab_0
+  - dominate=2.6.0=pyhd8ed1ab_0
+  - entrypoints=0.3=py38_0
+  - enum34=1.1.10=py38h32f6830_2
+  - expat=2.4.1=h2531618_2
+  - ffmpeg=4.3.2=hca11adc_0
+  - filelock=3.0.12=pyhd3eb1b0_1
+  - flask=1.1.2=pyh9f0ad1d_0
+  - flask-httpauth=4.4.0=pyhd8ed1ab_0
+  - fontconfig=2.13.1=h6c09931_0
+  - fonttools=4.25.0=pyhd3eb1b0_0
+  - freetype=2.10.4=h5ab3b9f_0
+  - fsspec=2021.7.0=pyhd8ed1ab_0
+  - ftfy=6.0.3=pyhd8ed1ab_0
+  - func_timeout=4.3.5=py_0
+  - future=0.18.2=py38_1
+  - gdown=4.2.0=pyhd8ed1ab_0
+  - geos=3.10.0=h9c3ff4c_0
+  - gettext=0.19.8.1=h0b5b191_1005
+  - git=2.23.0=pl526hacde149_0
+  - glib=2.68.4=h9c3ff4c_0
+  - glib-tools=2.68.4=h9c3ff4c_0
+  - glob2=0.7=pyhd3eb1b0_0
+  - gmp=6.2.1=h58526e2_0
+  - gnutls=3.6.13=h85f3911_1
+  - google-auth=1.35.0=pyh6c4a22f_0
+  - google-auth-oauthlib=0.4.5=pyhd8ed1ab_0
+  - gputil=1.4.0=pyh9f0ad1d_0
+  - graphite2=1.3.13=h58526e2_1001
+  - gst-plugins-base=1.18.4=hf529b03_2
+  - gstreamer=1.18.4=h76c114f_2
+  - harfbuzz=2.9.0=h83ec7ef_0
+  - hdf5=1.10.6=nompi_h6a2412b_1114
+  - icu=68.1=h58526e2_0
+  - idna=2.10=pyhd3eb1b0_0
+  - imagecodecs-lite=2019.12.3=py38h5c078b8_3
+  - imageio=2.9.0=py_0
+  - imageio-ffmpeg=0.4.5=pyhd8ed1ab_0
+  - imgaug=0.4.0=py_1
+  - importlib-metadata=3.10.0=py38h06a4308_0
+  - importlib_metadata=3.10.0=hd3eb1b0_0
+  - intel-openmp=2021.3.0=h06a4308_3350
+  - ipykernel=5.3.4=py38h5ca1d4c_0
+  - ipympl=0.8.2=pyhd8ed1ab_0
+  - ipython=7.26.0=py38hb070fc8_0
+  - ipython_genutils=0.2.0=pyhd3eb1b0_1
+  - ipywidgets=7.6.3=pyhd3eb1b0_1
+  - itsdangerous=2.0.1=pyhd3eb1b0_0
+  - jasper=1.900.1=h07fcdf6_1006
+  - jedi=0.18.0=py38h06a4308_1
+  - jinja2=2.11.3=pyhd3eb1b0_0
+  - joblib=1.1.0=pyhd8ed1ab_0
+  - jpeg=9d=h36c2ea0_0
+  - json5=0.9.6=pyhd3eb1b0_0
+  - jsonnet=0.17.0=py38hadf7658_0
+  - jsonschema=3.2.0=py_2
+  - jupyter_client=6.1.12=pyhd3eb1b0_0
+  - jupyter_core=4.7.1=py38h06a4308_0
+  - jupyter_server=1.4.1=py38h06a4308_0
+  - jupyterlab=3.1.7=pyhd3eb1b0_0
+  - jupyterlab_pygments=0.1.2=py_0
+  - jupyterlab_server=2.7.1=pyhd3eb1b0_0
+  - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1
+  - kiwisolver=1.3.1=py38h1fd1430_1
+  - krb5=1.19.2=hcc1bbae_0
+  - lame=3.100=h7f98852_1001
+  - lcms2=2.12=h3be6417_0
+  - ld_impl_linux-64=2.35.1=h7274673_9
+  - libarchive=3.4.2=h62408e4_0
+  - libblas=3.9.0=11_linux64_mkl
+  - libbrotlicommon=1.0.9=h7f98852_5
+  - libbrotlidec=1.0.9=h7f98852_5
+  - libbrotlienc=1.0.9=h7f98852_5
+  - libcblas=3.9.0=11_linux64_mkl
+  - libcurl=7.78.0=h2574ce0_0
+  - libedit=3.1.20191231=he28a2e2_2
+  - libev=4.33=h516909a_1
+  - libevent=2.1.10=hcdb4288_3
+  - libffi=3.3=he6710b0_2
+  - libgcc-ng=11.1.0=hc902ee8_8
+  - libgfortran-ng=11.1.0=h69a702a_8
+  - libgfortran5=11.1.0=h6c583b3_8
+  - libglib=2.68.4=h3e27bee_0
+  - libiconv=1.16=h516909a_0
+  - liblapack=3.9.0=11_linux64_mkl
+  - liblapacke=3.9.0=11_linux64_mkl
+  - liblief=0.10.1=he6710b0_0
+  - libllvm11=11.1.0=hf817b99_2
+  - libnghttp2=1.43.0=h812cca2_0
+  - libogg=1.3.4=h7f98852_1
+  - libopencv=4.5.2=py38hcdf9bf1_0
+  - libopus=1.3.1=h7f98852_1
+  - libpng=1.6.37=hbc83047_0
+  - libpq=13.3=hd57d9b9_0
+  - libprotobuf=3.15.8=h780b84a_0
+  - libsodium=1.0.18=h7b6447c_0
+  - libssh2=1.9.0=ha56f1ee_6
+  - libstdcxx-ng=11.1.0=h56837e0_8
+  - libtiff=4.2.0=h85742a9_0
+  - libuuid=1.0.3=h1bed415_2
+  - libuv=1.40.0=h7b6447c_0
+  - libvorbis=1.3.7=h9c3ff4c_0
+  - libwebp-base=1.2.0=h27cfd23_0
+  - libxcb=1.14=h7b6447c_0
+  - libxkbcommon=1.0.3=he3ba5ed_0
+  - libxml2=2.9.12=h72842e0_0
+  - llvm-openmp=12.0.1=h4bd325d_1
+  - locket=0.2.0=py_2
+  - lz4-c=1.9.3=h295c915_1
+  - markdown=3.3.4=pyhd8ed1ab_0
+  - markupsafe=2.0.1=py38h27cfd23_0
+  - matplotlib=3.4.2=py38h578d9bd_0
+  - matplotlib-base=3.4.2=py38hab158f2_0
+  - matplotlib-inline=0.1.2=pyhd3eb1b0_2
+  - mistune=0.8.4=py38h7b6447c_1000
+  - mkl=2021.3.0=h06a4308_520
+  - mkl-service=2.4.0=py38h7f8727e_0
+  - mkl_fft=1.3.0=py38h42c9631_2
+  - mkl_random=1.2.2=py38h51133e4_0
+  - multidict=5.1.0=py38h497a2fe_1
+  - munkres=1.1.4=pyh9f0ad1d_0
+  - mysql-common=8.0.25=ha770c72_0
+  - mysql-libs=8.0.25=h935591d_0
+  - navigator-updater=0.2.1=py38_0
+  - nbclassic=0.2.6=pyhd3eb1b0_0
+  - nbclient=0.5.3=pyhd3eb1b0_0
+  - nbconvert=6.1.0=py38h06a4308_0
+  - nbformat=5.1.3=pyhd3eb1b0_0
+  - ncurses=6.2=he6710b0_1
+  - nest-asyncio=1.5.1=pyhd3eb1b0_0
+  - nettle=3.6=he412f7d_0
+  - networkx=2.3=py_0
+  - ninja=1.10.2=hff7bd54_1
+  - notebook=6.4.3=py38h06a4308_0
+  - nspr=4.30=h9c3ff4c_0
+  - nss=3.69=hb5efdd6_0
+  - numpy=1.20.3=py38hf144106_0
+  - numpy-base=1.20.3=py38h74d4b33_0
+  - oauthlib=3.1.1=pyhd8ed1ab_0
+  - olefile=0.46=py_0
+  - opencv=4.5.2=py38h578d9bd_0
+  - openh264=2.1.1=h780b84a_0
+  - openjpeg=2.3.0=h05c96fa_1
+  - openssl=1.1.1l=h7f98852_0
+  - packaging=21.0=pyhd3eb1b0_0
+  - pandas=1.3.2=py38h43a58ef_0
+  - pandocfilters=1.4.3=py38h06a4308_1
+  - parso=0.8.2=pyhd3eb1b0_0
+  - partd=1.2.0=pyhd8ed1ab_0
+  - patchelf=0.12=h2531618_1
+  - pathlib=1.0.1=py38h578d9bd_4
+  - patsy=0.5.1=py_0
+  - pcre=8.45=h295c915_0
+  - perl=5.26.2=h14c3975_0
+  - pexpect=4.8.0=pyhd3eb1b0_3
+  - pickleshare=0.7.5=pyhd3eb1b0_1003
+  - pillow=8.3.1=py38h2c7a002_0
+  - pip=21.2.2=py38h06a4308_0
+  - pixman=0.40.0=h36c2ea0_0
+  - pkginfo=1.7.1=py38h06a4308_0
+  - pooch=1.5.1=pyhd8ed1ab_0
+  - portalocker=1.7.0=py38h578d9bd_1
+  - prometheus_client=0.11.0=pyhd3eb1b0_0
+  - prompt-toolkit=3.0.17=pyh06a4308_0
+  - protobuf=3.15.8=py38h709712a_0
+  - psutil=5.8.0=py38h27cfd23_1
+  - ptyprocess=0.7.0=pyhd3eb1b0_2
+  - py-lief=0.10.1=py38h403a769_0
+  - py-opencv=4.5.2=py38hd0cf306_0
+  - pyasn1=0.4.8=py_0
+  - pyasn1-modules=0.2.7=py_0
+  - pycosat=0.6.3=py38h7b6447c_1
+  - pycparser=2.20=py_2
+  - pygments=2.10.0=pyhd3eb1b0_0
+  - pyjwt=2.1.0=pyhd8ed1ab_0
+  - pyopenssl=20.0.1=pyhd3eb1b0_1
+  - pyparsing=2.4.7=pyhd3eb1b0_0
+  - pypng=0.0.20=py_0
+  - pyqt=5.12.3=py38h578d9bd_7
+  - pyqt-impl=5.12.3=py38h7400c14_7
+  - pyqt5-sip=4.19.18=py38h709712a_7
+  - pyqtchart=5.12=py38h7400c14_7
+  - pyqtwebengine=5.12.1=py38h7400c14_7
+  - pyrsistent=0.17.3=py38h7b6447c_0
+  - pysocks=1.7.1=py38h06a4308_0
+  - python=3.8.10=h12debd9_8
+  - python-dateutil=2.8.2=pyhd3eb1b0_0
+  - python-libarchive-c=2.9=pyhd3eb1b0_1
+  - python-lmdb=0.99=py38h709712a_0
+  - python_abi=3.8=2_cp38
+  - pytorch=1.9.0=py3.8_cuda11.1_cudnn8.0.5_0
+  - pytz=2021.1=pyhd3eb1b0_0
+  - pyu2f=0.1.5=pyhd8ed1ab_0
+  - pywavelets=1.1.1=py38h5c078b8_3
+  - pyyaml=5.4.1=py38h27cfd23_1
+  - pyzmq=22.2.1=py38h295c915_1
+  - qt=5.12.9=hda022c4_4
+  - qtpy=1.9.0=py_0
+  - readline=8.1=h27cfd23_0
+  - regex=2021.8.28=py38h497a2fe_0
+  - requests=2.25.1=pyhd3eb1b0_0
+  - requests-oauthlib=1.3.0=pyh9f0ad1d_0
+  - ripgrep=12.1.1=0
+  - rsa=4.7.2=pyh44b312d_0
+  - ruamel_yaml=0.15.100=py38h27cfd23_0
+  - scikit-image=0.18.3=py38h43a58ef_0
+  - scikit-learn=1.0=py38hacb3eff_1
+  - scipy=1.7.1=py38h56a6a73_0
+  - seaborn=0.11.2=hd8ed1ab_0
+  - seaborn-base=0.11.2=pyhd8ed1ab_0
+  - send2trash=1.5.0=pyhd3eb1b0_1
+  - setuptools=52.0.0=py38h06a4308_0
+  - shapely=1.8.0=py38hf7953bd_1
+  - sip=4.19.13=py38he6710b0_0
+  - sniffio=1.2.0=py38h06a4308_1
+  - soupsieve=2.2.1=pyhd3eb1b0_0
+  - sqlite=3.36.0=hc218d9a_0
+  - statsmodels=0.12.2=py38h5c078b8_0
+  - tensorboard=2.6.0=pyhd8ed1ab_1
+  - tensorboard-data-server=0.6.0=py38h2b97feb_0
+  - tensorboard-plugin-wit=1.8.0=pyh44b312d_0
+  - tensorboardx=2.4=pyhd8ed1ab_0
+  - terminado=0.9.4=py38h06a4308_0
+  - testpath=0.5.0=pyhd3eb1b0_0
+  - threadpoolctl=3.0.0=pyh8a188c0_0
+  - tifffile=2019.7.26.2=py38_0
+  - tk=8.6.10=hbc83047_0
+  - toolz=0.11.1=py_0
+  - torchfile=0.1.0=py_0
+  - tornado=6.1=py38h27cfd23_0
+  - tqdm=4.62.1=pyhd3eb1b0_1
+  - traitlets=5.0.5=pyhd3eb1b0_0
+  - typing_extensions=3.10.0.0=pyh06a4308_0
+  - urllib3=1.26.6=pyhd3eb1b0_1
+  - wcwidth=0.2.5=py_0
+  - webencodings=0.5.1=py38_1
+  - werkzeug=1.0.1=pyhd3eb1b0_0
+  - wheel=0.37.0=pyhd3eb1b0_0
+  - widgetsnbextension=3.5.1=py38_0
+  - x264=1!161.3030=h7f98852_1
+  - xmltodict=0.12.0=py_0
+  - xz=5.2.5=h7b6447c_0
+  - yacs=0.1.6=py_0
+  - yaml=0.2.5=h7b6447c_0
+  - yarl=1.6.3=py38h497a2fe_2
+  - zeromq=4.3.4=h2531618_0
+  - zipp=3.5.0=pyhd3eb1b0_0
+  - zlib=1.2.11=h7b6447c_3
+  - zstd=1.4.9=haebb681_0
+  - pip:
+    - addict==2.4.0
+    - altair==4.2.0
+    - astor==0.8.1
+    - astunparse==1.6.3
+    - backports-zoneinfo==0.2.1
+    - base58==2.1.1
+    - basicsr==1.3.4.1
+    - boto3==1.18.33
+    - botocore==1.21.33
+    - clang==5.0
+    - clean-fid==0.1.22
+    - clip==1.0
+    - colorama==0.4.4
+    - commonmark==0.9.1
+    - cython==0.29.30
+    - einops==0.3.2
+    - enum-compat==0.0.3
+    - facexlib==0.2.0.3
+    - filterpy==1.4.5
+    - flatbuffers==1.12
+    - gast==0.4.0
+    - google-pasta==0.2.0
+    - grpcio==1.39.0
+    - h5py==3.1.0
+    - ipdb==0.13.9
+    - jacinle==1.0.0
+    - jmespath==0.10.0
+    - jsonpickle==2.2.0
+    - keras==2.7.0
+    - keras-preprocessing==1.1.2
+    - libclang==12.0.0
+    - llvmlite==0.37.0
+    - lpips==0.1.4
+    - numba==0.54.0
+    - opencv-python==4.5.3.56
+    - opt-einsum==3.3.0
+    - pkgconfig==1.5.5
+    - pyarrow==8.0.0
+    - pydantic==1.8.2
+    - pydeck==0.7.1
+    - pyhocon==0.3.58
+    - pytz-deprecation-shim==0.1.0.post0
+    - pyvis==0.2.1
+    - realesrgan==0.2.2.3
+    - rich==10.9.0
+    - s3transfer==0.5.0
+    - six==1.15.0
+    - sklearn==0.0
+    - streamlit==0.64.0
+    - tabulate==0.8.9
+    - tb-nightly==2.7.0a20210827
+    - tensorflow-estimator==2.7.0
+    - tensorflow-gpu==2.7.0
+    - tensorflow-io-gcs-filesystem==0.21.0
+    - tensorfn==0.1.19
+    - termcolor==1.1.0
+    - toml==0.10.2
+    - torchsample==0.1.3
+    - torchvision==0.10.0+cu111
+    - typing-extensions==3.7.4.3
+    - tzdata==2022.1
+    - tzlocal==4.2
+    - validators==0.19.0
+    - vit-pytorch==0.24.3
+    - watchdog==2.1.8
+    - wrapt==1.12.1
+    - yapf==0.31.0
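With this file in place, the environment can typically be recreated with `conda env create -f env.yaml` and activated with `conda activate uclt` (standard conda usage; the commit itself does not spell this out).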
expansion/__init__.py
ADDED
File without changes

expansion/dataloader/__init__.py
ADDED
File without changes

expansion/dataloader/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (162 Bytes)

expansion/dataloader/__pycache__/seqlist.cpython-38.pyc
ADDED
Binary file (1.12 kB)
expansion/dataloader/chairslist.py
ADDED
@@ -0,0 +1,33 @@
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+import numpy as np
+import glob
+
+IMG_EXTENSIONS = [
+    '.jpg', '.JPG', '.jpeg', '.JPEG',
+    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+]
+
+
+def is_image_file(filename):
+    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+def dataloader(filepath):
+    l0_train = []
+    l1_train = []
+    flow_train = []
+    for flow_map in sorted(glob.glob(os.path.join(filepath,'*_flow.flo'))):
+        root_filename = flow_map[:-9]
+        img1 = root_filename+'_img1.ppm'
+        img2 = root_filename+'_img2.ppm'
+        if not (os.path.isfile(os.path.join(filepath,img1)) and os.path.isfile(os.path.join(filepath,img2))):
+            continue
+
+        l0_train.append(img1)
+        l1_train.append(img2)
+        flow_train.append(flow_map)
+
+    return l0_train, l1_train, flow_train
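A minimal sketch of how this lister is typically called (not from the commit; the path is a placeholder, and it should be absolute since the existence checks re-join the directory with already-joined paths). It expects a FlyingChairs-style directory of `*_img1.ppm` / `*_img2.ppm` / `*_flow.flo` triplets:

```python
# Hypothetical usage sketch; the dataset path is a placeholder.
from expansion.dataloader.chairslist import dataloader

l0, l1, flows = dataloader('/data/FlyingChairs_release/data')
print(len(flows), 'triplets found')  # one entry per flow map with both frames on disk
```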
expansion/dataloader/chairssdlist.py
ADDED
@@ -0,0 +1,30 @@
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+import numpy as np
+import glob
+
+IMG_EXTENSIONS = [
+    '.jpg', '.JPG', '.jpeg', '.JPEG',
+    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+]
+
+
+def is_image_file(filename):
+    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+def dataloader(filepath):
+    l0_train = []
+    l1_train = []
+    flow_train = []
+    for flow_map in sorted(glob.glob('%s/flow/*.pfm'%filepath)):
+        img1 = flow_map.replace('flow','t0').replace('.pfm','.png')
+        img2 = flow_map.replace('flow','t1').replace('.pfm','.png')
+
+        l0_train.append(img1)
+        l1_train.append(img2)
+        flow_train.append(flow_map)
+
+    return l0_train, l1_train, flow_train
expansion/dataloader/depth_transforms.py
ADDED
@@ -0,0 +1,471 @@
+from __future__ import division
+import torch
+import random
+import numpy as np
+import numbers
+import types
+import scipy.ndimage as ndimage
+import pdb
+import torchvision
+import PIL.Image as Image
+import cv2
+from torch.nn import functional as F
+
+
+class Compose(object):
+    """ Composes several co_transforms together.
+    For example:
+    >>> co_transforms.Compose([
+    >>>     co_transforms.CenterCrop(10),
+    >>>     co_transforms.ToTensor(),
+    >>> ])
+    """
+
+    def __init__(self, co_transforms):
+        self.co_transforms = co_transforms
+
+    def __call__(self, input, target, intr):
+        for t in self.co_transforms:
+            input,target,intr = t(input,target,intr)
+        return input,target,intr
+
+
+class Scale(object):
+    """ Rescales the inputs and target arrays to the given 'size'.
+    'size' will be the size of the smaller edge.
+    For example, if height > width, then image will be
+    rescaled to (size * height / width, size)
+    size: size of the smaller edge
+    interpolation order: Default: 2 (bilinear)
+    """
+
+    def __init__(self, size, order=1):
+        self.ratio = size
+        self.order = order
+        if order==0:
+            self.code=cv2.INTER_NEAREST
+        elif order==1:
+            self.code=cv2.INTER_LINEAR
+        elif order==2:
+            self.code=cv2.INTER_CUBIC
+
+    def __call__(self, inputs, target):
+        if self.ratio==1:
+            return inputs, target
+        h, w, _ = inputs[0].shape
+        ratio = self.ratio
+
+        inputs[0] = cv2.resize(inputs[0], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_LINEAR)
+        inputs[1] = cv2.resize(inputs[1], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_LINEAR)
+        # keep the mask same
+        tmp = cv2.resize(target[:,:,2], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_NEAREST)
+        target = cv2.resize(target, None, fx=ratio,fy=ratio,interpolation=self.code) * ratio
+        target[:,:,2] = tmp
+
+
+        return inputs, target
+
+
+class RandomCrop(object):
+    """Crops the given PIL.Image at a random location to have a region of
+    the given size. size can be a tuple (target_height, target_width)
+    or an integer, in which case the target will be of a square shape (size, size)
+    """
+
+    def __init__(self, size):
+        if isinstance(size, numbers.Number):
+            self.size = (int(size), int(size))
+        else:
+            self.size = size
+
+    def __call__(self, inputs,target,intr):
+        h, w, _ = inputs[0].shape
+        th, tw = self.size
+        if w < tw: tw=w
+        if h < th: th=h
+
+        x1 = random.randint(0, w - tw)
+        y1 = random.randint(0, h - th)
+        intr[1] -= x1
+        intr[2] -= y1
+
+        inputs[0] = inputs[0][y1: y1 + th,x1: x1 + tw].astype(float)
+        inputs[1] = inputs[1][y1: y1 + th,x1: x1 + tw].astype(float)
+        return inputs, target[y1: y1 + th,x1: x1 + tw].astype(float), list(np.asarray(intr).astype(float)) + list(np.asarray([1.,0.,0.,1.,0.,0.]).astype(float))
+
+
+
+class SpatialAug(object):
+    def __init__(self, crop, scale=None, rot=None, trans=None, squeeze=None, schedule_coeff=1, order=1, black=False):
+        self.crop = crop
+        self.scale = scale
+        self.rot = rot
+        self.trans = trans
+        self.squeeze = squeeze
+        self.t = np.zeros(6)
+        self.schedule_coeff = schedule_coeff
+        self.order = order
+        self.black = black
+
+    def to_identity(self):
+        self.t[0] = 1; self.t[2] = 0; self.t[4] = 0; self.t[1] = 0; self.t[3] = 1; self.t[5] = 0;
+
+    def left_multiply(self, u0, u1, u2, u3, u4, u5):
+        result = np.zeros(6)
+        result[0] = self.t[0]*u0 + self.t[1]*u2;
+        result[1] = self.t[0]*u1 + self.t[1]*u3;
+
+        result[2] = self.t[2]*u0 + self.t[3]*u2;
+        result[3] = self.t[2]*u1 + self.t[3]*u3;
+
+        result[4] = self.t[4]*u0 + self.t[5]*u2 + u4;
+        result[5] = self.t[4]*u1 + self.t[5]*u3 + u5;
+        self.t = result
+
+    def inverse(self):
+        result = np.zeros(6)
+        a = self.t[0]; c = self.t[2]; e = self.t[4];
+        b = self.t[1]; d = self.t[3]; f = self.t[5];
+
+        denom = a*d - b*c;
+
+        result[0] = d / denom;
+        result[1] = -b / denom;
+        result[2] = -c / denom;
+        result[3] = a / denom;
+        result[4] = (c*f-d*e) / denom;
+        result[5] = (b*e-a*f) / denom;
+
+        return result
+
+    def grid_transform(self, meshgrid, t, normalize=True, gridsize=None):
+        if gridsize is None:
+            h, w = meshgrid[0].shape
+        else:
+            h, w = gridsize
+        vgrid = torch.cat([(meshgrid[0] * t[0] + meshgrid[1] * t[2] + t[4])[:,:,np.newaxis],
+                           (meshgrid[0] * t[1] + meshgrid[1] * t[3] + t[5])[:,:,np.newaxis]],-1)
+        if normalize:
+            vgrid[:,:,0] = 2.0*vgrid[:,:,0]/max(w-1,1)-1.0
+            vgrid[:,:,1] = 2.0*vgrid[:,:,1]/max(h-1,1)-1.0
+        return vgrid
+
+
+    def __call__(self, inputs, target, intr):
+        h, w, _ = inputs[0].shape
+        th, tw = self.crop
+        meshgrid = torch.meshgrid([torch.Tensor(range(th)), torch.Tensor(range(tw))])[::-1]
+        cornergrid = torch.meshgrid([torch.Tensor([0,th-1]), torch.Tensor([0,tw-1])])[::-1]
+
+        for i in range(50):
+            # im0
+            self.to_identity()
+            #TODO add mirror
+            if np.random.binomial(1,0.5):
+                mirror = True
+            else:
+                mirror = False
+            ##TODO
+            #mirror = False
+            if mirror:
+                self.left_multiply(-1, 0, 0, 1, .5 * tw, -.5 * th);
+            else:
+                self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th);
+            scale0 = 1; scale1 = 1; squeeze0 = 1; squeeze1 = 1;
+            if not self.rot is None:
+                rot0 = np.random.uniform(-self.rot[0],+self.rot[0])
+                rot1 = np.random.uniform(-self.rot[1]*self.schedule_coeff, self.rot[1]*self.schedule_coeff) + rot0
+                self.left_multiply(np.cos(rot0), np.sin(rot0), -np.sin(rot0), np.cos(rot0), 0, 0)
+            if not self.trans is None:
+                trans0 = np.random.uniform(-self.trans[0],+self.trans[0], 2)
+                trans1 = np.random.uniform(-self.trans[1]*self.schedule_coeff,+self.trans[1]*self.schedule_coeff, 2) + trans0
+                self.left_multiply(1, 0, 0, 1, trans0[0] * tw, trans0[1] * th)
+            if not self.squeeze is None:
+                squeeze0 = np.exp(np.random.uniform(-self.squeeze[0], self.squeeze[0]))
+                squeeze1 = np.exp(np.random.uniform(-self.squeeze[1]*self.schedule_coeff, self.squeeze[1]*self.schedule_coeff)) * squeeze0
+            if not self.scale is None:
+                scale0 = np.exp(np.random.uniform(self.scale[2]-self.scale[0], self.scale[2]+self.scale[0]))
+                scale1 = np.exp(np.random.uniform(-self.scale[1]*self.schedule_coeff, self.scale[1]*self.schedule_coeff)) * scale0
+            self.left_multiply(1.0/(scale0*squeeze0), 0, 0, 1.0/(scale0/squeeze0), 0, 0)
+
+            self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h);
+            transmat0 = self.t.copy()
+
+            # im1
+            self.to_identity()
+            if mirror:
+                self.left_multiply(-1, 0, 0, 1, .5 * tw, -.5 * th);
+            else:
+                self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th);
+            if not self.rot is None:
+                self.left_multiply(np.cos(rot1), np.sin(rot1), -np.sin(rot1), np.cos(rot1), 0, 0)
+            if not self.trans is None:
+                self.left_multiply(1, 0, 0, 1, trans1[0] * tw, trans1[1] * th)
+            self.left_multiply(1.0/(scale1*squeeze1), 0, 0, 1.0/(scale1/squeeze1), 0, 0)
+            self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h);
+            transmat1 = self.t.copy()
+            transmat1_inv = self.inverse()
+
+            if self.black:
+                # black augmentation, allowing 0 values in the input images
+                # https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/black_augmentation_layer.cu
+                break
+            else:
+                if ((self.grid_transform(cornergrid, transmat0, gridsize=[float(h),float(w)]).abs()>1).sum() +\
+                    (self.grid_transform(cornergrid, transmat1, gridsize=[float(h),float(w)]).abs()>1).sum()) == 0:
+                    break
+        if i==49:
+            print('max_iter in augmentation')
+            self.to_identity()
+            self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th);
+            self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h);
+            transmat0 = self.t.copy()
+            transmat1 = self.t.copy()
+
+        # do the real work
+        vgrid = self.grid_transform(meshgrid, transmat0,gridsize=[float(h),float(w)])
+        inputs_0 = F.grid_sample(torch.Tensor(inputs[0]).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)
+        if self.order == 0:
+            target_0 = F.grid_sample(torch.Tensor(target).permute(2,0,1)[np.newaxis], vgrid[np.newaxis], mode='nearest')[0].permute(1,2,0)
+        else:
+            target_0 = F.grid_sample(torch.Tensor(target).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)
+
+        mask_0 = target[:,:,2:3].copy(); mask_0[mask_0==0]=np.nan
+        if self.order == 0:
+            mask_0 = F.grid_sample(torch.Tensor(mask_0).permute(2,0,1)[np.newaxis], vgrid[np.newaxis], mode='nearest')[0].permute(1,2,0)
+        else:
+            mask_0 = F.grid_sample(torch.Tensor(mask_0).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)
+        mask_0[torch.isnan(mask_0)] = 0
+
+
+        vgrid = self.grid_transform(meshgrid, transmat1,gridsize=[float(h),float(w)])
+        inputs_1 = F.grid_sample(torch.Tensor(inputs[1]).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)
+
+        # flow
+        pos = target_0[:,:,:2] + self.grid_transform(meshgrid, transmat0,normalize=False)
+        pos = self.grid_transform(pos.permute(2,0,1),transmat1_inv,normalize=False)
+        if target_0.shape[2]>=4:
+            # scale
+            exp = target_0[:,:,3:] * scale1 / scale0
+            target = torch.cat([ (pos[:,:,0] - meshgrid[0]).unsqueeze(-1),
+                                 (pos[:,:,1] - meshgrid[1]).unsqueeze(-1),
+                                 mask_0,
+                                 exp], -1)
+        else:
+            target = torch.cat([ (pos[:,:,0] - meshgrid[0]).unsqueeze(-1),
+                                 (pos[:,:,1] - meshgrid[1]).unsqueeze(-1),
+                                 mask_0], -1)
+        inputs = [np.asarray(inputs_0).astype(float), np.asarray(inputs_1).astype(float)]
+        target = np.asarray(target).astype(float)
+        return inputs,target, list(np.asarray(intr+list(transmat0)).astype(float))
+
+
+
+class pseudoPCAAug(object):
+    """
+    Chromatic Eigen Augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
+    This version is faster.
+    """
+    def __init__(self, schedule_coeff=1):
+        self.augcolor = torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.5, hue=0.5/3.14)
+
+    def __call__(self, inputs, target,intr):
+        img = np.concatenate([inputs[0],inputs[1]],0)
+        shape = img.shape[0]//2
+        aug_img = np.asarray(self.augcolor(Image.fromarray(np.uint8(img*255))))/255.
+        inputs[0] = aug_img[:shape]
+        inputs[1] = aug_img[shape:]
+        #inputs[0] = np.asarray(self.augcolor(Image.fromarray(np.uint8(inputs[0]*255))))/255.
+        #inputs[1] = np.asarray(self.augcolor(Image.fromarray(np.uint8(inputs[1]*255))))/255.
+        return inputs,target,intr
+
+
+class PCAAug(object):
+    """
+    Chromatic Eigen Augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
+    """
+    def __init__(self, lmult_pow  =[0.4, 0,-0.2],
+                       lmult_mult =[0.4, 0,0, ],
+                       lmult_add  =[0.03,0,0, ],
+                       sat_pow    =[0.4, 0,0, ],
+                       sat_mult   =[0.5, 0,-0.3],
+                       sat_add    =[0.03,0,0, ],
+                       col_pow    =[0.4, 0,0, ],
+                       col_mult   =[0.2, 0,0, ],
+                       col_add    =[0.02,0,0, ],
+                       ladd_pow   =[0.4, 0,0, ],
+                       ladd_mult  =[0.4, 0,0, ],
+                       ladd_add   =[0.04,0,0, ],
+                       col_rotate =[1., 0,0, ],
+                       schedule_coeff=1):
+        # no mean
+        self.pow_nomean = [1,1,1]
+        self.add_nomean = [0,0,0]
+        self.mult_nomean = [1,1,1]
+        self.pow_withmean = [1,1,1]
+        self.add_withmean = [0,0,0]
+        self.mult_withmean = [1,1,1]
+        self.lmult_pow = 1
+        self.lmult_mult = 1
+        self.lmult_add = 0
+        self.col_angle = 0
+        if not ladd_pow is None:
+            self.pow_nomean[0] =np.exp(np.random.normal(ladd_pow[2], ladd_pow[0]))
+        if not col_pow is None:
+            self.pow_nomean[1] =np.exp(np.random.normal(col_pow[2], col_pow[0]))
+            self.pow_nomean[2] =np.exp(np.random.normal(col_pow[2], col_pow[0]))
+
+        if not ladd_add is None:
+            self.add_nomean[0] =np.random.normal(ladd_add[2], ladd_add[0])
+        if not col_add is None:
+            self.add_nomean[1] =np.random.normal(col_add[2], col_add[0])
+            self.add_nomean[2] =np.random.normal(col_add[2], col_add[0])
+
+        if not ladd_mult is None:
+            self.mult_nomean[0] =np.exp(np.random.normal(ladd_mult[2], ladd_mult[0]))
+        if not col_mult is None:
+            self.mult_nomean[1] =np.exp(np.random.normal(col_mult[2], col_mult[0]))
+            self.mult_nomean[2] =np.exp(np.random.normal(col_mult[2], col_mult[0]))
+
+        # with mean
+        if not sat_pow is None:
+            self.pow_withmean[1] =np.exp(np.random.uniform(sat_pow[2]-sat_pow[0], sat_pow[2]+sat_pow[0]))
+            self.pow_withmean[2] =self.pow_withmean[1]
+        if not sat_add is None:
+            self.add_withmean[1] =np.random.uniform(sat_add[2]-sat_add[0], sat_add[2]+sat_add[0])
+            self.add_withmean[2] =self.add_withmean[1]
+        if not sat_mult is None:
+            self.mult_withmean[1] = np.exp(np.random.uniform(sat_mult[2]-sat_mult[0], sat_mult[2]+sat_mult[0]))
+            self.mult_withmean[2] = self.mult_withmean[1]
+
+        if not lmult_pow is None:
+            self.lmult_pow = np.exp(np.random.uniform(lmult_pow[2]-lmult_pow[0], lmult_pow[2]+lmult_pow[0]))
+        if not lmult_mult is None:
+            self.lmult_mult= np.exp(np.random.uniform(lmult_mult[2]-lmult_mult[0], lmult_mult[2]+lmult_mult[0]))
+        if not lmult_add is None:
+            self.lmult_add = np.random.uniform(lmult_add[2]-lmult_add[0], lmult_add[2]+lmult_add[0])
+        if not col_rotate is None:
+            self.col_angle= np.random.uniform(col_rotate[2]-col_rotate[0], col_rotate[2]+col_rotate[0])
+
+        # eigen vectors
+        self.eigvec = np.reshape([0.51,0.56,0.65,0.79,0.01,-0.62,0.35,-0.83,0.44],[3,3]).transpose()
+
+
+    def __call__(self, inputs, target, intr):
+        inputs[0] = self.pca_image(inputs[0])
+        inputs[1] = self.pca_image(inputs[1])
+        return inputs,target,intr
+
+    def pca_image(self, rgb):
+        eig = np.dot(rgb, self.eigvec)
+        max_rgb = np.clip(rgb,0,np.inf).max((0,1))
+        min_rgb = rgb.min((0,1))
+        mean_rgb = rgb.mean((0,1))
+        max_abs_eig = np.abs(eig).max((0,1))
+        max_l = np.sqrt(np.sum(max_abs_eig*max_abs_eig))
+        mean_eig = np.dot(mean_rgb, self.eigvec)
+
+        # no-mean stuff
+        eig -= mean_eig[np.newaxis, np.newaxis]
+
+        for c in range(3):
+            if max_abs_eig[c] > 1e-2:
+                mean_eig[c] /= max_abs_eig[c]
+                eig[:,:,c] = eig[:,:,c] / max_abs_eig[c];
+                eig[:,:,c] = np.power(np.abs(eig[:,:,c]),self.pow_nomean[c]) *\
+                        ((eig[:,:,c] > 0) -0.5)*2
+                eig[:,:,c] = eig[:,:,c] + self.add_nomean[c]
+                eig[:,:,c] = eig[:,:,c] * self.mult_nomean[c]
+        eig += mean_eig[np.newaxis,np.newaxis]
+
+        # withmean stuff
+        if max_abs_eig[0] > 1e-2:
+            eig[:,:,0] = np.power(np.abs(eig[:,:,0]),self.pow_withmean[0]) * \
+                                ((eig[:,:,0]>0)-0.5)*2;
+            eig[:,:,0] = eig[:,:,0] + self.add_withmean[0];
+            eig[:,:,0] = eig[:,:,0] * self.mult_withmean[0];
+
+            s = np.sqrt(eig[:,:,1]*eig[:,:,1] + eig[:,:,2] * eig[:,:,2])
+            smask = s > 1e-2
+            s1 = np.power(s, self.pow_withmean[1]);
+            s1 = np.clip(s1 + self.add_withmean[1], 0,np.inf)
+            s1 = s1 * self.mult_withmean[1]
+            s1 = s1 * smask + s*(1-smask)
+
+            # color angle
+            if self.col_angle!=0:
+                temp1 = np.cos(self.col_angle) * eig[:,:,1] - np.sin(self.col_angle) * eig[:,:,2]
+                temp2 = np.sin(self.col_angle) * eig[:,:,1] + np.cos(self.col_angle) * eig[:,:,2]
+                eig[:,:,1] = temp1
+                eig[:,:,2] = temp2
+
+            # to origin magnitude
+            for c in range(3):
+                if max_abs_eig[c] > 1e-2:
+                    eig[:,:,c] = eig[:,:,c] * max_abs_eig[c]
+
+            if max_l > 1e-2:
+                l1 = np.sqrt(eig[:,:,0]*eig[:,:,0] + eig[:,:,1]*eig[:,:,1] + eig[:,:,2]*eig[:,:,2])
+                l1 = l1 / max_l
+
+            eig[:,:,1][smask] = (eig[:,:,1] / s * s1)[smask]
+            eig[:,:,2][smask] = (eig[:,:,2] / s * s1)[smask]
+            #eig[:,:,1] = (eig[:,:,1] / s * s1) * smask + eig[:,:,1] * (1-smask)
+            #eig[:,:,2] = (eig[:,:,2] / s * s1) * smask + eig[:,:,2] * (1-smask)
+
+            if max_l > 1e-2:
+                l = np.sqrt(eig[:,:,0]*eig[:,:,0] + eig[:,:,1]*eig[:,:,1] + eig[:,:,2]*eig[:,:,2])
+                l1 = np.power(l1, self.lmult_pow)
+                l1 = np.clip(l1 + self.lmult_add, 0, np.inf)
+                l1 = l1 * self.lmult_mult
+                l1 = l1 * max_l
+                lmask = l > 1e-2
+                eig[lmask] = (eig / l[:,:,np.newaxis] * l1[:,:,np.newaxis])[lmask]
+                for c in range(3):
+                    eig[:,:,c][lmask] = (np.clip(eig[:,:,c], -np.inf, max_abs_eig[c]))[lmask]
+                # for c in range(3):
+                #     # eig[:,:,c][lmask] = (eig[:,:,c] / l * l1)[lmask] * lmask + eig[:,:,c] * (1-lmask)
+                #     eig[:,:,c][lmask] = (eig[:,:,c] / l * l1)[lmask]
+                #     eig[:,:,c] = (np.clip(eig[:,:,c], -np.inf, max_abs_eig[c])) * lmask + eig[:,:,c] * (1-lmask)
+
+        return np.clip(np.dot(eig, self.eigvec.transpose()), 0, 1)
+
+
+class ChromaticAug(object):
+    """
+    Chromatic augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
+    """
+    def __init__(self, noise = 0.06,
+                       gamma = 0.02,
+                       brightness = 0.02,
+                       contrast = 0.02,
+                       color = 0.02,
+                       schedule_coeff=1):
+
+        self.noise = np.random.uniform(0,noise)
+        self.gamma = np.exp(np.random.normal(0, gamma*schedule_coeff))
+        self.brightness = np.random.normal(0, brightness*schedule_coeff)
+        self.contrast = np.exp(np.random.normal(0, contrast*schedule_coeff))
+        self.color = np.exp(np.random.normal(0, color*schedule_coeff,3))
+
+    def __call__(self, inputs, target, intr):
+        inputs[1] = self.chrom_aug(inputs[1])
+        # noise
+        inputs[0]+=np.random.normal(0, self.noise, inputs[0].shape)
+        inputs[1]+=np.random.normal(0, self.noise, inputs[0].shape)
+        return inputs,target,intr
+
+    def chrom_aug(self, rgb):
+        # color change
+        mean_in = rgb.sum(-1)
+        rgb = rgb*self.color[np.newaxis,np.newaxis]
+        brightness_coeff = mean_in / (rgb.sum(-1)+0.01)
+        rgb = np.clip(rgb*brightness_coeff[:,:,np.newaxis],0,1)
+        # gamma
+        rgb = np.power(rgb,self.gamma)
+        # brightness
+        rgb += self.brightness
+        # contrast
+        rgb = 0.5 + ( rgb-0.5)*self.contrast
+        rgb = np.clip(rgb, 0, 1)
+        return rgb
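These transforms all share the `(inputs, target, intr)` calling convention, so they can be chained with `Compose`. A minimal sketch under assumed inputs (not from the commit: HxWx3 float images in [0, 1], a 3-channel `(u, v, mask)` target, and a placeholder intrinsics list of at least three values):

```python
# Hypothetical sketch of chaining co-transforms; shapes and intrinsics are made up.
import numpy as np
from expansion.dataloader import depth_transforms as dt

co_transform = dt.Compose([
    dt.RandomCrop((320, 448)),   # crops and shifts the principal point in intr
    dt.ChromaticAug(),           # color/gamma/contrast jitter plus additive noise
])
im0 = np.random.rand(375, 1242, 3)
im1 = np.random.rand(375, 1242, 3)
target = np.random.rand(375, 1242, 3)   # (u, v, mask) layout
intr = [721.5, 609.6, 172.9]            # [fl, cx, cy], placeholder values
inputs, target, intr = co_transform([im0, im1], target, intr)
print(inputs[0].shape, target.shape, len(intr))  # (320, 448, 3) (320, 448, 3) 9
```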
expansion/dataloader/depthloader.py
ADDED
@@ -0,0 +1,222 @@
import os
import numbers
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
import random
from PIL import Image, ImageOps
import numpy as np
import torchvision
from . import depth_transforms as flow_transforms
import pdb
import cv2
from utils.flowlib import read_flow
from utils.util_flow import readPFM, load_calib_cam_to_cam

def default_loader(path):
    return Image.open(path).convert('RGB')

def flow_loader(path):
    if '.pfm' in path:
        data = readPFM(path)[0]
        data[:,:,2] = 1
        return data
    else:
        return read_flow(path)

def load_exts(cam_file):
    with open(cam_file, 'r') as f:
        lines = f.readlines()

    l_exts = []
    r_exts = []
    for l in lines:
        if 'L ' in l:
            l_exts.append(np.asarray([float(i) for i in l[2:].strip().split(' ')]).reshape(4,4))
        if 'R ' in l:
            r_exts.append(np.asarray([float(i) for i in l[2:].strip().split(' ')]).reshape(4,4))
    return l_exts, r_exts

def disparity_loader(path):
    if '.png' in path:
        data = Image.open(path)
        data = np.ascontiguousarray(data, dtype=np.float32)/256
        return data
    else:
        return readPFM(path)[0]

# triangulation
def triangulation(disp, xcoord, ycoord, bl=1, fl=450, cx=479.5, cy=269.5):
    depth = bl*fl / disp  # 450px -> 15mm focal length
    X = (xcoord - cx) * depth / fl
    Y = (ycoord - cy) * depth / fl
    Z = depth
    P = np.concatenate((X[np.newaxis], Y[np.newaxis], Z[np.newaxis]), 0).reshape(3,-1)
    P = np.concatenate((P, np.ones((1, P.shape[-1]))), 0)
    return P

class myImageFloder(data.Dataset):
    def __init__(self, iml0, iml1, flowl0, loader=default_loader, dploader=flow_loader,
                 scale=1., shape=[320,448], order=1, noise=0.06, pca_augmentor=True,
                 prob=1., sc=False, disp0=None, disp1=None, calib=None):
        self.iml0 = iml0
        self.iml1 = iml1
        self.flowl0 = flowl0
        self.loader = loader
        self.dploader = dploader
        self.scale = scale
        self.shape = shape
        self.order = order
        self.noise = noise
        self.pca_augmentor = pca_augmentor
        self.prob = prob
        self.sc = sc
        self.disp0 = disp0
        self.disp1 = disp1
        self.calib = calib

    def __getitem__(self, index):
        iml0 = self.iml0[index]
        iml1 = self.iml1[index]
        flowl0 = self.flowl0[index]
        th, tw = self.shape

        iml0 = self.loader(iml0)
        iml1 = self.loader(iml1)

        # get disparity
        if self.sc:
            flowl0 = self.dploader(flowl0)
            flowl0 = np.ascontiguousarray(flowl0, dtype=np.float32)
            flowl0[np.isnan(flowl0)] = 1e6  # set to max
            if 'camera_data.txt' in self.calib[index]:
                bl = 1
                if '15mm_' in self.calib[index]:
                    fl = 450  # 450
                else:
                    fl = 1050
                cx = 479.5
                cy = 269.5
                # negative disp
                d1 = np.abs(disparity_loader(self.disp0[index]))
                d2 = np.abs(disparity_loader(self.disp1[index]) + d1)
            elif 'Sintel' in self.calib[index]:
                fl = 1000
                bl = 1
                cx = 511.5
                cy = 217.5
                d1 = np.zeros(flowl0.shape[:2])
                d2 = np.zeros(flowl0.shape[:2])
            else:
                ints = load_calib_cam_to_cam(self.calib[index])
                fl = ints['K_cam2'][0,0]
                cx = ints['K_cam2'][0,2]
                cy = ints['K_cam2'][1,2]
                bl = ints['b20']-ints['b30']
                d1 = disparity_loader(self.disp0[index])
                d2 = disparity_loader(self.disp1[index])
            #flowl0[:,:,2] = (flowl0[:,:,2]==1).astype(float)
            flowl0[:,:,2] = np.logical_and(np.logical_and(flowl0[:,:,2]==1, d1!=0), d2!=0).astype(float)

            shape = d1.shape
            mesh = np.meshgrid(range(shape[1]), range(shape[0]))
            xcoord = mesh[0].astype(float)
            ycoord = mesh[1].astype(float)

            # triangulation in two frames
            P0 = triangulation(d1, xcoord, ycoord, bl=bl, fl=fl, cx=cx, cy=cy)
            P1 = triangulation(d2, xcoord + flowl0[:,:,0], ycoord + flowl0[:,:,1], bl=bl, fl=fl, cx=cx, cy=cy)
            dis0 = P0[2]
            dis1 = P1[2]

            change_size = dis0.reshape(shape).astype(np.float32)
            flow3d = (P1-P0)[:3].reshape((3,)+shape).transpose((1,2,0))

            gt_normal = np.concatenate((d1[:,:,np.newaxis], d2[:,:,np.newaxis], d2[:,:,np.newaxis]), -1)
            change_size = np.concatenate((change_size[:,:,np.newaxis], gt_normal, flow3d), 2)
        else:
            shape = iml0.size
            shape = [shape[1], shape[0]]
            flowl0 = np.zeros((shape[0], shape[1], 3))
            change_size = np.zeros((shape[0], shape[1], 7))
            depth = disparity_loader(self.iml1[index].replace('camera','groundtruth'))
            change_size[:,:,0] = depth

            seqid = self.iml0[index].split('/')[-5].rsplit('_',3)[0]
            ints = load_calib_cam_to_cam('/data/gengshay/KITTI/%s/calib_cam_to_cam.txt'%seqid)
            fl = ints['K_cam2'][0,0]
            cx = ints['K_cam2'][0,2]
            cy = ints['K_cam2'][1,2]
            bl = ints['b20']-ints['b30']

        iml1 = np.asarray(iml1)/255.
        iml0 = np.asarray(iml0)/255.
        iml0 = iml0[:,:,::-1].copy()
        iml1 = iml1[:,:,::-1].copy()

        ## following data augmentation procedure in PWCNet
        ## https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
        import __main__  # a workaround for "discount_coeff"
        try:
            with open('/scratch/gengshay/iter_counts-%d.txt'%int(__main__.args.logname.split('-')[-1]), 'r') as f:
                iter_counts = int(f.readline())
        except:
            iter_counts = 0
        schedule = [0.5, 1., 50000.]  # initial coeff, final_coeff, half life
        schedule_coeff = schedule[0] + (schedule[1] - schedule[0]) * \
                (2/(1+np.exp(-1.0986*iter_counts/schedule[2])) - 1)

        if self.pca_augmentor:
            pca_augmentor = flow_transforms.pseudoPCAAug(schedule_coeff=schedule_coeff)
        else:
            pca_augmentor = flow_transforms.Scale(1., order=0)

        if np.random.binomial(1, self.prob):
            co_transform1 = flow_transforms.Compose([
                flow_transforms.SpatialAug([th,tw],
                                scale=[0.2,0.,0.1],
                                rot=[0.4,0.],
                                trans=[0.4,0.],
                                squeeze=[0.3,0.], schedule_coeff=schedule_coeff, order=self.order),
                ])
        else:
            co_transform1 = flow_transforms.Compose([
                flow_transforms.RandomCrop([th,tw]),
                ])

        co_transform2 = flow_transforms.Compose([
            flow_transforms.pseudoPCAAug(schedule_coeff=schedule_coeff),
            #flow_transforms.PCAAug(schedule_coeff=schedule_coeff),
            flow_transforms.ChromaticAug(schedule_coeff=schedule_coeff, noise=self.noise),
            ])

        flowl0 = np.concatenate([flowl0, change_size], -1)
        augmented, flowl0, intr = co_transform1([iml0, iml1], flowl0, [fl,cx,cy,bl])
        imol0 = augmented[0]
        imol1 = augmented[1]
        augmented, flowl0, intr = co_transform2(augmented, flowl0, intr)

        iml0 = augmented[0]
        iml1 = augmented[1]
        flowl0 = flowl0.astype(np.float32)
        change_size = flowl0[:,:,3:]
        flowl0 = flowl0[:,:,:3]

        # randomly cover a region
        sx = 0; sy = 0; cx = 0; cy = 0
        if np.random.binomial(1, 0.5):
            sx = int(np.random.uniform(25,100))
            sy = int(np.random.uniform(25,100))
            #sx = int(np.random.uniform(50,150))
            #sy = int(np.random.uniform(50,150))
            cx = int(np.random.uniform(sx, iml1.shape[0]-sx))
            cy = int(np.random.uniform(sy, iml1.shape[1]-sy))
            iml1[cx-sx:cx+sx, cy-sy:cy+sy] = np.mean(np.mean(iml1,0),0)[np.newaxis,np.newaxis]

        iml0 = torch.Tensor(np.transpose(iml0,(2,0,1)))
        iml1 = torch.Tensor(np.transpose(iml1,(2,0,1)))

        return iml0, iml1, flowl0, change_size, intr, imol0, imol1, np.asarray([cx-sx,cx+sx,cy-sy,cy+sy])

    def __len__(self):
        return len(self.iml0)
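A minimal usage sketch for this loader, assuming it runs from the expansion/ directory (so utils and dataloader resolve) and uses the SceneFlow lists from sceneflowlist below; paths are placeholders:

import torch.utils.data
from dataloader import sceneflowlist, depthloader

# SceneFlow-style lists: images, flow, disparity, disparity change, camera files
iml0, iml1, flowl0, disp0, dispc, calib = sceneflowlist.dataloader('/path/to/SceneFlow')  # placeholder path
dataset = depthloader.myImageFloder(iml0, iml1, flowl0, shape=[320, 448], sc=True,
                                    disp0=disp0, disp1=dispc, calib=calib)
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
# each batch: images, 2D flow, depth/expansion channels, intrinsics, pre-chromatic images, cover box
iml0_b, iml1_b, flow_b, change_b, intr_b, imol0_b, imol1_b, box_b = next(iter(loader))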
expansion/dataloader/flow_transforms.py
ADDED
@@ -0,0 +1,440 @@
from __future__ import division
import torch
import random
import numpy as np
import numbers
import types
import scipy.ndimage as ndimage
import pdb
import torchvision
import PIL.Image as Image
import cv2
from torch.nn import functional as F


class Compose(object):
    """ Composes several co_transforms together.
    For example:
    >>> co_transforms.Compose([
    >>>     co_transforms.CenterCrop(10),
    >>>     co_transforms.ToTensor(),
    >>> ])
    """

    def __init__(self, co_transforms):
        self.co_transforms = co_transforms

    def __call__(self, input, target):
        for t in self.co_transforms:
            input, target = t(input, target)
        return input, target


class Scale(object):
    """ Rescales the inputs and target arrays to the given 'size'.
    'size' will be the size of the smaller edge.
    For example, if height > width, then image will be
    rescaled to (size * height / width, size)
    size: size of the smaller edge
    interpolation order: Default: 2 (bilinear)
    """

    def __init__(self, size, order=1):
        self.ratio = size
        self.order = order
        if order == 0:
            self.code = cv2.INTER_NEAREST
        elif order == 1:
            self.code = cv2.INTER_LINEAR
        elif order == 2:
            self.code = cv2.INTER_CUBIC

    def __call__(self, inputs, target):
        if self.ratio == 1:
            return inputs, target
        h, w, _ = inputs[0].shape
        ratio = self.ratio

        inputs[0] = cv2.resize(inputs[0], None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR)
        inputs[1] = cv2.resize(inputs[1], None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR)
        # keep the mask same
        tmp = cv2.resize(target[:,:,2], None, fx=ratio, fy=ratio, interpolation=cv2.INTER_NEAREST)
        target = cv2.resize(target, None, fx=ratio, fy=ratio, interpolation=self.code) * ratio
        target[:,:,2] = tmp

        return inputs, target


class SpatialAug(object):
    def __init__(self, crop, scale=None, rot=None, trans=None, squeeze=None, schedule_coeff=1, order=1, black=False):
        self.crop = crop
        self.scale = scale
        self.rot = rot
        self.trans = trans
        self.squeeze = squeeze
        self.t = np.zeros(6)
        self.schedule_coeff = schedule_coeff
        self.order = order
        self.black = black

    def to_identity(self):
        self.t[0] = 1; self.t[2] = 0; self.t[4] = 0; self.t[1] = 0; self.t[3] = 1; self.t[5] = 0

    def left_multiply(self, u0, u1, u2, u3, u4, u5):
        result = np.zeros(6)
        result[0] = self.t[0]*u0 + self.t[1]*u2
        result[1] = self.t[0]*u1 + self.t[1]*u3

        result[2] = self.t[2]*u0 + self.t[3]*u2
        result[3] = self.t[2]*u1 + self.t[3]*u3

        result[4] = self.t[4]*u0 + self.t[5]*u2 + u4
        result[5] = self.t[4]*u1 + self.t[5]*u3 + u5
        self.t = result

    def inverse(self):
        result = np.zeros(6)
        a = self.t[0]; c = self.t[2]; e = self.t[4]
        b = self.t[1]; d = self.t[3]; f = self.t[5]

        denom = a*d - b*c

        result[0] = d / denom
        result[1] = -b / denom
        result[2] = -c / denom
        result[3] = a / denom
        result[4] = (c*f - d*e) / denom
        result[5] = (b*e - a*f) / denom

        return result

    def grid_transform(self, meshgrid, t, normalize=True, gridsize=None):
        if gridsize is None:
            h, w = meshgrid[0].shape
        else:
            h, w = gridsize
        vgrid = torch.cat([(meshgrid[0] * t[0] + meshgrid[1] * t[2] + t[4])[:,:,np.newaxis],
                           (meshgrid[0] * t[1] + meshgrid[1] * t[3] + t[5])[:,:,np.newaxis]], -1)
        if normalize:
            vgrid[:,:,0] = 2.0*vgrid[:,:,0]/max(w-1,1)-1.0
            vgrid[:,:,1] = 2.0*vgrid[:,:,1]/max(h-1,1)-1.0
        return vgrid

    def __call__(self, inputs, target):
        h, w, _ = inputs[0].shape
        th, tw = self.crop
        meshgrid = torch.meshgrid([torch.Tensor(range(th)), torch.Tensor(range(tw))])[::-1]
        cornergrid = torch.meshgrid([torch.Tensor([0,th-1]), torch.Tensor([0,tw-1])])[::-1]

        for i in range(50):
            # im0
            self.to_identity()
            #TODO add mirror
            if np.random.binomial(1, 0.5):
                mirror = True
            else:
                mirror = False
            ##TODO
            #mirror = False
            if mirror:
                self.left_multiply(-1, 0, 0, 1, .5 * tw, -.5 * th)
            else:
                self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th)
            scale0 = 1; scale1 = 1; squeeze0 = 1; squeeze1 = 1
            if not self.rot is None:
                rot0 = np.random.uniform(-self.rot[0], +self.rot[0])
                rot1 = np.random.uniform(-self.rot[1]*self.schedule_coeff, self.rot[1]*self.schedule_coeff) + rot0
                self.left_multiply(np.cos(rot0), np.sin(rot0), -np.sin(rot0), np.cos(rot0), 0, 0)
            if not self.trans is None:
                trans0 = np.random.uniform(-self.trans[0], +self.trans[0], 2)
                trans1 = np.random.uniform(-self.trans[1]*self.schedule_coeff, +self.trans[1]*self.schedule_coeff, 2) + trans0
                self.left_multiply(1, 0, 0, 1, trans0[0] * tw, trans0[1] * th)
            if not self.squeeze is None:
                squeeze0 = np.exp(np.random.uniform(-self.squeeze[0], self.squeeze[0]))
                squeeze1 = np.exp(np.random.uniform(-self.squeeze[1]*self.schedule_coeff, self.squeeze[1]*self.schedule_coeff)) * squeeze0
            if not self.scale is None:
                scale0 = np.exp(np.random.uniform(self.scale[2]-self.scale[0], self.scale[2]+self.scale[0]))
                scale1 = np.exp(np.random.uniform(-self.scale[1]*self.schedule_coeff, self.scale[1]*self.schedule_coeff)) * scale0
            self.left_multiply(1.0/(scale0*squeeze0), 0, 0, 1.0/(scale0/squeeze0), 0, 0)

            self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h)
            transmat0 = self.t.copy()

            # im1
            self.to_identity()
            if mirror:
                self.left_multiply(-1, 0, 0, 1, .5 * tw, -.5 * th)
            else:
                self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th)
            if not self.rot is None:
                self.left_multiply(np.cos(rot1), np.sin(rot1), -np.sin(rot1), np.cos(rot1), 0, 0)
            if not self.trans is None:
                self.left_multiply(1, 0, 0, 1, trans1[0] * tw, trans1[1] * th)
            self.left_multiply(1.0/(scale1*squeeze1), 0, 0, 1.0/(scale1/squeeze1), 0, 0)
            self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h)
            transmat1 = self.t.copy()
            transmat1_inv = self.inverse()

            if self.black:
                # black augmentation, allowing 0 values in the input images
                # https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/black_augmentation_layer.cu
                break
            else:
                if ((self.grid_transform(cornergrid, transmat0, gridsize=[float(h),float(w)]).abs()>1).sum() +\
                    (self.grid_transform(cornergrid, transmat1, gridsize=[float(h),float(w)]).abs()>1).sum()) == 0:
                    break
        if i == 49:
            print('max_iter in augmentation')
            self.to_identity()
            self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th)
            self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h)
            transmat0 = self.t.copy()
            transmat1 = self.t.copy()

        # do the real work
        vgrid = self.grid_transform(meshgrid, transmat0, gridsize=[float(h),float(w)])
        inputs_0 = F.grid_sample(torch.Tensor(inputs[0]).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)
        if self.order == 0:
            target_0 = F.grid_sample(torch.Tensor(target).permute(2,0,1)[np.newaxis], vgrid[np.newaxis], mode='nearest')[0].permute(1,2,0)
        else:
            target_0 = F.grid_sample(torch.Tensor(target).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)

        mask_0 = target[:,:,2:3].copy(); mask_0[mask_0==0] = np.nan
        if self.order == 0:
            mask_0 = F.grid_sample(torch.Tensor(mask_0).permute(2,0,1)[np.newaxis], vgrid[np.newaxis], mode='nearest')[0].permute(1,2,0)
        else:
            mask_0 = F.grid_sample(torch.Tensor(mask_0).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)
        mask_0[torch.isnan(mask_0)] = 0

        vgrid = self.grid_transform(meshgrid, transmat1, gridsize=[float(h),float(w)])
        inputs_1 = F.grid_sample(torch.Tensor(inputs[1]).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)

        # flow
        pos = target_0[:,:,:2] + self.grid_transform(meshgrid, transmat0, normalize=False)
        pos = self.grid_transform(pos.permute(2,0,1), transmat1_inv, normalize=False)
        if target_0.shape[2] >= 4:
            # scale
            exp = target_0[:,:,3:] * scale1 / scale0
            target = torch.cat([(pos[:,:,0] - meshgrid[0]).unsqueeze(-1),
                                (pos[:,:,1] - meshgrid[1]).unsqueeze(-1),
                                mask_0,
                                exp], -1)
        else:
            target = torch.cat([(pos[:,:,0] - meshgrid[0]).unsqueeze(-1),
                                (pos[:,:,1] - meshgrid[1]).unsqueeze(-1),
                                mask_0], -1)
            #                   target_0[:,:,2].unsqueeze(-1) ], -1)
        inputs = [np.asarray(inputs_0), np.asarray(inputs_1)]
        target = np.asarray(target)

        return inputs, target


class pseudoPCAAug(object):
    """
    Chromatic Eigen Augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
    This version is faster.
    """
    def __init__(self, schedule_coeff=1):
        self.augcolor = torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.5, hue=0.5/3.14)

    def __call__(self, inputs, target):
        inputs[0] = np.asarray(self.augcolor(Image.fromarray(np.uint8(inputs[0]*255))))/255.
        inputs[1] = np.asarray(self.augcolor(Image.fromarray(np.uint8(inputs[1]*255))))/255.
        return inputs, target


class PCAAug(object):
    """
    Chromatic Eigen Augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
    """
    def __init__(self, lmult_pow  =[0.4, 0, -0.2],
                       lmult_mult =[0.4, 0, 0,  ],
                       lmult_add  =[0.03,0, 0,  ],
                       sat_pow    =[0.4, 0, 0,  ],
                       sat_mult   =[0.5, 0, -0.3],
                       sat_add    =[0.03,0, 0,  ],
                       col_pow    =[0.4, 0, 0,  ],
                       col_mult   =[0.2, 0, 0,  ],
                       col_add    =[0.02,0, 0,  ],
                       ladd_pow   =[0.4, 0, 0,  ],
                       ladd_mult  =[0.4, 0, 0,  ],
                       ladd_add   =[0.04,0, 0,  ],
                       col_rotate =[1.,  0, 0,  ],
                       schedule_coeff=1):
        # no mean
        self.pow_nomean = [1,1,1]
        self.add_nomean = [0,0,0]
        self.mult_nomean = [1,1,1]
        self.pow_withmean = [1,1,1]
        self.add_withmean = [0,0,0]
        self.mult_withmean = [1,1,1]
        self.lmult_pow = 1
        self.lmult_mult = 1
        self.lmult_add = 0
        self.col_angle = 0
        if not ladd_pow is None:
            self.pow_nomean[0] = np.exp(np.random.normal(ladd_pow[2], ladd_pow[0]))
        if not col_pow is None:
            self.pow_nomean[1] = np.exp(np.random.normal(col_pow[2], col_pow[0]))
            self.pow_nomean[2] = np.exp(np.random.normal(col_pow[2], col_pow[0]))

        if not ladd_add is None:
            self.add_nomean[0] = np.random.normal(ladd_add[2], ladd_add[0])
        if not col_add is None:
            self.add_nomean[1] = np.random.normal(col_add[2], col_add[0])
            self.add_nomean[2] = np.random.normal(col_add[2], col_add[0])

        if not ladd_mult is None:
            self.mult_nomean[0] = np.exp(np.random.normal(ladd_mult[2], ladd_mult[0]))
        if not col_mult is None:
            self.mult_nomean[1] = np.exp(np.random.normal(col_mult[2], col_mult[0]))
            self.mult_nomean[2] = np.exp(np.random.normal(col_mult[2], col_mult[0]))

        # with mean
        if not sat_pow is None:
            self.pow_withmean[1] = np.exp(np.random.uniform(sat_pow[2]-sat_pow[0], sat_pow[2]+sat_pow[0]))
            self.pow_withmean[2] = self.pow_withmean[1]
        if not sat_add is None:
            self.add_withmean[1] = np.random.uniform(sat_add[2]-sat_add[0], sat_add[2]+sat_add[0])
            self.add_withmean[2] = self.add_withmean[1]
        if not sat_mult is None:
            self.mult_withmean[1] = np.exp(np.random.uniform(sat_mult[2]-sat_mult[0], sat_mult[2]+sat_mult[0]))
            self.mult_withmean[2] = self.mult_withmean[1]

        if not lmult_pow is None:
            self.lmult_pow = np.exp(np.random.uniform(lmult_pow[2]-lmult_pow[0], lmult_pow[2]+lmult_pow[0]))
        if not lmult_mult is None:
            self.lmult_mult = np.exp(np.random.uniform(lmult_mult[2]-lmult_mult[0], lmult_mult[2]+lmult_mult[0]))
        if not lmult_add is None:
            self.lmult_add = np.random.uniform(lmult_add[2]-lmult_add[0], lmult_add[2]+lmult_add[0])
        if not col_rotate is None:
            self.col_angle = np.random.uniform(col_rotate[2]-col_rotate[0], col_rotate[2]+col_rotate[0])

        # eigen vectors
        self.eigvec = np.reshape([0.51,0.56,0.65,0.79,0.01,-0.62,0.35,-0.83,0.44],[3,3]).transpose()

    def __call__(self, inputs, target):
        inputs[0] = self.pca_image(inputs[0])
        inputs[1] = self.pca_image(inputs[1])
        return inputs, target

    def pca_image(self, rgb):
        eig = np.dot(rgb, self.eigvec)
        max_rgb = np.clip(rgb, 0, np.inf).max((0,1))
        min_rgb = rgb.min((0,1))
        mean_rgb = rgb.mean((0,1))
        max_abs_eig = np.abs(eig).max((0,1))
        max_l = np.sqrt(np.sum(max_abs_eig*max_abs_eig))
        mean_eig = np.dot(mean_rgb, self.eigvec)

        # no-mean stuff
        eig -= mean_eig[np.newaxis, np.newaxis]

        for c in range(3):
            if max_abs_eig[c] > 1e-2:
                mean_eig[c] /= max_abs_eig[c]
                eig[:,:,c] = eig[:,:,c] / max_abs_eig[c]
                eig[:,:,c] = np.power(np.abs(eig[:,:,c]), self.pow_nomean[c]) *\
                        ((eig[:,:,c] > 0) - 0.5)*2
                eig[:,:,c] = eig[:,:,c] + self.add_nomean[c]
                eig[:,:,c] = eig[:,:,c] * self.mult_nomean[c]
        eig += mean_eig[np.newaxis, np.newaxis]

        # withmean stuff
        if max_abs_eig[0] > 1e-2:
            eig[:,:,0] = np.power(np.abs(eig[:,:,0]), self.pow_withmean[0]) * \
                    ((eig[:,:,0]>0)-0.5)*2
            eig[:,:,0] = eig[:,:,0] + self.add_withmean[0]
            eig[:,:,0] = eig[:,:,0] * self.mult_withmean[0]

        s = np.sqrt(eig[:,:,1]*eig[:,:,1] + eig[:,:,2]*eig[:,:,2])
        smask = s > 1e-2
        s1 = np.power(s, self.pow_withmean[1])
        s1 = np.clip(s1 + self.add_withmean[1], 0, np.inf)
        s1 = s1 * self.mult_withmean[1]
        s1 = s1 * smask + s*(1-smask)

        # color angle
        if self.col_angle != 0:
            temp1 = np.cos(self.col_angle) * eig[:,:,1] - np.sin(self.col_angle) * eig[:,:,2]
            temp2 = np.sin(self.col_angle) * eig[:,:,1] + np.cos(self.col_angle) * eig[:,:,2]
            eig[:,:,1] = temp1
            eig[:,:,2] = temp2

        # to origin magnitude
        for c in range(3):
            if max_abs_eig[c] > 1e-2:
                eig[:,:,c] = eig[:,:,c] * max_abs_eig[c]

        if max_l > 1e-2:
            l1 = np.sqrt(eig[:,:,0]*eig[:,:,0] + eig[:,:,1]*eig[:,:,1] + eig[:,:,2]*eig[:,:,2])
            l1 = l1 / max_l

        eig[:,:,1][smask] = (eig[:,:,1] / s * s1)[smask]
        eig[:,:,2][smask] = (eig[:,:,2] / s * s1)[smask]
        #eig[:,:,1] = (eig[:,:,1] / s * s1) * smask + eig[:,:,1] * (1-smask)
        #eig[:,:,2] = (eig[:,:,2] / s * s1) * smask + eig[:,:,2] * (1-smask)

        if max_l > 1e-2:
            l = np.sqrt(eig[:,:,0]*eig[:,:,0] + eig[:,:,1]*eig[:,:,1] + eig[:,:,2]*eig[:,:,2])
            l1 = np.power(l1, self.lmult_pow)
            l1 = np.clip(l1 + self.lmult_add, 0, np.inf)
            l1 = l1 * self.lmult_mult
            l1 = l1 * max_l
            lmask = l > 1e-2
            eig[lmask] = (eig / l[:,:,np.newaxis] * l1[:,:,np.newaxis])[lmask]
            for c in range(3):
                eig[:,:,c][lmask] = (np.clip(eig[:,:,c], -np.inf, max_abs_eig[c]))[lmask]
            # for c in range(3):
            #     # eig[:,:,c][lmask] = (eig[:,:,c] / l * l1)[lmask] * lmask + eig[:,:,c] * (1-lmask)
            #     eig[:,:,c][lmask] = (eig[:,:,c] / l * l1)[lmask]
            #     eig[:,:,c] = (np.clip(eig[:,:,c], -np.inf, max_abs_eig[c])) * lmask + eig[:,:,c] * (1-lmask)

        return np.clip(np.dot(eig, self.eigvec.transpose()), 0, 1)


class ChromaticAug(object):
    """
    Chromatic augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
    """
    def __init__(self, noise=0.06,
                       gamma=0.02,
                       brightness=0.02,
                       contrast=0.02,
                       color=0.02,
                       schedule_coeff=1):

        self.noise = np.random.uniform(0, noise)
        self.gamma = np.exp(np.random.normal(0, gamma*schedule_coeff))
        self.brightness = np.random.normal(0, brightness*schedule_coeff)
        self.contrast = np.exp(np.random.normal(0, contrast*schedule_coeff))
        self.color = np.exp(np.random.normal(0, color*schedule_coeff, 3))

    def __call__(self, inputs, target):
        inputs[1] = self.chrom_aug(inputs[1])
        # noise
        inputs[0] += np.random.normal(0, self.noise, inputs[0].shape)
        inputs[1] += np.random.normal(0, self.noise, inputs[0].shape)
        return inputs, target

    def chrom_aug(self, rgb):
        # color change
        mean_in = rgb.sum(-1)
        rgb = rgb*self.color[np.newaxis, np.newaxis]
        brightness_coeff = mean_in / (rgb.sum(-1)+0.01)
        rgb = np.clip(rgb*brightness_coeff[:,:,np.newaxis], 0, 1)
        # gamma
        rgb = np.power(rgb, self.gamma)
        # brightness
        rgb += self.brightness
        # contrast
        rgb = 0.5 + (rgb-0.5)*self.contrast
        rgb = np.clip(rgb, 0, 1)
        return rgb
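SpatialAug stores each 2x3 affine map as a flat vector t = (a, b, c, d, e, f), mapping (x, y) to (a*x + c*y + e, b*x + d*y + f) as in grid_transform. A small numeric check of the inverse() convention, assuming that same layout:

import numpy as np

t = np.array([0.9, 0.1, -0.2, 1.1, 5.0, -3.0])        # a, b, c, d, e, f
A = np.array([[t[0], t[2]], [t[1], t[3]]])            # linear part used by grid_transform
b = np.array([t[4], t[5]])
denom = t[0]*t[3] - t[1]*t[2]
# same formulas as inverse(): adj(A)/det(A) plus the inverted translation
t_inv = np.array([t[3]/denom, -t[1]/denom, -t[2]/denom, t[0]/denom,
                  (t[2]*t[5]-t[3]*t[4])/denom, (t[1]*t[4]-t[0]*t[5])/denom])
A_inv = np.array([[t_inv[0], t_inv[2]], [t_inv[1], t_inv[3]]])
b_inv = np.array([t_inv[4], t_inv[5]])
p = np.array([7.0, 2.0])
assert np.allclose(A_inv @ (A @ p + b) + b_inv, p)    # inverse() undoes the map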
expansion/dataloader/hd1klist.py
ADDED
@@ -0,0 +1,29 @@
import torch.utils.data as data

from PIL import Image
import os
import os.path
import numpy as np
import pdb

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def dataloader(filepath):

    left_fold = 'image_2/'
    train = [img for img in os.listdir(filepath+left_fold) if img.find('HD1K2018') > -1]
    train = sorted(train)

    l0_train = [filepath+left_fold+img for img in train]
    l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0], '%04d'%(1+int(img.split('.')[0].split('_')[-1]))) in l0_train]
    l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0], '%04d'%(1+int(img.split('.')[0].split('_')[-1]))) for img in l0_train]
    flow_train = [img.replace('image_2','flow_occ') for img in l0_train]

    return l0_train, l1_train, flow_train
expansion/dataloader/kitti12list.py
ADDED
@@ -0,0 +1,29 @@
import torch.utils.data as data

from PIL import Image
import os
import os.path
import numpy as np

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def dataloader(filepath):

    left_fold = 'colored_0/'
    flow_noc = 'flow_occ/'

    train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]

    l0_train = [filepath+left_fold+img for img in train]
    l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
    flow_train = [filepath+flow_noc+img for img in train]

    return l0_train, l1_train, flow_train
expansion/dataloader/kitti15list.py
ADDED
@@ -0,0 +1,29 @@
import torch.utils.data as data

from PIL import Image
import os
import os.path
import numpy as np

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def dataloader(filepath):

    left_fold = 'image_2/'
    flow_noc = 'flow_occ/'

    train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]

    l0_train = [filepath+left_fold+img for img in train]
    l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
    flow_train = [filepath+flow_noc+img for img in train]

    return sorted(l0_train), sorted(l1_train), sorted(flow_train)
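These list helpers all return three parallel lists: reference frames, next frames, and ground-truth flow; for KITTI the _10/_11 suffixes pair the two frames of each scene. A sketch of the expected pairing, assuming the standard KITTI scene flow layout (path is a placeholder):

l0_train, l1_train, flow_train = dataloader('/path/to/kitti_scene/training/')
# l0_train[k]   -> .../image_2/000000_10.png   (reference frame)
# l1_train[k]   -> .../image_2/000000_11.png   (next frame)
# flow_train[k] -> .../flow_occ/000000_10.png  (ground-truth flow for the pair)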
expansion/dataloader/kitti15list_train.py
ADDED
@@ -0,0 +1,31 @@
import torch.utils.data as data

from PIL import Image
import os
import os.path
import numpy as np

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def dataloader(filepath):

    left_fold = 'image_2/'
    flow_noc = 'flow_occ/'

    train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]

    train = [i for i in train if int(i.split('_')[0])%5 != 0]

    l0_train = [filepath+left_fold+img for img in train]
    l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
    flow_train = [filepath+flow_noc+img for img in train]

    return sorted(l0_train), sorted(l1_train), sorted(flow_train)
expansion/dataloader/kitti15list_train_lidar.py
ADDED
@@ -0,0 +1,34 @@
import torch.utils.data as data

from PIL import Image
import os
import os.path
import numpy as np

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def dataloader(filepath):

    left_fold = 'image_2/'
    flow_noc = 'flow_occ/'

    train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]

    # train = [i for i in train if int(i.split('_')[0])%5!=0]
    with open('/data/gengshay/kitti_scene/devkit/mapping/train_mapping.txt','r') as f:
        flags = [True if len(i)>1 else False for i in f.readlines()]
    train = [fn for (it,fn) in enumerate(sorted(train)) if flags[it]][:100]

    l0_train = [filepath+left_fold+img for img in train]
    l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
    flow_train = [filepath+flow_noc+img for img in train]

    return sorted(l0_train), sorted(l1_train), sorted(flow_train)
expansion/dataloader/kitti15list_val.py
ADDED
@@ -0,0 +1,31 @@
import torch.utils.data as data

from PIL import Image
import os
import os.path
import numpy as np

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def dataloader(filepath):

    left_fold = 'image_2/'
    flow_noc = 'flow_occ/'

    train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]

    train = [i for i in train if int(i.split('_')[0])%5 == 0]

    l0_train = [filepath+left_fold+img for img in train]
    l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
    flow_train = [filepath+flow_noc+img for img in train]

    return sorted(l0_train), sorted(l1_train), sorted(flow_train)
expansion/dataloader/kitti15list_val_lidar.py
ADDED
@@ -0,0 +1,34 @@
import torch.utils.data as data

from PIL import Image
import os
import os.path
import numpy as np

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def dataloader(filepath):

    left_fold = 'image_2/'
    flow_noc = 'flow_occ/'

    train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]

    # train = [i for i in train if int(i.split('_')[0])%5!=0]
    with open('/data/gengshay/kitti_scene/devkit/mapping/train_mapping.txt','r') as f:
        flags = [True if len(i)>1 else False for i in f.readlines()]
    train = [fn for (it,fn) in enumerate(sorted(train)) if flags[it]][100:]

    l0_train = [filepath+left_fold+img for img in train]
    l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
    flow_train = [filepath+flow_noc+img for img in train]

    return sorted(l0_train), sorted(l1_train), sorted(flow_train)
expansion/dataloader/kitti15list_val_mr.py
ADDED
@@ -0,0 +1,41 @@
import torch.utils.data as data

from PIL import Image
import os
import os.path
import numpy as np

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def dataloader(filepath):

    left_fold = 'image_2/'
    flow_noc = 'flow_occ/'

    train = [img for img in os.listdir(filepath+left_fold) if 'Kitti' in img and img.find('_10') > -1]

    # train = [i for i in train if int(i.split('_')[1])%5==0]
    train = sorted([i for i in train if int(i.split('_')[1])%5 == 0])[0:1]

    l0_train = [filepath+left_fold+img for img in train]
    l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
    flow_train = [filepath+flow_noc+img for img in train]

    l0_train += [filepath+left_fold+img.replace('_10','_09') for img in train]
    l1_train += [filepath+left_fold+img for img in train]
    flow_train += flow_train

    tmp = l0_train
    l0_train = l0_train + [i.replace('rob_flow', 'kitti_scene').replace('Kitti2015_','') for i in l1_train]
    l1_train = l1_train + tmp
    flow_train += flow_train

    return l0_train, l1_train, flow_train
expansion/dataloader/robloader.py
ADDED
@@ -0,0 +1,133 @@
import os
import numbers
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
import random
from PIL import Image, ImageOps
import numpy as np
import torchvision
from . import flow_transforms
import pdb
import cv2
from utils.flowlib import read_flow
from utils.util_flow import readPFM


def default_loader(path):
    return Image.open(path).convert('RGB')

def flow_loader(path):
    if '.pfm' in path:
        data = readPFM(path)[0]
        data[:,:,2] = 1
        return data
    else:
        return read_flow(path)


def disparity_loader(path):
    if '.png' in path:
        data = Image.open(path)
        data = np.ascontiguousarray(data, dtype=np.float32)/256
        return data
    else:
        return readPFM(path)[0]

class myImageFloder(data.Dataset):
    def __init__(self, iml0, iml1, flowl0, loader=default_loader, dploader=flow_loader,
                 scale=1., shape=[320,448], order=1, noise=0.06, pca_augmentor=True,
                 prob=1., cover=False, black=False, scale_aug=[0.4,0.2]):
        self.iml0 = iml0
        self.iml1 = iml1
        self.flowl0 = flowl0
        self.loader = loader
        self.dploader = dploader
        self.scale = scale
        self.shape = shape
        self.order = order
        self.noise = noise
        self.pca_augmentor = pca_augmentor
        self.prob = prob
        self.cover = cover
        self.black = black
        self.scale_aug = scale_aug

    def __getitem__(self, index):
        iml0 = self.iml0[index]
        iml1 = self.iml1[index]
        flowl0 = self.flowl0[index]
        th, tw = self.shape

        iml0 = self.loader(iml0)
        iml1 = self.loader(iml1)
        iml1 = np.asarray(iml1)/255.
        iml0 = np.asarray(iml0)/255.
        iml0 = iml0[:,:,::-1].copy()
        iml1 = iml1[:,:,::-1].copy()
        flowl0 = self.dploader(flowl0)
        #flowl0[:,:,-1][flowl0[:,:,0]==np.inf]=0 # for gtav window pfm files
        #flowl0[:,:,0][~flowl0[:,:,2].astype(bool)]=0
        #flowl0[:,:,1][~flowl0[:,:,2].astype(bool)]=0 # avoid nan in grad
        flowl0 = np.ascontiguousarray(flowl0, dtype=np.float32)
        flowl0[np.isnan(flowl0)] = 1e6  # set to max

        ## following data augmentation procedure in PWCNet
        ## https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
        import __main__  # a workaround for "discount_coeff"
        try:
            with open('iter_counts-%d.txt'%int(__main__.args.logname.split('-')[-1]), 'r') as f:
                iter_counts = int(f.readline())
        except:
            iter_counts = 0
        schedule = [0.5, 1., 50000.]  # initial coeff, final_coeff, half life
        schedule_coeff = schedule[0] + (schedule[1] - schedule[0]) * \
                (2/(1+np.exp(-1.0986*iter_counts/schedule[2])) - 1)

        if self.pca_augmentor:
            pca_augmentor = flow_transforms.pseudoPCAAug(schedule_coeff=schedule_coeff)
        else:
            pca_augmentor = flow_transforms.Scale(1., order=0)

        if np.random.binomial(1, self.prob):
            co_transform = flow_transforms.Compose([
                flow_transforms.Scale(self.scale, order=self.order),
                #flow_transforms.SpatialAug([th,tw], trans=[0.2,0.03], order=self.order, black=self.black),
                flow_transforms.SpatialAug([th,tw], scale=[self.scale_aug[0],0.03,self.scale_aug[1]],
                                rot=[0.4,0.03],
                                trans=[0.4,0.03],
                                squeeze=[0.3,0.], schedule_coeff=schedule_coeff, order=self.order, black=self.black),
                #flow_transforms.pseudoPCAAug(schedule_coeff=schedule_coeff),
                flow_transforms.PCAAug(schedule_coeff=schedule_coeff),
                flow_transforms.ChromaticAug(schedule_coeff=schedule_coeff, noise=self.noise),
                ])
        else:
            co_transform = flow_transforms.Compose([
                flow_transforms.Scale(self.scale, order=self.order),
                flow_transforms.SpatialAug([th,tw], trans=[0.4,0.03], order=self.order, black=self.black)
                ])

        augmented, flowl0 = co_transform([iml0, iml1], flowl0)
        iml0 = augmented[0]
        iml1 = augmented[1]

        if self.cover:
            ## randomly cover a region
            # following sec. 3.2 of http://openaccess.thecvf.com/content_CVPR_2019/html/Yang_Hierarchical_Deep_Stereo_Matching_on_High-Resolution_Images_CVPR_2019_paper.html
            if np.random.binomial(1, 0.5):
                #sx = int(np.random.uniform(25,100))
                #sy = int(np.random.uniform(25,100))
                sx = int(np.random.uniform(50,125))
                sy = int(np.random.uniform(50,125))
                #sx = int(np.random.uniform(50,150))
                #sy = int(np.random.uniform(50,150))
                cx = int(np.random.uniform(sx, iml1.shape[0]-sx))
                cy = int(np.random.uniform(sy, iml1.shape[1]-sy))
                iml1[cx-sx:cx+sx, cy-sy:cy+sy] = np.mean(np.mean(iml1,0),0)[np.newaxis,np.newaxis]

        iml0 = torch.Tensor(np.transpose(iml0,(2,0,1)))
        iml1 = torch.Tensor(np.transpose(iml1,(2,0,1)))

        return iml0, iml1, flowl0

    def __len__(self):
        return len(self.iml0)
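The schedule in __getitem__ ramps schedule_coeff from 0.5 toward 1.0; since 1.0986 is approximately ln 3, the sigmoid sits exactly halfway (0.75) after one half-life of 50k iterations. A quick check of the same formula:

import numpy as np

def schedule_coeff(iter_counts, init=0.5, final=1., half_life=50000.):
    return init + (final - init) * (2/(1 + np.exp(-1.0986*iter_counts/half_life)) - 1)

print(schedule_coeff(0))        # 0.5   (mild augmentation at the start)
print(schedule_coeff(50000))    # ~0.75 (halfway after one half-life)
print(schedule_coeff(500000))   # ~1.0  (full augmentation strength)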
expansion/dataloader/sceneflowlist.py
ADDED
@@ -0,0 +1,51 @@
import os
import os.path
import glob

def dataloader(filepath, level=6):
    iml0 = []
    iml1 = []
    flowl0 = []
    disp0 = []
    dispc = []
    calib = []
    level_stars = '/*'*level
    candidate_pool = glob.glob('%s/optical_flow%s'%(filepath, level_stars))
    for flow_path in sorted(candidate_pool):
        if 'TEST' in flow_path: continue
        if 'flower_storm_x2/into_future/right/OpticalFlowIntoFuture_0023_R.pfm' in flow_path:
            continue
        if 'flower_storm_x2/into_future/left/OpticalFlowIntoFuture_0023_L.pfm' in flow_path:
            continue
        if 'flower_storm_augmented0_x2/into_future/right/OpticalFlowIntoFuture_0023_R.pfm' in flow_path:
            continue
        if 'flower_storm_augmented0_x2/into_future/left/OpticalFlowIntoFuture_0023_L.pfm' in flow_path:
            continue
        if 'FlyingThings' in flow_path and '_0014_' in flow_path:
            continue
        if 'FlyingThings' in flow_path and '_0015_' in flow_path:
            continue
        idd = flow_path.split('/')[-1].split('_')[-2]
        if 'into_future' in flow_path:
            idd_p1 = '%04d'%(int(idd)+1)
        else:
            idd_p1 = '%04d'%(int(idd)-1)
        if os.path.exists(flow_path.replace(idd, idd_p1)):
            d0_path = flow_path.replace('/into_future/','/').replace('/into_past/','/').replace('optical_flow','disparity')
            d0_path = '%s/%s.pfm'%(d0_path.rsplit('/',1)[0], idd)
            dc_path = flow_path.replace('optical_flow','disparity_change')
            dc_path = '%s/%s.pfm'%(dc_path.rsplit('/',1)[0], idd)
            im_path = flow_path.replace('/into_future/','/').replace('/into_past/','/').replace('optical_flow','frames_cleanpass')
            im0_path = '%s/%s.png'%(im_path.rsplit('/',1)[0], idd)
            im1_path = '%s/%s.png'%(im_path.rsplit('/',1)[0], idd_p1)
            #with open('%s/camera_data.txt'%(im0_path.replace('frames_cleanpass','camera_data').rsplit('/',2)[0]),'r') as f:
            #    if 'FlyingThings' in flow_path and len(f.readlines())!=40:
            #        print(flow_path)
            #        continue
            iml0.append(im0_path)
            iml1.append(im1_path)
            flowl0.append(flow_path)
            disp0.append(d0_path)
            dispc.append(dc_path)
            calib.append('%s/camera_data.txt'%(im0_path.replace('frames_cleanpass','camera_data').rsplit('/',2)[0]))
    return iml0, iml1, flowl0, disp0, dispc, calib
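The pairing logic above assumes the standard SceneFlow directory naming; for an into_future flow file the neighboring files resolve roughly as follows (into_past pairs frame t with t-1 instead), where <scene> stands for an arbitrary sequence path:

# flow:  .../optical_flow/<scene>/into_future/left/OpticalFlowIntoFuture_0006_L.pfm
# im0:   .../frames_cleanpass/<scene>/left/0006.png
# im1:   .../frames_cleanpass/<scene>/left/0007.png
# disp0: .../disparity/<scene>/left/0006.pfm
# dispc: .../disparity_change/<scene>/into_future/left/0006.pfm
# calib: .../camera_data/<scene>/camera_data.txt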
expansion/dataloader/seqlist.py
ADDED
@@ -0,0 +1,26 @@
import torch.utils.data as data

from PIL import Image
import os
import os.path
import numpy as np
import glob

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def dataloader(filepath):

    train = [img for img in sorted(glob.glob('%s/*'%filepath))]

    l0_train = train[:-1]
    l1_train = train[1:]

    return sorted(l0_train), sorted(l1_train), sorted(l0_train)
expansion/dataloader/sintellist.py
ADDED
@@ -0,0 +1,32 @@
import torch.utils.data as data

from PIL import Image
import os
import os.path
import numpy as np
import pdb

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def dataloader(filepath):

    left_fold = 'image_2/'
    train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel') > -1]

    l0_train = [filepath+left_fold+img for img in train]
    l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0], '%02d'%(1+int(img.split('.')[0].split('_')[-1]))) in l0_train]

    #l0_train = [i for i in l0_train if not '10.png' in i] # remove 10 as val

    l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0], '%02d'%(1+int(img.split('.')[0].split('_')[-1]))) for img in l0_train]
    flow_train = [img.replace('image_2','flow_occ') for img in l0_train]

    return l0_train, l1_train, flow_train
expansion/dataloader/sintellist_clean.py
ADDED
@@ -0,0 +1,31 @@
import torch.utils.data as data

from PIL import Image
import os
import os.path
import numpy as np
import pdb

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def dataloader(filepath):

    left_fold = 'image_2/'
    train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel_clean') > -1]

    l0_train = [filepath+left_fold+img for img in train]
    l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0], '%02d'%(1+int(img.split('.')[0].split('_')[-1]))) in l0_train]

    #l0_train = [i for i in l0_train if not '10.png' in i] # remove 10 as val

    l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0], '%02d'%(1+int(img.split('.')[0].split('_')[-1]))) for img in l0_train]
    flow_train = [img.replace('image_2','flow_occ') for img in l0_train]

    return l0_train, l1_train, flow_train
expansion/dataloader/sintellist_final.py
ADDED
@@ -0,0 +1,32 @@
import torch.utils.data as data

from PIL import Image
import os
import os.path
import numpy as np
import pdb

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def dataloader(filepath):

    left_fold = 'image_2/'
    train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel_final') > -1]

    l0_train = [filepath+left_fold+img for img in train]
    l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0], '%02d'%(1+int(img.split('.')[0].split('_')[-1]))) in l0_train]

    #l0_train = [i for i in l0_train if not '10.png' in i] # remove 10 as val

    l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0], '%02d'%(1+int(img.split('.')[0].split('_')[-1]))) for img in l0_train]
    flow_train = [img.replace('image_2','flow_occ') for img in l0_train]

    return l0_train, l1_train, flow_train
expansion/dataloader/sintellist_train.py
ADDED
@@ -0,0 +1,32 @@
import torch.utils.data as data

from PIL import Image
import os
import os.path
import numpy as np
import pdb

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def dataloader(filepath):

    left_fold = 'image_2/'
    train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel') > -1]

    l0_train = [filepath+left_fold+img for img in train]
    l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0], '%02d'%(1+int(img.split('.')[0].split('_')[-1]))) in l0_train]

    l0_train = [i for i in l0_train if not(('_2_' in i) and ('alley' not in i) and ('bandage' not in i) and ('sleeping' not in i))]  # hold out pass-2 frames (minus alley/bandage/sleeping) as val

    l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0], '%02d'%(1+int(img.split('.')[0].split('_')[-1]))) for img in l0_train]
    flow_train = [img.replace('image_2','flow_occ') for img in l0_train]

    return l0_train, l1_train, flow_train
expansion/dataloader/sintellist_val.py
ADDED
@@ -0,0 +1,34 @@
import torch.utils.data as data

from PIL import Image
import os
import os.path
import numpy as np
import pdb

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def dataloader(filepath):

    left_fold = 'image_2/'
    train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel') > -1]

    l0_train = [filepath+left_fold+img for img in train]
    l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0], '%02d'%(1+int(img.split('.')[0].split('_')[-1]))) in l0_train]

    l0_train = [i for i in l0_train if ('_2_' in i) and ('alley' not in i) and ('bandage' not in i) and ('sleeping' not in i)]  # the held-out pass-2 frames (minus alley/bandage/sleeping) used as val
    #l0_train = [i for i in l0_train if not(('_2_' in i) and ('alley' not in i) and ('bandage' not in i) and ('sleeping' not in i))]

    l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0], '%02d'%(1+int(img.split('.')[0].split('_')[-1]))) for img in l0_train]
    flow_train = [img.replace('image_2','flow_occ') for img in l0_train]

    return sorted(l0_train)[::3], sorted(l1_train)[::3], sorted(flow_train)[::3]
    # return sorted(l0_train)[::10], sorted(l1_train)[::10], sorted(flow_train)[::10]
expansion/dataloader/thingslist.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch.utils.data as data

from PIL import Image
import os
import os.path
import numpy as np

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def dataloader(filepath):
    exc_list = [
        '0004117.flo',
        '0003149.flo',
        '0001203.flo',
        '0003147.flo',
        '0003666.flo',
        '0006337.flo',
        '0006336.flo',
        '0007126.flo',
        '0004118.flo',
    ]

    left_fold = 'image_clean/left/'
    flow_noc = 'flow/left/into_future/'
    train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0]

    l0_trainlf = [filepath+left_fold+img.replace('flo','png') for img in train]
    l1_trainlf = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainlf]
    flow_trainlf = [filepath+flow_noc+img for img in train]


    exc_list = [
        '0003148.flo',
        '0004117.flo',
        '0002890.flo',
        '0003149.flo',
        '0001203.flo',
        '0003666.flo',
        '0006337.flo',
        '0006336.flo',
        '0004118.flo',
    ]

    left_fold = 'image_clean/right/'
    flow_noc = 'flow/right/into_future/'
    train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0]

    l0_trainrf = [filepath+left_fold+img.replace('flo','png') for img in train]
    l1_trainrf = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainrf]
    flow_trainrf = [filepath+flow_noc+img for img in train]


    exc_list = [
        '0004237.flo',
        '0004705.flo',
        '0004045.flo',
        '0004346.flo',
        '0000161.flo',
        '0000931.flo',
        '0000121.flo',
        '0010822.flo',
        '0004117.flo',
        '0006023.flo',
        '0005034.flo',
        '0005054.flo',
        '0000162.flo',
        '0000053.flo',
        '0005055.flo',
        '0003147.flo',
        '0004876.flo',
        '0000163.flo',
        '0006878.flo',
    ]

    left_fold = 'image_clean/left/'
    flow_noc = 'flow/left/into_past/'
    train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0]

    l0_trainlp = [filepath+left_fold+img.replace('flo','png') for img in train]
    l1_trainlp = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(-1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainlp]
    flow_trainlp = [filepath+flow_noc+img for img in train]

    exc_list = [
        '0003148.flo',
        '0004705.flo',
        '0000161.flo',
        '0000121.flo',
        '0004117.flo',
        '0000160.flo',
        '0005034.flo',
        '0005054.flo',
        '0000162.flo',
        '0000053.flo',
        '0005055.flo',
        '0003147.flo',
        '0001549.flo',
        '0000163.flo',
        '0006336.flo',
        '0001648.flo',
        '0006878.flo',
    ]

    left_fold = 'image_clean/right/'
    flow_noc = 'flow/right/into_past/'
    train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0]

    l0_trainrp = [filepath+left_fold+img.replace('flo','png') for img in train]
    l1_trainrp = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(-1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainrp]
    flow_trainrp = [filepath+flow_noc+img for img in train]


    l0_train = l0_trainlf + l0_trainrf + l0_trainlp + l0_trainrp
    l1_train = l1_trainlf + l1_trainrf + l1_trainlp + l1_trainrp
    flow_train = flow_trainlf + flow_trainrf + flow_trainlp + flow_trainrp
    return l0_train, l1_train, flow_train
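The FlyingThings3D loader above merges four directional splits (left/right camera, into_future/into_past) into one index-aligned list triple. A quick sanity-check sketch, with a placeholder dataset root:

# Hypothetical usage of thingslist.dataloader (path is a placeholder).
from expansion.dataloader import thingslist

l0, l1, flo = thingslist.dataloader('/path/to/FlyingThings3D/')
assert len(l0) == len(l1) == len(flo)  # the three lists stay index-aligned
print('image pairs across the four splits:', len(l0))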
expansion/models/VCN_exp.py
ADDED
@@ -0,0 +1,561 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import os
os.environ['PYTHON_EGG_CACHE'] = 'tmp/' # a writable directory
import numpy as np
import math
import pdb
import time

from .submodule import pspnet, bfmodule, conv
from .conv4d import sepConv4d, sepConv4dBlock, butterfly4D

class flow_reg(nn.Module):
    """
    Soft winner-take-all that selects the most likely displacement.
    Set ent=True to enable entropy output.
    Set maxdisp to adjust the maximum allowed displacement towards one side.
    maxdisp=4 searches a 9x9 region.
    Set fac to squeeze the search window.
    maxdisp=4 and fac=2 gives a search window of 9x5.
    """
    def __init__(self, size, ent=False, maxdisp = int(4), fac=1):
        B,W,H = size
        super(flow_reg, self).__init__()
        self.ent = ent
        self.md = maxdisp
        self.fac = fac
        self.truncated = True
        self.wsize = 3  # by default using truncation 7x7

        flowrangey = range(-maxdisp,maxdisp+1)
        flowrangex = range(-int(maxdisp//self.fac),int(maxdisp//self.fac)+1)
        meshgrid = np.meshgrid(flowrangex,flowrangey)
        flowy = np.tile( np.reshape(meshgrid[0],[1,2*maxdisp+1,2*int(maxdisp//self.fac)+1,1,1]), (B,1,1,H,W) )
        flowx = np.tile( np.reshape(meshgrid[1],[1,2*maxdisp+1,2*int(maxdisp//self.fac)+1,1,1]), (B,1,1,H,W) )
        self.register_buffer('flowx',torch.Tensor(flowx))
        self.register_buffer('flowy',torch.Tensor(flowy))

        self.pool3d = nn.MaxPool3d((self.wsize*2+1,self.wsize*2+1,1),stride=1,padding=(self.wsize,self.wsize,0))

    def forward(self, x):
        b,u,v,h,w = x.shape
        oldx = x

        if self.truncated:
            # truncated softmax
            x = x.view(b,u*v,h,w)

            idx = x.argmax(1)[:,np.newaxis]
            if x.is_cuda:
                mask = Variable(torch.cuda.HalfTensor(b,u*v,h,w)).fill_(0)
            else:
                mask = Variable(torch.FloatTensor(b,u*v,h,w)).fill_(0)
            mask.scatter_(1,idx,1)
            mask = mask.view(b,1,u,v,-1)
            mask = self.pool3d(mask)[:,0].view(b,u,v,h,w)

            ninf = x.clone().fill_(-np.inf).view(b,u,v,h,w)
            x = torch.where(mask.byte(),oldx,ninf)
        else:
            self.wsize = (np.sqrt(u*v)-1)/2

        b,u,v,h,w = x.shape
        x = F.softmax(x.view(b,-1,h,w),1).view(b,u,v,h,w)
        outx = torch.sum(torch.sum(x*self.flowx,1),1,keepdim=True)
        outy = torch.sum(torch.sum(x*self.flowy,1),1,keepdim=True)

        if self.ent:
            # local entropy (over the truncated window)
            local_entropy = (-x*torch.clamp(x,1e-9,1-1e-9).log()).sum(1).sum(1)[:,np.newaxis]
            if self.wsize == 0:
                local_entropy[:] = 1.
            else:
                local_entropy /= np.log((self.wsize*2+1)**2)

            # global entropy (over the full search window)
            x = F.softmax(oldx.view(b,-1,h,w),1).view(b,u,v,h,w)
            global_entropy = (-x*torch.clamp(x,1e-9,1-1e-9).log()).sum(1).sum(1)[:,np.newaxis]
            global_entropy /= np.log(x.shape[1]*x.shape[2])
            return torch.cat([outx,outy],1),torch.cat([local_entropy, global_entropy],1)
        else:
            return torch.cat([outx,outy],1),None


class WarpModule(nn.Module):
    """
    taken from https://github.com/NVlabs/PWC-Net/blob/master/PyTorch/models/PWCNet.py
    """
    def __init__(self, size):
        super(WarpModule, self).__init__()
        B,W,H = size
        # mesh grid
        xx = torch.arange(0, W).view(1,-1).repeat(H,1)
        yy = torch.arange(0, H).view(-1,1).repeat(1,W)
        xx = xx.view(1,1,H,W).repeat(B,1,1,1)
        yy = yy.view(1,1,H,W).repeat(B,1,1,1)
        self.register_buffer('grid',torch.cat((xx,yy),1).float())

    def forward(self, x, flo):
        """
        warp an image/tensor (im2) back to im1, according to the optical flow

        x: [B, C, H, W] (im2)
        flo: [B, 2, H, W] flow
        """
        B, C, H, W = x.size()
        vgrid = self.grid + flo

        # scale grid to [-1,1]
        vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:]/max(W-1,1)-1.0
        vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:]/max(H-1,1)-1.0

        vgrid = vgrid.permute(0,2,3,1)
        output = nn.functional.grid_sample(x, vgrid, align_corners=True)
        mask = ((vgrid[:,:,:,0].abs()<1) * (vgrid[:,:,:,1].abs()<1)) > 0
        return output*mask.unsqueeze(1).float(), mask


def get_grid(B,H,W):
    meshgrid_base = np.meshgrid(range(0,W), range(0,H))[::-1]
    basey = np.reshape(meshgrid_base[0],[1,1,1,H,W])
    basex = np.reshape(meshgrid_base[1],[1,1,1,H,W])
    grid = torch.tensor(np.concatenate((basex.reshape((-1,H,W,1)),basey.reshape((-1,H,W,1))),-1)).cuda().float()
    return grid.view(1,1,H,W,2)


class VCN(nn.Module):
    """
    VCN.
    md defines the maximum displacement for each level, following a coarse-to-fine warping scheme
    fac defines the squeeze parameter for the coarsest level
    """
    def __init__(self, size, md=[4,4,4,4,4], fac=1., exp_unc=False): # exp_unc: expansion uncertainty
        super(VCN,self).__init__()
        self.md = md
        self.fac = fac
        use_entropy = True
        withbn = True

        ## pspnet
        self.pspnet = pspnet(is_proj=False)

        ### Volumetric-UNet
        fdima1 = 128 # 6/5/4
        fdima2 = 64  # 3/2
        fdimb1 = 16  # 6/5/4/3
        fdimb2 = 12  # 2

        full = False
        self.f6 = butterfly4D(fdima1, fdimb1, withbn=withbn, full=full)
        self.p6 = sepConv4d(fdimb1, fdimb1, with_bn=False, full=full)

        self.f5 = butterfly4D(fdima1, fdimb1, withbn=withbn, full=full)
        self.p5 = sepConv4d(fdimb1, fdimb1, with_bn=False, full=full)

        self.f4 = butterfly4D(fdima1, fdimb1, withbn=withbn, full=full)
        self.p4 = sepConv4d(fdimb1, fdimb1, with_bn=False, full=full)

        self.f3 = butterfly4D(fdima2, fdimb1, withbn=withbn, full=full)
        self.p3 = sepConv4d(fdimb1, fdimb1, with_bn=False, full=full)

        full = True
        self.f2 = butterfly4D(fdima2, fdimb2, withbn=withbn, full=full)
        self.p2 = sepConv4d(fdimb2, fdimb2, with_bn=False, full=full)

        self.flow_reg64 = flow_reg([fdimb1*size[0],size[1]//64,size[2]//64], ent=use_entropy, maxdisp=self.md[0], fac=self.fac)
        self.flow_reg32 = flow_reg([fdimb1*size[0],size[1]//32,size[2]//32], ent=use_entropy, maxdisp=self.md[1])
        self.flow_reg16 = flow_reg([fdimb1*size[0],size[1]//16,size[2]//16], ent=use_entropy, maxdisp=self.md[2])
        self.flow_reg8  = flow_reg([fdimb1*size[0],size[1]//8, size[2]//8],  ent=use_entropy, maxdisp=self.md[3])
        self.flow_reg4  = flow_reg([fdimb2*size[0],size[1]//4, size[2]//4],  ent=use_entropy, maxdisp=self.md[4])

        self.warp5 = WarpModule([size[0],size[1]//32,size[2]//32])
        self.warp4 = WarpModule([size[0],size[1]//16,size[2]//16])
        self.warp3 = WarpModule([size[0],size[1]//8, size[2]//8])
        self.warp2 = WarpModule([size[0],size[1]//4, size[2]//4])

        ## hypotheses fusion modules, adopted from the refinement module of PWCNet
        # https://github.com/NVlabs/PWC-Net/blob/master/PyTorch/models/PWCNet.py
        # c6
        self.dc6_conv1 = conv(128+4*fdimb1, 128, kernel_size=3, stride=1, padding=1,  dilation=1)
        self.dc6_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2,  dilation=2)
        self.dc6_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4,  dilation=4)
        self.dc6_conv4 = conv(128, 96,  kernel_size=3, stride=1, padding=8,  dilation=8)
        self.dc6_conv5 = conv(96,  64,  kernel_size=3, stride=1, padding=16, dilation=16)
        self.dc6_conv6 = conv(64,  32,  kernel_size=3, stride=1, padding=1,  dilation=1)
        self.dc6_conv7 = nn.Conv2d(32,2*fdimb1,kernel_size=3,stride=1,padding=1,bias=True)

        # c5
        self.dc5_conv1 = conv(128+4*fdimb1*2, 128, kernel_size=3, stride=1, padding=1,  dilation=1)
        self.dc5_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2,  dilation=2)
        self.dc5_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4,  dilation=4)
        self.dc5_conv4 = conv(128, 96,  kernel_size=3, stride=1, padding=8,  dilation=8)
        self.dc5_conv5 = conv(96,  64,  kernel_size=3, stride=1, padding=16, dilation=16)
        self.dc5_conv6 = conv(64,  32,  kernel_size=3, stride=1, padding=1,  dilation=1)
        self.dc5_conv7 = nn.Conv2d(32,2*fdimb1*2,kernel_size=3,stride=1,padding=1,bias=True)

        # c4
        self.dc4_conv1 = conv(128+4*fdimb1*3, 128, kernel_size=3, stride=1, padding=1,  dilation=1)
        self.dc4_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2,  dilation=2)
        self.dc4_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4,  dilation=4)
        self.dc4_conv4 = conv(128, 96,  kernel_size=3, stride=1, padding=8,  dilation=8)
        self.dc4_conv5 = conv(96,  64,  kernel_size=3, stride=1, padding=16, dilation=16)
        self.dc4_conv6 = conv(64,  32,  kernel_size=3, stride=1, padding=1,  dilation=1)
        self.dc4_conv7 = nn.Conv2d(32,2*fdimb1*3,kernel_size=3,stride=1,padding=1,bias=True)

        # c3
        self.dc3_conv1 = conv(64+16*fdimb1, 128, kernel_size=3, stride=1, padding=1,  dilation=1)
        self.dc3_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2,  dilation=2)
        self.dc3_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4,  dilation=4)
        self.dc3_conv4 = conv(128, 96,  kernel_size=3, stride=1, padding=8,  dilation=8)
        self.dc3_conv5 = conv(96,  64,  kernel_size=3, stride=1, padding=16, dilation=16)
        self.dc3_conv6 = conv(64,  32,  kernel_size=3, stride=1, padding=1,  dilation=1)
        self.dc3_conv7 = nn.Conv2d(32,8*fdimb1,kernel_size=3,stride=1,padding=1,bias=True)

        # c2
        self.dc2_conv1 = conv(64+16*fdimb1+4*fdimb2, 128, kernel_size=3, stride=1, padding=1,  dilation=1)
        self.dc2_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2,  dilation=2)
        self.dc2_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4,  dilation=4)
        self.dc2_conv4 = conv(128, 96,  kernel_size=3, stride=1, padding=8,  dilation=8)
        self.dc2_conv5 = conv(96,  64,  kernel_size=3, stride=1, padding=16, dilation=16)
        self.dc2_conv6 = conv(64,  32,  kernel_size=3, stride=1, padding=1,  dilation=1)
        self.dc2_conv7 = nn.Conv2d(32,4*2*fdimb1 + 2*fdimb2,kernel_size=3,stride=1,padding=1,bias=True)

        self.dc6_conv = nn.Sequential(self.dc6_conv1,
                                      self.dc6_conv2,
                                      self.dc6_conv3,
                                      self.dc6_conv4,
                                      self.dc6_conv5,
                                      self.dc6_conv6,
                                      self.dc6_conv7)
        self.dc5_conv = nn.Sequential(self.dc5_conv1,
                                      self.dc5_conv2,
                                      self.dc5_conv3,
                                      self.dc5_conv4,
                                      self.dc5_conv5,
                                      self.dc5_conv6,
                                      self.dc5_conv7)
        self.dc4_conv = nn.Sequential(self.dc4_conv1,
                                      self.dc4_conv2,
                                      self.dc4_conv3,
                                      self.dc4_conv4,
                                      self.dc4_conv5,
                                      self.dc4_conv6,
                                      self.dc4_conv7)
        self.dc3_conv = nn.Sequential(self.dc3_conv1,
                                      self.dc3_conv2,
                                      self.dc3_conv3,
                                      self.dc3_conv4,
                                      self.dc3_conv5,
                                      self.dc3_conv6,
                                      self.dc3_conv7)
        self.dc2_conv = nn.Sequential(self.dc2_conv1,
                                      self.dc2_conv2,
                                      self.dc2_conv3,
                                      self.dc2_conv4,
                                      self.dc2_conv5,
                                      self.dc2_conv6,
                                      self.dc2_conv7)

        ## Out-of-range detection
        self.dc6_convo = nn.Sequential(conv(128+4*fdimb1, 128, kernel_size=3, stride=1, padding=1,  dilation=1),
                                       conv(128, 128, kernel_size=3, stride=1, padding=2,  dilation=2),
                                       conv(128, 128, kernel_size=3, stride=1, padding=4,  dilation=4),
                                       conv(128, 96,  kernel_size=3, stride=1, padding=8,  dilation=8),
                                       conv(96,  64,  kernel_size=3, stride=1, padding=16, dilation=16),
                                       conv(64,  32,  kernel_size=3, stride=1, padding=1,  dilation=1),
                                       nn.Conv2d(32,1,kernel_size=3,stride=1,padding=1,bias=True))

        self.dc5_convo = nn.Sequential(conv(128+2*4*fdimb1, 128, kernel_size=3, stride=1, padding=1,  dilation=1),
                                       conv(128, 128, kernel_size=3, stride=1, padding=2,  dilation=2),
                                       conv(128, 128, kernel_size=3, stride=1, padding=4,  dilation=4),
                                       conv(128, 96,  kernel_size=3, stride=1, padding=8,  dilation=8),
                                       conv(96,  64,  kernel_size=3, stride=1, padding=16, dilation=16),
                                       conv(64,  32,  kernel_size=3, stride=1, padding=1,  dilation=1),
                                       nn.Conv2d(32,1,kernel_size=3,stride=1,padding=1,bias=True))

        self.dc4_convo = nn.Sequential(conv(128+3*4*fdimb1, 128, kernel_size=3, stride=1, padding=1,  dilation=1),
                                       conv(128, 128, kernel_size=3, stride=1, padding=2,  dilation=2),
                                       conv(128, 128, kernel_size=3, stride=1, padding=4,  dilation=4),
                                       conv(128, 96,  kernel_size=3, stride=1, padding=8,  dilation=8),
                                       conv(96,  64,  kernel_size=3, stride=1, padding=16, dilation=16),
                                       conv(64,  32,  kernel_size=3, stride=1, padding=1,  dilation=1),
                                       nn.Conv2d(32,1,kernel_size=3,stride=1,padding=1,bias=True))

        self.dc3_convo = nn.Sequential(conv(64+16*fdimb1, 128, kernel_size=3, stride=1, padding=1,  dilation=1),
                                       conv(128, 128, kernel_size=3, stride=1, padding=2,  dilation=2),
                                       conv(128, 128, kernel_size=3, stride=1, padding=4,  dilation=4),
                                       conv(128, 96,  kernel_size=3, stride=1, padding=8,  dilation=8),
                                       conv(96,  64,  kernel_size=3, stride=1, padding=16, dilation=16),
                                       conv(64,  32,  kernel_size=3, stride=1, padding=1,  dilation=1),
                                       nn.Conv2d(32,1,kernel_size=3,stride=1,padding=1,bias=True))

        self.dc2_convo = nn.Sequential(conv(64+16*fdimb1+4*fdimb2, 128, kernel_size=3, stride=1, padding=1,  dilation=1),
                                       conv(128, 128, kernel_size=3, stride=1, padding=2,  dilation=2),
                                       conv(128, 128, kernel_size=3, stride=1, padding=4,  dilation=4),
                                       conv(128, 96,  kernel_size=3, stride=1, padding=8,  dilation=8),
                                       conv(96,  64,  kernel_size=3, stride=1, padding=16, dilation=16),
                                       conv(64,  32,  kernel_size=3, stride=1, padding=1,  dilation=1),
                                       nn.Conv2d(32,1,kernel_size=3,stride=1,padding=1,bias=True))

        # affine-exp
        self.f3d2v1 = conv(64,    32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.f3d2v2 = conv(1,     32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.f3d2v3 = conv(1,     32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.f3d2v4 = conv(1,     32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.f3d2v5 = conv(64,    32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.f3d2v6 = conv(12*81, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.f3d2 = bfmodule(128-64,1)

        # depth change net
        self.dcnetv1 = conv(64,    32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.dcnetv2 = conv(1,     32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.dcnetv3 = conv(1,     32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.dcnetv4 = conv(1,     32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.dcnetv5 = conv(12*81, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.dcnetv6 = conv(4,     32, kernel_size=3, stride=1, padding=1, dilation=1)
        if exp_unc:
            self.dcnet = bfmodule(128,2)
        else:
            self.dcnet = bfmodule(128,1)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                n = m.kernel_size[0] * m.kernel_size[1]*m.kernel_size[2] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if hasattr(m.bias,'data'):
                    m.bias.data.zero_()

        self.facs = [self.fac,1,1,1,1]
        self.warp_modules = nn.ModuleList([None, self.warp5, self.warp4, self.warp3, self.warp2])
        self.f_modules = nn.ModuleList([self.f6, self.f5, self.f4, self.f3, self.f2])
        self.p_modules = nn.ModuleList([self.p6, self.p5, self.p4, self.p3, self.p2])
        self.reg_modules = nn.ModuleList([self.flow_reg64, self.flow_reg32, self.flow_reg16, self.flow_reg8, self.flow_reg4])
        self.oor_modules = nn.ModuleList([self.dc6_convo, self.dc5_convo, self.dc4_convo, self.dc3_convo, self.dc2_convo])
        self.fuse_modules = nn.ModuleList([self.dc6_conv, self.dc5_conv, self.dc4_conv, self.dc3_conv, self.dc2_conv])

    def corrf(self, refimg_fea, targetimg_fea, maxdisp, fac=1):
        """
        slow correlation function
        """
        b,c,height,width = refimg_fea.shape
        if refimg_fea.is_cuda:
            cost = Variable(torch.cuda.FloatTensor(b,c,2*maxdisp+1,2*int(maxdisp//fac)+1,height,width)).fill_(0.) # b,c,u,v,h,w
        else:
            cost = Variable(torch.FloatTensor(b,c,2*maxdisp+1,2*int(maxdisp//fac)+1,height,width)).fill_(0.) # b,c,u,v,h,w
        for i in range(2*maxdisp+1):
            ind = i-maxdisp
            for j in range(2*int(maxdisp//fac)+1):
                indd = j-int(maxdisp//fac)
                feata = refimg_fea[:,:,max(0,-indd):height-indd,max(0,-ind):width-ind]
                featb = targetimg_fea[:,:,max(0,+indd):height+indd,max(0,ind):width+ind]
                diff = (feata*featb)
                cost[:, :, i,j,max(0,-indd):height-indd,max(0,-ind):width-ind] = diff # standard
        cost = F.leaky_relu(cost, 0.1, inplace=True)
        return cost

    def cost_matching(self, up_flow, c1, c2, flowh, enth, level):
        """
        up_flow: upsampled coarse flow
        c1: normalized feature of image 1
        c2: normalized feature of image 2
        flowh: flow hypotheses
        enth: entropy
        """

        # normalize
        c1n = c1 / (c1.norm(dim=1, keepdim=True)+1e-9)
        c2n = c2 / (c2.norm(dim=1, keepdim=True)+1e-9)

        # cost volume
        if level == 0:
            warp = c2n
        else:
            warp,_ = self.warp_modules[level](c2n, up_flow)

        feat = self.corrf(c1n,warp,self.md[level],fac=self.facs[level])
        feat = self.f_modules[level](feat)
        cost = self.p_modules[level](feat) # b, 16, u,v,h,w

        # soft WTA
        b,c,u,v,h,w = cost.shape
        cost = cost.view(-1,u,v,h,w)  # bx16, 9,9,h,w, also predict uncertainty from here
        flowhh,enthh = self.reg_modules[level](cost) # bx16, 2, h, w
        flowhh = flowhh.view(b,c,2,h,w)
        if level > 0:
            flowhh = flowhh + up_flow[:,np.newaxis]
        flowhh = flowhh.view(b,-1,h,w) # b, 16*2, h, w
        enthh = enthh.view(b,-1,h,w)   # b, 16*1, h, w

        # append coarse hypotheses
        if level == 0:
            flowh = flowhh
            enth = enthh
        else:
            flowh = torch.cat((flowhh, F.upsample(flowh.detach()*2, [flowhh.shape[2],flowhh.shape[3]], mode='bilinear')),1) # b, k2--k2, h, w
            enth = torch.cat((enthh, F.upsample(enth, [flowhh.shape[2],flowhh.shape[3]], mode='bilinear')),1)

        if self.training or level==4:
            x = torch.cat((enth.detach(), flowh.detach(), c1),1)
            oor = self.oor_modules[level](x)[:,0]
        else: oor = None

        # hypotheses fusion
        x = torch.cat((enth.detach(), flowh.detach(), c1),1)
        va = self.fuse_modules[level](x)
        va = va.view(b,-1,2,h,w)
        flow = ( flowh.view(b,-1,2,h,w) * F.softmax(va,1) ).sum(1) # b, 2k, 2, h, w

        return flow, flowh, enth, oor

    def affine(self, pref, flow, pw=1):
        b,_,lh,lw = flow.shape
        ptar = pref + flow
        pw = 1
        pref = F.unfold(pref, (pw*2+1,pw*2+1), padding=(pw)).view(b,2,(pw*2+1)**2,lh,lw)-pref[:,:,np.newaxis]
        ptar = F.unfold(ptar, (pw*2+1,pw*2+1), padding=(pw)).view(b,2,(pw*2+1)**2,lh,lw)-ptar[:,:,np.newaxis] # b, 2,9,h,w
        pref = pref.permute(0,3,4,1,2).reshape(b*lh*lw,2,(pw*2+1)**2)
        ptar = ptar.permute(0,3,4,1,2).reshape(b*lh*lw,2,(pw*2+1)**2)

        prefprefT = pref.matmul(pref.permute(0,2,1))
        ppdet = prefprefT[:,0,0]*prefprefT[:,1,1]-prefprefT[:,1,0]*prefprefT[:,0,1]
        ppinv = torch.cat((prefprefT[:,1,1:],-prefprefT[:,0,1:], -prefprefT[:,1:,0], prefprefT[:,0:1,0]),1).view(-1,2,2)/ppdet.clamp(1e-10,np.inf)[:,np.newaxis,np.newaxis]

        Affine = ptar.matmul(pref.permute(0,2,1)).matmul(ppinv)
        Error = (Affine.matmul(pref)-ptar).norm(2,1).mean(1).view(b,1,lh,lw)

        Avol = (Affine[:,0,0]*Affine[:,1,1]-Affine[:,1,0]*Affine[:,0,1]).view(b,1,lh,lw).abs().clamp(1e-10,np.inf)
        exp = Avol.sqrt()
        mask = (exp>0.5) & (exp<2) & (Error<0.1)
        mask = mask[:,0]

        exp = exp.clamp(0.5,2)
        exp[Error>0.1]=1
        return exp, Error, mask

    def affine_mask(self, pref, flow, pw=3):
        """
        pref: reference coordinates
        pw: patch width
        """
        flmask = flow[:,2:]
        flow = flow[:,:2]
        b,_,lh,lw = flow.shape
        ptar = pref + flow
        pref = F.unfold(pref, (pw*2+1,pw*2+1), padding=(pw)).view(b,2,(pw*2+1)**2,lh,lw)-pref[:,:,np.newaxis]
        ptar = F.unfold(ptar, (pw*2+1,pw*2+1), padding=(pw)).view(b,2,(pw*2+1)**2,lh,lw)-ptar[:,:,np.newaxis] # b, 2,9,h,w

        conf_flow = flmask
        conf_flow = F.unfold(conf_flow,(pw*2+1,pw*2+1), padding=(pw)).view(b,1,(pw*2+1)**2,lh,lw)
        count = conf_flow.sum(2,keepdims=True)
        conf_flow = ((pw*2+1)**2)*conf_flow / count
        pref = pref * conf_flow
        ptar = ptar * conf_flow

        pref = pref.permute(0,3,4,1,2).reshape(b*lh*lw,2,(pw*2+1)**2)
        ptar = ptar.permute(0,3,4,1,2).reshape(b*lh*lw,2,(pw*2+1)**2)

        prefprefT = pref.matmul(pref.permute(0,2,1))
        ppdet = prefprefT[:,0,0]*prefprefT[:,1,1]-prefprefT[:,1,0]*prefprefT[:,0,1]
        ppinv = torch.cat((prefprefT[:,1,1:],-prefprefT[:,0,1:], -prefprefT[:,1:,0], prefprefT[:,0:1,0]),1).view(-1,2,2)/ppdet.clamp(1e-10,np.inf)[:,np.newaxis,np.newaxis]

        Affine = ptar.matmul(pref.permute(0,2,1)).matmul(ppinv)
        Error = (Affine.matmul(pref)-ptar).norm(2,1).mean(1).view(b,1,lh,lw)

        Avol = (Affine[:,0,0]*Affine[:,1,1]-Affine[:,1,0]*Affine[:,0,1]).view(b,1,lh,lw).abs().clamp(1e-10,np.inf)
        exp = Avol.sqrt()
        mask = (exp>0.5) & (exp<2) & (Error<0.2) & (flmask.bool()) & (count[:,0]>4)
        mask = mask[:,0]

        exp = exp.clamp(0.5,2)
        exp[Error>0.2]=1
        return exp, Error, mask

    def weight_parameters(self):
        return [param for name, param in self.named_parameters() if 'weight' in name]

    def bias_parameters(self):
        return [param for name, param in self.named_parameters() if 'bias' in name]

    def forward(self, im, disc_aux=None):
        bs = im.shape[0]//2

        if self.training and disc_aux[-1]: # if only fine-tuning expansion
            reset = True
            self.eval()
            torch.set_grad_enabled(False)
        else: reset = False

        c06,c05,c04,c03,c02 = self.pspnet(im)
        c16 = c06[:bs]; c26 = c06[bs:]
        c15 = c05[:bs]; c25 = c05[bs:]
        c14 = c04[:bs]; c24 = c04[bs:]
        c13 = c03[:bs]; c23 = c03[bs:]
        c12 = c02[:bs]; c22 = c02[bs:]

        ## matching 6
        flow6, flow6h, ent6h, oor6 = self.cost_matching(None, c16, c26, None, None, level=0)

        ## matching 5
        up_flow6 = F.upsample(flow6, [im.size()[2]//32,im.size()[3]//32], mode='bilinear')*2
        flow5, flow5h, ent5h, oor5 = self.cost_matching(up_flow6, c15, c25, flow6h, ent6h, level=1)

        ## matching 4
        up_flow5 = F.upsample(flow5, [im.size()[2]//16,im.size()[3]//16], mode='bilinear')*2
        flow4, flow4h, ent4h, oor4 = self.cost_matching(up_flow5, c14, c24, flow5h, ent5h, level=2)

        ## matching 3
        up_flow4 = F.upsample(flow4, [im.size()[2]//8,im.size()[3]//8], mode='bilinear')*2
        flow3, flow3h, ent3h, oor3 = self.cost_matching(up_flow4, c13, c23, flow4h, ent4h, level=3)

        ## matching 2
        up_flow3 = F.upsample(flow3, [im.size()[2]//4,im.size()[3]//4], mode='bilinear')*2
        flow2, flow2h, ent2h, oor2 = self.cost_matching(up_flow3, c12, c22, flow3h, ent3h, level=4)

        if reset:
            torch.set_grad_enabled(True)
            self.train()

        # expansion
        b,_,h,w = flow2.shape
        exp2,err2,_ = self.affine(get_grid(b,h,w)[:,0].permute(0,3,1,2).repeat(b,1,1,1).clone(), flow2.detach(), pw=1)
        x = torch.cat((
            self.f3d2v2(-exp2.log()),
            self.f3d2v3(err2),
            ),1)
        dchange2 = -exp2.log()+1./200*self.f3d2(x)[0]

        # depth change net
        iexp2 = F.upsample(dchange2.clone(), [im.size()[2],im.size()[3]], mode='bilinear')

        x = torch.cat((self.dcnetv1(c12.detach()),
                       self.dcnetv2(dchange2.detach()),
                       self.dcnetv3(-exp2.log()),
                       self.dcnetv4(err2),
                       ),1)
        dcneto = 1./200*self.dcnet(x)[0]
        dchange2 = dchange2.detach() + dcneto[:,:1]

        flow2 = F.upsample(flow2.detach(), [im.size()[2],im.size()[3]], mode='bilinear')*4
        dchange2 = F.upsample(dchange2, [im.size()[2],im.size()[3]], mode='bilinear')

        if self.training:
            flowl0 = disc_aux[0].permute(0,3,1,2).clone()
            gt_depth = disc_aux[2][:,:,:,0]
            gt_f3d = disc_aux[2][:,:,:,4:7].permute(0,3,1,2).clone()
            gt_dchange = (1+gt_f3d[:,2]/gt_depth)
            maskdc = (gt_dchange < 2) & (gt_dchange > 0.5) & disc_aux[1]

            gt_expi,gt_expi_err,maskoe = self.affine_mask(get_grid(b,4*h,4*w)[:,0].permute(0,3,1,2).repeat(b,1,1,1), flowl0, pw=3)
            gt_exp = 1./gt_expi[:,0]

            loss =  0.1* (dchange2[:,0]-gt_dchange.log()).abs()[maskdc].mean()
            loss += 0.1* (iexp2[:,0]-gt_exp.log()).abs()[maskoe].mean()
            return flow2*4, flow3*8, flow4*16, flow5*32, flow6*64, loss, dchange2[:,0], iexp2[:,0]

        else:
            return flow2, oor2, dchange2, iexp2
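A minimal inference sketch for the VCN model above (illustrative only). It assumes a CUDA device (get_grid calls .cuda()), input height/width divisible by 64, frames stacked along the batch axis (first half im1, second half im2, matching bs = im.shape[0]//2), and a [batch, width, height] size argument; the 256x256 size and random input are placeholders, and pretrained weights would be loaded separately.

import torch
from expansion.models.VCN_exp import VCN

B, H, W = 1, 256, 256                       # H, W assumed divisible by 64
model = VCN([B, W, H], md=[4, 4, 4, 4, 4], fac=1.).cuda().eval()

im = torch.rand(2 * B, 3, H, W).cuda()      # [im1; im2] stacked on the batch dim
with torch.no_grad():
    flow, oor, dchange, iexp = model(im)    # eval-mode outputs
print(flow.shape)     # (B, 2, H, W) full-resolution optical flow
print(dchange.shape)  # (B, 1, H, W) log motion-in-depth (depth-change) map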
expansion/models/__init__.py
ADDED
File without changes
expansion/models/__pycache__/VCN_exp.cpython-38.pyc
ADDED
Binary file (17.7 kB).
expansion/models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (158 Bytes).
expansion/models/__pycache__/conv4d.cpython-38.pyc
ADDED
Binary file (8.26 kB).
expansion/models/__pycache__/submodule.cpython-38.pyc
ADDED
Binary file (12 kB).
expansion/models/conv4d.py
ADDED
@@ -0,0 +1,296 @@
import pdb
import torch.nn as nn
import math
import torch
from torch.nn.parameter import Parameter
import torch.nn.functional as F
from torch.nn import Module
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _quadruple
from torch.autograd import Variable
from torch.nn import Conv2d

def conv4d(data, filters, bias=None, permute_filters=True, use_half=False):
    """
    4D convolution implemented by stacking the results of multiple 3D convolutions; very slow.
    Taken from https://github.com/ignacio-rocco/ncnet
    """
    b,c,h,w,d,t = data.size()

    data = data.permute(2,0,1,3,4,5).contiguous() # permute to avoid making contiguous inside loop

    # Same permutation is done with filters, unless already provided with permutation
    if permute_filters:
        filters = filters.permute(2,0,1,3,4,5).contiguous() # permute to avoid making contiguous inside loop

    c_out = filters.size(1)
    if use_half:
        output = Variable(torch.HalfTensor(h,b,c_out,w,d,t),requires_grad=data.requires_grad)
    else:
        output = Variable(torch.zeros(h,b,c_out,w,d,t),requires_grad=data.requires_grad)

    padding = filters.size(0)//2
    if use_half:
        Z = Variable(torch.zeros(padding,b,c,w,d,t).half())
    else:
        Z = Variable(torch.zeros(padding,b,c,w,d,t))

    if data.is_cuda:
        Z = Z.cuda(data.get_device())
        output = output.cuda(data.get_device())

    data_padded = torch.cat((Z,data,Z),0)


    for i in range(output.size(0)): # loop on first feature dimension
        # convolve with center channel of filter (at position=padding)
        output[i,:,:,:,:,:] = F.conv3d(data_padded[i+padding,:,:,:,:,:],
                                       filters[padding,:,:,:,:,:], bias=bias, stride=1, padding=padding)
        # convolve with upper/lower channels of filter (at positions [:padding] [padding+1:])
        for p in range(1,padding+1):
            output[i,:,:,:,:,:] = output[i,:,:,:,:,:]+F.conv3d(data_padded[i+padding-p,:,:,:,:,:],
                                       filters[padding-p,:,:,:,:,:], bias=None, stride=1, padding=padding)
            output[i,:,:,:,:,:] = output[i,:,:,:,:,:]+F.conv3d(data_padded[i+padding+p,:,:,:,:,:],
                                       filters[padding+p,:,:,:,:,:], bias=None, stride=1, padding=padding)

    output = output.permute(1,2,0,3,4,5).contiguous()
    return output

class Conv4d(_ConvNd):
    """Applies a 4D convolution over an input signal composed of several input
    planes.
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias=True, pre_permuted_filters=True):
        # stride, dilation and groups !=1 functionality not tested
        stride = 1
        dilation = 1
        groups = 1
        # zero padding is added automatically in conv4d function to preserve tensor size
        padding = 0
        kernel_size = _quadruple(kernel_size)
        stride = _quadruple(stride)
        padding = _quadruple(padding)
        dilation = _quadruple(dilation)
        super(Conv4d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _quadruple(0), groups, bias)
        # weights will be sliced along one dimension during convolution loop
        # make the looping dimension to be the first one in the tensor,
        # so that we don't need to call contiguous() inside the loop
        self.pre_permuted_filters = pre_permuted_filters
        if self.pre_permuted_filters:
            self.weight.data = self.weight.data.permute(2,0,1,3,4,5).contiguous()
        self.use_half = False
        # self.isbias = bias
        # if not self.isbias:
        #     self.bn = torch.nn.BatchNorm1d(out_channels)


    def forward(self, input):
        out = conv4d(input, self.weight, bias=self.bias, permute_filters=not self.pre_permuted_filters, use_half=self.use_half) # filters pre-permuted in constructor
        # if not self.isbias:
        #     b,c,u,v,h,w = out.shape
        #     out = self.bn(out.view(b,c,-1)).view(b,c,u,v,h,w)
        return out

class fullConv4d(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, bias=True, pre_permuted_filters=True):
        super(fullConv4d, self).__init__()
        self.conv = Conv4d(in_channels, out_channels, kernel_size, bias=bias, pre_permuted_filters=pre_permuted_filters)
        self.isbias = bias
        if not self.isbias:
            self.bn = torch.nn.BatchNorm1d(out_channels)

    def forward(self, input):
        out = self.conv(input)
        if not self.isbias:
            b,c,u,v,h,w = out.shape
            out = self.bn(out.view(b,c,-1)).view(b,c,u,v,h,w)
        return out

class butterfly4D(torch.nn.Module):
    '''
    Butterfly 4D: an hourglass-style stack of separable 4D convolutions
    that downsamples and re-upsamples the displacement dimensions of the cost volume.
    '''
    def __init__(self, fdima, fdimb, withbn=True, full=True, groups=1):
        super(butterfly4D, self).__init__()
        self.proj = nn.Sequential(projfeat4d(fdima, fdimb, 1, with_bn=withbn, groups=groups),
                                  nn.ReLU(inplace=True),)
        self.conva1 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(2,1,1),full=full,groups=groups)
        self.conva2 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(2,1,1),full=full,groups=groups)
        self.convb3 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(1,1,1),full=full,groups=groups)
        self.convb2 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(1,1,1),full=full,groups=groups)
        self.convb1 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(1,1,1),full=full,groups=groups)

    #@profile
    def forward(self, x):
        out = self.proj(x)
        b,c,u,v,h,w = out.shape # 9x9

        out1 = self.conva1(out) # 5x5, 3
        _,c1,u1,v1,h1,w1 = out1.shape

        out2 = self.conva2(out1) # 3x3, 9
        _,c2,u2,v2,h2,w2 = out2.shape

        out2 = self.convb3(out2) # 3x3, 9

        tout1 = F.upsample(out2.view(b,c,u2,v2,-1),(u1,v1,h2*w2),mode='trilinear').view(b,c,u1,v1,h2,w2) # 5x5
        tout1 = F.upsample(tout1.view(b,c,-1,h2,w2),(u1*v1,h1,w1),mode='trilinear').view(b,c,u1,v1,h1,w1) # 5x5
        out1 = tout1 + out1
        out1 = self.convb2(out1)

        tout = F.upsample(out1.view(b,c,u1,v1,-1),(u,v,h1*w1),mode='trilinear').view(b,c,u,v,h1,w1)
        tout = F.upsample(tout.view(b,c,-1,h1,w1),(u*v,h,w),mode='trilinear').view(b,c,u,v,h,w)
        out = tout + out
        out = self.convb1(out)

        return out



class projfeat4d(torch.nn.Module):
    '''
    Channel projection over the 4D volume: a 1x1x1 3D convolution
    (optionally strided over the displacement dimensions u,v).
    '''
    def __init__(self, in_planes, out_planes, stride, with_bn=True, groups=1):
        super(projfeat4d, self).__init__()
        self.with_bn = with_bn
        self.stride = stride
        self.conv1 = nn.Conv3d(in_planes, out_planes, 1, (stride,stride,1), padding=0, bias=not with_bn, groups=groups)
        self.bn = nn.BatchNorm3d(out_planes)

    def forward(self, x):
        b,c,u,v,h,w = x.size()
        x = self.conv1(x.view(b,c,u,v,h*w))
        if self.with_bn:
            x = self.bn(x)
        _,c,u,v,_ = x.shape
        x = x.view(b,c,u,v,h,w)
        return x

class sepConv4d(torch.nn.Module):
    '''
    Separable 4D convolution block as 2 3D convolutions
    '''
    def __init__(self, in_planes, out_planes, stride=(1,1,1), with_bn=True, ksize=3, full=True, groups=1):
        super(sepConv4d, self).__init__()
        bias = not with_bn
        self.isproj = False
        self.stride = stride[0]
        expand = 1

        if with_bn:
            if in_planes != out_planes:
                self.isproj = True
                self.proj = nn.Sequential(nn.Conv2d(in_planes, out_planes, 1, bias=bias, padding=0, groups=groups),
                                          nn.BatchNorm2d(out_planes))
            if full:
                self.conv1 = nn.Sequential(nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=(1,self.stride,self.stride), bias=bias, padding=(0,ksize//2,ksize//2), groups=groups),
                                           nn.BatchNorm3d(in_planes))
            else:
                self.conv1 = nn.Sequential(nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=1, bias=bias, padding=(0,ksize//2,ksize//2), groups=groups),
                                           nn.BatchNorm3d(in_planes))
            self.conv2 = nn.Sequential(nn.Conv3d(in_planes, in_planes*expand, (ksize,ksize,1), stride=(self.stride,self.stride,1), bias=bias, padding=(ksize//2,ksize//2,0), groups=groups),
                                       nn.BatchNorm3d(in_planes*expand))
        else:
            if in_planes != out_planes:
                self.isproj = True
                self.proj = nn.Conv2d(in_planes, out_planes, 1, bias=bias, padding=0, groups=groups)
            if full:
                self.conv1 = nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=(1,self.stride,self.stride), bias=bias, padding=(0,ksize//2,ksize//2), groups=groups)
            else:
                self.conv1 = nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=1, bias=bias, padding=(0,ksize//2,ksize//2), groups=groups)
            self.conv2 = nn.Conv3d(in_planes, in_planes*expand, (ksize,ksize,1), stride=(self.stride,self.stride,1), bias=bias, padding=(ksize//2,ksize//2,0), groups=groups)
        self.relu = nn.ReLU(inplace=True)

    #@profile
    def forward(self, x):
        b,c,u,v,h,w = x.shape
        x = self.conv2(x.view(b,c,u,v,-1))
        b,c,u,v,_ = x.shape
        x = self.relu(x)
        x = self.conv1(x.view(b,c,-1,h,w))
        b,c,_,h,w = x.shape

        if self.isproj:
            x = self.proj(x.view(b,c,-1,w))
        x = x.view(b,-1,u,v,h,w)
        return x


class sepConv4dBlock(torch.nn.Module):
    '''
    Residual block of two separable 4D convolutions, with an optional
    downsampling shortcut when the shape or channel count changes.
    '''
    def __init__(self, in_planes, out_planes, stride=(1,1,1), with_bn=True, full=True, groups=1):
        super(sepConv4dBlock, self).__init__()
        if in_planes == out_planes and stride==(1,1,1):
            self.downsample = None
        else:
            if full:
                self.downsample = sepConv4d(in_planes, out_planes, stride, with_bn=with_bn, ksize=1, full=full, groups=groups)
            else:
                self.downsample = projfeat4d(in_planes, out_planes, stride[0], with_bn=with_bn, groups=groups)
        self.conv1 = sepConv4d(in_planes, out_planes, stride, with_bn=with_bn, full=full, groups=groups)
        self.conv2 = sepConv4d(out_planes, out_planes, (1,1,1), with_bn=with_bn, full=full, groups=groups)
        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)

    #@profile
    def forward(self, x):
        out = self.relu1(self.conv1(x))
        if self.downsample:
            x = self.downsample(x)
        out = self.relu2(x + self.conv2(out))
        return out


# ---- ad-hoc timing scratch (commented out) ----
##import torch.backends.cudnn as cudnn
##cudnn.benchmark = True
#import time
##im = torch.randn(9,64,9,160,224).cuda()
##net = torch.nn.Conv3d(64, 64, 3).cuda()
##net = Conv4d(1,1,3,bias=True,pre_permuted_filters=True).cuda()
##net = sepConv4dBlock(2,2,stride=(1,1,1)).cuda()
#
##im = torch.randn(1,16,9,9,96,320).cuda()
##net = sepConv4d(16,16,with_bn=False).cuda()
#
##im = torch.randn(1,16,81,96,320).cuda()
##net = torch.nn.Conv3d(16,16,(1,3,3),padding=(0,1,1)).cuda()
#
##im = torch.randn(1,16,9,9,96*320).cuda()
##net = torch.nn.Conv3d(16,16,(3,3,1),padding=(1,1,0)).cuda()
#
##im = torch.randn(10000,10,9,9).cuda()
##net = torch.nn.Conv2d(10,10,3,padding=1).cuda()
#
##im = torch.randn(81,16,96,320).cuda()
##net = torch.nn.Conv2d(16,16,3,padding=1).cuda()
#c = int(16 *1)
#cp = int(16 *1)
#h = int(96 *4)
#w = int(320 *4)
#k = 3
#im = torch.randn(1,c,h,w).cuda()
#net = torch.nn.Conv2d(c,cp,k,padding=k//2).cuda()
#
#im2 = torch.randn(cp,k*k*c).cuda()
#im1 = F.unfold(im, (k,k), padding=k//2)[0]
#
#
#net(im)
#net(im)
#torch.mm(im2,im1)
#torch.mm(im2,im1)
#torch.cuda.synchronize()
#beg = time.time()
#for i in range(100):
#    net(im)
#    #im1 = F.unfold(im, (k,k), padding=k//2)[0]
#    torch.mm(im2,im1)
#torch.cuda.synchronize()
#print('%f'%((time.time()-beg)*10.))
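A small shape-propagation sketch for the 4D blocks above (illustrative only), assuming a cost volume laid out as (batch, channels, u, v, h, w) with placeholder sizes:

import torch
from expansion.models.conv4d import butterfly4D, sepConv4d

vol = torch.rand(1, 64, 9, 9, 16, 32)              # b, c, u, v, h, w
feat = butterfly4D(64, 16, withbn=True, full=False)(vol)
print(feat.shape)                                   # (1, 16, 9, 9, 16, 32): channels projected 64 -> 16
out = sepConv4d(16, 16, with_bn=False, full=False)(feat)
print(out.shape)                                    # displacement and spatial dims preserved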
expansion/models/submodule.py
ADDED
@@ -0,0 +1,450 @@
1 |
+
from __future__ import print_function
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.utils.data
|
5 |
+
from torch.autograd import Variable
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import math
|
8 |
+
import numpy as np
|
9 |
+
import pdb
|
10 |
+
|
11 |
+
class residualBlock(nn.Module):
|
12 |
+
expansion = 1
|
13 |
+
|
14 |
+
def __init__(self, in_channels, n_filters, stride=1, downsample=None,dilation=1,with_bn=True):
|
15 |
+
super(residualBlock, self).__init__()
|
16 |
+
if dilation > 1:
|
17 |
+
padding = dilation
|
18 |
+
else:
|
19 |
+
padding = 1
|
20 |
+
|
21 |
+
if with_bn:
|
22 |
+
self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters, 3, stride, padding, dilation=dilation)
|
23 |
+
self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1)
|
24 |
+
else:
|
25 |
+
self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters, 3, stride, padding, dilation=dilation,with_bn=False)
|
26 |
+
self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1, with_bn=False)
|
27 |
+
self.downsample = downsample
|
28 |
+
self.relu = nn.LeakyReLU(0.1, inplace=True)
|
29 |
+
|
30 |
+
def forward(self, x):
|
31 |
+
residual = x
|
32 |
+
|
33 |
+
out = self.convbnrelu1(x)
|
34 |
+
out = self.convbn2(out)
|
35 |
+
|
36 |
+
if self.downsample is not None:
|
37 |
+
residual = self.downsample(x)
|
38 |
+
|
39 |
+
out += residual
|
40 |
+
return self.relu(out)
|
41 |
+
|
42 |
+
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
|
43 |
+
return nn.Sequential(
|
44 |
+
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
|
45 |
+
padding=padding, dilation=dilation, bias=True),
|
46 |
+
nn.BatchNorm2d(out_planes),
|
47 |
+
nn.LeakyReLU(0.1,inplace=True))
|
48 |
+
|
49 |
+
|
50 |
+
class conv2DBatchNorm(nn.Module):
|
51 |
+
def __init__(self, in_channels, n_filters, k_size, stride, padding, dilation=1, with_bn=True):
|
52 |
+
super(conv2DBatchNorm, self).__init__()
|
53 |
+
bias = not with_bn
|
54 |
+
|
55 |
+
if dilation > 1:
|
56 |
+
conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
|
57 |
+
padding=padding, stride=stride, bias=bias, dilation=dilation)
|
58 |
+
|
59 |
+
else:
|
60 |
+
conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
|
61 |
+
padding=padding, stride=stride, bias=bias, dilation=1)
|
62 |
+
|
63 |
+
|
64 |
+
if with_bn:
|
65 |
+
self.cb_unit = nn.Sequential(conv_mod,
|
66 |
+
nn.BatchNorm2d(int(n_filters)),)
|
67 |
+
else:
|
68 |
+
self.cb_unit = nn.Sequential(conv_mod,)
|
69 |
+
|
70 |
+
def forward(self, inputs):
|
71 |
+
outputs = self.cb_unit(inputs)
|
72 |
+
return outputs
|
73 |
+
|
74 |
+
class conv2DBatchNormRelu(nn.Module):
|
75 |
+
def __init__(self, in_channels, n_filters, k_size, stride, padding, dilation=1, with_bn=True):
|
76 |
+
super(conv2DBatchNormRelu, self).__init__()
|
77 |
+
bias = not with_bn
|
78 |
+
if dilation > 1:
|
79 |
+
conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
|
80 |
+
padding=padding, stride=stride, bias=bias, dilation=dilation)
|
81 |
+
|
82 |
+
else:
|
83 |
+
conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
|
84 |
+
padding=padding, stride=stride, bias=bias, dilation=1)
|
85 |
+
|
86 |
+
if with_bn:
|
87 |
+
self.cbr_unit = nn.Sequential(conv_mod,
|
88 |
+
nn.BatchNorm2d(int(n_filters)),
|
89 |
+
nn.LeakyReLU(0.1, inplace=True),)
|
90 |
+
else:
|
91 |
+
self.cbr_unit = nn.Sequential(conv_mod,
|
92 |
+
nn.LeakyReLU(0.1, inplace=True),)
|
93 |
+
|
94 |
+
def forward(self, inputs):
|
95 |
+
outputs = self.cbr_unit(inputs)
|
96 |
+
return outputs
|
97 |
+
|
98 |
+
class pyramidPooling(nn.Module):
|
99 |
+
|
100 |
+
def __init__(self, in_channels, with_bn=True, levels=4):
|
101 |
+
super(pyramidPooling, self).__init__()
|
102 |
+
self.levels = levels
|
103 |
+
|
104 |
+
self.paths = []
|
105 |
+
for i in range(levels):
|
106 |
+
self.paths.append(conv2DBatchNormRelu(in_channels, in_channels, 1, 1, 0, with_bn=with_bn))
|
107 |
+
self.path_module_list = nn.ModuleList(self.paths)
|
108 |
+
self.relu = nn.LeakyReLU(0.1, inplace=True)
|
109 |
+
|
110 |
+
def forward(self, x):
|
111 |
+
h, w = x.shape[2:]
|
112 |
+
|
113 |
+
k_sizes = []
|
114 |
+
strides = []
|
115 |
+
for pool_size in np.linspace(1,min(h,w)//2,self.levels,dtype=int):
|
116 |
+
k_sizes.append((int(h/pool_size), int(w/pool_size)))
|
117 |
+
strides.append((int(h/pool_size), int(w/pool_size)))
|
118 |
+
k_sizes = k_sizes[::-1]
|
119 |
+
strides = strides[::-1]
|
120 |
+
|
121 |
+
pp_sum = x
|
122 |
+
|
123 |
+
for i, module in enumerate(self.path_module_list):
|
124 |
+
out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0)
|
125 |
+
out = module(out)
|
126 |
+
out = F.upsample(out, size=(h,w), mode='bilinear')
|
127 |
+
pp_sum = pp_sum + 1./self.levels*out
|
128 |
+
pp_sum = self.relu(pp_sum/2.)
|
129 |
+
|
130 |
+
return pp_sum
|

class pspnet(nn.Module):
    """
    Modified PSPNet. https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/models/pspnet.py
    """
    def __init__(self, is_proj=True, groups=1):
        super(pspnet, self).__init__()
        self.inplanes = 32
        self.is_proj = is_proj

        # Encoder
        self.convbnrelu1_1 = conv2DBatchNormRelu(in_channels=3, k_size=3, n_filters=16,
                                                 padding=1, stride=2)
        self.convbnrelu1_2 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=16,
                                                 padding=1, stride=1)
        self.convbnrelu1_3 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=32,
                                                 padding=1, stride=1)
        # Vanilla Residual Blocks
        self.res_block3 = self._make_layer(residualBlock, 64, 1, stride=2)
        self.res_block5 = self._make_layer(residualBlock, 128, 1, stride=2)
        self.res_block6 = self._make_layer(residualBlock, 128, 1, stride=2)
        self.res_block7 = self._make_layer(residualBlock, 128, 1, stride=2)
        self.pyramid_pooling = pyramidPooling(128, levels=3)

        # Iconvs
        self.upconv6 = nn.Sequential(nn.Upsample(scale_factor=2),
                                     conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                                         padding=1, stride=1))
        self.iconv5 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128,
                                          padding=1, stride=1)
        self.upconv5 = nn.Sequential(nn.Upsample(scale_factor=2),
                                     conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                                         padding=1, stride=1))
        self.iconv4 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128,
                                          padding=1, stride=1)
        self.upconv4 = nn.Sequential(nn.Upsample(scale_factor=2),
                                     conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                                         padding=1, stride=1))
        self.iconv3 = conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                          padding=1, stride=1)
        self.upconv3 = nn.Sequential(nn.Upsample(scale_factor=2),
                                     conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32,
                                                         padding=1, stride=1))
        self.iconv2 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=64,
                                          padding=1, stride=1)

        if self.is_proj:
            self.proj6 = conv2DBatchNormRelu(in_channels=128, k_size=1, n_filters=128//groups, padding=0, stride=1)
            self.proj5 = conv2DBatchNormRelu(in_channels=128, k_size=1, n_filters=128//groups, padding=0, stride=1)
            self.proj4 = conv2DBatchNormRelu(in_channels=128, k_size=1, n_filters=128//groups, padding=0, stride=1)
            self.proj3 = conv2DBatchNormRelu(in_channels=64,  k_size=1, n_filters=64//groups,  padding=0, stride=1)
            self.proj2 = conv2DBatchNormRelu(in_channels=64,  k_size=1, n_filters=64//groups,  padding=0, stride=1)

        # He-style initialization for all conv layers
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if hasattr(m.bias, 'data'):
                    m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * block.expansion,
                                                 kernel_size=1, stride=stride, bias=False),
                                       nn.BatchNorm2d(planes * block.expansion),)
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        # H, W -> H/2, W/2
        conv1 = self.convbnrelu1_1(x)
        conv1 = self.convbnrelu1_2(conv1)
        conv1 = self.convbnrelu1_3(conv1)

        # H/2, W/2 -> H/4, W/4
        pool1 = F.max_pool2d(conv1, 3, 2, 1)

        # H/4, W/4 -> H/64, W/64 (four stride-2 residual stages)
        rconv3 = self.res_block3(pool1)
        conv4 = self.res_block5(rconv3)
        conv5 = self.res_block6(conv4)
        conv6 = self.res_block7(conv5)
        conv6 = self.pyramid_pooling(conv6)

        # decoder: upsample to the next-finer skip, concatenate, fuse
        # (indexing upconv*[1] skips the fixed 2x Upsample, since the exact
        #  target size is already produced by F.upsample here)
        conv6x = F.upsample(conv6, [conv5.size()[2], conv5.size()[3]], mode='bilinear')
        concat5 = torch.cat((conv5, self.upconv6[1](conv6x)), dim=1)
        conv5 = self.iconv5(concat5)

        conv5x = F.upsample(conv5, [conv4.size()[2], conv4.size()[3]], mode='bilinear')
        concat4 = torch.cat((conv4, self.upconv5[1](conv5x)), dim=1)
        conv4 = self.iconv4(concat4)

        conv4x = F.upsample(conv4, [rconv3.size()[2], rconv3.size()[3]], mode='bilinear')
        concat3 = torch.cat((rconv3, self.upconv4[1](conv4x)), dim=1)
        conv3 = self.iconv3(concat3)

        conv3x = F.upsample(conv3, [pool1.size()[2], pool1.size()[3]], mode='bilinear')
        concat2 = torch.cat((pool1, self.upconv3[1](conv3x)), dim=1)
        conv2 = self.iconv2(concat2)

        if self.is_proj:
            proj6 = self.proj6(conv6)
            proj5 = self.proj5(conv5)
            proj4 = self.proj4(conv4)
            proj3 = self.proj3(conv3)
            proj2 = self.proj2(conv2)
            return proj6, proj5, proj4, proj3, proj2
        else:
            return conv6, conv5, conv4, conv3, conv2

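As a quick sanity check on the decoder, the five projected outputs span 1/64 down to 1/4 of the input resolution. A minimal shape-probe sketch (not part of the repo), assuming the conv2DBatchNormRelu and residualBlock helpers defined earlier in this file are in scope and the repo's pinned PyTorch is installed:

    import torch

    net = pspnet(is_proj=True, groups=1)
    net.eval()
    with torch.no_grad():
        feats = net(torch.randn(1, 3, 256, 256))  # NCHW probe input
    print([tuple(f.shape) for f in feats])
    # expected: [(1, 128, 4, 4), (1, 128, 8, 8), (1, 128, 16, 16),
    #            (1, 64, 32, 32), (1, 64, 64, 64)]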
class pspnet_s(nn.Module):
    """
    Modified PSPNet. https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/models/pspnet.py
    """
    def __init__(self, is_proj=True, groups=1):
        super(pspnet_s, self).__init__()
        self.inplanes = 32
        self.is_proj = is_proj

        # Encoder
        self.convbnrelu1_1 = conv2DBatchNormRelu(in_channels=3, k_size=3, n_filters=16,
                                                 padding=1, stride=2)
        self.convbnrelu1_2 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=16,
                                                 padding=1, stride=1)
        self.convbnrelu1_3 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=32,
                                                 padding=1, stride=1)
        # Vanilla Residual Blocks
        self.res_block3 = self._make_layer(residualBlock, 64, 1, stride=2)
        self.res_block5 = self._make_layer(residualBlock, 128, 1, stride=2)
        self.res_block6 = self._make_layer(residualBlock, 128, 1, stride=2)
        self.res_block7 = self._make_layer(residualBlock, 128, 1, stride=2)
        self.pyramid_pooling = pyramidPooling(128, levels=3)

        # Iconvs
        self.upconv6 = nn.Sequential(nn.Upsample(scale_factor=2),
                                     conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                                         padding=1, stride=1))
        self.iconv5 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128,
                                          padding=1, stride=1)
        self.upconv5 = nn.Sequential(nn.Upsample(scale_factor=2),
                                     conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                                         padding=1, stride=1))
        self.iconv4 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128,
                                          padding=1, stride=1)
        self.upconv4 = nn.Sequential(nn.Upsample(scale_factor=2),
                                     conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                                         padding=1, stride=1))
        self.iconv3 = conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                          padding=1, stride=1)
        # the 1/4-resolution branch is disabled in this variant:
        #self.upconv3 = nn.Sequential(nn.Upsample(scale_factor=2),
        #                             conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32,
        #                                                 padding=1, stride=1))
        #self.iconv2 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=64,
        #                                  padding=1, stride=1)

        if self.is_proj:
            self.proj6 = conv2DBatchNormRelu(in_channels=128, k_size=1, n_filters=128//groups, padding=0, stride=1)
            self.proj5 = conv2DBatchNormRelu(in_channels=128, k_size=1, n_filters=128//groups, padding=0, stride=1)
            self.proj4 = conv2DBatchNormRelu(in_channels=128, k_size=1, n_filters=128//groups, padding=0, stride=1)
            self.proj3 = conv2DBatchNormRelu(in_channels=64,  k_size=1, n_filters=64//groups,  padding=0, stride=1)
            #self.proj2 = conv2DBatchNormRelu(in_channels=64, k_size=1, n_filters=64//groups, padding=0, stride=1)

        # He-style initialization for all conv layers
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if hasattr(m.bias, 'data'):
                    m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * block.expansion,
                                                 kernel_size=1, stride=stride, bias=False),
                                       nn.BatchNorm2d(planes * block.expansion),)
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        # H, W -> H/2, W/2
        conv1 = self.convbnrelu1_1(x)
        conv1 = self.convbnrelu1_2(conv1)
        conv1 = self.convbnrelu1_3(conv1)

        # H/2, W/2 -> H/4, W/4
        pool1 = F.max_pool2d(conv1, 3, 2, 1)

        # H/4, W/4 -> H/64, W/64
        rconv3 = self.res_block3(pool1)
        conv4 = self.res_block5(rconv3)
        conv5 = self.res_block6(conv4)
        conv6 = self.res_block7(conv5)
        conv6 = self.pyramid_pooling(conv6)

        conv6x = F.upsample(conv6, [conv5.size()[2], conv5.size()[3]], mode='bilinear')
        concat5 = torch.cat((conv5, self.upconv6[1](conv6x)), dim=1)
        conv5 = self.iconv5(concat5)

        conv5x = F.upsample(conv5, [conv4.size()[2], conv4.size()[3]], mode='bilinear')
        concat4 = torch.cat((conv4, self.upconv5[1](conv5x)), dim=1)
        conv4 = self.iconv4(concat4)

        conv4x = F.upsample(conv4, [rconv3.size()[2], rconv3.size()[3]], mode='bilinear')
        concat3 = torch.cat((rconv3, self.upconv4[1](conv4x)), dim=1)
        conv3 = self.iconv3(concat3)

        #conv3x = F.upsample(conv3, [pool1.size()[2], pool1.size()[3]], mode='bilinear')
        #concat2 = torch.cat((pool1, self.upconv3[1](conv3x)), dim=1)
        #conv2 = self.iconv2(concat2)

        if self.is_proj:
            proj6 = self.proj6(conv6)
            proj5 = self.proj5(conv5)
            proj4 = self.proj4(conv4)
            proj3 = self.proj3(conv3)
            # proj2 = self.proj2(conv2)
            # return proj6, proj5, proj4, proj3, proj2
            return proj6, proj5, proj4, proj3
        else:
            # return conv6, conv5, conv4, conv3, conv2
            return conv6, conv5, conv4, conv3

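pspnet_s is the same encoder-decoder truncated one stage earlier: the 1/4-resolution branch (upconv3, iconv2, proj2) is commented out, so it returns four feature levels, 1/64 down to 1/8. Continuing the probe sketch above under the same assumptions:

    net_s = pspnet_s(is_proj=True, groups=1)
    net_s.eval()
    with torch.no_grad():
        feats = net_s(torch.randn(1, 3, 256, 256))
    print([tuple(f.shape) for f in feats])
    # expected: [(1, 128, 4, 4), (1, 128, 8, 8), (1, 128, 16, 16), (1, 64, 32, 32)]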
class bfmodule(nn.Module):
    def __init__(self, inplanes, outplanes):
        super(bfmodule, self).__init__()
        self.proj = conv2DBatchNormRelu(in_channels=inplanes, k_size=1, n_filters=64, padding=0, stride=1)
        self.inplanes = 64
        # Vanilla Residual Blocks
        self.res_block3 = self._make_layer(residualBlock, 64, 1, stride=2)
        self.res_block5 = self._make_layer(residualBlock, 64, 1, stride=2)
        self.res_block6 = self._make_layer(residualBlock, 64, 1, stride=2)
        self.res_block7 = self._make_layer(residualBlock, 128, 1, stride=2)
        self.pyramid_pooling = pyramidPooling(128, levels=3)
        # Iconvs
        self.upconv6 = conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                           padding=1, stride=1)
        self.upconv5 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32,
                                           padding=1, stride=1)
        self.upconv4 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32,
                                           padding=1, stride=1)
        self.upconv3 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32,
                                           padding=1, stride=1)
        self.iconv5 = conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                          padding=1, stride=1)
        self.iconv4 = conv2DBatchNormRelu(in_channels=96, k_size=3, n_filters=64,
                                          padding=1, stride=1)
        self.iconv3 = conv2DBatchNormRelu(in_channels=96, k_size=3, n_filters=64,
                                          padding=1, stride=1)
        self.iconv2 = nn.Sequential(conv2DBatchNormRelu(in_channels=96, k_size=3, n_filters=64,
                                                        padding=1, stride=1),
                                    nn.Conv2d(64, outplanes, kernel_size=3, stride=1, padding=1, bias=True))

        self.proj6 = nn.Conv2d(128, outplanes, kernel_size=3, stride=1, padding=1, bias=True)
        self.proj5 = nn.Conv2d(64, outplanes, kernel_size=3, stride=1, padding=1, bias=True)
        self.proj4 = nn.Conv2d(64, outplanes, kernel_size=3, stride=1, padding=1, bias=True)
        self.proj3 = nn.Conv2d(64, outplanes, kernel_size=3, stride=1, padding=1, bias=True)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if hasattr(m.bias, 'data'):
                    m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * block.expansion,
                                                 kernel_size=1, stride=stride, bias=False),
                                       nn.BatchNorm2d(planes * block.expansion),)
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        proj = self.proj(x)                  # 4x
        rconv3 = self.res_block3(proj)       # 8x
        conv4 = self.res_block5(rconv3)      # 16x
        conv5 = self.res_block6(conv4)       # 32x
        conv6 = self.res_block7(conv5)       # 64x
        conv6 = self.pyramid_pooling(conv6)  # 64x
        pred6 = self.proj6(conv6)

        conv6u = F.upsample(conv6, [conv5.size()[2], conv5.size()[3]], mode='bilinear')
        concat5 = torch.cat((conv5, self.upconv6(conv6u)), dim=1)
        conv5 = self.iconv5(concat5)  # 32x
        pred5 = self.proj5(conv5)

        conv5u = F.upsample(conv5, [conv4.size()[2], conv4.size()[3]], mode='bilinear')
        concat4 = torch.cat((conv4, self.upconv5(conv5u)), dim=1)
        conv4 = self.iconv4(concat4)  # 16x
        pred4 = self.proj4(conv4)

        conv4u = F.upsample(conv4, [rconv3.size()[2], rconv3.size()[3]], mode='bilinear')
        concat3 = torch.cat((rconv3, self.upconv4(conv4u)), dim=1)
        conv3 = self.iconv3(concat3)  # 8x
        pred3 = self.proj3(conv3)

        conv3u = F.upsample(conv3, [x.size()[2], x.size()[3]], mode='bilinear')
        concat2 = torch.cat((proj, self.upconv3(conv3u)), dim=1)
        pred2 = self.iconv2(concat2)  # 4x

        return pred2, pred3, pred4, pred5, pred6

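bfmodule is a small hourglass decoder: it projects its input to 64 channels, runs four more stride-2 residual stages with pyramid pooling at the bottleneck, and emits an outplanes-channel prediction at each scale on the way back up (pred2 at the module input's own resolution through pred6 at 1/16 of it). A minimal probe sketch under the same assumptions as above, feeding a 64-channel map:

    dec = bfmodule(inplanes=64, outplanes=1)
    dec.eval()
    with torch.no_grad():
        preds = dec(torch.randn(1, 64, 64, 64))
    print([tuple(p.shape) for p in preds])
    # expected: [(1, 1, 64, 64), (1, 1, 32, 32), (1, 1, 16, 16), (1, 1, 8, 8), (1, 1, 4, 4)]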
expansion/submission.py
ADDED
@@ -0,0 +1,95 @@
from __future__ import print_function
import sys
import cv2
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.nn.functional as F
cudnn.benchmark = False

class Expansion():

    def __init__(self, loadmodel='pretrained_models/optical_expansion/robust.pth', testres=1, maxdisp=256, fac=1):

        maxw, maxh = [int(testres*1280), int(testres*384)]

        # round the working resolution up to a multiple of 64
        max_h = int(maxh // 64 * 64)
        max_w = int(maxw // 64 * 64)
        if max_h < maxh: max_h += 64
        if max_w < maxw: max_w += 64
        maxh = max_h
        maxw = max_w

        mean_L = [[0.33, 0.33, 0.33]]
        mean_R = [[0.33, 0.33, 0.33]]

        # construct model, VCN-expansion
        from expansion.models.VCN_exp import VCN
        model = VCN([1, maxw, maxh], md=[int(4*(maxdisp/256)), 4, 4, 4, 4], fac=fac,
                    exp_unc=('robust' in loadmodel))  # expansion uncertainty only in the new model
        model = nn.DataParallel(model, device_ids=[0])
        model.cuda()

        if loadmodel is not None:
            pretrained_dict = torch.load(loadmodel)
            mean_L = pretrained_dict['mean_L']
            mean_R = pretrained_dict['mean_R']
            pretrained_dict['state_dict'] = {k: v for k, v in pretrained_dict['state_dict'].items()}
            model.load_state_dict(pretrained_dict['state_dict'], strict=False)
        else:
            print('dry run')

        model.eval()
        # resize: fix the inference resolution at 256x256 (already a multiple of 64)
        maxh = 256
        maxw = 256
        max_h = int(maxh // 64 * 64)
        max_w = int(maxw // 64 * 64)
        if max_h < maxh: max_h += 64
        if max_w < maxw: max_w += 64

        # modify module according to inputs: rebuild the per-level flow
        # regularization and warping modules for the new resolution
        from expansion.models.VCN_exp import WarpModule, flow_reg
        for i in range(len(model.module.reg_modules)):
            model.module.reg_modules[i] = flow_reg([1, max_w//(2**(6-i)), max_h//(2**(6-i))],
                                                   ent=getattr(model.module, 'flow_reg%d'%2**(6-i)).ent,
                                                   maxdisp=getattr(model.module, 'flow_reg%d'%2**(6-i)).md,
                                                   fac=getattr(model.module, 'flow_reg%d'%2**(6-i)).fac).cuda()
        for i in range(len(model.module.warp_modules)):
            model.module.warp_modules[i] = WarpModule([1, max_w//(2**(6-i)), max_h//(2**(6-i))]).cuda()

        mean_L = torch.from_numpy(np.asarray(mean_L).astype(np.float32).mean(0)[np.newaxis, :, np.newaxis, np.newaxis]).cuda()
        mean_R = torch.from_numpy(np.asarray(mean_R).astype(np.float32).mean(0)[np.newaxis, :, np.newaxis, np.newaxis]).cuda()

        self.max_h = max_h
        self.max_w = max_w
        self.model = model
        self.mean_L = mean_L
        self.mean_R = mean_R

    def run(self, imgL_o, imgR_o):
        model = self.model
        mean_L = self.mean_L
        mean_R = self.mean_R

        # clamp inputs to [-1, 1], map them to [0, 1], and subtract the dataset mean
        imgL_o[imgL_o < -1] = -1
        imgL_o[imgL_o > 1] = 1
        imgR_o[imgR_o < -1] = -1
        imgR_o[imgR_o > 1] = 1
        imgL = (imgL_o+1.)*0.5 - mean_L
        imgR = (imgR_o+1.)*0.5 - mean_R

        with torch.no_grad():
            imgLR = torch.cat([imgL, imgR], 0)
            model.eval()
            torch.cuda.synchronize()
            rts = model(imgLR)
            torch.cuda.synchronize()
            flow, occ, logmid, logexp = rts

        torch.cuda.empty_cache()

        return flow, logexp
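A usage sketch for this wrapper (not part of the repo): it assumes a CUDA device, the robust.pth checkpoint at the default path, and two frames already loaded as (1, 3, 256, 256) tensors in [-1, 1] to match the 256x256 working size configured above; the random frames here are placeholders.

    import torch

    exp = Expansion()  # loads pretrained_models/optical_expansion/robust.pth

    frame0 = torch.rand(1, 3, 256, 256).cuda() * 2. - 1.  # placeholder inputs
    frame1 = torch.rand(1, 3, 256, 256).cuda() * 2. - 1.

    flow, logexp = exp.run(frame0, frame1)
    print(flow.shape, logexp.shape)  # optical flow and log-expansion outputs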
expansion/utils/__init__.py
ADDED
File without changes

expansion/utils/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (157 Bytes).

expansion/utils/__pycache__/flowlib.cpython-38.pyc
ADDED
Binary file (16 kB).

expansion/utils/__pycache__/io.cpython-38.pyc
ADDED
Binary file (3.97 kB).

expansion/utils/__pycache__/pfm.cpython-38.pyc
ADDED
Binary file (1.65 kB).