diff --git a/LICENSE b/LICENSE
new file mode 100755
index 0000000000000000000000000000000000000000..7f570ae807e29e43c7d7b8c56954bc2b6c969a0b
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Yuki Endo
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..bab85c36e531f0723147e22554a45c2f9a7bfbd2
--- /dev/null
+++ b/README.md
@@ -0,0 +1,51 @@
+# User-Controllable Latent Transformer for StyleGAN Image Layout Editing
+![Teaser](docs/teaser.jpg)
+
+![Demo](docs/thumb.gif)
+
+This repository contains our implementation of the following paper:
+
+Yuki Endo: "User-Controllable Latent Transformer for StyleGAN Image Layout Editing," Computer Graphpics Forum (Pacific Graphics 2022) [[Project](http://www.cgg.cs.tsukuba.ac.jp/~endo/projects/UserControllableLT)] [[PDF (preprint)]()]
+
+## Prerequisites
+1. Python 3.8
+2. PyTorch 1.9.0
+3. Flask
+4. Others (see env.yaml)
+
+## Preparation
+Download and decompress our pre-trained models, and place them under `pretrained_models/` so that the checkpoint paths used in the commands below (e.g., `pretrained_models/latent_transformer/cat.pt`) resolve.
+
+## Inference with our pre-trained models
+
+We provide an interactive interface based on Flask. The interface can be launched locally with
+```
+python interface/flask_app.py --checkpoint_path=pretrained_models/latent_transformer/cat.pt
+```
+The interface can be accessed via http://localhost:8000/.
+
+## Training
+The latent transformer can be trained with
+```
+python scripts/train.py --exp_dir=results --stylegan_weights=pretrained_models/stylegan2-cat-config-f.pt
+```
+
+## Citation
+Please cite our paper if you find the code useful:
+```
+@Article{endoPG2022,
+Title = {User-Controllable Latent Transformer for StyleGAN Image Layout Editing},
+Author = {Yuki Endo},
+Journal = {Computer Graphics Forum},
+volume = {},
+number = {},
+pages = {},
+doi = {},
+Year = {2022}
+}
+```
+
+## Acknowledgements
+This code heavily borrows from the [pixel2style2pixel](https://github.com/eladrich/pixel2style2pixel) and [expansion](https://github.com/gengshan-y/expansion) repositories.
diff --git a/criteria/__init__.py b/criteria/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/criteria/lpips/__init__.py b/criteria/lpips/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/criteria/lpips/lpips.py b/criteria/lpips/lpips.py
new file mode 100755
index 0000000000000000000000000000000000000000..36b220b37a60c08391a2e31b24c4fde49b726a8a
--- /dev/null
+++ b/criteria/lpips/lpips.py
@@ -0,0 +1,35 @@
+import torch
+import torch.nn as nn
+
+from criteria.lpips.networks import get_network, LinLayers
+from criteria.lpips.utils import get_state_dict
+
+
+class LPIPS(nn.Module):
+ r"""Creates a criterion that measures
+ Learned Perceptual Image Patch Similarity (LPIPS).
+ Arguments:
+ net_type (str): the network type to compare the features:
+ 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
+ version (str): the version of LPIPS. Default: 0.1.
+ """
+ def __init__(self, net_type: str = 'alex', version: str = '0.1'):
+
+        assert version in ['0.1'], 'only v0.1 is currently supported'
+
+ super(LPIPS, self).__init__()
+
+ # pretrained network
+ self.net = get_network(net_type).to("cuda")
+
+ # linear layers
+ self.lin = LinLayers(self.net.n_channels_list).to("cuda")
+ self.lin.load_state_dict(get_state_dict(net_type, version))
+
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
+ feat_x, feat_y = self.net(x), self.net(y)
+
+ diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
+ res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
+
+ return torch.sum(torch.cat(res, 0)) / x.shape[0]
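+
+
+# Usage sketch (illustrative; assumes a CUDA device is available, since the
+# backbone and linear layers are placed on "cuda", and that inputs are scaled
+# to the range the LPIPS weights expect, typically [-1, 1]):
+#
+#   criterion = LPIPS(net_type='alex')
+#   x = torch.rand(4, 3, 256, 256, device='cuda') * 2 - 1
+#   y = torch.rand(4, 3, 256, 256, device='cuda') * 2 - 1
+#   loss = criterion(x, y)  # scalar LPIPS distance averaged over the batch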
diff --git a/criteria/lpips/networks.py b/criteria/lpips/networks.py
new file mode 100755
index 0000000000000000000000000000000000000000..7b182d0dfd3f7da57f13fdb9d724efd2f2ec5615
--- /dev/null
+++ b/criteria/lpips/networks.py
@@ -0,0 +1,96 @@
+from typing import Sequence
+
+from itertools import chain
+
+import torch
+import torch.nn as nn
+from torchvision import models
+
+from criteria.lpips.utils import normalize_activation
+
+
+def get_network(net_type: str):
+ if net_type == 'alex':
+ return AlexNet()
+ elif net_type == 'squeeze':
+ return SqueezeNet()
+ elif net_type == 'vgg':
+ return VGG16()
+ else:
+ raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
+
+
+class LinLayers(nn.ModuleList):
+ def __init__(self, n_channels_list: Sequence[int]):
+ super(LinLayers, self).__init__([
+ nn.Sequential(
+ nn.Identity(),
+ nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
+ ) for nc in n_channels_list
+ ])
+
+ for param in self.parameters():
+ param.requires_grad = False
+
+
+class BaseNet(nn.Module):
+ def __init__(self):
+ super(BaseNet, self).__init__()
+
+ # register buffer
+ self.register_buffer(
+ 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
+ self.register_buffer(
+ 'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
+
+ def set_requires_grad(self, state: bool):
+ for param in chain(self.parameters(), self.buffers()):
+ param.requires_grad = state
+
+ def z_score(self, x: torch.Tensor):
+ return (x - self.mean) / self.std
+
+ def forward(self, x: torch.Tensor):
+ x = self.z_score(x)
+
+ output = []
+ for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
+ x = layer(x)
+ if i in self.target_layers:
+ output.append(normalize_activation(x))
+ if len(output) == len(self.target_layers):
+ break
+ return output
+
+
+class SqueezeNet(BaseNet):
+ def __init__(self):
+ super(SqueezeNet, self).__init__()
+
+ self.layers = models.squeezenet1_1(True).features
+ self.target_layers = [2, 5, 8, 10, 11, 12, 13]
+ self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
+
+ self.set_requires_grad(False)
+
+
+class AlexNet(BaseNet):
+ def __init__(self):
+ super(AlexNet, self).__init__()
+
+ self.layers = models.alexnet(True).features
+ self.target_layers = [2, 5, 8, 10, 12]
+ self.n_channels_list = [64, 192, 384, 256, 256]
+
+ self.set_requires_grad(False)
+
+
+class VGG16(BaseNet):
+ def __init__(self):
+ super(VGG16, self).__init__()
+
+ self.layers = models.vgg16(True).features
+ self.target_layers = [4, 9, 16, 23, 30]
+ self.n_channels_list = [64, 128, 256, 512, 512]
+
+ self.set_requires_grad(False)
\ No newline at end of file
diff --git a/criteria/lpips/utils.py b/criteria/lpips/utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..28f1867edc419a4b679b2775bcfa26ad6c9704b4
--- /dev/null
+++ b/criteria/lpips/utils.py
@@ -0,0 +1,30 @@
+from collections import OrderedDict
+
+import torch
+
+
+def normalize_activation(x, eps=1e-10):
+ norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
+ return x / (norm_factor + eps)
+
+
+def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
+ # build url
+ url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
+ + f'master/lpips/weights/v{version}/{net_type}.pth'
+
+ # download
+ old_state_dict = torch.hub.load_state_dict_from_url(
+ url, progress=True,
+ map_location=None if torch.cuda.is_available() else torch.device('cpu')
+ )
+
+ # rename keys
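+    # (e.g. a key such as 'lin0.model.1.weight' becomes '0.1.weight',
+    # matching the indexing of LinLayers' Sequential(Identity, Conv2d) modules)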
+ new_state_dict = OrderedDict()
+ for key, val in old_state_dict.items():
+ new_key = key
+ new_key = new_key.replace('lin', '')
+ new_key = new_key.replace('model.', '')
+ new_state_dict[new_key] = val
+
+ return new_state_dict
diff --git a/docs/teaser.jpg b/docs/teaser.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..9edca9dea6895789be942a061776ead041efdee6
Binary files /dev/null and b/docs/teaser.jpg differ
diff --git a/docs/thumb.gif b/docs/thumb.gif
new file mode 100755
index 0000000000000000000000000000000000000000..80f434f5e416339e96def927957d8ada091b7130
Binary files /dev/null and b/docs/thumb.gif differ
diff --git a/env.yaml b/env.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..95f073b68ae7f500d5c0311e5912350e0fd6c116
--- /dev/null
+++ b/env.yaml
@@ -0,0 +1,380 @@
+name: uclt
+channels:
+ - pytorch
+ - anaconda
+ - nvidia
+ - conda-forge
+ - defaults
+dependencies:
+ - _ipyw_jlab_nb_ext_conf=0.1.0=py38_0
+ - _libgcc_mutex=0.1=conda_forge
+ - _openmp_mutex=4.5=1_llvm
+ - absl-py=0.13.0=pyhd8ed1ab_0
+ - aiohttp=3.7.4.post0=py38h497a2fe_0
+ - albumentations=1.0.3=pyhd8ed1ab_0
+ - alsa-lib=1.2.3=h516909a_0
+ - anaconda-client=1.8.0=py38h06a4308_0
+ - anaconda-navigator=2.0.4=py38_0
+ - anyio=2.2.0=py38h06a4308_1
+ - appdirs=1.4.4=pyh9f0ad1d_0
+ - argon2-cffi=20.1.0=py38h27cfd23_1
+ - async-timeout=3.0.1=py_1000
+ - async_generator=1.10=pyhd3eb1b0_0
+ - attrs=21.2.0=pyhd3eb1b0_0
+ - babel=2.9.1=pyhd3eb1b0_0
+ - backcall=0.2.0=pyhd3eb1b0_0
+ - backports=1.0=pyhd3eb1b0_2
+ - backports.functools_lru_cache=1.6.4=pyhd3eb1b0_0
+ - backports.tempfile=1.0=pyhd3eb1b0_1
+ - backports.weakref=1.0.post1=py_1
+ - beautifulsoup4=4.9.3=pyha847dfd_0
+ - blas=1.0=mkl
+ - bleach=4.0.0=pyhd3eb1b0_0
+ - blinker=1.4=py_1
+ - brotli=1.0.9=h7f98852_5
+ - brotli-bin=1.0.9=h7f98852_5
+ - brotlipy=0.7.0=py38h27cfd23_1003
+ - bzip2=1.0.8=h7b6447c_0
+ - c-ares=1.17.1=h27cfd23_0
+ - ca-certificates=2021.10.8=ha878542_0
+ - cachetools=4.2.2=pyhd8ed1ab_0
+ - cairo=1.16.0=hf32fb01_1
+ - certifi=2021.10.8=py38h578d9bd_1
+ - cffi=1.14.6=py38h400218f_0
+ - chardet=4.0.0=py38h06a4308_1003
+ - click=8.0.1=pyhd3eb1b0_0
+ - cloudpickle=1.6.0=py_0
+ - clyent=1.2.2=py38_1
+ - conda=4.11.0=py38h578d9bd_0
+ - conda-build=3.21.4=py38h06a4308_0
+ - conda-content-trust=0.1.1=pyhd3eb1b0_0
+ - conda-env=2.6.0=1
+ - conda-package-handling=1.7.3=py38h27cfd23_1
+ - conda-repo-cli=1.0.4=pyhd3eb1b0_0
+ - conda-token=0.3.0=pyhd3eb1b0_0
+ - conda-verify=3.4.2=py_1
+ - cryptography=3.4.7=py38hd23ed53_0
+ - cudatoolkit=11.1.74=h6bb024c_0
+ - cycler=0.10.0=py_2
+ - cytoolz=0.11.0=py38h497a2fe_3
+ - dask-core=2021.8.1=pyhd8ed1ab_0
+ - dbus=1.13.18=hb2f20db_0
+ - decorator=5.0.9=pyhd3eb1b0_0
+ - defusedxml=0.7.1=pyhd3eb1b0_0
+ - dill=0.3.4=pyhd8ed1ab_0
+ - dominate=2.6.0=pyhd8ed1ab_0
+ - entrypoints=0.3=py38_0
+ - enum34=1.1.10=py38h32f6830_2
+ - expat=2.4.1=h2531618_2
+ - ffmpeg=4.3.2=hca11adc_0
+ - filelock=3.0.12=pyhd3eb1b0_1
+ - flask=1.1.2=pyh9f0ad1d_0
+ - flask-httpauth=4.4.0=pyhd8ed1ab_0
+ - fontconfig=2.13.1=h6c09931_0
+ - fonttools=4.25.0=pyhd3eb1b0_0
+ - freetype=2.10.4=h5ab3b9f_0
+ - fsspec=2021.7.0=pyhd8ed1ab_0
+ - ftfy=6.0.3=pyhd8ed1ab_0
+ - func_timeout=4.3.5=py_0
+ - future=0.18.2=py38_1
+ - gdown=4.2.0=pyhd8ed1ab_0
+ - geos=3.10.0=h9c3ff4c_0
+ - gettext=0.19.8.1=h0b5b191_1005
+ - git=2.23.0=pl526hacde149_0
+ - glib=2.68.4=h9c3ff4c_0
+ - glib-tools=2.68.4=h9c3ff4c_0
+ - glob2=0.7=pyhd3eb1b0_0
+ - gmp=6.2.1=h58526e2_0
+ - gnutls=3.6.13=h85f3911_1
+ - google-auth=1.35.0=pyh6c4a22f_0
+ - google-auth-oauthlib=0.4.5=pyhd8ed1ab_0
+ - gputil=1.4.0=pyh9f0ad1d_0
+ - graphite2=1.3.13=h58526e2_1001
+ - gst-plugins-base=1.18.4=hf529b03_2
+ - gstreamer=1.18.4=h76c114f_2
+ - harfbuzz=2.9.0=h83ec7ef_0
+ - hdf5=1.10.6=nompi_h6a2412b_1114
+ - icu=68.1=h58526e2_0
+ - idna=2.10=pyhd3eb1b0_0
+ - imagecodecs-lite=2019.12.3=py38h5c078b8_3
+ - imageio=2.9.0=py_0
+ - imageio-ffmpeg=0.4.5=pyhd8ed1ab_0
+ - imgaug=0.4.0=py_1
+ - importlib-metadata=3.10.0=py38h06a4308_0
+ - importlib_metadata=3.10.0=hd3eb1b0_0
+ - intel-openmp=2021.3.0=h06a4308_3350
+ - ipykernel=5.3.4=py38h5ca1d4c_0
+ - ipympl=0.8.2=pyhd8ed1ab_0
+ - ipython=7.26.0=py38hb070fc8_0
+ - ipython_genutils=0.2.0=pyhd3eb1b0_1
+ - ipywidgets=7.6.3=pyhd3eb1b0_1
+ - itsdangerous=2.0.1=pyhd3eb1b0_0
+ - jasper=1.900.1=h07fcdf6_1006
+ - jedi=0.18.0=py38h06a4308_1
+ - jinja2=2.11.3=pyhd3eb1b0_0
+ - joblib=1.1.0=pyhd8ed1ab_0
+ - jpeg=9d=h36c2ea0_0
+ - json5=0.9.6=pyhd3eb1b0_0
+ - jsonnet=0.17.0=py38hadf7658_0
+ - jsonschema=3.2.0=py_2
+ - jupyter_client=6.1.12=pyhd3eb1b0_0
+ - jupyter_core=4.7.1=py38h06a4308_0
+ - jupyter_server=1.4.1=py38h06a4308_0
+ - jupyterlab=3.1.7=pyhd3eb1b0_0
+ - jupyterlab_pygments=0.1.2=py_0
+ - jupyterlab_server=2.7.1=pyhd3eb1b0_0
+ - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1
+ - kiwisolver=1.3.1=py38h1fd1430_1
+ - krb5=1.19.2=hcc1bbae_0
+ - lame=3.100=h7f98852_1001
+ - lcms2=2.12=h3be6417_0
+ - ld_impl_linux-64=2.35.1=h7274673_9
+ - libarchive=3.4.2=h62408e4_0
+ - libblas=3.9.0=11_linux64_mkl
+ - libbrotlicommon=1.0.9=h7f98852_5
+ - libbrotlidec=1.0.9=h7f98852_5
+ - libbrotlienc=1.0.9=h7f98852_5
+ - libcblas=3.9.0=11_linux64_mkl
+ - libcurl=7.78.0=h2574ce0_0
+ - libedit=3.1.20191231=he28a2e2_2
+ - libev=4.33=h516909a_1
+ - libevent=2.1.10=hcdb4288_3
+ - libffi=3.3=he6710b0_2
+ - libgcc-ng=11.1.0=hc902ee8_8
+ - libgfortran-ng=11.1.0=h69a702a_8
+ - libgfortran5=11.1.0=h6c583b3_8
+ - libglib=2.68.4=h3e27bee_0
+ - libiconv=1.16=h516909a_0
+ - liblapack=3.9.0=11_linux64_mkl
+ - liblapacke=3.9.0=11_linux64_mkl
+ - liblief=0.10.1=he6710b0_0
+ - libllvm11=11.1.0=hf817b99_2
+ - libnghttp2=1.43.0=h812cca2_0
+ - libogg=1.3.4=h7f98852_1
+ - libopencv=4.5.2=py38hcdf9bf1_0
+ - libopus=1.3.1=h7f98852_1
+ - libpng=1.6.37=hbc83047_0
+ - libpq=13.3=hd57d9b9_0
+ - libprotobuf=3.15.8=h780b84a_0
+ - libsodium=1.0.18=h7b6447c_0
+ - libssh2=1.9.0=ha56f1ee_6
+ - libstdcxx-ng=11.1.0=h56837e0_8
+ - libtiff=4.2.0=h85742a9_0
+ - libuuid=1.0.3=h1bed415_2
+ - libuv=1.40.0=h7b6447c_0
+ - libvorbis=1.3.7=h9c3ff4c_0
+ - libwebp-base=1.2.0=h27cfd23_0
+ - libxcb=1.14=h7b6447c_0
+ - libxkbcommon=1.0.3=he3ba5ed_0
+ - libxml2=2.9.12=h72842e0_0
+ - llvm-openmp=12.0.1=h4bd325d_1
+ - locket=0.2.0=py_2
+ - lz4-c=1.9.3=h295c915_1
+ - markdown=3.3.4=pyhd8ed1ab_0
+ - markupsafe=2.0.1=py38h27cfd23_0
+ - matplotlib=3.4.2=py38h578d9bd_0
+ - matplotlib-base=3.4.2=py38hab158f2_0
+ - matplotlib-inline=0.1.2=pyhd3eb1b0_2
+ - mistune=0.8.4=py38h7b6447c_1000
+ - mkl=2021.3.0=h06a4308_520
+ - mkl-service=2.4.0=py38h7f8727e_0
+ - mkl_fft=1.3.0=py38h42c9631_2
+ - mkl_random=1.2.2=py38h51133e4_0
+ - multidict=5.1.0=py38h497a2fe_1
+ - munkres=1.1.4=pyh9f0ad1d_0
+ - mysql-common=8.0.25=ha770c72_0
+ - mysql-libs=8.0.25=h935591d_0
+ - navigator-updater=0.2.1=py38_0
+ - nbclassic=0.2.6=pyhd3eb1b0_0
+ - nbclient=0.5.3=pyhd3eb1b0_0
+ - nbconvert=6.1.0=py38h06a4308_0
+ - nbformat=5.1.3=pyhd3eb1b0_0
+ - ncurses=6.2=he6710b0_1
+ - nest-asyncio=1.5.1=pyhd3eb1b0_0
+ - nettle=3.6=he412f7d_0
+ - networkx=2.3=py_0
+ - ninja=1.10.2=hff7bd54_1
+ - notebook=6.4.3=py38h06a4308_0
+ - nspr=4.30=h9c3ff4c_0
+ - nss=3.69=hb5efdd6_0
+ - numpy=1.20.3=py38hf144106_0
+ - numpy-base=1.20.3=py38h74d4b33_0
+ - oauthlib=3.1.1=pyhd8ed1ab_0
+ - olefile=0.46=py_0
+ - opencv=4.5.2=py38h578d9bd_0
+ - openh264=2.1.1=h780b84a_0
+ - openjpeg=2.3.0=h05c96fa_1
+ - openssl=1.1.1l=h7f98852_0
+ - packaging=21.0=pyhd3eb1b0_0
+ - pandas=1.3.2=py38h43a58ef_0
+ - pandocfilters=1.4.3=py38h06a4308_1
+ - parso=0.8.2=pyhd3eb1b0_0
+ - partd=1.2.0=pyhd8ed1ab_0
+ - patchelf=0.12=h2531618_1
+ - pathlib=1.0.1=py38h578d9bd_4
+ - patsy=0.5.1=py_0
+ - pcre=8.45=h295c915_0
+ - perl=5.26.2=h14c3975_0
+ - pexpect=4.8.0=pyhd3eb1b0_3
+ - pickleshare=0.7.5=pyhd3eb1b0_1003
+ - pillow=8.3.1=py38h2c7a002_0
+ - pip=21.2.2=py38h06a4308_0
+ - pixman=0.40.0=h36c2ea0_0
+ - pkginfo=1.7.1=py38h06a4308_0
+ - pooch=1.5.1=pyhd8ed1ab_0
+ - portalocker=1.7.0=py38h578d9bd_1
+ - prometheus_client=0.11.0=pyhd3eb1b0_0
+ - prompt-toolkit=3.0.17=pyh06a4308_0
+ - protobuf=3.15.8=py38h709712a_0
+ - psutil=5.8.0=py38h27cfd23_1
+ - ptyprocess=0.7.0=pyhd3eb1b0_2
+ - py-lief=0.10.1=py38h403a769_0
+ - py-opencv=4.5.2=py38hd0cf306_0
+ - pyasn1=0.4.8=py_0
+ - pyasn1-modules=0.2.7=py_0
+ - pycosat=0.6.3=py38h7b6447c_1
+ - pycparser=2.20=py_2
+ - pygments=2.10.0=pyhd3eb1b0_0
+ - pyjwt=2.1.0=pyhd8ed1ab_0
+ - pyopenssl=20.0.1=pyhd3eb1b0_1
+ - pyparsing=2.4.7=pyhd3eb1b0_0
+ - pypng=0.0.20=py_0
+ - pyqt=5.12.3=py38h578d9bd_7
+ - pyqt-impl=5.12.3=py38h7400c14_7
+ - pyqt5-sip=4.19.18=py38h709712a_7
+ - pyqtchart=5.12=py38h7400c14_7
+ - pyqtwebengine=5.12.1=py38h7400c14_7
+ - pyrsistent=0.17.3=py38h7b6447c_0
+ - pysocks=1.7.1=py38h06a4308_0
+ - python=3.8.10=h12debd9_8
+ - python-dateutil=2.8.2=pyhd3eb1b0_0
+ - python-libarchive-c=2.9=pyhd3eb1b0_1
+ - python-lmdb=0.99=py38h709712a_0
+ - python_abi=3.8=2_cp38
+ - pytorch=1.9.0=py3.8_cuda11.1_cudnn8.0.5_0
+ - pytz=2021.1=pyhd3eb1b0_0
+ - pyu2f=0.1.5=pyhd8ed1ab_0
+ - pywavelets=1.1.1=py38h5c078b8_3
+ - pyyaml=5.4.1=py38h27cfd23_1
+ - pyzmq=22.2.1=py38h295c915_1
+ - qt=5.12.9=hda022c4_4
+ - qtpy=1.9.0=py_0
+ - readline=8.1=h27cfd23_0
+ - regex=2021.8.28=py38h497a2fe_0
+ - requests=2.25.1=pyhd3eb1b0_0
+ - requests-oauthlib=1.3.0=pyh9f0ad1d_0
+ - ripgrep=12.1.1=0
+ - rsa=4.7.2=pyh44b312d_0
+ - ruamel_yaml=0.15.100=py38h27cfd23_0
+ - scikit-image=0.18.3=py38h43a58ef_0
+ - scikit-learn=1.0=py38hacb3eff_1
+ - scipy=1.7.1=py38h56a6a73_0
+ - seaborn=0.11.2=hd8ed1ab_0
+ - seaborn-base=0.11.2=pyhd8ed1ab_0
+ - send2trash=1.5.0=pyhd3eb1b0_1
+ - setuptools=52.0.0=py38h06a4308_0
+ - shapely=1.8.0=py38hf7953bd_1
+ - sip=4.19.13=py38he6710b0_0
+ - sniffio=1.2.0=py38h06a4308_1
+ - soupsieve=2.2.1=pyhd3eb1b0_0
+ - sqlite=3.36.0=hc218d9a_0
+ - statsmodels=0.12.2=py38h5c078b8_0
+ - tensorboard=2.6.0=pyhd8ed1ab_1
+ - tensorboard-data-server=0.6.0=py38h2b97feb_0
+ - tensorboard-plugin-wit=1.8.0=pyh44b312d_0
+ - tensorboardx=2.4=pyhd8ed1ab_0
+ - terminado=0.9.4=py38h06a4308_0
+ - testpath=0.5.0=pyhd3eb1b0_0
+ - threadpoolctl=3.0.0=pyh8a188c0_0
+ - tifffile=2019.7.26.2=py38_0
+ - tk=8.6.10=hbc83047_0
+ - toolz=0.11.1=py_0
+ - torchfile=0.1.0=py_0
+ - tornado=6.1=py38h27cfd23_0
+ - tqdm=4.62.1=pyhd3eb1b0_1
+ - traitlets=5.0.5=pyhd3eb1b0_0
+ - typing_extensions=3.10.0.0=pyh06a4308_0
+ - urllib3=1.26.6=pyhd3eb1b0_1
+ - wcwidth=0.2.5=py_0
+ - webencodings=0.5.1=py38_1
+ - werkzeug=1.0.1=pyhd3eb1b0_0
+ - wheel=0.37.0=pyhd3eb1b0_0
+ - widgetsnbextension=3.5.1=py38_0
+ - x264=1!161.3030=h7f98852_1
+ - xmltodict=0.12.0=py_0
+ - xz=5.2.5=h7b6447c_0
+ - yacs=0.1.6=py_0
+ - yaml=0.2.5=h7b6447c_0
+ - yarl=1.6.3=py38h497a2fe_2
+ - zeromq=4.3.4=h2531618_0
+ - zipp=3.5.0=pyhd3eb1b0_0
+ - zlib=1.2.11=h7b6447c_3
+ - zstd=1.4.9=haebb681_0
+ - pip:
+ - addict==2.4.0
+ - altair==4.2.0
+ - astor==0.8.1
+ - astunparse==1.6.3
+ - backports-zoneinfo==0.2.1
+ - base58==2.1.1
+ - basicsr==1.3.4.1
+ - boto3==1.18.33
+ - botocore==1.21.33
+ - clang==5.0
+ - clean-fid==0.1.22
+ - clip==1.0
+ - colorama==0.4.4
+ - commonmark==0.9.1
+ - cython==0.29.30
+ - einops==0.3.2
+ - enum-compat==0.0.3
+ - facexlib==0.2.0.3
+ - filterpy==1.4.5
+ - flatbuffers==1.12
+ - gast==0.4.0
+ - google-pasta==0.2.0
+ - grpcio==1.39.0
+ - h5py==3.1.0
+ - ipdb==0.13.9
+ - jacinle==1.0.0
+ - jmespath==0.10.0
+ - jsonpickle==2.2.0
+ - keras==2.7.0
+ - keras-preprocessing==1.1.2
+ - libclang==12.0.0
+ - llvmlite==0.37.0
+ - lpips==0.1.4
+ - numba==0.54.0
+ - opencv-python==4.5.3.56
+ - opt-einsum==3.3.0
+ - pkgconfig==1.5.5
+ - pyarrow==8.0.0
+ - pydantic==1.8.2
+ - pydeck==0.7.1
+ - pyhocon==0.3.58
+ - pytz-deprecation-shim==0.1.0.post0
+ - pyvis==0.2.1
+ - realesrgan==0.2.2.3
+ - rich==10.9.0
+ - s3transfer==0.5.0
+ - six==1.15.0
+ - sklearn==0.0
+ - streamlit==0.64.0
+ - tabulate==0.8.9
+ - tb-nightly==2.7.0a20210827
+ - tensorflow-estimator==2.7.0
+ - tensorflow-gpu==2.7.0
+ - tensorflow-io-gcs-filesystem==0.21.0
+ - tensorfn==0.1.19
+ - termcolor==1.1.0
+ - toml==0.10.2
+ - torchsample==0.1.3
+ - torchvision==0.10.0+cu111
+ - typing-extensions==3.7.4.3
+ - tzdata==2022.1
+ - tzlocal==4.2
+ - validators==0.19.0
+ - vit-pytorch==0.24.3
+ - watchdog==2.1.8
+ - wrapt==1.12.1
+ - yapf==0.31.0
\ No newline at end of file
diff --git a/expansion/__init__.py b/expansion/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/expansion/dataloader/__init__.py b/expansion/dataloader/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/expansion/dataloader/__pycache__/__init__.cpython-38.pyc b/expansion/dataloader/__pycache__/__init__.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..36079aff3f6dfa75f083f17f30c7a0dbc6f69dd5
Binary files /dev/null and b/expansion/dataloader/__pycache__/__init__.cpython-38.pyc differ
diff --git a/expansion/dataloader/__pycache__/seqlist.cpython-38.pyc b/expansion/dataloader/__pycache__/seqlist.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..980a3a34d42ed22b8c711cb73be9424ace31bf95
Binary files /dev/null and b/expansion/dataloader/__pycache__/seqlist.cpython-38.pyc differ
diff --git a/expansion/dataloader/chairslist.py b/expansion/dataloader/chairslist.py
new file mode 100755
index 0000000000000000000000000000000000000000..107f4ee2914a369d31fd2c89b49a7d4ebba51638
--- /dev/null
+++ b/expansion/dataloader/chairslist.py
@@ -0,0 +1,33 @@
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+import numpy as np
+import glob
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+def dataloader(filepath):
+ l0_train = []
+ l1_train = []
+ flow_train = []
+ for flow_map in sorted(glob.glob(os.path.join(filepath,'*_flow.flo'))):
+ root_filename = flow_map[:-9]
+ img1 = root_filename+'_img1.ppm'
+ img2 = root_filename+'_img2.ppm'
+ if not (os.path.isfile(os.path.join(filepath,img1)) and os.path.isfile(os.path.join(filepath,img2))):
+ continue
+
+ l0_train.append(img1)
+ l1_train.append(img2)
+ flow_train.append(flow_map)
+
+ return l0_train, l1_train, flow_train
diff --git a/expansion/dataloader/chairssdlist.py b/expansion/dataloader/chairssdlist.py
new file mode 100755
index 0000000000000000000000000000000000000000..714c81c8a801beb549806bd72ac6da7d284d25e9
--- /dev/null
+++ b/expansion/dataloader/chairssdlist.py
@@ -0,0 +1,30 @@
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+import numpy as np
+import glob
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+def dataloader(filepath):
+ l0_train = []
+ l1_train = []
+ flow_train = []
+ for flow_map in sorted(glob.glob('%s/flow/*.pfm'%filepath)):
+ img1 = flow_map.replace('flow','t0').replace('.pfm','.png')
+ img2 = flow_map.replace('flow','t1').replace('.pfm','.png')
+
+ l0_train.append(img1)
+ l1_train.append(img2)
+ flow_train.append(flow_map)
+
+ return l0_train, l1_train, flow_train
diff --git a/expansion/dataloader/depth_transforms.py b/expansion/dataloader/depth_transforms.py
new file mode 100755
index 0000000000000000000000000000000000000000..19a768c788137bfb89077b6c415dc9401a540e4e
--- /dev/null
+++ b/expansion/dataloader/depth_transforms.py
@@ -0,0 +1,471 @@
+from __future__ import division
+import torch
+import random
+import numpy as np
+import numbers
+import types
+import scipy.ndimage as ndimage
+import pdb
+import torchvision
+import PIL.Image as Image
+import cv2
+from torch.nn import functional as F
+
+
+class Compose(object):
+ """ Composes several co_transforms together.
+ For example:
+ >>> co_transforms.Compose([
+ >>> co_transforms.CenterCrop(10),
+ >>> co_transforms.ToTensor(),
+ >>> ])
+ """
+
+ def __init__(self, co_transforms):
+ self.co_transforms = co_transforms
+
+ def __call__(self, input, target,intr):
+ for t in self.co_transforms:
+ input,target,intr = t(input,target,intr)
+ return input,target,intr
+
+
+class Scale(object):
+ """ Rescales the inputs and target arrays to the given 'size'.
+ 'size' will be the size of the smaller edge.
+ For example, if height > width, then image will be
+ rescaled to (size * height / width, size)
+ size: size of the smaller edge
+ interpolation order: Default: 2 (bilinear)
+ """
+
+ def __init__(self, size, order=1):
+ self.ratio = size
+ self.order = order
+ if order==0:
+ self.code=cv2.INTER_NEAREST
+ elif order==1:
+ self.code=cv2.INTER_LINEAR
+ elif order==2:
+ self.code=cv2.INTER_CUBIC
+
+ def __call__(self, inputs, target):
+ if self.ratio==1:
+ return inputs, target
+ h, w, _ = inputs[0].shape
+ ratio = self.ratio
+
+ inputs[0] = cv2.resize(inputs[0], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_LINEAR)
+ inputs[1] = cv2.resize(inputs[1], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_LINEAR)
+ # keep the mask same
+ tmp = cv2.resize(target[:,:,2], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_NEAREST)
+ target = cv2.resize(target, None, fx=ratio,fy=ratio,interpolation=self.code) * ratio
+ target[:,:,2] = tmp
+
+
+ return inputs, target
+
+
+class RandomCrop(object):
+ """Crops the given PIL.Image at a random location to have a region of
+ the given size. size can be a tuple (target_height, target_width)
+ or an integer, in which case the target will be of a square shape (size, size)
+ """
+
+ def __init__(self, size):
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ self.size = size
+
+ def __call__(self, inputs,target,intr):
+ h, w, _ = inputs[0].shape
+ th, tw = self.size
+ if w < tw: tw=w
+ if h < th: th=h
+
+ x1 = random.randint(0, w - tw)
+ y1 = random.randint(0, h - th)
+ intr[1] -= x1
+ intr[2] -= y1
+
+ inputs[0] = inputs[0][y1: y1 + th,x1: x1 + tw].astype(float)
+ inputs[1] = inputs[1][y1: y1 + th,x1: x1 + tw].astype(float)
+ return inputs, target[y1: y1 + th,x1: x1 + tw].astype(float), list(np.asarray(intr).astype(float)) + list(np.asarray([1.,0.,0.,1.,0.,0.]).astype(float))
+
+
+
+class SpatialAug(object):
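+    # Random spatial augmentation (crop plus optional mirror/scale/rotation/
+    # translation/squeeze). self.t holds a 2x3 affine transform stored as
+    # [a, b, c, d, e, f] mapping (x, y) -> (a*x + c*y + e, b*x + d*y + f),
+    # as applied in grid_transform().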
+ def __init__(self, crop, scale=None, rot=None, trans=None, squeeze=None, schedule_coeff=1, order=1, black=False):
+ self.crop = crop
+ self.scale = scale
+ self.rot = rot
+ self.trans = trans
+ self.squeeze = squeeze
+ self.t = np.zeros(6)
+ self.schedule_coeff = schedule_coeff
+ self.order = order
+ self.black = black
+
+ def to_identity(self):
+ self.t[0] = 1; self.t[2] = 0; self.t[4] = 0; self.t[1] = 0; self.t[3] = 1; self.t[5] = 0;
+
+ def left_multiply(self, u0, u1, u2, u3, u4, u5):
+ result = np.zeros(6)
+ result[0] = self.t[0]*u0 + self.t[1]*u2;
+ result[1] = self.t[0]*u1 + self.t[1]*u3;
+
+ result[2] = self.t[2]*u0 + self.t[3]*u2;
+ result[3] = self.t[2]*u1 + self.t[3]*u3;
+
+ result[4] = self.t[4]*u0 + self.t[5]*u2 + u4;
+ result[5] = self.t[4]*u1 + self.t[5]*u3 + u5;
+ self.t = result
+
+ def inverse(self):
+ result = np.zeros(6)
+ a = self.t[0]; c = self.t[2]; e = self.t[4];
+ b = self.t[1]; d = self.t[3]; f = self.t[5];
+
+ denom = a*d - b*c;
+
+ result[0] = d / denom;
+ result[1] = -b / denom;
+ result[2] = -c / denom;
+ result[3] = a / denom;
+ result[4] = (c*f-d*e) / denom;
+ result[5] = (b*e-a*f) / denom;
+
+ return result
+
+ def grid_transform(self, meshgrid, t, normalize=True, gridsize=None):
+ if gridsize is None:
+ h, w = meshgrid[0].shape
+ else:
+ h, w = gridsize
+ vgrid = torch.cat([(meshgrid[0] * t[0] + meshgrid[1] * t[2] + t[4])[:,:,np.newaxis],
+ (meshgrid[0] * t[1] + meshgrid[1] * t[3] + t[5])[:,:,np.newaxis]],-1)
+ if normalize:
+ vgrid[:,:,0] = 2.0*vgrid[:,:,0]/max(w-1,1)-1.0
+ vgrid[:,:,1] = 2.0*vgrid[:,:,1]/max(h-1,1)-1.0
+ return vgrid
+
+
+ def __call__(self, inputs, target, intr):
+ h, w, _ = inputs[0].shape
+ th, tw = self.crop
+ meshgrid = torch.meshgrid([torch.Tensor(range(th)), torch.Tensor(range(tw))])[::-1]
+ cornergrid = torch.meshgrid([torch.Tensor([0,th-1]), torch.Tensor([0,tw-1])])[::-1]
+
+ for i in range(50):
+ # im0
+ self.to_identity()
+ #TODO add mirror
+ if np.random.binomial(1,0.5):
+ mirror = True
+ else:
+ mirror = False
+ ##TODO
+ #mirror = False
+ if mirror:
+ self.left_multiply(-1, 0, 0, 1, .5 * tw, -.5 * th);
+ else:
+ self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th);
+ scale0 = 1; scale1 = 1; squeeze0 = 1; squeeze1 = 1;
+ if not self.rot is None:
+ rot0 = np.random.uniform(-self.rot[0],+self.rot[0])
+ rot1 = np.random.uniform(-self.rot[1]*self.schedule_coeff, self.rot[1]*self.schedule_coeff) + rot0
+ self.left_multiply(np.cos(rot0), np.sin(rot0), -np.sin(rot0), np.cos(rot0), 0, 0)
+ if not self.trans is None:
+ trans0 = np.random.uniform(-self.trans[0],+self.trans[0], 2)
+ trans1 = np.random.uniform(-self.trans[1]*self.schedule_coeff,+self.trans[1]*self.schedule_coeff, 2) + trans0
+ self.left_multiply(1, 0, 0, 1, trans0[0] * tw, trans0[1] * th)
+ if not self.squeeze is None:
+ squeeze0 = np.exp(np.random.uniform(-self.squeeze[0], self.squeeze[0]))
+ squeeze1 = np.exp(np.random.uniform(-self.squeeze[1]*self.schedule_coeff, self.squeeze[1]*self.schedule_coeff)) * squeeze0
+ if not self.scale is None:
+ scale0 = np.exp(np.random.uniform(self.scale[2]-self.scale[0], self.scale[2]+self.scale[0]))
+ scale1 = np.exp(np.random.uniform(-self.scale[1]*self.schedule_coeff, self.scale[1]*self.schedule_coeff)) * scale0
+ self.left_multiply(1.0/(scale0*squeeze0), 0, 0, 1.0/(scale0/squeeze0), 0, 0)
+
+ self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h);
+ transmat0 = self.t.copy()
+
+ # im1
+ self.to_identity()
+ if mirror:
+ self.left_multiply(-1, 0, 0, 1, .5 * tw, -.5 * th);
+ else:
+ self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th);
+ if not self.rot is None:
+ self.left_multiply(np.cos(rot1), np.sin(rot1), -np.sin(rot1), np.cos(rot1), 0, 0)
+ if not self.trans is None:
+ self.left_multiply(1, 0, 0, 1, trans1[0] * tw, trans1[1] * th)
+ self.left_multiply(1.0/(scale1*squeeze1), 0, 0, 1.0/(scale1/squeeze1), 0, 0)
+ self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h);
+ transmat1 = self.t.copy()
+ transmat1_inv = self.inverse()
+
+ if self.black:
+ # black augmentation, allowing 0 values in the input images
+ # https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/black_augmentation_layer.cu
+ break
+ else:
+ if ((self.grid_transform(cornergrid, transmat0, gridsize=[float(h),float(w)]).abs()>1).sum() +\
+ (self.grid_transform(cornergrid, transmat1, gridsize=[float(h),float(w)]).abs()>1).sum()) == 0:
+ break
+ if i==49:
+ print('max_iter in augmentation')
+ self.to_identity()
+ self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th);
+ self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h);
+ transmat0 = self.t.copy()
+ transmat1 = self.t.copy()
+
+ # do the real work
+ vgrid = self.grid_transform(meshgrid, transmat0,gridsize=[float(h),float(w)])
+ inputs_0 = F.grid_sample(torch.Tensor(inputs[0]).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)
+ if self.order == 0:
+ target_0 = F.grid_sample(torch.Tensor(target).permute(2,0,1)[np.newaxis], vgrid[np.newaxis], mode='nearest')[0].permute(1,2,0)
+ else:
+ target_0 = F.grid_sample(torch.Tensor(target).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)
+
+ mask_0 = target[:,:,2:3].copy(); mask_0[mask_0==0]=np.nan
+ if self.order == 0:
+ mask_0 = F.grid_sample(torch.Tensor(mask_0).permute(2,0,1)[np.newaxis], vgrid[np.newaxis], mode='nearest')[0].permute(1,2,0)
+ else:
+ mask_0 = F.grid_sample(torch.Tensor(mask_0).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)
+ mask_0[torch.isnan(mask_0)] = 0
+
+
+ vgrid = self.grid_transform(meshgrid, transmat1,gridsize=[float(h),float(w)])
+ inputs_1 = F.grid_sample(torch.Tensor(inputs[1]).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)
+
+ # flow
+ pos = target_0[:,:,:2] + self.grid_transform(meshgrid, transmat0,normalize=False)
+ pos = self.grid_transform(pos.permute(2,0,1),transmat1_inv,normalize=False)
+ if target_0.shape[2]>=4:
+ # scale
+ exp = target_0[:,:,3:] * scale1 / scale0
+ target = torch.cat([ (pos[:,:,0] - meshgrid[0]).unsqueeze(-1),
+ (pos[:,:,1] - meshgrid[1]).unsqueeze(-1),
+ mask_0,
+ exp], -1)
+ else:
+ target = torch.cat([ (pos[:,:,0] - meshgrid[0]).unsqueeze(-1),
+ (pos[:,:,1] - meshgrid[1]).unsqueeze(-1),
+ mask_0], -1)
+ inputs = [np.asarray(inputs_0).astype(float), np.asarray(inputs_1).astype(float)]
+ target = np.asarray(target).astype(float)
+ return inputs,target, list(np.asarray(intr+list(transmat0)).astype(float))
+
+
+
+class pseudoPCAAug(object):
+ """
+ Chromatic Eigen Augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
+ This version is faster.
+ """
+ def __init__(self, schedule_coeff=1):
+ self.augcolor = torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.5, hue=0.5/3.14)
+
+ def __call__(self, inputs, target,intr):
+ img = np.concatenate([inputs[0],inputs[1]],0)
+ shape = img.shape[0]//2
+ aug_img = np.asarray(self.augcolor(Image.fromarray(np.uint8(img*255))))/255.
+ inputs[0] = aug_img[:shape]
+ inputs[1] = aug_img[shape:]
+ #inputs[0] = np.asarray(self.augcolor(Image.fromarray(np.uint8(inputs[0]*255))))/255.
+ #inputs[1] = np.asarray(self.augcolor(Image.fromarray(np.uint8(inputs[1]*255))))/255.
+ return inputs,target,intr
+
+
+class PCAAug(object):
+ """
+ Chromatic Eigen Augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
+ """
+ def __init__(self, lmult_pow =[0.4, 0,-0.2],
+ lmult_mult =[0.4, 0,0, ],
+ lmult_add =[0.03,0,0, ],
+ sat_pow =[0.4, 0,0, ],
+ sat_mult =[0.5, 0,-0.3],
+ sat_add =[0.03,0,0, ],
+ col_pow =[0.4, 0,0, ],
+ col_mult =[0.2, 0,0, ],
+ col_add =[0.02,0,0, ],
+ ladd_pow =[0.4, 0,0, ],
+ ladd_mult =[0.4, 0,0, ],
+ ladd_add =[0.04,0,0, ],
+ col_rotate =[1., 0,0, ],
+ schedule_coeff=1):
+ # no mean
+ self.pow_nomean = [1,1,1]
+ self.add_nomean = [0,0,0]
+ self.mult_nomean = [1,1,1]
+ self.pow_withmean = [1,1,1]
+ self.add_withmean = [0,0,0]
+ self.mult_withmean = [1,1,1]
+ self.lmult_pow = 1
+ self.lmult_mult = 1
+ self.lmult_add = 0
+ self.col_angle = 0
+ if not ladd_pow is None:
+ self.pow_nomean[0] =np.exp(np.random.normal(ladd_pow[2], ladd_pow[0]))
+ if not col_pow is None:
+ self.pow_nomean[1] =np.exp(np.random.normal(col_pow[2], col_pow[0]))
+ self.pow_nomean[2] =np.exp(np.random.normal(col_pow[2], col_pow[0]))
+
+ if not ladd_add is None:
+ self.add_nomean[0] =np.random.normal(ladd_add[2], ladd_add[0])
+ if not col_add is None:
+ self.add_nomean[1] =np.random.normal(col_add[2], col_add[0])
+ self.add_nomean[2] =np.random.normal(col_add[2], col_add[0])
+
+ if not ladd_mult is None:
+ self.mult_nomean[0] =np.exp(np.random.normal(ladd_mult[2], ladd_mult[0]))
+ if not col_mult is None:
+ self.mult_nomean[1] =np.exp(np.random.normal(col_mult[2], col_mult[0]))
+ self.mult_nomean[2] =np.exp(np.random.normal(col_mult[2], col_mult[0]))
+
+ # with mean
+ if not sat_pow is None:
+ self.pow_withmean[1] =np.exp(np.random.uniform(sat_pow[2]-sat_pow[0], sat_pow[2]+sat_pow[0]))
+ self.pow_withmean[2] =self.pow_withmean[1]
+ if not sat_add is None:
+ self.add_withmean[1] =np.random.uniform(sat_add[2]-sat_add[0], sat_add[2]+sat_add[0])
+ self.add_withmean[2] =self.add_withmean[1]
+ if not sat_mult is None:
+ self.mult_withmean[1] = np.exp(np.random.uniform(sat_mult[2]-sat_mult[0], sat_mult[2]+sat_mult[0]))
+ self.mult_withmean[2] = self.mult_withmean[1]
+
+ if not lmult_pow is None:
+ self.lmult_pow = np.exp(np.random.uniform(lmult_pow[2]-lmult_pow[0], lmult_pow[2]+lmult_pow[0]))
+ if not lmult_mult is None:
+ self.lmult_mult= np.exp(np.random.uniform(lmult_mult[2]-lmult_mult[0], lmult_mult[2]+lmult_mult[0]))
+ if not lmult_add is None:
+ self.lmult_add = np.random.uniform(lmult_add[2]-lmult_add[0], lmult_add[2]+lmult_add[0])
+ if not col_rotate is None:
+ self.col_angle= np.random.uniform(col_rotate[2]-col_rotate[0], col_rotate[2]+col_rotate[0])
+
+ # eigen vectors
+ self.eigvec = np.reshape([0.51,0.56,0.65,0.79,0.01,-0.62,0.35,-0.83,0.44],[3,3]).transpose()
+
+
+ def __call__(self, inputs, target, intr):
+ inputs[0] = self.pca_image(inputs[0])
+ inputs[1] = self.pca_image(inputs[1])
+ return inputs,target,intr
+
+ def pca_image(self, rgb):
+ eig = np.dot(rgb, self.eigvec)
+ max_rgb = np.clip(rgb,0,np.inf).max((0,1))
+ min_rgb = rgb.min((0,1))
+ mean_rgb = rgb.mean((0,1))
+ max_abs_eig = np.abs(eig).max((0,1))
+ max_l = np.sqrt(np.sum(max_abs_eig*max_abs_eig))
+ mean_eig = np.dot(mean_rgb, self.eigvec)
+
+ # no-mean stuff
+ eig -= mean_eig[np.newaxis, np.newaxis]
+
+ for c in range(3):
+ if max_abs_eig[c] > 1e-2:
+ mean_eig[c] /= max_abs_eig[c]
+ eig[:,:,c] = eig[:,:,c] / max_abs_eig[c];
+ eig[:,:,c] = np.power(np.abs(eig[:,:,c]),self.pow_nomean[c]) *\
+ ((eig[:,:,c] > 0) -0.5)*2
+ eig[:,:,c] = eig[:,:,c] + self.add_nomean[c]
+ eig[:,:,c] = eig[:,:,c] * self.mult_nomean[c]
+ eig += mean_eig[np.newaxis,np.newaxis]
+
+ # withmean stuff
+ if max_abs_eig[0] > 1e-2:
+ eig[:,:,0] = np.power(np.abs(eig[:,:,0]),self.pow_withmean[0]) * \
+ ((eig[:,:,0]>0)-0.5)*2;
+ eig[:,:,0] = eig[:,:,0] + self.add_withmean[0];
+ eig[:,:,0] = eig[:,:,0] * self.mult_withmean[0];
+
+ s = np.sqrt(eig[:,:,1]*eig[:,:,1] + eig[:,:,2] * eig[:,:,2])
+ smask = s > 1e-2
+ s1 = np.power(s, self.pow_withmean[1]);
+ s1 = np.clip(s1 + self.add_withmean[1], 0,np.inf)
+ s1 = s1 * self.mult_withmean[1]
+ s1 = s1 * smask + s*(1-smask)
+
+ # color angle
+ if self.col_angle!=0:
+ temp1 = np.cos(self.col_angle) * eig[:,:,1] - np.sin(self.col_angle) * eig[:,:,2]
+ temp2 = np.sin(self.col_angle) * eig[:,:,1] + np.cos(self.col_angle) * eig[:,:,2]
+ eig[:,:,1] = temp1
+ eig[:,:,2] = temp2
+
+ # to origin magnitude
+ for c in range(3):
+ if max_abs_eig[c] > 1e-2:
+ eig[:,:,c] = eig[:,:,c] * max_abs_eig[c]
+
+ if max_l > 1e-2:
+ l1 = np.sqrt(eig[:,:,0]*eig[:,:,0] + eig[:,:,1]*eig[:,:,1] + eig[:,:,2]*eig[:,:,2])
+ l1 = l1 / max_l
+
+ eig[:,:,1][smask] = (eig[:,:,1] / s * s1)[smask]
+ eig[:,:,2][smask] = (eig[:,:,2] / s * s1)[smask]
+ #eig[:,:,1] = (eig[:,:,1] / s * s1) * smask + eig[:,:,1] * (1-smask)
+ #eig[:,:,2] = (eig[:,:,2] / s * s1) * smask + eig[:,:,2] * (1-smask)
+
+ if max_l > 1e-2:
+ l = np.sqrt(eig[:,:,0]*eig[:,:,0] + eig[:,:,1]*eig[:,:,1] + eig[:,:,2]*eig[:,:,2])
+ l1 = np.power(l1, self.lmult_pow)
+ l1 = np.clip(l1 + self.lmult_add, 0, np.inf)
+ l1 = l1 * self.lmult_mult
+ l1 = l1 * max_l
+ lmask = l > 1e-2
+ eig[lmask] = (eig / l[:,:,np.newaxis] * l1[:,:,np.newaxis])[lmask]
+ for c in range(3):
+ eig[:,:,c][lmask] = (np.clip(eig[:,:,c], -np.inf, max_abs_eig[c]))[lmask]
+ # for c in range(3):
+# # eig[:,:,c][lmask] = (eig[:,:,c] / l * l1)[lmask] * lmask + eig[:,:,c] * (1-lmask)
+ # eig[:,:,c][lmask] = (eig[:,:,c] / l * l1)[lmask]
+ # eig[:,:,c] = (np.clip(eig[:,:,c], -np.inf, max_abs_eig[c])) * lmask + eig[:,:,c] * (1-lmask)
+
+ return np.clip(np.dot(eig, self.eigvec.transpose()), 0, 1)
+
+
+class ChromaticAug(object):
+ """
+ Chromatic augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
+ """
+ def __init__(self, noise = 0.06,
+ gamma = 0.02,
+ brightness = 0.02,
+ contrast = 0.02,
+ color = 0.02,
+ schedule_coeff=1):
+
+ self.noise = np.random.uniform(0,noise)
+ self.gamma = np.exp(np.random.normal(0, gamma*schedule_coeff))
+ self.brightness = np.random.normal(0, brightness*schedule_coeff)
+ self.contrast = np.exp(np.random.normal(0, contrast*schedule_coeff))
+ self.color = np.exp(np.random.normal(0, color*schedule_coeff,3))
+
+ def __call__(self, inputs, target, intr):
+ inputs[1] = self.chrom_aug(inputs[1])
+ # noise
+ inputs[0]+=np.random.normal(0, self.noise, inputs[0].shape)
+ inputs[1]+=np.random.normal(0, self.noise, inputs[0].shape)
+ return inputs,target,intr
+
+ def chrom_aug(self, rgb):
+ # color change
+ mean_in = rgb.sum(-1)
+ rgb = rgb*self.color[np.newaxis,np.newaxis]
+ brightness_coeff = mean_in / (rgb.sum(-1)+0.01)
+ rgb = np.clip(rgb*brightness_coeff[:,:,np.newaxis],0,1)
+ # gamma
+ rgb = np.power(rgb,self.gamma)
+ # brightness
+ rgb += self.brightness
+ # contrast
+ rgb = 0.5 + ( rgb-0.5)*self.contrast
+ rgb = np.clip(rgb, 0, 1)
+ return rgb
diff --git a/expansion/dataloader/depthloader.py b/expansion/dataloader/depthloader.py
new file mode 100755
index 0000000000000000000000000000000000000000..b189ba5b70a546eeb0a7e78812f19057e43e02fa
--- /dev/null
+++ b/expansion/dataloader/depthloader.py
@@ -0,0 +1,222 @@
+import os
+import numbers
+import torch
+import torch.utils.data as data
+import torch
+import torchvision.transforms as transforms
+import random
+from PIL import Image, ImageOps
+import numpy as np
+import torchvision
+from . import depth_transforms as flow_transforms
+import pdb
+import cv2
+from utils.flowlib import read_flow
+from utils.util_flow import readPFM, load_calib_cam_to_cam
+
+def default_loader(path):
+ return Image.open(path).convert('RGB')
+
+def flow_loader(path):
+ if '.pfm' in path:
+ data = readPFM(path)[0]
+ data[:,:,2] = 1
+ return data
+ else:
+ return read_flow(path)
+
+def load_exts(cam_file):
+ with open(cam_file, 'r') as f:
+ lines = f.readlines()
+
+ l_exts = []
+ r_exts = []
+ for l in lines:
+ if 'L ' in l:
+ l_exts.append(np.asarray([float(i) for i in l[2:].strip().split(' ')]).reshape(4,4))
+ if 'R ' in l:
+ r_exts.append(np.asarray([float(i) for i in l[2:].strip().split(' ')]).reshape(4,4))
+ return l_exts,r_exts
+
+def disparity_loader(path):
+ if '.png' in path:
+ data = Image.open(path)
+ data = np.ascontiguousarray(data,dtype=np.float32)/256
+ return data
+ else:
+ return readPFM(path)[0]
+
+# triangulation
+def triangulation(disp, xcoord, ycoord, bl=1, fl = 450, cx = 479.5, cy = 269.5):
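+    # Back-projects pixels (xcoord, ycoord) with disparity `disp` into camera
+    # space using baseline bl, focal length fl, and principal point (cx, cy);
+    # returns 4xN points in homogeneous coordinates.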
+ depth = bl*fl / disp # 450px->15mm focal length
+ X = (xcoord - cx) * depth / fl
+ Y = (ycoord - cy) * depth / fl
+ Z = depth
+ P = np.concatenate((X[np.newaxis],Y[np.newaxis],Z[np.newaxis]),0).reshape(3,-1)
+ P = np.concatenate((P,np.ones((1,P.shape[-1]))),0)
+ return P
+
+class myImageFloder(data.Dataset):
+ def __init__(self, iml0, iml1, flowl0, loader=default_loader, dploader= flow_loader, scale=1.,shape=[320,448], order=1, noise=0.06, pca_augmentor=True, prob = 1.,sc=False,disp0=None,disp1=None,calib=None ):
+ self.iml0 = iml0
+ self.iml1 = iml1
+ self.flowl0 = flowl0
+ self.loader = loader
+ self.dploader = dploader
+ self.scale=scale
+ self.shape=shape
+ self.order=order
+ self.noise = noise
+ self.pca_augmentor = pca_augmentor
+ self.prob = prob
+ self.sc = sc
+ self.disp0 = disp0
+ self.disp1 = disp1
+ self.calib = calib
+
+ def __getitem__(self, index):
+ iml0 = self.iml0[index]
+ iml1 = self.iml1[index]
+ flowl0= self.flowl0[index]
+ th, tw = self.shape
+
+ iml0 = self.loader(iml0)
+ iml1 = self.loader(iml1)
+
+ # get disparity
+ if self.sc:
+ flowl0 = self.dploader(flowl0)
+ flowl0 = np.ascontiguousarray(flowl0,dtype=np.float32)
+ flowl0[np.isnan(flowl0)] = 1e6 # set to max
+ if 'camera_data.txt' in self.calib[index]:
+ bl=1
+ if '15mm_' in self.calib[index]:
+ fl=450 # 450
+ else:
+ fl=1050
+ cx = 479.5
+ cy = 269.5
+ # negative disp
+ d1 = np.abs(disparity_loader(self.disp0[index]))
+ d2 = np.abs(disparity_loader(self.disp1[index]) + d1)
+ elif 'Sintel' in self.calib[index]:
+ fl = 1000
+ bl = 1
+ cx = 511.5
+ cy = 217.5
+ d1 = np.zeros(flowl0.shape[:2])
+ d2 = np.zeros(flowl0.shape[:2])
+ else:
+ ints = load_calib_cam_to_cam(self.calib[index])
+ fl = ints['K_cam2'][0,0]
+ cx = ints['K_cam2'][0,2]
+ cy = ints['K_cam2'][1,2]
+ bl = ints['b20']-ints['b30']
+ d1 = disparity_loader(self.disp0[index])
+ d2 = disparity_loader(self.disp1[index])
+ #flowl0[:,:,2] = (flowl0[:,:,2]==1).astype(float)
+ flowl0[:,:,2] = np.logical_and(np.logical_and(flowl0[:,:,2]==1, d1!=0), d2!=0).astype(float)
+
+ shape = d1.shape
+ mesh = np.meshgrid(range(shape[1]),range(shape[0]))
+ xcoord = mesh[0].astype(float)
+ ycoord = mesh[1].astype(float)
+
+ # triangulation in two frames
+ P0 = triangulation(d1, xcoord, ycoord, bl=bl, fl = fl, cx = cx, cy = cy)
+ P1 = triangulation(d2, xcoord + flowl0[:,:,0], ycoord + flowl0[:,:,1], bl=bl, fl = fl, cx = cx, cy = cy)
+ dis0 = P0[2]
+ dis1 = P1[2]
+
+ change_size = dis0.reshape(shape).astype(np.float32)
+ flow3d = (P1-P0)[:3].reshape((3,)+shape).transpose((1,2,0))
+
+ gt_normal = np.concatenate((d1[:,:,np.newaxis],d2[:,:,np.newaxis],d2[:,:,np.newaxis]),-1)
+ change_size = np.concatenate((change_size[:,:,np.newaxis],gt_normal,flow3d),2)
+ else:
+ shape = iml0.size
+ shape=[shape[1],shape[0]]
+ flowl0 = np.zeros((shape[0],shape[1],3))
+ change_size = np.zeros((shape[0],shape[1],7))
+ depth = disparity_loader(self.iml1[index].replace('camera','groundtruth'))
+ change_size[:,:,0] = depth
+
+ seqid = self.iml0[index].split('/')[-5].rsplit('_',3)[0]
+ ints = load_calib_cam_to_cam('/data/gengshay/KITTI/%s/calib_cam_to_cam.txt'%seqid)
+ fl = ints['K_cam2'][0,0]
+ cx = ints['K_cam2'][0,2]
+ cy = ints['K_cam2'][1,2]
+ bl = ints['b20']-ints['b30']
+
+
+ iml1 = np.asarray(iml1)/255.
+ iml0 = np.asarray(iml0)/255.
+ iml0 = iml0[:,:,::-1].copy()
+ iml1 = iml1[:,:,::-1].copy()
+
+ ## following data augmentation procedure in PWCNet
+ ## https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
+ import __main__ # a workaround for "discount_coeff"
+ try:
+ with open('/scratch/gengshay/iter_counts-%d.txt'%int(__main__.args.logname.split('-')[-1]), 'r') as f:
+ iter_counts = int(f.readline())
+ except:
+ iter_counts = 0
+ schedule = [0.5, 1., 50000.] # initial coeff, final_coeff, half life
+ schedule_coeff = schedule[0] + (schedule[1] - schedule[0]) * \
+ (2/(1+np.exp(-1.0986*iter_counts/schedule[2])) - 1)
+
+ if self.pca_augmentor:
+ pca_augmentor = flow_transforms.pseudoPCAAug( schedule_coeff=schedule_coeff)
+ else:
+ pca_augmentor = flow_transforms.Scale(1., order=0)
+
+ if np.random.binomial(1,self.prob):
+ co_transform1 = flow_transforms.Compose([
+ flow_transforms.SpatialAug([th,tw],
+ scale=[0.2,0.,0.1],
+ rot=[0.4,0.],
+ trans=[0.4,0.],
+ squeeze=[0.3,0.], schedule_coeff=schedule_coeff, order=self.order),
+ ])
+ else:
+ co_transform1 = flow_transforms.Compose([
+ flow_transforms.RandomCrop([th,tw]),
+ ])
+
+ co_transform2 = flow_transforms.Compose([
+ flow_transforms.pseudoPCAAug( schedule_coeff=schedule_coeff),
+ #flow_transforms.PCAAug(schedule_coeff=schedule_coeff),
+ flow_transforms.ChromaticAug( schedule_coeff=schedule_coeff, noise=self.noise),
+ ])
+
+ flowl0 = np.concatenate([flowl0,change_size],-1)
+ augmented,flowl0,intr = co_transform1([iml0, iml1], flowl0, [fl,cx,cy,bl])
+ imol0 = augmented[0]
+ imol1 = augmented[1]
+ augmented,flowl0,intr = co_transform2(augmented, flowl0, intr)
+
+ iml0 = augmented[0]
+ iml1 = augmented[1]
+ flowl0 = flowl0.astype(np.float32)
+ change_size = flowl0[:,:,3:]
+ flowl0 = flowl0[:,:,:3]
+
+ # randomly cover a region
+ sx=0;sy=0;cx=0;cy=0
+ if np.random.binomial(1,0.5):
+ sx = int(np.random.uniform(25,100))
+ sy = int(np.random.uniform(25,100))
+ #sx = int(np.random.uniform(50,150))
+ #sy = int(np.random.uniform(50,150))
+ cx = int(np.random.uniform(sx,iml1.shape[0]-sx))
+ cy = int(np.random.uniform(sy,iml1.shape[1]-sy))
+ iml1[cx-sx:cx+sx,cy-sy:cy+sy] = np.mean(np.mean(iml1,0),0)[np.newaxis,np.newaxis]
+
+ iml0 = torch.Tensor(np.transpose(iml0,(2,0,1)))
+ iml1 = torch.Tensor(np.transpose(iml1,(2,0,1)))
+
+ return iml0, iml1, flowl0, change_size, intr, imol0, imol1, np.asarray([cx-sx,cx+sx,cy-sy,cy+sy])
+
+ def __len__(self):
+ return len(self.iml0)
diff --git a/expansion/dataloader/flow_transforms.py b/expansion/dataloader/flow_transforms.py
new file mode 100755
index 0000000000000000000000000000000000000000..3bff1ac448a408012a5aa3a1c1035cbf2695e240
--- /dev/null
+++ b/expansion/dataloader/flow_transforms.py
@@ -0,0 +1,440 @@
+from __future__ import division
+import torch
+import random
+import numpy as np
+import numbers
+import types
+import scipy.ndimage as ndimage
+import pdb
+import torchvision
+import PIL.Image as Image
+import cv2
+from torch.nn import functional as F
+
+
+class Compose(object):
+ """ Composes several co_transforms together.
+ For example:
+ >>> co_transforms.Compose([
+ >>> co_transforms.CenterCrop(10),
+ >>> co_transforms.ToTensor(),
+ >>> ])
+ """
+
+ def __init__(self, co_transforms):
+ self.co_transforms = co_transforms
+
+ def __call__(self, input, target):
+ for t in self.co_transforms:
+ input,target = t(input,target)
+ return input,target
+
+
+class Scale(object):
+ """ Rescales the inputs and target arrays to the given 'size'.
+ 'size' will be the size of the smaller edge.
+ For example, if height > width, then image will be
+ rescaled to (size * height / width, size)
+ size: size of the smaller edge
+ interpolation order: Default: 2 (bilinear)
+ """
+
+ def __init__(self, size, order=1):
+ self.ratio = size
+ self.order = order
+ if order==0:
+ self.code=cv2.INTER_NEAREST
+ elif order==1:
+ self.code=cv2.INTER_LINEAR
+ elif order==2:
+ self.code=cv2.INTER_CUBIC
+
+ def __call__(self, inputs, target):
+ if self.ratio==1:
+ return inputs, target
+ h, w, _ = inputs[0].shape
+ ratio = self.ratio
+
+ inputs[0] = cv2.resize(inputs[0], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_LINEAR)
+ inputs[1] = cv2.resize(inputs[1], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_LINEAR)
+ # keep the mask same
+ tmp = cv2.resize(target[:,:,2], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_NEAREST)
+ target = cv2.resize(target, None, fx=ratio,fy=ratio,interpolation=self.code) * ratio
+ target[:,:,2] = tmp
+
+
+ return inputs, target
+
+
+
+
+class SpatialAug(object):
+ def __init__(self, crop, scale=None, rot=None, trans=None, squeeze=None, schedule_coeff=1, order=1, black=False):
+ self.crop = crop
+ self.scale = scale
+ self.rot = rot
+ self.trans = trans
+ self.squeeze = squeeze
+ self.t = np.zeros(6)
+ self.schedule_coeff = schedule_coeff
+ self.order = order
+ self.black = black
+
+ def to_identity(self):
+ self.t[0] = 1; self.t[2] = 0; self.t[4] = 0; self.t[1] = 0; self.t[3] = 1; self.t[5] = 0;
+
+ def left_multiply(self, u0, u1, u2, u3, u4, u5):
+ result = np.zeros(6)
+ result[0] = self.t[0]*u0 + self.t[1]*u2;
+ result[1] = self.t[0]*u1 + self.t[1]*u3;
+
+ result[2] = self.t[2]*u0 + self.t[3]*u2;
+ result[3] = self.t[2]*u1 + self.t[3]*u3;
+
+ result[4] = self.t[4]*u0 + self.t[5]*u2 + u4;
+ result[5] = self.t[4]*u1 + self.t[5]*u3 + u5;
+ self.t = result
+
+ def inverse(self):
+ result = np.zeros(6)
+ a = self.t[0]; c = self.t[2]; e = self.t[4];
+ b = self.t[1]; d = self.t[3]; f = self.t[5];
+
+ denom = a*d - b*c;
+
+ result[0] = d / denom;
+ result[1] = -b / denom;
+ result[2] = -c / denom;
+ result[3] = a / denom;
+ result[4] = (c*f-d*e) / denom;
+ result[5] = (b*e-a*f) / denom;
+
+ return result
+
+ def grid_transform(self, meshgrid, t, normalize=True, gridsize=None):
+ if gridsize is None:
+ h, w = meshgrid[0].shape
+ else:
+ h, w = gridsize
+ vgrid = torch.cat([(meshgrid[0] * t[0] + meshgrid[1] * t[2] + t[4])[:,:,np.newaxis],
+ (meshgrid[0] * t[1] + meshgrid[1] * t[3] + t[5])[:,:,np.newaxis]],-1)
+ if normalize:
+ vgrid[:,:,0] = 2.0*vgrid[:,:,0]/max(w-1,1)-1.0
+ vgrid[:,:,1] = 2.0*vgrid[:,:,1]/max(h-1,1)-1.0
+ return vgrid
+
+
+ def __call__(self, inputs, target):
+ h, w, _ = inputs[0].shape
+ th, tw = self.crop
+ meshgrid = torch.meshgrid([torch.Tensor(range(th)), torch.Tensor(range(tw))])[::-1]
+ cornergrid = torch.meshgrid([torch.Tensor([0,th-1]), torch.Tensor([0,tw-1])])[::-1]
+
+ for i in range(50):
+ # im0
+ self.to_identity()
+ #TODO add mirror
+ if np.random.binomial(1,0.5):
+ mirror = True
+ else:
+ mirror = False
+ ##TODO
+ #mirror = False
+ if mirror:
+ self.left_multiply(-1, 0, 0, 1, .5 * tw, -.5 * th);
+ else:
+ self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th);
+ scale0 = 1; scale1 = 1; squeeze0 = 1; squeeze1 = 1;
+ if not self.rot is None:
+ rot0 = np.random.uniform(-self.rot[0],+self.rot[0])
+ rot1 = np.random.uniform(-self.rot[1]*self.schedule_coeff, self.rot[1]*self.schedule_coeff) + rot0
+ self.left_multiply(np.cos(rot0), np.sin(rot0), -np.sin(rot0), np.cos(rot0), 0, 0)
+ if not self.trans is None:
+ trans0 = np.random.uniform(-self.trans[0],+self.trans[0], 2)
+ trans1 = np.random.uniform(-self.trans[1]*self.schedule_coeff,+self.trans[1]*self.schedule_coeff, 2) + trans0
+ self.left_multiply(1, 0, 0, 1, trans0[0] * tw, trans0[1] * th)
+ if not self.squeeze is None:
+ squeeze0 = np.exp(np.random.uniform(-self.squeeze[0], self.squeeze[0]))
+ squeeze1 = np.exp(np.random.uniform(-self.squeeze[1]*self.schedule_coeff, self.squeeze[1]*self.schedule_coeff)) * squeeze0
+ if not self.scale is None:
+ scale0 = np.exp(np.random.uniform(self.scale[2]-self.scale[0], self.scale[2]+self.scale[0]))
+ scale1 = np.exp(np.random.uniform(-self.scale[1]*self.schedule_coeff, self.scale[1]*self.schedule_coeff)) * scale0
+ self.left_multiply(1.0/(scale0*squeeze0), 0, 0, 1.0/(scale0/squeeze0), 0, 0)
+
+ self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h);
+ transmat0 = self.t.copy()
+
+ # im1
+ self.to_identity()
+ if mirror:
+ self.left_multiply(-1, 0, 0, 1, .5 * tw, -.5 * th);
+ else:
+ self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th);
+ if not self.rot is None:
+ self.left_multiply(np.cos(rot1), np.sin(rot1), -np.sin(rot1), np.cos(rot1), 0, 0)
+ if not self.trans is None:
+ self.left_multiply(1, 0, 0, 1, trans1[0] * tw, trans1[1] * th)
+ self.left_multiply(1.0/(scale1*squeeze1), 0, 0, 1.0/(scale1/squeeze1), 0, 0)
+ self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h);
+ transmat1 = self.t.copy()
+ transmat1_inv = self.inverse()
+
+ if self.black:
+ # black augmentation, allowing 0 values in the input images
+ # https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/black_augmentation_layer.cu
+ break
+ else:
+ if ((self.grid_transform(cornergrid, transmat0, gridsize=[float(h),float(w)]).abs()>1).sum() +\
+ (self.grid_transform(cornergrid, transmat1, gridsize=[float(h),float(w)]).abs()>1).sum()) == 0:
+ break
+ if i==49:
+ print('max_iter in augmentation')
+ self.to_identity()
+ self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th);
+ self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h);
+ transmat0 = self.t.copy()
+ transmat1 = self.t.copy()
+
+ # do the real work
+ vgrid = self.grid_transform(meshgrid, transmat0,gridsize=[float(h),float(w)])
+ inputs_0 = F.grid_sample(torch.Tensor(inputs[0]).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)
+ if self.order == 0:
+ target_0 = F.grid_sample(torch.Tensor(target).permute(2,0,1)[np.newaxis], vgrid[np.newaxis], mode='nearest')[0].permute(1,2,0)
+ else:
+ target_0 = F.grid_sample(torch.Tensor(target).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)
+
+ mask_0 = target[:,:,2:3].copy(); mask_0[mask_0==0]=np.nan
+ if self.order == 0:
+ mask_0 = F.grid_sample(torch.Tensor(mask_0).permute(2,0,1)[np.newaxis], vgrid[np.newaxis], mode='nearest')[0].permute(1,2,0)
+ else:
+ mask_0 = F.grid_sample(torch.Tensor(mask_0).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)
+ mask_0[torch.isnan(mask_0)] = 0
+
+
+ vgrid = self.grid_transform(meshgrid, transmat1,gridsize=[float(h),float(w)])
+ inputs_1 = F.grid_sample(torch.Tensor(inputs[1]).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)
+
+ # flow
+ pos = target_0[:,:,:2] + self.grid_transform(meshgrid, transmat0,normalize=False)
+ pos = self.grid_transform(pos.permute(2,0,1),transmat1_inv,normalize=False)
+ if target_0.shape[2]>=4:
+ # scale
+ exp = target_0[:,:,3:] * scale1 / scale0
+ target = torch.cat([ (pos[:,:,0] - meshgrid[0]).unsqueeze(-1),
+ (pos[:,:,1] - meshgrid[1]).unsqueeze(-1),
+ mask_0,
+ exp], -1)
+ else:
+ target = torch.cat([ (pos[:,:,0] - meshgrid[0]).unsqueeze(-1),
+ (pos[:,:,1] - meshgrid[1]).unsqueeze(-1),
+ mask_0], -1)
+# target_0[:,:,2].unsqueeze(-1) ], -1)
+ inputs = [np.asarray(inputs_0), np.asarray(inputs_1)]
+ target = np.asarray(target)
+
+ return inputs,target
+
+
+class pseudoPCAAug(object):
+ """
+ Chromatic Eigen Augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
+ This version is faster.
+ """
+ def __init__(self, schedule_coeff=1):
+ self.augcolor = torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.5, hue=0.5/3.14)
+
+ def __call__(self, inputs, target):
+ inputs[0] = np.asarray(self.augcolor(Image.fromarray(np.uint8(inputs[0]*255))))/255.
+ inputs[1] = np.asarray(self.augcolor(Image.fromarray(np.uint8(inputs[1]*255))))/255.
+ return inputs,target
+
+
+class PCAAug(object):
+ """
+ Chromatic Eigen Augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
+ """
+ def __init__(self, lmult_pow =[0.4, 0,-0.2],
+ lmult_mult =[0.4, 0,0, ],
+ lmult_add =[0.03,0,0, ],
+ sat_pow =[0.4, 0,0, ],
+ sat_mult =[0.5, 0,-0.3],
+ sat_add =[0.03,0,0, ],
+ col_pow =[0.4, 0,0, ],
+ col_mult =[0.2, 0,0, ],
+ col_add =[0.02,0,0, ],
+ ladd_pow =[0.4, 0,0, ],
+ ladd_mult =[0.4, 0,0, ],
+ ladd_add =[0.04,0,0, ],
+ col_rotate =[1., 0,0, ],
+ schedule_coeff=1):
+ # no mean
+ self.pow_nomean = [1,1,1]
+ self.add_nomean = [0,0,0]
+ self.mult_nomean = [1,1,1]
+ self.pow_withmean = [1,1,1]
+ self.add_withmean = [0,0,0]
+ self.mult_withmean = [1,1,1]
+ self.lmult_pow = 1
+ self.lmult_mult = 1
+ self.lmult_add = 0
+ self.col_angle = 0
+ if not ladd_pow is None:
+ self.pow_nomean[0] =np.exp(np.random.normal(ladd_pow[2], ladd_pow[0]))
+ if not col_pow is None:
+ self.pow_nomean[1] =np.exp(np.random.normal(col_pow[2], col_pow[0]))
+ self.pow_nomean[2] =np.exp(np.random.normal(col_pow[2], col_pow[0]))
+
+ if not ladd_add is None:
+ self.add_nomean[0] =np.random.normal(ladd_add[2], ladd_add[0])
+ if not col_add is None:
+ self.add_nomean[1] =np.random.normal(col_add[2], col_add[0])
+ self.add_nomean[2] =np.random.normal(col_add[2], col_add[0])
+
+ if not ladd_mult is None:
+ self.mult_nomean[0] =np.exp(np.random.normal(ladd_mult[2], ladd_mult[0]))
+ if not col_mult is None:
+ self.mult_nomean[1] =np.exp(np.random.normal(col_mult[2], col_mult[0]))
+ self.mult_nomean[2] =np.exp(np.random.normal(col_mult[2], col_mult[0]))
+
+ # with mean
+ if not sat_pow is None:
+ self.pow_withmean[1] =np.exp(np.random.uniform(sat_pow[2]-sat_pow[0], sat_pow[2]+sat_pow[0]))
+ self.pow_withmean[2] =self.pow_withmean[1]
+ if not sat_add is None:
+ self.add_withmean[1] =np.random.uniform(sat_add[2]-sat_add[0], sat_add[2]+sat_add[0])
+ self.add_withmean[2] =self.add_withmean[1]
+ if not sat_mult is None:
+ self.mult_withmean[1] = np.exp(np.random.uniform(sat_mult[2]-sat_mult[0], sat_mult[2]+sat_mult[0]))
+ self.mult_withmean[2] = self.mult_withmean[1]
+
+ if not lmult_pow is None:
+ self.lmult_pow = np.exp(np.random.uniform(lmult_pow[2]-lmult_pow[0], lmult_pow[2]+lmult_pow[0]))
+ if not lmult_mult is None:
+ self.lmult_mult= np.exp(np.random.uniform(lmult_mult[2]-lmult_mult[0], lmult_mult[2]+lmult_mult[0]))
+ if not lmult_add is None:
+ self.lmult_add = np.random.uniform(lmult_add[2]-lmult_add[0], lmult_add[2]+lmult_add[0])
+ if not col_rotate is None:
+ self.col_angle= np.random.uniform(col_rotate[2]-col_rotate[0], col_rotate[2]+col_rotate[0])
+
+ # eigen vectors
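+ # fixed RGB-to-eigenspace basis (natural-image colour PCA), following the
+ # FlowNet2 data_augmentation_layer referenced above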
+ self.eigvec = np.reshape([0.51,0.56,0.65,0.79,0.01,-0.62,0.35,-0.83,0.44],[3,3]).transpose()
+
+
+ def __call__(self, inputs, target):
+ inputs[0] = self.pca_image(inputs[0])
+ inputs[1] = self.pca_image(inputs[1])
+ return inputs,target
+
+ def pca_image(self, rgb):
+ eig = np.dot(rgb, self.eigvec)
+ max_rgb = np.clip(rgb,0,np.inf).max((0,1))
+ min_rgb = rgb.min((0,1))
+ mean_rgb = rgb.mean((0,1))
+ max_abs_eig = np.abs(eig).max((0,1))
+ max_l = np.sqrt(np.sum(max_abs_eig*max_abs_eig))
+ mean_eig = np.dot(mean_rgb, self.eigvec)
+
+ # no-mean stuff
+ eig -= mean_eig[np.newaxis, np.newaxis]
+
+ for c in range(3):
+ if max_abs_eig[c] > 1e-2:
+ mean_eig[c] /= max_abs_eig[c]
+ eig[:,:,c] = eig[:,:,c] / max_abs_eig[c]
+ eig[:,:,c] = np.power(np.abs(eig[:,:,c]),self.pow_nomean[c]) *\
+ ((eig[:,:,c] > 0) -0.5)*2
+ eig[:,:,c] = eig[:,:,c] + self.add_nomean[c]
+ eig[:,:,c] = eig[:,:,c] * self.mult_nomean[c]
+ eig += mean_eig[np.newaxis,np.newaxis]
+
+ # withmean stuff
+ if max_abs_eig[0] > 1e-2:
+ eig[:,:,0] = np.power(np.abs(eig[:,:,0]),self.pow_withmean[0]) * \
+ ((eig[:,:,0]>0)-0.5)*2
+ eig[:,:,0] = eig[:,:,0] + self.add_withmean[0]
+ eig[:,:,0] = eig[:,:,0] * self.mult_withmean[0]
+
+ s = np.sqrt(eig[:,:,1]*eig[:,:,1] + eig[:,:,2] * eig[:,:,2])
+ smask = s > 1e-2
+ s1 = np.power(s, self.pow_withmean[1])
+ s1 = np.clip(s1 + self.add_withmean[1], 0,np.inf)
+ s1 = s1 * self.mult_withmean[1]
+ s1 = s1 * smask + s*(1-smask)
+
+ # color angle
+ if self.col_angle!=0:
+ temp1 = np.cos(self.col_angle) * eig[:,:,1] - np.sin(self.col_angle) * eig[:,:,2]
+ temp2 = np.sin(self.col_angle) * eig[:,:,1] + np.cos(self.col_angle) * eig[:,:,2]
+ eig[:,:,1] = temp1
+ eig[:,:,2] = temp2
+
+ # to origin magnitude
+ for c in range(3):
+ if max_abs_eig[c] > 1e-2:
+ eig[:,:,c] = eig[:,:,c] * max_abs_eig[c]
+
+ if max_l > 1e-2:
+ l1 = np.sqrt(eig[:,:,0]*eig[:,:,0] + eig[:,:,1]*eig[:,:,1] + eig[:,:,2]*eig[:,:,2])
+ l1 = l1 / max_l
+
+ eig[:,:,1][smask] = (eig[:,:,1] / s * s1)[smask]
+ eig[:,:,2][smask] = (eig[:,:,2] / s * s1)[smask]
+ #eig[:,:,1] = (eig[:,:,1] / s * s1) * smask + eig[:,:,1] * (1-smask)
+ #eig[:,:,2] = (eig[:,:,2] / s * s1) * smask + eig[:,:,2] * (1-smask)
+
+ if max_l > 1e-2:
+ l = np.sqrt(eig[:,:,0]*eig[:,:,0] + eig[:,:,1]*eig[:,:,1] + eig[:,:,2]*eig[:,:,2])
+ l1 = np.power(l1, self.lmult_pow)
+ l1 = np.clip(l1 + self.lmult_add, 0, np.inf)
+ l1 = l1 * self.lmult_mult
+ l1 = l1 * max_l
+ lmask = l > 1e-2
+ eig[lmask] = (eig / l[:,:,np.newaxis] * l1[:,:,np.newaxis])[lmask]
+ for c in range(3):
+ eig[:,:,c][lmask] = (np.clip(eig[:,:,c], -np.inf, max_abs_eig[c]))[lmask]
+ # for c in range(3):
+# # eig[:,:,c][lmask] = (eig[:,:,c] / l * l1)[lmask] * lmask + eig[:,:,c] * (1-lmask)
+ # eig[:,:,c][lmask] = (eig[:,:,c] / l * l1)[lmask]
+ # eig[:,:,c] = (np.clip(eig[:,:,c], -np.inf, max_abs_eig[c])) * lmask + eig[:,:,c] * (1-lmask)
+
+ return np.clip(np.dot(eig, self.eigvec.transpose()), 0, 1)
+
+
+class ChromaticAug(object):
+ """
+ Chromatic augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
+ """
+ def __init__(self, noise = 0.06,
+ gamma = 0.02,
+ brightness = 0.02,
+ contrast = 0.02,
+ color = 0.02,
+ schedule_coeff=1):
+
+ self.noise = np.random.uniform(0,noise)
+ self.gamma = np.exp(np.random.normal(0, gamma*schedule_coeff))
+ self.brightness = np.random.normal(0, brightness*schedule_coeff)
+ self.contrast = np.exp(np.random.normal(0, contrast*schedule_coeff))
+ self.color = np.exp(np.random.normal(0, color*schedule_coeff,3))
+
+ def __call__(self, inputs, target):
+ inputs[1] = self.chrom_aug(inputs[1])
+ # noise
+ inputs[0]+=np.random.normal(0, self.noise, inputs[0].shape)
+ inputs[1]+=np.random.normal(0, self.noise, inputs[1].shape)
+ return inputs,target
+
+ def chrom_aug(self, rgb):
+ # color change
+ mean_in = rgb.sum(-1)
+ rgb = rgb*self.color[np.newaxis,np.newaxis]
+ brightness_coeff = mean_in / (rgb.sum(-1)+0.01)
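+ # rescale so each pixel keeps roughly its original summed intensity after the
+ # random per-channel colour scaling (mean_in is a per-pixel channel sum)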
+ rgb = np.clip(rgb*brightness_coeff[:,:,np.newaxis],0,1)
+ # gamma
+ rgb = np.power(rgb,self.gamma)
+ # brightness
+ rgb += self.brightness
+ # contrast
+ rgb = 0.5 + ( rgb-0.5)*self.contrast
+ rgb = np.clip(rgb, 0, 1)
+ return rgb
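+
+# Minimal usage sketch (illustration only, not part of the training code): these
+# photometric transforms are chained after the spatial ones defined earlier in
+# this module, roughly as
+#   co_transform = Compose([Scale(1.), SpatialAug([320, 448], trans=[0.4, 0.03]),
+#                           PCAAug(), ChromaticAug()])
+#   images_aug, flow_aug = co_transform([im0, im1], flow)
+# with im0/im1 as HxWx3 float arrays in [0,1] and flow as HxWx3 with a validity
+# mask in the last channel; see myImageFloder in robloader.py for the exact
+# composition used in training.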
diff --git a/expansion/dataloader/hd1klist.py b/expansion/dataloader/hd1klist.py
new file mode 100755
index 0000000000000000000000000000000000000000..9b4754986a850007f43345068d42648c2d6ecc24
--- /dev/null
+++ b/expansion/dataloader/hd1klist.py
@@ -0,0 +1,29 @@
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+import numpy as np
+import pdb
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+def dataloader(filepath):
+
+ left_fold = 'image_2/'
+ train = [img for img in os.listdir(filepath+left_fold) if img.find('HD1K2018') > -1]
+ train = sorted(train)
+
+ l0_train = [filepath+left_fold+img for img in train]
+ l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%04d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]
+ l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%04d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train]
+ flow_train = [img.replace('image_2','flow_occ') for img in l0_train]
+
+ return l0_train, l1_train, flow_train
diff --git a/expansion/dataloader/kitti12list.py b/expansion/dataloader/kitti12list.py
new file mode 100755
index 0000000000000000000000000000000000000000..af1e7f2e9438f79dd7219d01274786c5c4f0e794
--- /dev/null
+++ b/expansion/dataloader/kitti12list.py
@@ -0,0 +1,29 @@
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+import numpy as np
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+def dataloader(filepath):
+
+ left_fold = 'colored_0/'
+ flow_noc = 'flow_occ/'
+
+ train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
+
+ l0_train = [filepath+left_fold+img for img in train]
+ l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
+ flow_train = [filepath+flow_noc+img for img in train]
+
+
+ return l0_train, l1_train, flow_train
diff --git a/expansion/dataloader/kitti15list.py b/expansion/dataloader/kitti15list.py
new file mode 100755
index 0000000000000000000000000000000000000000..ff44e5cbb43d64e99bda79aaac4bdb8f7cdd9d2b
--- /dev/null
+++ b/expansion/dataloader/kitti15list.py
@@ -0,0 +1,29 @@
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+import numpy as np
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+def dataloader(filepath):
+
+ left_fold = 'image_2/'
+ flow_noc = 'flow_occ/'
+
+ train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
+
+ l0_train = [filepath+left_fold+img for img in train]
+ l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
+ flow_train = [filepath+flow_noc+img for img in train]
+
+
+ return sorted(l0_train), sorted(l1_train), sorted(flow_train)
diff --git a/expansion/dataloader/kitti15list_train.py b/expansion/dataloader/kitti15list_train.py
new file mode 100755
index 0000000000000000000000000000000000000000..e1eca1af37426274490e916b884271281024cc47
--- /dev/null
+++ b/expansion/dataloader/kitti15list_train.py
@@ -0,0 +1,31 @@
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+import numpy as np
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+def dataloader(filepath):
+
+ left_fold = 'image_2/'
+ flow_noc = 'flow_occ/'
+
+ train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
+
+ train = [i for i in train if int(i.split('_')[0])%5!=0]
+
+ l0_train = [filepath+left_fold+img for img in train]
+ l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
+ flow_train = [filepath+flow_noc+img for img in train]
+
+
+ return sorted(l0_train), sorted(l1_train), sorted(flow_train)
diff --git a/expansion/dataloader/kitti15list_train_lidar.py b/expansion/dataloader/kitti15list_train_lidar.py
new file mode 100755
index 0000000000000000000000000000000000000000..aa77c139455c195d0532aa890c5c8b8f137923cb
--- /dev/null
+++ b/expansion/dataloader/kitti15list_train_lidar.py
@@ -0,0 +1,34 @@
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+import numpy as np
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+def dataloader(filepath):
+
+ left_fold = 'image_2/'
+ flow_noc = 'flow_occ/'
+
+ train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
+
+# train = [i for i in train if int(i.split('_')[0])%5!=0]
+ with open('/data/gengshay/kitti_scene/devkit/mapping/train_mapping.txt','r') as f:
+ flags = [len(i)>1 for i in f.readlines()]
+ train = [fn for (it,fn) in enumerate(sorted(train)) if flags[it] ][:100]
+
+ l0_train = [filepath+left_fold+img for img in train]
+ l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
+ flow_train = [filepath+flow_noc+img for img in train]
+
+
+ return sorted(l0_train), sorted(l1_train), sorted(flow_train)
diff --git a/expansion/dataloader/kitti15list_val.py b/expansion/dataloader/kitti15list_val.py
new file mode 100755
index 0000000000000000000000000000000000000000..3d5e39e245dd8878f1f6270dcc2b28ff2c85ed5d
--- /dev/null
+++ b/expansion/dataloader/kitti15list_val.py
@@ -0,0 +1,31 @@
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+import numpy as np
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+def dataloader(filepath):
+
+ left_fold = 'image_2/'
+ flow_noc = 'flow_occ/'
+
+ train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
+
+ train = [i for i in train if int(i.split('_')[0])%5==0]
+
+ l0_train = [filepath+left_fold+img for img in train]
+ l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
+ flow_train = [filepath+flow_noc+img for img in train]
+
+
+ return sorted(l0_train), sorted(l1_train), sorted(flow_train)
diff --git a/expansion/dataloader/kitti15list_val_lidar.py b/expansion/dataloader/kitti15list_val_lidar.py
new file mode 100755
index 0000000000000000000000000000000000000000..12420446c328bfa45c283a80a9d406b54e7cf346
--- /dev/null
+++ b/expansion/dataloader/kitti15list_val_lidar.py
@@ -0,0 +1,34 @@
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+import numpy as np
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+def dataloader(filepath):
+
+ left_fold = 'image_2/'
+ flow_noc = 'flow_occ/'
+
+ train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
+
+# train = [i for i in train if int(i.split('_')[0])%5!=0]
+ with open('/data/gengshay/kitti_scene/devkit/mapping/train_mapping.txt','r') as f:
+ flags = [len(i)>1 for i in f.readlines()]
+ train = [fn for (it,fn) in enumerate(sorted(train)) if flags[it] ][100:]
+
+ l0_train = [filepath+left_fold+img for img in train]
+ l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
+ flow_train = [filepath+flow_noc+img for img in train]
+
+
+ return sorted(l0_train), sorted(l1_train), sorted(flow_train)
diff --git a/expansion/dataloader/kitti15list_val_mr.py b/expansion/dataloader/kitti15list_val_mr.py
new file mode 100755
index 0000000000000000000000000000000000000000..56c209f65a734b9874738456a1bebc843f6fecb1
--- /dev/null
+++ b/expansion/dataloader/kitti15list_val_mr.py
@@ -0,0 +1,41 @@
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+import numpy as np
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+def dataloader(filepath):
+
+ left_fold = 'image_2/'
+ flow_noc = 'flow_occ/'
+
+ train = [img for img in os.listdir(filepath+left_fold) if 'Kitti' in img and img.find('_10') > -1]
+
+# train = [i for i in train if int(i.split('_')[1])%5==0]
+ # import pdb; pdb.set_trace()  # debugging leftover, disabled so the loader runs non-interactively
+ train = sorted([i for i in train if int(i.split('_')[1])%5==0])[0:1]
+
+ l0_train = [filepath+left_fold+img for img in train]
+ l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
+ flow_train = [filepath+flow_noc+img for img in train]
+
+ l0_train += [filepath+left_fold+img.replace('_10','_09') for img in train]
+ l1_train += [filepath+left_fold+img for img in train]
+ flow_train += flow_train
+
+ tmp = l0_train
+ l0_train = l0_train+ [i.replace('rob_flow', 'kitti_scene').replace('Kitti2015_','') for i in l1_train]
+ l1_train = l1_train+tmp
+ flow_train += flow_train
+
+ return l0_train, l1_train, flow_train
diff --git a/expansion/dataloader/robloader.py b/expansion/dataloader/robloader.py
new file mode 100755
index 0000000000000000000000000000000000000000..471f2262c77d7c71b48b0fb187159e9d2dbc517a
--- /dev/null
+++ b/expansion/dataloader/robloader.py
@@ -0,0 +1,133 @@
+import os
+import numbers
+import torch
+import torch.utils.data as data
+import torch
+import torchvision.transforms as transforms
+import random
+from PIL import Image, ImageOps
+import numpy as np
+import torchvision
+from . import flow_transforms
+import pdb
+import cv2
+from utils.flowlib import read_flow
+from utils.util_flow import readPFM
+
+
+def default_loader(path):
+ return Image.open(path).convert('RGB')
+
+def flow_loader(path):
+ if '.pfm' in path:
+ data = readPFM(path)[0]
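+ # mark all pixels valid: downstream augmentation (e.g. SpatialAug) reads
+ # channel 2 of the flow as a validity mask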
+ data[:,:,2] = 1
+ return data
+ else:
+ return read_flow(path)
+
+
+def disparity_loader(path):
+ if '.png' in path:
+ data = Image.open(path)
+ data = np.ascontiguousarray(data,dtype=np.float32)/256
+ return data
+ else:
+ return readPFM(path)[0]
+
+class myImageFloder(data.Dataset):
+ def __init__(self, iml0, iml1, flowl0, loader=default_loader, dploader= flow_loader, scale=1.,shape=[320,448], order=1, noise=0.06, pca_augmentor=True, prob = 1., cover=False, black=False, scale_aug=[0.4,0.2]):
+ self.iml0 = iml0
+ self.iml1 = iml1
+ self.flowl0 = flowl0
+ self.loader = loader
+ self.dploader = dploader
+ self.scale=scale
+ self.shape=shape
+ self.order=order
+ self.noise = noise
+ self.pca_augmentor = pca_augmentor
+ self.prob = prob
+ self.cover = cover
+ self.black = black
+ self.scale_aug = scale_aug
+
+ def __getitem__(self, index):
+ iml0 = self.iml0[index]
+ iml1 = self.iml1[index]
+ flowl0= self.flowl0[index]
+ th, tw = self.shape
+
+ iml0 = self.loader(iml0)
+ iml1 = self.loader(iml1)
+ iml1 = np.asarray(iml1)/255.
+ iml0 = np.asarray(iml0)/255.
+ iml0 = iml0[:,:,::-1].copy()
+ iml1 = iml1[:,:,::-1].copy()
+ flowl0 = self.dploader(flowl0)
+ #flowl0[:,:,-1][flowl0[:,:,0]==np.inf]=0 # for gtav window pfm files
+ #flowl0[:,:,0][~flowl0[:,:,2].astype(bool)]=0
+ #flowl0[:,:,1][~flowl0[:,:,2].astype(bool)]=0 # avoid nan in grad
+ flowl0 = np.ascontiguousarray(flowl0,dtype=np.float32)
+ flowl0[np.isnan(flowl0)] = 1e6 # set to max
+
+ ## following data augmentation procedure in PWCNet
+ ## https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
+ import __main__ # a workaround for "discount_coeff"
+ try:
+ with open('iter_counts-%d.txt'%int(__main__.args.logname.split('-')[-1]), 'r') as f:
+ iter_counts = int(f.readline())
+ except:
+ iter_counts = 0
+ schedule = [0.5, 1., 50000.] # initial coeff, final_coeff, half life
+ schedule_coeff = schedule[0] + (schedule[1] - schedule[0]) * \
+ (2/(1+np.exp(-1.0986*iter_counts/schedule[2])) - 1)
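+ # schedule_coeff ramps smoothly from 0.5 to 1.0 with a ~50k-iteration half-life,
+ # so augmentation strength grows over training (FlowNet2-style discount schedule)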
+
+ if self.pca_augmentor:
+ pca_augmentor = flow_transforms.pseudoPCAAug( schedule_coeff=schedule_coeff)
+ else:
+ pca_augmentor = flow_transforms.Scale(1., order=0)
+
+ if np.random.binomial(1,self.prob):
+ co_transform = flow_transforms.Compose([
+ flow_transforms.Scale(self.scale, order=self.order),
+ #flow_transforms.SpatialAug([th,tw], trans=[0.2,0.03], order=self.order, black=self.black),
+ flow_transforms.SpatialAug([th,tw],scale=[self.scale_aug[0],0.03,self.scale_aug[1]],
+ rot=[0.4,0.03],
+ trans=[0.4,0.03],
+ squeeze=[0.3,0.], schedule_coeff=schedule_coeff, order=self.order, black=self.black),
+ #flow_transforms.pseudoPCAAug(schedule_coeff=schedule_coeff),
+ flow_transforms.PCAAug(schedule_coeff=schedule_coeff),
+ flow_transforms.ChromaticAug( schedule_coeff=schedule_coeff, noise=self.noise),
+ ])
+ else:
+ co_transform = flow_transforms.Compose([
+ flow_transforms.Scale(self.scale, order=self.order),
+ flow_transforms.SpatialAug([th,tw], trans=[0.4,0.03], order=self.order, black=self.black)
+ ])
+
+ augmented,flowl0 = co_transform([iml0, iml1], flowl0)
+ iml0 = augmented[0]
+ iml1 = augmented[1]
+
+ if self.cover:
+ ## randomly cover a region
+ # following sec. 3.2 of http://openaccess.thecvf.com/content_CVPR_2019/html/Yang_Hierarchical_Deep_Stereo_Matching_on_High-Resolution_Images_CVPR_2019_paper.html
+ if np.random.binomial(1,0.5):
+ #sx = int(np.random.uniform(25,100))
+ #sy = int(np.random.uniform(25,100))
+ sx = int(np.random.uniform(50,125))
+ sy = int(np.random.uniform(50,125))
+ #sx = int(np.random.uniform(50,150))
+ #sy = int(np.random.uniform(50,150))
+ cx = int(np.random.uniform(sx,iml1.shape[0]-sx))
+ cy = int(np.random.uniform(sy,iml1.shape[1]-sy))
+ iml1[cx-sx:cx+sx,cy-sy:cy+sy] = np.mean(np.mean(iml1,0),0)[np.newaxis,np.newaxis]
+
+ iml0 = torch.Tensor(np.transpose(iml0,(2,0,1)))
+ iml1 = torch.Tensor(np.transpose(iml1,(2,0,1)))
+
+ return iml0, iml1, flowl0
+
+ def __len__(self):
+ return len(self.iml0)
diff --git a/expansion/dataloader/sceneflowlist.py b/expansion/dataloader/sceneflowlist.py
new file mode 100755
index 0000000000000000000000000000000000000000..0ba93e6caa33214f0a2bf7a46c74566c86414d0b
--- /dev/null
+++ b/expansion/dataloader/sceneflowlist.py
@@ -0,0 +1,51 @@
+import os
+import os.path
+import glob
+
+def dataloader(filepath, level=6):
+ iml0 = []
+ iml1 = []
+ flowl0 = []
+ disp0 = []
+ dispc = []
+ calib = []
+ level_stars = '/*'*level
+ candidate_pool = glob.glob('%s/optical_flow%s'%(filepath,level_stars))
+ for flow_path in sorted(candidate_pool):
+ if 'TEST' in flow_path: continue
+ if 'flower_storm_x2/into_future/right/OpticalFlowIntoFuture_0023_R.pfm' in flow_path:
+ continue
+ if 'flower_storm_x2/into_future/left/OpticalFlowIntoFuture_0023_L.pfm' in flow_path:
+ continue
+ if 'flower_storm_augmented0_x2/into_future/right/OpticalFlowIntoFuture_0023_R.pfm' in flow_path:
+ continue
+ if 'flower_storm_augmented0_x2/into_future/left/OpticalFlowIntoFuture_0023_L.pfm' in flow_path:
+ continue
+ if 'FlyingThings' in flow_path and '_0014_' in flow_path:
+ continue
+ if 'FlyingThings' in flow_path and '_0015_' in flow_path:
+ continue
+ idd = flow_path.split('/')[-1].split('_')[-2]
+ if 'into_future' in flow_path:
+ idd_p1 = '%04d'%(int(idd)+1)
+ else:
+ idd_p1 = '%04d'%(int(idd)-1)
+ if os.path.exists(flow_path.replace(idd,idd_p1)):
+ d0_path = flow_path.replace('/into_future/','/').replace('/into_past/','/').replace('optical_flow','disparity')
+ d0_path = '%s/%s.pfm'%(d0_path.rsplit('/',1)[0],idd)
+ dc_path = flow_path.replace('optical_flow','disparity_change')
+ dc_path = '%s/%s.pfm'%(dc_path.rsplit('/',1)[0],idd)
+ im_path = flow_path.replace('/into_future/','/').replace('/into_past/','/').replace('optical_flow','frames_cleanpass')
+ im0_path = '%s/%s.png'%(im_path.rsplit('/',1)[0],idd)
+ im1_path = '%s/%s.png'%(im_path.rsplit('/',1)[0],idd_p1)
+ #with open('%s/camera_data.txt'%(im0_path.replace('frames_cleanpass','camera_data').rsplit('/',2)[0]),'r') as f:
+ # if 'FlyingThings' in flow_path and len(f.readlines())!=40:
+ # print(flow_path)
+ # continue
+ iml0.append(im0_path)
+ iml1.append(im1_path)
+ flowl0.append(flow_path)
+ disp0.append(d0_path)
+ dispc.append(dc_path)
+ calib.append('%s/camera_data.txt'%(im0_path.replace('frames_cleanpass','camera_data').rsplit('/',2)[0]))
+ return iml0, iml1, flowl0, disp0, dispc, calib
diff --git a/expansion/dataloader/seqlist.py b/expansion/dataloader/seqlist.py
new file mode 100755
index 0000000000000000000000000000000000000000..2e0e8bda391a7ddfd09b3268065dc02712bc7575
--- /dev/null
+++ b/expansion/dataloader/seqlist.py
@@ -0,0 +1,26 @@
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+import numpy as np
+import glob
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+def dataloader(filepath):
+
+ train = sorted(glob.glob('%s/*'%filepath))
+
+ l0_train = train[:-1]
+ l1_train = train[1:]
+
+
+ return sorted(l0_train), sorted(l1_train), sorted(l0_train)
diff --git a/expansion/dataloader/sintellist.py b/expansion/dataloader/sintellist.py
new file mode 100755
index 0000000000000000000000000000000000000000..44bc1ab5d466d605b7bc695adb41f2431aa0f790
--- /dev/null
+++ b/expansion/dataloader/sintellist.py
@@ -0,0 +1,32 @@
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+import numpy as np
+import pdb
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+def dataloader(filepath):
+
+ left_fold = 'image_2/'
+ train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel') > -1]
+
+ l0_train = [filepath+left_fold+img for img in train]
+ l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]
+
+ #l0_train = [i for i in l0_train if not '10.png' in i] # remove 10 as val
+
+ l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train]
+ flow_train = [img.replace('image_2','flow_occ') for img in l0_train]
+
+
+ return l0_train, l1_train, flow_train
diff --git a/expansion/dataloader/sintellist_clean.py b/expansion/dataloader/sintellist_clean.py
new file mode 100755
index 0000000000000000000000000000000000000000..c008399e94e150856e921bef12199af9910400f6
--- /dev/null
+++ b/expansion/dataloader/sintellist_clean.py
@@ -0,0 +1,31 @@
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+import numpy as np
+import pdb
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+def dataloader(filepath):
+
+ left_fold = 'image_2/'
+ train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel_clean') > -1]
+
+ l0_train = [filepath+left_fold+img for img in train]
+ l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]
+
+ #l0_train = [i for i in l0_train if not '10.png' in i] # remove 10 as val
+
+ l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train]
+ flow_train = [img.replace('image_2','flow_occ') for img in l0_train]
+
+ return l0_train, l1_train, flow_train
diff --git a/expansion/dataloader/sintellist_final.py b/expansion/dataloader/sintellist_final.py
new file mode 100755
index 0000000000000000000000000000000000000000..b8585d594b02984bbe13005f680aaf1e859864d3
--- /dev/null
+++ b/expansion/dataloader/sintellist_final.py
@@ -0,0 +1,32 @@
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+import numpy as np
+import pdb
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+def dataloader(filepath):
+
+ left_fold = 'image_2/'
+ train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel_final') > -1]
+
+ l0_train = [filepath+left_fold+img for img in train]
+ l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]
+
+ #l0_train = [i for i in l0_train if not '10.png' in i] # remove 10 as val
+
+ l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train]
+ flow_train = [img.replace('image_2','flow_occ') for img in l0_train]
+
+ # pdb.set_trace()  # debugging leftover, disabled so the loader runs non-interactively
+ return l0_train, l1_train, flow_train
diff --git a/expansion/dataloader/sintellist_train.py b/expansion/dataloader/sintellist_train.py
new file mode 100755
index 0000000000000000000000000000000000000000..81ff9393fb5896983648885186f5fb7c2e6907de
--- /dev/null
+++ b/expansion/dataloader/sintellist_train.py
@@ -0,0 +1,32 @@
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+import numpy as np
+import pdb
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+def dataloader(filepath):
+
+ left_fold = 'image_2/'
+ train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel') > -1]
+
+ l0_train = [filepath+left_fold+img for img in train]
+ l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]
+
+ l0_train = [i for i in l0_train if not(('_2_' in i) and ('alley' not in i) and ('bandage' not in i) and ('sleeping' not in i))] # hold out the '_2_' passes (except alley/bandage/sleeping) for validation
+
+ l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train]
+ flow_train = [img.replace('image_2','flow_occ') for img in l0_train]
+
+
+ return l0_train, l1_train, flow_train
diff --git a/expansion/dataloader/sintellist_val.py b/expansion/dataloader/sintellist_val.py
new file mode 100755
index 0000000000000000000000000000000000000000..452446556bc18aee9ff2b47ab5d2ff81c7b5a2ec
--- /dev/null
+++ b/expansion/dataloader/sintellist_val.py
@@ -0,0 +1,34 @@
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+import numpy as np
+import pdb
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+def dataloader(filepath):
+
+ left_fold = 'image_2/'
+ train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel') > -1]
+
+ l0_train = [filepath+left_fold+img for img in train]
+ l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]
+
+ l0_train = [i for i in l0_train if ('_2_' in i) and ('alley' not in i) and ('bandage' not in i) and ('sleeping' not in i)] # keep only the held-out '_2_' passes (except alley/bandage/sleeping) as validation
+ #l0_train = [i for i in l0_train if not(('_2_' in i) and ('alley' not in i) and ('bandage' not in i) and ('sleeping' not in i))] # remove 10 as val
+
+ l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train]
+ flow_train = [img.replace('image_2','flow_occ') for img in l0_train]
+
+
+ return sorted(l0_train)[::3], sorted(l1_train)[::3], sorted(flow_train)[::3]
+# return sorted(l0_train)[::10], sorted(l1_train)[::10], sorted(flow_train)[::10]
diff --git a/expansion/dataloader/thingslist.py b/expansion/dataloader/thingslist.py
new file mode 100755
index 0000000000000000000000000000000000000000..cfe3976bce5738dc2000b2dfc736b41a28d8624d
--- /dev/null
+++ b/expansion/dataloader/thingslist.py
@@ -0,0 +1,122 @@
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+import numpy as np
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+def dataloader(filepath):
+ exc_list = [
+'0004117.flo',
+'0003149.flo',
+'0001203.flo',
+'0003147.flo',
+'0003666.flo',
+'0006337.flo',
+'0006336.flo',
+'0007126.flo',
+'0004118.flo',
+]
+
+ left_fold = 'image_clean/left/'
+ flow_noc = 'flow/left/into_future/'
+ train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0]
+
+ l0_trainlf = [filepath+left_fold+img.replace('flo','png') for img in train]
+ l1_trainlf = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainlf]
+ flow_trainlf = [filepath+flow_noc+img for img in train]
+
+
+ exc_list = [
+'0003148.flo',
+'0004117.flo',
+'0002890.flo',
+'0003149.flo',
+'0001203.flo',
+'0003666.flo',
+'0006337.flo',
+'0006336.flo',
+'0004118.flo',
+]
+
+ left_fold = 'image_clean/right/'
+ flow_noc = 'flow/right/into_future/'
+ train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0]
+
+ l0_trainrf = [filepath+left_fold+img.replace('flo','png') for img in train]
+ l1_trainrf = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainrf]
+ flow_trainrf = [filepath+flow_noc+img for img in train]
+
+
+ exc_list = [
+'0004237.flo',
+'0004705.flo',
+'0004045.flo',
+'0004346.flo',
+'0000161.flo',
+'0000931.flo',
+'0000121.flo',
+'0010822.flo',
+'0004117.flo',
+'0006023.flo',
+'0005034.flo',
+'0005054.flo',
+'0000162.flo',
+'0000053.flo',
+'0005055.flo',
+'0003147.flo',
+'0004876.flo',
+'0000163.flo',
+'0006878.flo',
+]
+
+ left_fold = 'image_clean/left/'
+ flow_noc = 'flow/left/into_past/'
+ train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0]
+
+ l0_trainlp = [filepath+left_fold+img.replace('flo','png') for img in train]
+ l1_trainlp = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(-1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainlp]
+ flow_trainlp = [filepath+flow_noc+img for img in train]
+
+ exc_list = [
+'0003148.flo',
+'0004705.flo',
+'0000161.flo',
+'0000121.flo',
+'0004117.flo',
+'0000160.flo',
+'0005034.flo',
+'0005054.flo',
+'0000162.flo',
+'0000053.flo',
+'0005055.flo',
+'0003147.flo',
+'0001549.flo',
+'0000163.flo',
+'0006336.flo',
+'0001648.flo',
+'0006878.flo',
+]
+
+ left_fold = 'image_clean/right/'
+ flow_noc = 'flow/right/into_past/'
+ train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0]
+
+ l0_trainrp = [filepath+left_fold+img.replace('flo','png') for img in train]
+ l1_trainrp = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(-1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainrp]
+ flow_trainrp = [filepath+flow_noc+img for img in train]
+
+
+ l0_train = l0_trainlf + l0_trainrf + l0_trainlp + l0_trainrp
+ l1_train = l1_trainlf + l1_trainrf + l1_trainlp + l1_trainrp
+ flow_train = flow_trainlf + flow_trainrf + flow_trainlp + flow_trainrp
+ return l0_train, l1_train, flow_train
diff --git a/expansion/models/VCN_exp.py b/expansion/models/VCN_exp.py
new file mode 100755
index 0000000000000000000000000000000000000000..6115ef9a68c63a12fae83bec92846de173c829b5
--- /dev/null
+++ b/expansion/models/VCN_exp.py
@@ -0,0 +1,561 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+import os
+os.environ['PYTHON_EGG_CACHE'] = 'tmp/' # a writable directory
+import numpy as np
+import math
+import pdb
+import time
+
+from .submodule import pspnet, bfmodule, conv
+from .conv4d import sepConv4d, sepConv4dBlock, butterfly4D
+
+class flow_reg(nn.Module):
+ """
+ Soft winner-take-all that selects the most likely displacement.
+ Set ent=True to enable entropy output.
+ Set maxdisp to adjust the maximum allowed displacement towards one side.
+ maxdisp=4 searches over a 9x9 region.
+ Set fac to squeeze the search window.
+ maxdisp=4 and fac=2 gives a search window of 9x5.
+ """
+ def __init__(self, size, ent=False, maxdisp = int(4), fac=1):
+ B,W,H = size
+ super(flow_reg, self).__init__()
+ self.ent = ent
+ self.md = maxdisp
+ self.fac = fac
+ self.truncated = True
+ self.wsize = 3 # by default using truncation 7x7
+
+ flowrangey = range(-maxdisp,maxdisp+1)
+ flowrangex = range(-int(maxdisp//self.fac),int(maxdisp//self.fac)+1)
+ meshgrid = np.meshgrid(flowrangex,flowrangey)
+ flowy = np.tile( np.reshape(meshgrid[0],[1,2*maxdisp+1,2*int(maxdisp//self.fac)+1,1,1]), (B,1,1,H,W) )
+ flowx = np.tile( np.reshape(meshgrid[1],[1,2*maxdisp+1,2*int(maxdisp//self.fac)+1,1,1]), (B,1,1,H,W) )
+ self.register_buffer('flowx',torch.Tensor(flowx))
+ self.register_buffer('flowy',torch.Tensor(flowy))
+
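+ # 3D max-pool used below to dilate the one-hot argmax mask into a
+ # (2*wsize+1) x (2*wsize+1) displacement window for the truncated softmax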
+ self.pool3d = nn.MaxPool3d((self.wsize*2+1,self.wsize*2+1,1),stride=1,padding=(self.wsize,self.wsize,0))
+
+ def forward(self, x):
+ b,u,v,h,w = x.shape
+ oldx = x
+
+ if self.truncated:
+ # truncated softmax
+ x = x.view(b,u*v,h,w)
+
+ idx = x.argmax(1)[:,np.newaxis]
+ if x.is_cuda:
+ mask = Variable(torch.cuda.HalfTensor(b,u*v,h,w)).fill_(0)
+ else:
+ mask = Variable(torch.FloatTensor(b,u*v,h,w)).fill_(0)
+ mask.scatter_(1,idx,1)
+ mask = mask.view(b,1,u,v,-1)
+ mask = self.pool3d(mask)[:,0].view(b,u,v,h,w)
+
+ ninf = x.clone().fill_(-np.inf).view(b,u,v,h,w)
+ x = torch.where(mask.byte(),oldx,ninf)
+ else:
+ self.wsize = (np.sqrt(u*v)-1)/2
+
+ b,u,v,h,w = x.shape
+ x = F.softmax(x.view(b,-1,h,w),1).view(b,u,v,h,w)
+ outx = torch.sum(torch.sum(x*self.flowx,1),1,keepdim=True)
+ outy = torch.sum(torch.sum(x*self.flowy,1),1,keepdim=True)
+
+ if self.ent:
+ # local
+ local_entropy = (-x*torch.clamp(x,1e-9,1-1e-9).log()).sum(1).sum(1)[:,np.newaxis]
+ if self.wsize == 0:
+ local_entropy[:] = 1.
+ else:
+ local_entropy /= np.log((self.wsize*2+1)**2)
+
+ # global
+ x = F.softmax(oldx.view(b,-1,h,w),1).view(b,u,v,h,w)
+ global_entropy = (-x*torch.clamp(x,1e-9,1-1e-9).log()).sum(1).sum(1)[:,np.newaxis]
+ global_entropy /= np.log(x.shape[1]*x.shape[2])
+ return torch.cat([outx,outy],1),torch.cat([local_entropy, global_entropy],1)
+ else:
+ return torch.cat([outx,outy],1),None
+
+
+class WarpModule(nn.Module):
+ """
+ taken from https://github.com/NVlabs/PWC-Net/blob/master/PyTorch/models/PWCNet.py
+ """
+ def __init__(self, size):
+ super(WarpModule, self).__init__()
+ B,W,H = size
+ # mesh grid
+ xx = torch.arange(0, W).view(1,-1).repeat(H,1)
+ yy = torch.arange(0, H).view(-1,1).repeat(1,W)
+ xx = xx.view(1,1,H,W).repeat(B,1,1,1)
+ yy = yy.view(1,1,H,W).repeat(B,1,1,1)
+ self.register_buffer('grid',torch.cat((xx,yy),1).float())
+
+ def forward(self, x, flo):
+ """
+ warp an image/tensor (im2) back to im1, according to the optical flow
+
+ x: [B, C, H, W] (im2)
+ flo: [B, 2, H, W] flow
+
+ """
+ B, C, H, W = x.size()
+ vgrid = self.grid + flo
+
+ # scale grid to [-1,1]
+ vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:]/max(W-1,1)-1.0
+ vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:]/max(H-1,1)-1.0
+
+ vgrid = vgrid.permute(0,2,3,1)
+ output = nn.functional.grid_sample(x, vgrid,align_corners=True)
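+ # keep only pixels whose warped sampling coordinates fall inside [-1,1]
+ # (i.e. inside im2); everything else is zeroed out by the mask below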
+ mask = ((vgrid[:,:,:,0].abs()<1) * (vgrid[:,:,:,1].abs()<1)) >0
+ return output*mask.unsqueeze(1).float(), mask
+
+
+def get_grid(B,H,W):
+ meshgrid_base = np.meshgrid(range(0,W), range(0,H))[::-1]
+ basey = np.reshape(meshgrid_base[0],[1,1,1,H,W])
+ basex = np.reshape(meshgrid_base[1],[1,1,1,H,W])
+ grid = torch.tensor(np.concatenate((basex.reshape((-1,H,W,1)),basey.reshape((-1,H,W,1))),-1)).cuda().float()
+ return grid.view(1,1,H,W,2)
+
+
+class VCN(nn.Module):
+ """
+ VCN backbone with an optical-expansion head.
+ md defines the maximum displacement for each level, following a coarse-to-fine warping scheme.
+ fac defines the squeeze parameter for the coarsest level.
+ """
+ def __init__(self, size, md=[4,4,4,4,4], fac=1.,exp_unc=False): # exp_uncertainty
+ super(VCN,self).__init__()
+ self.md = md
+ self.fac = fac
+ use_entropy = True
+ withbn = True
+
+ ## pspnet
+ self.pspnet = pspnet(is_proj=False)
+
+ ### Volumetric-UNet
+ fdima1 = 128 # 6/5/4
+ fdima2 = 64 # 3/2
+ fdimb1 = 16 # 6/5/4/3
+ fdimb2 = 12 # 2
+
+ full=False
+ self.f6 = butterfly4D(fdima1, fdimb1,withbn=withbn,full=full)
+ self.p6 = sepConv4d(fdimb1,fdimb1, with_bn=False, full=full)
+
+ self.f5 = butterfly4D(fdima1, fdimb1,withbn=withbn, full=full)
+ self.p5 = sepConv4d(fdimb1,fdimb1, with_bn=False,full=full)
+
+ self.f4 = butterfly4D(fdima1, fdimb1,withbn=withbn,full=full)
+ self.p4 = sepConv4d(fdimb1,fdimb1, with_bn=False,full=full)
+
+ self.f3 = butterfly4D(fdima2, fdimb1,withbn=withbn,full=full)
+ self.p3 = sepConv4d(fdimb1,fdimb1, with_bn=False,full=full)
+
+ full=True
+ self.f2 = butterfly4D(fdima2, fdimb2,withbn=withbn,full=full)
+ self.p2 = sepConv4d(fdimb2,fdimb2, with_bn=False,full=full)
+
+ self.flow_reg64 = flow_reg([fdimb1*size[0],size[1]//64,size[2]//64], ent=use_entropy, maxdisp=self.md[0], fac=self.fac)
+ self.flow_reg32 = flow_reg([fdimb1*size[0],size[1]//32,size[2]//32], ent=use_entropy, maxdisp=self.md[1])
+ self.flow_reg16 = flow_reg([fdimb1*size[0],size[1]//16,size[2]//16], ent=use_entropy, maxdisp=self.md[2])
+ self.flow_reg8 = flow_reg([fdimb1*size[0],size[1]//8,size[2]//8] , ent=use_entropy, maxdisp=self.md[3])
+ self.flow_reg4 = flow_reg([fdimb2*size[0],size[1]//4,size[2]//4] , ent=use_entropy, maxdisp=self.md[4])
+
+ self.warp5 = WarpModule([size[0],size[1]//32,size[2]//32])
+ self.warp4 = WarpModule([size[0],size[1]//16,size[2]//16])
+ self.warp3 = WarpModule([size[0],size[1]//8,size[2]//8])
+ self.warp2 = WarpModule([size[0],size[1]//4,size[2]//4])
+
+ ## hypotheses fusion modules, adopted from the refinement module of PWCNet
+ # https://github.com/NVlabs/PWC-Net/blob/master/PyTorch/models/PWCNet.py
+ # c6
+ self.dc6_conv1 = conv(128+4*fdimb1, 128, kernel_size=3, stride=1, padding=1, dilation=1)
+ self.dc6_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2)
+ self.dc6_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4)
+ self.dc6_conv4 = conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8)
+ self.dc6_conv5 = conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16)
+ self.dc6_conv6 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1)
+ self.dc6_conv7 = nn.Conv2d(32,2*fdimb1,kernel_size=3,stride=1,padding=1,bias=True)
+
+ # c5
+ self.dc5_conv1 = conv(128+4*fdimb1*2, 128, kernel_size=3, stride=1, padding=1, dilation=1)
+ self.dc5_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2)
+ self.dc5_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4)
+ self.dc5_conv4 = conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8)
+ self.dc5_conv5 = conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16)
+ self.dc5_conv6 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1)
+ self.dc5_conv7 = nn.Conv2d(32,2*fdimb1*2,kernel_size=3,stride=1,padding=1,bias=True)
+
+ # c4
+ self.dc4_conv1 = conv(128+4*fdimb1*3, 128, kernel_size=3, stride=1, padding=1, dilation=1)
+ self.dc4_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2)
+ self.dc4_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4)
+ self.dc4_conv4 = conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8)
+ self.dc4_conv5 = conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16)
+ self.dc4_conv6 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1)
+ self.dc4_conv7 = nn.Conv2d(32,2*fdimb1*3,kernel_size=3,stride=1,padding=1,bias=True)
+
+ # c3
+ self.dc3_conv1 = conv(64+16*fdimb1, 128, kernel_size=3, stride=1, padding=1, dilation=1)
+ self.dc3_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2)
+ self.dc3_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4)
+ self.dc3_conv4 = conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8)
+ self.dc3_conv5 = conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16)
+ self.dc3_conv6 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1)
+ self.dc3_conv7 = nn.Conv2d(32,8*fdimb1,kernel_size=3,stride=1,padding=1,bias=True)
+
+ # c2
+ self.dc2_conv1 = conv(64+16*fdimb1+4*fdimb2, 128, kernel_size=3, stride=1, padding=1, dilation=1)
+ self.dc2_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2)
+ self.dc2_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4)
+ self.dc2_conv4 = conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8)
+ self.dc2_conv5 = conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16)
+ self.dc2_conv6 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1)
+ self.dc2_conv7 = nn.Conv2d(32,4*2*fdimb1 + 2*fdimb2,kernel_size=3,stride=1,padding=1,bias=True)
+
+ self.dc6_conv = nn.Sequential( self.dc6_conv1,
+ self.dc6_conv2,
+ self.dc6_conv3,
+ self.dc6_conv4,
+ self.dc6_conv5,
+ self.dc6_conv6,
+ self.dc6_conv7)
+ self.dc5_conv = nn.Sequential( self.dc5_conv1,
+ self.dc5_conv2,
+ self.dc5_conv3,
+ self.dc5_conv4,
+ self.dc5_conv5,
+ self.dc5_conv6,
+ self.dc5_conv7)
+ self.dc4_conv = nn.Sequential( self.dc4_conv1,
+ self.dc4_conv2,
+ self.dc4_conv3,
+ self.dc4_conv4,
+ self.dc4_conv5,
+ self.dc4_conv6,
+ self.dc4_conv7)
+ self.dc3_conv = nn.Sequential( self.dc3_conv1,
+ self.dc3_conv2,
+ self.dc3_conv3,
+ self.dc3_conv4,
+ self.dc3_conv5,
+ self.dc3_conv6,
+ self.dc3_conv7)
+ self.dc2_conv = nn.Sequential( self.dc2_conv1,
+ self.dc2_conv2,
+ self.dc2_conv3,
+ self.dc2_conv4,
+ self.dc2_conv5,
+ self.dc2_conv6,
+ self.dc2_conv7)
+
+ ## Out-of-range detection
+ self.dc6_convo = nn.Sequential(conv(128+4*fdimb1, 128, kernel_size=3, stride=1, padding=1, dilation=1),
+ conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2),
+ conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4),
+ conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8),
+ conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16),
+ conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1),
+ nn.Conv2d(32,1,kernel_size=3,stride=1,padding=1,bias=True))
+
+ self.dc5_convo = nn.Sequential(conv(128+2*4*fdimb1, 128, kernel_size=3, stride=1, padding=1, dilation=1),
+ conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2),
+ conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4),
+ conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8),
+ conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16),
+ conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1),
+ nn.Conv2d(32,1,kernel_size=3,stride=1,padding=1,bias=True))
+
+ self.dc4_convo = nn.Sequential(conv(128+3*4*fdimb1, 128, kernel_size=3, stride=1, padding=1, dilation=1),
+ conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2),
+ conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4),
+ conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8),
+ conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16),
+ conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1),
+ nn.Conv2d(32,1,kernel_size=3,stride=1,padding=1,bias=True))
+
+ self.dc3_convo = nn.Sequential(conv(64+16*fdimb1, 128, kernel_size=3, stride=1, padding=1, dilation=1),
+ conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2),
+ conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4),
+ conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8),
+ conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16),
+ conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1),
+ nn.Conv2d(32,1,kernel_size=3,stride=1,padding=1,bias=True))
+
+ self.dc2_convo = nn.Sequential(conv(64+16*fdimb1+4*fdimb2, 128, kernel_size=3, stride=1, padding=1, dilation=1),
+ conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2),
+ conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4),
+ conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8),
+ conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16),
+ conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1),
+ nn.Conv2d(32,1,kernel_size=3,stride=1,padding=1,bias=True))
+
+ # affine-exp
+ self.f3d2v1 = conv(64, 32, kernel_size=3, stride=1, padding=1,dilation=1) #
+ self.f3d2v2 = conv(1, 32, kernel_size=3, stride=1, padding=1,dilation=1) #
+ self.f3d2v3 = conv(1, 32, kernel_size=3, stride=1, padding=1,dilation=1) #
+ self.f3d2v4 = conv(1, 32, kernel_size=3, stride=1, padding=1,dilation=1) #
+ self.f3d2v5 = conv(64, 32, kernel_size=3, stride=1, padding=1,dilation=1) #
+ self.f3d2v6 = conv(12*81, 32, kernel_size=3, stride=1, padding=1,dilation=1) #
+ self.f3d2 = bfmodule(128-64,1)
+
+ # depth change net
+ self.dcnetv1 = conv(64, 32, kernel_size=3, stride=1, padding=1,dilation=1) #
+ self.dcnetv2 = conv(1, 32, kernel_size=3, stride=1, padding=1,dilation=1) #
+ self.dcnetv3 = conv(1, 32, kernel_size=3, stride=1, padding=1,dilation=1) #
+ self.dcnetv4 = conv(1, 32, kernel_size=3, stride=1, padding=1,dilation=1) #
+ self.dcnetv5 = conv(12*81, 32, kernel_size=3, stride=1, padding=1,dilation=1) #
+ self.dcnetv6 = conv(4, 32, kernel_size=3, stride=1, padding=1,dilation=1) #
+ if exp_unc:
+ self.dcnet = bfmodule(128,2)
+ else:
+ self.dcnet = bfmodule(128,1)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv3d):
+ n = m.kernel_size[0] * m.kernel_size[1]*m.kernel_size[2] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ if hasattr(m.bias,'data'):
+ m.bias.data.zero_()
+
+ self.facs = [self.fac,1,1,1,1]
+ self.warp_modules = nn.ModuleList([None, self.warp5, self.warp4, self.warp3, self.warp2])
+ self.f_modules = nn.ModuleList([self.f6, self.f5, self.f4, self.f3, self.f2])
+ self.p_modules = nn.ModuleList([self.p6, self.p5, self.p4, self.p3, self.p2])
+ self.reg_modules = nn.ModuleList([self.flow_reg64, self.flow_reg32, self.flow_reg16, self.flow_reg8, self.flow_reg4])
+ self.oor_modules = nn.ModuleList([self.dc6_convo, self.dc5_convo, self.dc4_convo, self.dc3_convo, self.dc2_convo])
+ self.fuse_modules = nn.ModuleList([self.dc6_conv, self.dc5_conv, self.dc4_conv, self.dc3_conv, self.dc2_conv])
+
+ def corrf(self, refimg_fea, targetimg_fea,maxdisp, fac=1):
+ """
+ Slow, loop-based correlation: builds a (b, c, 2*maxdisp+1, 2*maxdisp//fac+1, h, w) volume of per-channel feature products.
+ """
+ b,c,height,width = refimg_fea.shape
+ if refimg_fea.is_cuda:
+ cost = Variable(torch.cuda.FloatTensor(b,c,2*maxdisp+1,2*int(maxdisp//fac)+1,height,width)).fill_(0.) # b,c,u,v,h,w
+ else:
+ cost = Variable(torch.FloatTensor(b,c,2*maxdisp+1,2*int(maxdisp//fac)+1,height,width)).fill_(0.) # b,c,u,v,h,w
+ for i in range(2*maxdisp+1):
+ ind = i-maxdisp
+ for j in range(2*int(maxdisp//fac)+1):
+ indd = j-int(maxdisp//fac)
+ feata = refimg_fea[:,:,max(0,-indd):height-indd,max(0,-ind):width-ind]
+ featb = targetimg_fea[:,:,max(0,+indd):height+indd,max(0,ind):width+ind]
+ diff = (feata*featb)
+ cost[:, :, i,j,max(0,-indd):height-indd,max(0,-ind):width-ind] = diff # standard
+ cost = F.leaky_relu(cost, 0.1,inplace=True)
+ return cost
+
+ def cost_matching(self,up_flow, c1, c2, flowh, enth, level):
+ """
+ up_flow: upsampled coarse flow from the previous level
+ c1: normalized feature of image 1
+ c2: normalized feature of image 2
+ flowh: flow hypotheses from coarser levels
+ enth: entropy of the coarser hypotheses
+ level: pyramid level, 0 being the coarsest
+ """
+
+ # normalize
+ c1n = c1 / (c1.norm(dim=1, keepdim=True)+1e-9)
+ c2n = c2 / (c2.norm(dim=1, keepdim=True)+1e-9)
+
+ # cost volume
+ if level == 0:
+ warp = c2n
+ else:
+ warp,_ = self.warp_modules[level](c2n, up_flow)
+
+ feat = self.corrf(c1n,warp,self.md[level],fac=self.facs[level])
+ feat = self.f_modules[level](feat)
+ cost = self.p_modules[level](feat) # b, 16, u,v,h,w
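+ # butterfly4D + sepConv4d aggregate the raw correlation volume into c matching
+ # cost channels per displacement (c = 16 at coarse levels, 12 at the finest)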
+
+ # soft WTA
+ b,c,u,v,h,w = cost.shape
+ cost = cost.view(-1,u,v,h,w) # bx16, 9,9,h,w, also predict uncertainty from here
+ flowhh,enthh = self.reg_modules[level](cost) # bx16, 2, h, w
+ flowhh = flowhh.view(b,c,2,h,w)
+ if level > 0:
+ flowhh = flowhh + up_flow[:,np.newaxis]
+ flowhh = flowhh.view(b,-1,h,w) # b, 16*2, h, w
+ enthh = enthh.view(b,-1,h,w) # b, 16*1, h, w
+
+ # append coarse hypotheses
+ if level == 0:
+ flowh = flowhh
+ enth = enthh
+ else:
+ flowh = torch.cat((flowhh, F.upsample(flowh.detach()*2, [flowhh.shape[2],flowhh.shape[3]], mode='bilinear')),1) # b, k2--k2, h, w
+ enth = torch.cat((enthh, F.upsample(enth, [flowhh.shape[2],flowhh.shape[3]], mode='bilinear')),1)
+
+ if self.training or level==4:
+ x = torch.cat((enth.detach(), flowh.detach(), c1),1)
+ oor = self.oor_modules[level](x)[:,0]
+ else: oor = None
+
+ # hypotheses fusion
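+ # a softmax over the predicted scores va weights the stacked flow hypotheses
+ # per pixel, conditioned on their entropies and the image-1 features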
+ x = torch.cat((enth.detach(), flowh.detach(), c1),1)
+ va = self.fuse_modules[level](x)
+ va = va.view(b,-1,2,h,w)
+ flow = ( flowh.view(b,-1,2,h,w) * F.softmax(va,1) ).sum(1) # b, 2k, 2, h, w
+
+ return flow, flowh, enth, oor
+
+ def affine(self,pref,flow, pw=1):
+ b,_,lh,lw=flow.shape
+ ptar = pref + flow
+ pw = 1
+ pref = F.unfold(pref, (pw*2+1,pw*2+1), padding=(pw)).view(b,2,(pw*2+1)**2,lh,lw)-pref[:,:,np.newaxis]
+ ptar = F.unfold(ptar, (pw*2+1,pw*2+1), padding=(pw)).view(b,2,(pw*2+1)**2,lh,lw)-ptar[:,:,np.newaxis] # b, 2,9,h,w
+ pref = pref.permute(0,3,4,1,2).reshape(b*lh*lw,2,(pw*2+1)**2)
+ ptar = ptar.permute(0,3,4,1,2).reshape(b*lh*lw,2,(pw*2+1)**2)
+
+ prefprefT = pref.matmul(pref.permute(0,2,1))
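+ # closed-form inverse of the 2x2 Gram matrix pref @ pref^T (adjugate / det),
+ # used to solve the least-squares fit A = ptar @ pref^T @ (pref @ pref^T)^-1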
+ ppdet = prefprefT[:,0,0]*prefprefT[:,1,1]-prefprefT[:,1,0]*prefprefT[:,0,1]
+ ppinv = torch.cat((prefprefT[:,1,1:],-prefprefT[:,0,1:], -prefprefT[:,1:,0], prefprefT[:,0:1,0]),1).view(-1,2,2)/ppdet.clamp(1e-10,np.inf)[:,np.newaxis,np.newaxis]
+
+ Affine = ptar.matmul(pref.permute(0,2,1)).matmul(ppinv)
+ Error = (Affine.matmul(pref)-ptar).norm(2,1).mean(1).view(b,1,lh,lw)
+
+ Avol = (Affine[:,0,0]*Affine[:,1,1]-Affine[:,1,0]*Affine[:,0,1]).view(b,1,lh,lw).abs().clamp(1e-10,np.inf)
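+ # optical expansion = sqrt(|det A|): the isotropic scale change of the local
+ # patch under the fitted affine motion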
+ exp = Avol.sqrt()
+ mask = (exp>0.5) & (exp<2) & (Error<0.1)
+ mask = mask[:,0]
+
+ exp = exp.clamp(0.5,2)
+ exp[Error>0.1]=1
+ return exp, Error, mask
+
+ def affine_mask(self,pref,flow, pw=3):
+ """
+ pref: reference coordinates
+ flow: flow with a validity mask in its third channel
+ pw: half-width of the local patch
+ """
+ flmask = flow[:,2:]
+ flow = flow[:,:2]
+ b,_,lh,lw=flow.shape
+ ptar = pref + flow
+ pref = F.unfold(pref, (pw*2+1,pw*2+1), padding=(pw)).view(b,2,(pw*2+1)**2,lh,lw)-pref[:,:,np.newaxis]
+ ptar = F.unfold(ptar, (pw*2+1,pw*2+1), padding=(pw)).view(b,2,(pw*2+1)**2,lh,lw)-ptar[:,:,np.newaxis] # b, 2,9,h,w
+
+ conf_flow = flmask
+ conf_flow = F.unfold(conf_flow,(pw*2+1,pw*2+1), padding=(pw)).view(b,1,(pw*2+1)**2,lh,lw)
+ count = conf_flow.sum(2,keepdims=True)
+ conf_flow = ((pw*2+1)**2)*conf_flow / count
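+ # weight each neighbour by its flow-validity flag, renormalised so the weights
+ # in a patch sum to the full window size (2*pw+1)^2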
+ pref = pref * conf_flow
+ ptar = ptar * conf_flow
+
+ pref = pref.permute(0,3,4,1,2).reshape(b*lh*lw,2,(pw*2+1)**2)
+ ptar = ptar.permute(0,3,4,1,2).reshape(b*lh*lw,2,(pw*2+1)**2)
+
+ prefprefT = pref.matmul(pref.permute(0,2,1))
+ ppdet = prefprefT[:,0,0]*prefprefT[:,1,1]-prefprefT[:,1,0]*prefprefT[:,0,1]
+ ppinv = torch.cat((prefprefT[:,1,1:],-prefprefT[:,0,1:], -prefprefT[:,1:,0], prefprefT[:,0:1,0]),1).view(-1,2,2)/ppdet.clamp(1e-10,np.inf)[:,np.newaxis,np.newaxis]
+
+ Affine = ptar.matmul(pref.permute(0,2,1)).matmul(ppinv)
+ Error = (Affine.matmul(pref)-ptar).norm(2,1).mean(1).view(b,1,lh,lw)
+
+ Avol = (Affine[:,0,0]*Affine[:,1,1]-Affine[:,1,0]*Affine[:,0,1]).view(b,1,lh,lw).abs().clamp(1e-10,np.inf)
+ exp = Avol.sqrt()
+ mask = (exp>0.5) & (exp<2) & (Error<0.2) & (flmask.bool()) & (count[:,0]>4)
+ mask = mask[:,0]
+
+ exp = exp.clamp(0.5,2)
+ exp[Error>0.2]=1
+ return exp, Error, mask
+
+ def weight_parameters(self):
+ return [param for name, param in self.named_parameters() if 'weight' in name]
+
+ def bias_parameters(self):
+ return [param for name, param in self.named_parameters() if 'bias' in name]
+
+ def forward(self,im,disc_aux=None):
+ bs = im.shape[0]//2
+
+ if self.training and disc_aux[-1]: # if only fine-tuning expansion
+ reset=True
+ self.eval()
+ torch.set_grad_enabled(False)
+ else: reset=False
+
+ c06,c05,c04,c03,c02 = self.pspnet(im)
+ c16 = c06[:bs]; c26 = c06[bs:]
+ c15 = c05[:bs]; c25 = c05[bs:]
+ c14 = c04[:bs]; c24 = c04[bs:]
+ c13 = c03[:bs]; c23 = c03[bs:]
+ c12 = c02[:bs]; c22 = c02[bs:]
+
+ ## matching 6
+ flow6, flow6h, ent6h, oor6 = self.cost_matching(None, c16, c26, None, None,level=0)
+
+ ## matching 5
+ up_flow6 = F.upsample(flow6, [im.size()[2]//32,im.size()[3]//32], mode='bilinear')*2
+ flow5, flow5h, ent5h, oor5 = self.cost_matching(up_flow6, c15, c25, flow6h, ent6h,level=1)
+
+ ## matching 4
+ up_flow5 = F.upsample(flow5, [im.size()[2]//16,im.size()[3]//16], mode='bilinear')*2
+ flow4, flow4h, ent4h, oor4 = self.cost_matching(up_flow5, c14, c24, flow5h, ent5h,level=2)
+
+ ## matching 3
+ up_flow4 = F.upsample(flow4, [im.size()[2]//8,im.size()[3]//8], mode='bilinear')*2
+ flow3, flow3h, ent3h, oor3 = self.cost_matching(up_flow4, c13, c23, flow4h, ent4h,level=3)
+
+ ## matching 2
+ up_flow3 = F.upsample(flow3, [im.size()[2]//4,im.size()[3]//4], mode='bilinear')*2
+ flow2, flow2h, ent2h, oor2 = self.cost_matching(up_flow3, c12, c22, flow3h, ent3h,level=4)
+
+ if reset:
+ torch.set_grad_enabled(True)
+ self.train()
+
+ # expansion
+ b,_,h,w = flow2.shape
+ exp2,err2,_ = self.affine(get_grid(b,h,w)[:,0].permute(0,3,1,2).repeat(b,1,1,1).clone(), flow2.detach(),pw=1)
+ x = torch.cat((
+ self.f3d2v2(-exp2.log()),
+ self.f3d2v3(err2),
+ ),1)
+ dchange2 = -exp2.log()+1./200*self.f3d2(x)[0]
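+ # convert the affine expansion to log motion-in-depth (-log exp2) and refine it
+ # with a small residual from f3d2; the 1/200 factor keeps the correction small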
+
+ # depth change net
+ iexp2 = F.upsample(dchange2.clone(), [im.size()[2],im.size()[3]], mode='bilinear')
+
+ x = torch.cat((self.dcnetv1(c12.detach()),
+ self.dcnetv2(dchange2.detach()),
+ self.dcnetv3(-exp2.log()),
+ self.dcnetv4(err2),
+ ),1)
+ dcneto = 1./200*self.dcnet(x)[0]
+ dchange2 = dchange2.detach() + dcneto[:,:1]
+
+ flow2 = F.upsample(flow2.detach(), [im.size()[2],im.size()[3]], mode='bilinear')*4
+ dchange2 = F.upsample(dchange2, [im.size()[2],im.size()[3]], mode='bilinear')
+
+ if self.training:
+ flowl0 = disc_aux[0].permute(0,3,1,2).clone()
+ gt_depth = disc_aux[2][:,:,:,0]
+ gt_f3d = disc_aux[2][:,:,:,4:7].permute(0,3,1,2).clone()
+ gt_dchange = (1+gt_f3d[:,2]/gt_depth)
+ maskdc = (gt_dchange < 2) & (gt_dchange > 0.5) & disc_aux[1]
+
+ gt_expi,gt_expi_err,maskoe = self.affine_mask(get_grid(b,4*h,4*w)[:,0].permute(0,3,1,2).repeat(b,1,1,1), flowl0,pw=3)
+ gt_exp = 1./gt_expi[:,0]
+
+ loss = 0.1* (dchange2[:,0]-gt_dchange.log()).abs()[maskdc].mean()
+ loss += 0.1* (iexp2[:,0]-gt_exp.log()).abs()[maskoe].mean()
+ return flow2*4, flow3*8,flow4*16,flow5*32,flow6*64,loss, dchange2[:,0], iexp2[:,0]
+
+ else:
+ return flow2, oor2, dchange2, iexp2
+
diff --git a/expansion/models/__init__.py b/expansion/models/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/expansion/models/__pycache__/VCN_exp.cpython-38.pyc b/expansion/models/__pycache__/VCN_exp.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..adab22249d7acf5f2de9c51538b0d2f2779cae49
Binary files /dev/null and b/expansion/models/__pycache__/VCN_exp.cpython-38.pyc differ
diff --git a/expansion/models/__pycache__/__init__.cpython-38.pyc b/expansion/models/__pycache__/__init__.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..0be5cf74212f653439ad3462dae141dbe4016ac1
Binary files /dev/null and b/expansion/models/__pycache__/__init__.cpython-38.pyc differ
diff --git a/expansion/models/__pycache__/conv4d.cpython-38.pyc b/expansion/models/__pycache__/conv4d.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..254f5581fc8389dd765f0d8008b9bd9f9574d870
Binary files /dev/null and b/expansion/models/__pycache__/conv4d.cpython-38.pyc differ
diff --git a/expansion/models/__pycache__/submodule.cpython-38.pyc b/expansion/models/__pycache__/submodule.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..4a42df05b4e4a0431ad436bb8dc9942e6f3224d6
Binary files /dev/null and b/expansion/models/__pycache__/submodule.cpython-38.pyc differ
diff --git a/expansion/models/conv4d.py b/expansion/models/conv4d.py
new file mode 100755
index 0000000000000000000000000000000000000000..2747f2cf1709cc3b0adb1f0a16583eb01b2e4a1d
--- /dev/null
+++ b/expansion/models/conv4d.py
@@ -0,0 +1,296 @@
+import pdb
+import torch.nn as nn
+import math
+import torch
+from torch.nn.parameter import Parameter
+import torch.nn.functional as F
+from torch.nn import Module
+from torch.nn.modules.conv import _ConvNd
+from torch.nn.modules.utils import _quadruple
+from torch.autograd import Variable
+from torch.nn import Conv2d
+
+def conv4d(data,filters,bias=None,permute_filters=True,use_half=False):
+ """
+ Performs a 4D convolution by stacking the results of multiple 3D convolutions, which is very slow.
+ Taken from https://github.com/ignacio-rocco/ncnet
+ """
+ b,c,h,w,d,t=data.size()
+
+ data=data.permute(2,0,1,3,4,5).contiguous() # permute to avoid making contiguous inside loop
+
+ # Same permutation is done with filters, unless already provided with permutation
+ if permute_filters:
+ filters=filters.permute(2,0,1,3,4,5).contiguous() # permute to avoid making contiguous inside loop
+
+ c_out=filters.size(1)
+ if use_half:
+ output = Variable(torch.HalfTensor(h,b,c_out,w,d,t),requires_grad=data.requires_grad)
+ else:
+ output = Variable(torch.zeros(h,b,c_out,w,d,t),requires_grad=data.requires_grad)
+
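+ # Zero-pad the looped (first) dimension on both sides so the 4D convolution preserves the size along that axis.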
+ padding=filters.size(0)//2
+ if use_half:
+ Z=Variable(torch.zeros(padding,b,c,w,d,t).half())
+ else:
+ Z=Variable(torch.zeros(padding,b,c,w,d,t))
+
+ if data.is_cuda:
+ Z=Z.cuda(data.get_device())
+ output=output.cuda(data.get_device())
+
+ data_padded = torch.cat((Z,data,Z),0)
+
+
+ for i in range(output.size(0)): # loop on first feature dimension
+ # convolve with center channel of filter (at position=padding)
+ output[i,:,:,:,:,:]=F.conv3d(data_padded[i+padding,:,:,:,:,:],
+ filters[padding,:,:,:,:,:], bias=bias, stride=1, padding=padding)
+ # convolve with upper/lower channels of filter (at positions [:padding] and [padding+1:])
+ for p in range(1,padding+1):
+ output[i,:,:,:,:,:]=output[i,:,:,:,:,:]+F.conv3d(data_padded[i+padding-p,:,:,:,:,:],
+ filters[padding-p,:,:,:,:,:], bias=None, stride=1, padding=padding)
+ output[i,:,:,:,:,:]=output[i,:,:,:,:,:]+F.conv3d(data_padded[i+padding+p,:,:,:,:,:],
+ filters[padding+p,:,:,:,:,:], bias=None, stride=1, padding=padding)
+
+ output=output.permute(1,2,0,3,4,5).contiguous()
+ return output
+
+class Conv4d(_ConvNd):
+ """Applies a 4D convolution over an input signal composed of several input
+ planes.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size, bias=True, pre_permuted_filters=True):
+ # stride, dilation and groups !=1 functionality not tested
+ stride=1
+ dilation=1
+ groups=1
+ # zero padding is added automatically in conv4d function to preserve tensor size
+ padding = 0
+ kernel_size = _quadruple(kernel_size)
+ stride = _quadruple(stride)
+ padding = _quadruple(padding)
+ dilation = _quadruple(dilation)
+ super(Conv4d, self).__init__(
+ in_channels, out_channels, kernel_size, stride, padding, dilation,
+ False, _quadruple(0), groups, bias)
+ # weights will be sliced along one dimension during convolution loop
+ # make the looping dimension to be the first one in the tensor,
+ # so that we don't need to call contiguous() inside the loop
+ self.pre_permuted_filters=pre_permuted_filters
+ if self.pre_permuted_filters:
+ self.weight.data=self.weight.data.permute(2,0,1,3,4,5).contiguous()
+ self.use_half=False
+ # self.isbias = bias
+ # if not self.isbias:
+ # self.bn = torch.nn.BatchNorm1d(out_channels)
+
+
+ def forward(self, input):
+ out = conv4d(input, self.weight, bias=self.bias,permute_filters=not self.pre_permuted_filters,use_half=self.use_half) # filters pre-permuted in constructor
+ # if not self.isbias:
+ # b,c,u,v,h,w = out.shape
+ # out = self.bn(out.view(b,c,-1)).view(b,c,u,v,h,w)
+ return out
+
+class fullConv4d(torch.nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, bias=True, pre_permuted_filters=True):
+ super(fullConv4d, self).__init__()
+ self.conv = Conv4d(in_channels, out_channels, kernel_size, bias=bias, pre_permuted_filters=pre_permuted_filters)
+ self.isbias = bias
+ if not self.isbias:
+ self.bn = torch.nn.BatchNorm1d(out_channels)
+
+ def forward(self, input):
+ out = self.conv(input)
+ if not self.isbias:
+ b,c,u,v,h,w = out.shape
+ out = self.bn(out.view(b,c,-1)).view(b,c,u,v,h,w)
+ return out
+
+class butterfly4D(torch.nn.Module):
+ '''
+ Butterfly (hourglass-style) module over the 4D cost volume
+ '''
+ def __init__(self, fdima, fdimb, withbn=True, full=True,groups=1):
+ super(butterfly4D, self).__init__()
+ self.proj = nn.Sequential(projfeat4d(fdima, fdimb, 1, with_bn=withbn,groups=groups),
+ nn.ReLU(inplace=True),)
+ self.conva1 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(2,1,1),full=full,groups=groups)
+ self.conva2 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(2,1,1),full=full,groups=groups)
+ self.convb3 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(1,1,1),full=full,groups=groups)
+ self.convb2 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(1,1,1),full=full,groups=groups)
+ self.convb1 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(1,1,1),full=full,groups=groups)
+
+ #@profile
+ def forward(self,x):
+ out = self.proj(x)
+ b,c,u,v,h,w = out.shape # 9x9
+
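+ # Hourglass over the 4D cost volume: downsample the (u,v)/(h,w) dimensions twice, then upsample with trilinear interpolation and fuse via residual additions.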
+ out1 = self.conva1(out) # 5x5, 3
+ _,c1,u1,v1,h1,w1 = out1.shape
+
+ out2 = self.conva2(out1) # 3x3, 9
+ _,c2,u2,v2,h2,w2 = out2.shape
+
+ out2 = self.convb3(out2) # 3x3, 9
+
+ tout1 = F.upsample(out2.view(b,c,u2,v2,-1),(u1,v1,h2*w2),mode='trilinear').view(b,c,u1,v1,h2,w2) # 5x5
+ tout1 = F.upsample(tout1.view(b,c,-1,h2,w2),(u1*v1,h1,w1),mode='trilinear').view(b,c,u1,v1,h1,w1) # 5x5
+ out1 = tout1 + out1
+ out1 = self.convb2(out1)
+
+ tout = F.upsample(out1.view(b,c,u1,v1,-1),(u,v,h1*w1),mode='trilinear').view(b,c,u,v,h1,w1)
+ tout = F.upsample(tout.view(b,c,-1,h1,w1),(u*v,h,w),mode='trilinear').view(b,c,u,v,h,w)
+ out = tout + out
+ out = self.convb1(out)
+
+ return out
+
+
+
+class projfeat4d(torch.nn.Module):
+ '''
+ 1x1 channel projection of the 4D cost volume, implemented as a Conv3d with the spatial dims flattened
+ '''
+ def __init__(self, in_planes, out_planes, stride, with_bn=True,groups=1):
+ super(projfeat4d, self).__init__()
+ self.with_bn = with_bn
+ self.stride = stride
+ self.conv1 = nn.Conv3d(in_planes, out_planes, 1, (stride,stride,1), padding=0,bias=not with_bn,groups=groups)
+ self.bn = nn.BatchNorm3d(out_planes)
+
+ def forward(self,x):
+ b,c,u,v,h,w = x.size()
+ x = self.conv1(x.view(b,c,u,v,h*w))
+ if self.with_bn:
+ x = self.bn(x)
+ _,c,u,v,_ = x.shape
+ x = x.view(b,c,u,v,h,w)
+ return x
+
+class sepConv4d(torch.nn.Module):
+ '''
+ Separable 4d convolution block as 2 3D convolutions
+ '''
+ def __init__(self, in_planes, out_planes, stride=(1,1,1), with_bn=True, ksize=3, full=True,groups=1):
+ super(sepConv4d, self).__init__()
+ bias = not with_bn
+ self.isproj = False
+ self.stride = stride[0]
+ expand = 1
+
+ if with_bn:
+ if in_planes != out_planes:
+ self.isproj = True
+ self.proj = nn.Sequential(nn.Conv2d(in_planes, out_planes, 1, bias=bias, padding=0,groups=groups),
+ nn.BatchNorm2d(out_planes))
+ if full:
+ self.conv1 = nn.Sequential(nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=(1,self.stride,self.stride), bias=bias, padding=(0,ksize//2,ksize//2),groups=groups),
+ nn.BatchNorm3d(in_planes))
+ else:
+ self.conv1 = nn.Sequential(nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=1, bias=bias, padding=(0,ksize//2,ksize//2),groups=groups),
+ nn.BatchNorm3d(in_planes))
+ self.conv2 = nn.Sequential(nn.Conv3d(in_planes, in_planes*expand, (ksize,ksize,1), stride=(self.stride,self.stride,1), bias=bias, padding=(ksize//2,ksize//2,0),groups=groups),
+ nn.BatchNorm3d(in_planes*expand))
+ else:
+ if in_planes != out_planes:
+ self.isproj = True
+ self.proj = nn.Conv2d(in_planes, out_planes, 1, bias=bias, padding=0,groups=groups)
+ if full:
+ self.conv1 = nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=(1,self.stride,self.stride), bias=bias, padding=(0,ksize//2,ksize//2),groups=groups)
+ else:
+ self.conv1 = nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=1, bias=bias, padding=(0,ksize//2,ksize//2),groups=groups)
+ self.conv2 = nn.Conv3d(in_planes, in_planes*expand, (ksize,ksize,1), stride=(self.stride,self.stride,1), bias=bias, padding=(ksize//2,ksize//2,0),groups=groups)
+ self.relu = nn.ReLU(inplace=True)
+
+ #@profile
+ def forward(self,x):
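+ # Factorized 4D convolution: a 3D conv over the displacement dims (u,v) with (h,w) flattened, followed by a 3D conv over the spatial dims (h,w) with (u,v) flattened.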
+ b,c,u,v,h,w = x.shape
+ x = self.conv2(x.view(b,c,u,v,-1))
+ b,c,u,v,_ = x.shape
+ x = self.relu(x)
+ x = self.conv1(x.view(b,c,-1,h,w))
+ b,c,_,h,w = x.shape
+
+ if self.isproj:
+ x = self.proj(x.view(b,c,-1,w))
+ x = x.view(b,-1,u,v,h,w)
+ return x
+
+
+class sepConv4dBlock(torch.nn.Module):
+ '''
+ Residual block built from two separable 4D convolutions, with an optional
+ downsampling/projection layer on the skip connection
+ '''
+ def __init__(self, in_planes, out_planes, stride=(1,1,1), with_bn=True, full=True,groups=1):
+ super(sepConv4dBlock, self).__init__()
+ if in_planes == out_planes and stride==(1,1,1):
+ self.downsample = None
+ else:
+ if full:
+ self.downsample = sepConv4d(in_planes, out_planes, stride, with_bn=with_bn,ksize=1, full=full,groups=groups)
+ else:
+ self.downsample = projfeat4d(in_planes, out_planes,stride[0], with_bn=with_bn,groups=groups)
+ self.conv1 = sepConv4d(in_planes, out_planes, stride, with_bn=with_bn, full=full ,groups=groups)
+ self.conv2 = sepConv4d(out_planes, out_planes,(1,1,1), with_bn=with_bn, full=full,groups=groups)
+ self.relu1 = nn.ReLU(inplace=True)
+ self.relu2 = nn.ReLU(inplace=True)
+
+ #@profile
+ def forward(self,x):
+ out = self.relu1(self.conv1(x))
+ if self.downsample:
+ x = self.downsample(x)
+ out = self.relu2(x + self.conv2(out))
+ return out
+
+
+##import torch.backends.cudnn as cudnn
+##cudnn.benchmark = True
+#import time
+##im = torch.randn(9,64,9,160,224).cuda()
+##net = torch.nn.Conv3d(64, 64, 3).cuda()
+##net = Conv4d(1,1,3,bias=True,pre_permuted_filters=True).cuda()
+##net = sepConv4dBlock(2,2,stride=(1,1,1)).cuda()
+#
+##im = torch.randn(1,16,9,9,96,320).cuda()
+##net = sepConv4d(16,16,with_bn=False).cuda()
+#
+##im = torch.randn(1,16,81,96,320).cuda()
+##net = torch.nn.Conv3d(16,16,(1,3,3),padding=(0,1,1)).cuda()
+#
+##im = torch.randn(1,16,9,9,96*320).cuda()
+##net = torch.nn.Conv3d(16,16,(3,3,1),padding=(1,1,0)).cuda()
+#
+##im = torch.randn(10000,10,9,9).cuda()
+##net = torch.nn.Conv2d(10,10,3,padding=1).cuda()
+#
+##im = torch.randn(81,16,96,320).cuda()
+##net = torch.nn.Conv2d(16,16,3,padding=1).cuda()
+#c= int(16 *1)
+#cp = int(16 *1)
+#h=int(96 *4)
+#w=int(320 *4)
+#k=3
+#im = torch.randn(1,c,h,w).cuda()
+#net = torch.nn.Conv2d(c,cp,k,padding=k//2).cuda()
+#
+#im2 = torch.randn(cp,k*k*c).cuda()
+#im1 = F.unfold(im, (k,k), padding=k//2)[0]
+#
+#
+#net(im)
+#net(im)
+#torch.mm(im2,im1)
+#torch.mm(im2,im1)
+#torch.cuda.synchronize()
+#beg = time.time()
+#for i in range(100):
+# net(im)
+# #im1 = F.unfold(im, (k,k), padding=k//2)[0]
+# torch.mm(im2,im1)
+#torch.cuda.synchronize()
+#print('%f'%((time.time()-beg)*10.))
diff --git a/expansion/models/submodule.py b/expansion/models/submodule.py
new file mode 100755
index 0000000000000000000000000000000000000000..0e9a032c776f320dc35aaa5a9219022232811709
--- /dev/null
+++ b/expansion/models/submodule.py
@@ -0,0 +1,450 @@
+from __future__ import print_function
+import torch
+import torch.nn as nn
+import torch.utils.data
+from torch.autograd import Variable
+import torch.nn.functional as F
+import math
+import numpy as np
+import pdb
+
+class residualBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, in_channels, n_filters, stride=1, downsample=None,dilation=1,with_bn=True):
+ super(residualBlock, self).__init__()
+ if dilation > 1:
+ padding = dilation
+ else:
+ padding = 1
+
+ if with_bn:
+ self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters, 3, stride, padding, dilation=dilation)
+ self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1)
+ else:
+ self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters, 3, stride, padding, dilation=dilation,with_bn=False)
+ self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1, with_bn=False)
+ self.downsample = downsample
+ self.relu = nn.LeakyReLU(0.1, inplace=True)
+
+ def forward(self, x):
+ residual = x
+
+ out = self.convbnrelu1(x)
+ out = self.convbn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ return self.relu(out)
+
+def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
+ return nn.Sequential(
+ nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
+ padding=padding, dilation=dilation, bias=True),
+ nn.BatchNorm2d(out_planes),
+ nn.LeakyReLU(0.1,inplace=True))
+
+
+class conv2DBatchNorm(nn.Module):
+ def __init__(self, in_channels, n_filters, k_size, stride, padding, dilation=1, with_bn=True):
+ super(conv2DBatchNorm, self).__init__()
+ bias = not with_bn
+
+ if dilation > 1:
+ conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
+ padding=padding, stride=stride, bias=bias, dilation=dilation)
+
+ else:
+ conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
+ padding=padding, stride=stride, bias=bias, dilation=1)
+
+
+ if with_bn:
+ self.cb_unit = nn.Sequential(conv_mod,
+ nn.BatchNorm2d(int(n_filters)),)
+ else:
+ self.cb_unit = nn.Sequential(conv_mod,)
+
+ def forward(self, inputs):
+ outputs = self.cb_unit(inputs)
+ return outputs
+
+class conv2DBatchNormRelu(nn.Module):
+ def __init__(self, in_channels, n_filters, k_size, stride, padding, dilation=1, with_bn=True):
+ super(conv2DBatchNormRelu, self).__init__()
+ bias = not with_bn
+ if dilation > 1:
+ conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
+ padding=padding, stride=stride, bias=bias, dilation=dilation)
+
+ else:
+ conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
+ padding=padding, stride=stride, bias=bias, dilation=1)
+
+ if with_bn:
+ self.cbr_unit = nn.Sequential(conv_mod,
+ nn.BatchNorm2d(int(n_filters)),
+ nn.LeakyReLU(0.1, inplace=True),)
+ else:
+ self.cbr_unit = nn.Sequential(conv_mod,
+ nn.LeakyReLU(0.1, inplace=True),)
+
+ def forward(self, inputs):
+ outputs = self.cbr_unit(inputs)
+ return outputs
+
+class pyramidPooling(nn.Module):
+
+ def __init__(self, in_channels, with_bn=True, levels=4):
+ super(pyramidPooling, self).__init__()
+ self.levels = levels
+
+ self.paths = []
+ for i in range(levels):
+ self.paths.append(conv2DBatchNormRelu(in_channels, in_channels, 1, 1, 0, with_bn=with_bn))
+ self.path_module_list = nn.ModuleList(self.paths)
+ self.relu = nn.LeakyReLU(0.1, inplace=True)
+
+ def forward(self, x):
+ h, w = x.shape[2:]
+
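+ # Spatial pyramid pooling: average-pool at several scales, transform each with a 1x1 conv block, upsample back to the input size, and blend into the running sum.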
+ k_sizes = []
+ strides = []
+ for pool_size in np.linspace(1,min(h,w)//2,self.levels,dtype=int):
+ k_sizes.append((int(h/pool_size), int(w/pool_size)))
+ strides.append((int(h/pool_size), int(w/pool_size)))
+ k_sizes = k_sizes[::-1]
+ strides = strides[::-1]
+
+ pp_sum = x
+
+ for i, module in enumerate(self.path_module_list):
+ out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0)
+ out = module(out)
+ out = F.upsample(out, size=(h,w), mode='bilinear')
+ pp_sum = pp_sum + 1./self.levels*out
+ pp_sum = self.relu(pp_sum/2.)
+
+ return pp_sum
+
+class pspnet(nn.Module):
+ """
+ Modified PSPNet. https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/models/pspnet.py
+ """
+ def __init__(self, is_proj=True,groups=1):
+ super(pspnet, self).__init__()
+ self.inplanes = 32
+ self.is_proj = is_proj
+
+ # Encoder
+ self.convbnrelu1_1 = conv2DBatchNormRelu(in_channels=3, k_size=3, n_filters=16,
+ padding=1, stride=2)
+ self.convbnrelu1_2 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=16,
+ padding=1, stride=1)
+ self.convbnrelu1_3 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=32,
+ padding=1, stride=1)
+ # Vanilla Residual Blocks
+ self.res_block3 = self._make_layer(residualBlock,64,1,stride=2)
+ self.res_block5 = self._make_layer(residualBlock,128,1,stride=2)
+ self.res_block6 = self._make_layer(residualBlock,128,1,stride=2)
+ self.res_block7 = self._make_layer(residualBlock,128,1,stride=2)
+ self.pyramid_pooling = pyramidPooling(128, levels=3)
+
+ # Iconvs
+ self.upconv6 = nn.Sequential(nn.Upsample(scale_factor=2),
+ conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
+ padding=1, stride=1))
+ self.iconv5 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128,
+ padding=1, stride=1)
+ self.upconv5 = nn.Sequential(nn.Upsample(scale_factor=2),
+ conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
+ padding=1, stride=1))
+ self.iconv4 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128,
+ padding=1, stride=1)
+ self.upconv4 = nn.Sequential(nn.Upsample(scale_factor=2),
+ conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
+ padding=1, stride=1))
+ self.iconv3 = conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
+ padding=1, stride=1)
+ self.upconv3 = nn.Sequential(nn.Upsample(scale_factor=2),
+ conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32,
+ padding=1, stride=1))
+ self.iconv2 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=64,
+ padding=1, stride=1)
+
+ if self.is_proj:
+ self.proj6 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1)
+ self.proj5 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1)
+ self.proj4 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1)
+ self.proj3 = conv2DBatchNormRelu(in_channels=64, k_size=1,n_filters=64//groups, padding=0,stride=1)
+ self.proj2 = conv2DBatchNormRelu(in_channels=64, k_size=1,n_filters=64//groups, padding=0,stride=1)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ if hasattr(m.bias,'data'):
+ m.bias.data.zero_()
+
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion),)
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ # H, W -> H/2, W/2
+ conv1 = self.convbnrelu1_1(x)
+ conv1 = self.convbnrelu1_2(conv1)
+ conv1 = self.convbnrelu1_3(conv1)
+
+ ## H/2, W/2 -> H/4, W/4
+ pool1 = F.max_pool2d(conv1, 3, 2, 1)
+
+ # H/4, W/4 -> H/16, W/16
+ rconv3 = self.res_block3(pool1)
+ conv4 = self.res_block5(rconv3)
+ conv5 = self.res_block6(conv4)
+ conv6 = self.res_block7(conv5)
+ conv6 = self.pyramid_pooling(conv6)
+
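+ # Decoder: repeatedly upsample the coarser features, concatenate them with the encoder skip connection, and fuse with an iconv block.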
+ conv6x = F.upsample(conv6, [conv5.size()[2],conv5.size()[3]],mode='bilinear')
+ concat5 = torch.cat((conv5,self.upconv6[1](conv6x)),dim=1)
+ conv5 = self.iconv5(concat5)
+
+ conv5x = F.upsample(conv5, [conv4.size()[2],conv4.size()[3]],mode='bilinear')
+ concat4 = torch.cat((conv4,self.upconv5[1](conv5x)),dim=1)
+ conv4 = self.iconv4(concat4)
+
+ conv4x = F.upsample(conv4, [rconv3.size()[2],rconv3.size()[3]],mode='bilinear')
+ concat3 = torch.cat((rconv3,self.upconv4[1](conv4x)),dim=1)
+ conv3 = self.iconv3(concat3)
+
+ conv3x = F.upsample(conv3, [pool1.size()[2],pool1.size()[3]],mode='bilinear')
+ concat2 = torch.cat((pool1,self.upconv3[1](conv3x)),dim=1)
+ conv2 = self.iconv2(concat2)
+
+ if self.is_proj:
+ proj6 = self.proj6(conv6)
+ proj5 = self.proj5(conv5)
+ proj4 = self.proj4(conv4)
+ proj3 = self.proj3(conv3)
+ proj2 = self.proj2(conv2)
+ return proj6,proj5,proj4,proj3,proj2
+ else:
+ return conv6, conv5, conv4, conv3, conv2
+
+
+class pspnet_s(nn.Module):
+ """
+ Modified PSPNet. https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/models/pspnet.py
+ """
+ def __init__(self, is_proj=True,groups=1):
+ super(pspnet_s, self).__init__()
+ self.inplanes = 32
+ self.is_proj = is_proj
+
+ # Encoder
+ self.convbnrelu1_1 = conv2DBatchNormRelu(in_channels=3, k_size=3, n_filters=16,
+ padding=1, stride=2)
+ self.convbnrelu1_2 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=16,
+ padding=1, stride=1)
+ self.convbnrelu1_3 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=32,
+ padding=1, stride=1)
+ # Vanilla Residual Blocks
+ self.res_block3 = self._make_layer(residualBlock,64,1,stride=2)
+ self.res_block5 = self._make_layer(residualBlock,128,1,stride=2)
+ self.res_block6 = self._make_layer(residualBlock,128,1,stride=2)
+ self.res_block7 = self._make_layer(residualBlock,128,1,stride=2)
+ self.pyramid_pooling = pyramidPooling(128, levels=3)
+
+ # Iconvs
+ self.upconv6 = nn.Sequential(nn.Upsample(scale_factor=2),
+ conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
+ padding=1, stride=1))
+ self.iconv5 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128,
+ padding=1, stride=1)
+ self.upconv5 = nn.Sequential(nn.Upsample(scale_factor=2),
+ conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
+ padding=1, stride=1))
+ self.iconv4 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128,
+ padding=1, stride=1)
+ self.upconv4 = nn.Sequential(nn.Upsample(scale_factor=2),
+ conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
+ padding=1, stride=1))
+ self.iconv3 = conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
+ padding=1, stride=1)
+ #self.upconv3 = nn.Sequential(nn.Upsample(scale_factor=2),
+ # conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32,
+ # padding=1, stride=1))
+ #self.iconv2 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=64,
+ # padding=1, stride=1)
+
+ if self.is_proj:
+ self.proj6 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1)
+ self.proj5 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1)
+ self.proj4 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1)
+ self.proj3 = conv2DBatchNormRelu(in_channels=64, k_size=1,n_filters=64//groups, padding=0,stride=1)
+ #self.proj2 = conv2DBatchNormRelu(in_channels=64, k_size=1,n_filters=64//groups, padding=0,stride=1)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ if hasattr(m.bias,'data'):
+ m.bias.data.zero_()
+
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion),)
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ # H, W -> H/2, W/2
+ conv1 = self.convbnrelu1_1(x)
+ conv1 = self.convbnrelu1_2(conv1)
+ conv1 = self.convbnrelu1_3(conv1)
+
+ ## H/2, W/2 -> H/4, W/4
+ pool1 = F.max_pool2d(conv1, 3, 2, 1)
+
+ # H/4, W/4 -> H/16, W/16
+ rconv3 = self.res_block3(pool1)
+ conv4 = self.res_block5(rconv3)
+ conv5 = self.res_block6(conv4)
+ conv6 = self.res_block7(conv5)
+ conv6 = self.pyramid_pooling(conv6)
+
+ conv6x = F.upsample(conv6, [conv5.size()[2],conv5.size()[3]],mode='bilinear')
+ concat5 = torch.cat((conv5,self.upconv6[1](conv6x)),dim=1)
+ conv5 = self.iconv5(concat5)
+
+ conv5x = F.upsample(conv5, [conv4.size()[2],conv4.size()[3]],mode='bilinear')
+ concat4 = torch.cat((conv4,self.upconv5[1](conv5x)),dim=1)
+ conv4 = self.iconv4(concat4)
+
+ conv4x = F.upsample(conv4, [rconv3.size()[2],rconv3.size()[3]],mode='bilinear')
+ concat3 = torch.cat((rconv3,self.upconv4[1](conv4x)),dim=1)
+ conv3 = self.iconv3(concat3)
+
+ #conv3x = F.upsample(conv3, [pool1.size()[2],pool1.size()[3]],mode='bilinear')
+ #concat2 = torch.cat((pool1,self.upconv3[1](conv3x)),dim=1)
+ #conv2 = self.iconv2(concat2)
+
+ if self.is_proj:
+ proj6 = self.proj6(conv6)
+ proj5 = self.proj5(conv5)
+ proj4 = self.proj4(conv4)
+ proj3 = self.proj3(conv3)
+ # proj2 = self.proj2(conv2)
+ # return proj6,proj5,proj4,proj3,proj2
+ return proj6,proj5,proj4,proj3
+ else:
+ # return conv6, conv5, conv4, conv3, conv2
+ return conv6, conv5, conv4, conv3
+
+class bfmodule(nn.Module):
+ def __init__(self, inplanes, outplanes):
+ super(bfmodule, self).__init__()
+ self.proj = conv2DBatchNormRelu(in_channels=inplanes,k_size=1,n_filters=64,padding=0,stride=1)
+ self.inplanes = 64
+ # Vanilla Residual Blocks
+ self.res_block3 = self._make_layer(residualBlock,64,1,stride=2)
+ self.res_block5 = self._make_layer(residualBlock,64,1,stride=2)
+ self.res_block6 = self._make_layer(residualBlock,64,1,stride=2)
+ self.res_block7 = self._make_layer(residualBlock,128,1,stride=2)
+ self.pyramid_pooling = pyramidPooling(128, levels=3)
+ # Iconvs
+ self.upconv6 = conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
+ padding=1, stride=1)
+ self.upconv5 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32,
+ padding=1, stride=1)
+ self.upconv4 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32,
+ padding=1, stride=1)
+ self.upconv3 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32,
+ padding=1, stride=1)
+ self.iconv5 = conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
+ padding=1, stride=1)
+ self.iconv4 = conv2DBatchNormRelu(in_channels=96, k_size=3, n_filters=64,
+ padding=1, stride=1)
+ self.iconv3 = conv2DBatchNormRelu(in_channels=96, k_size=3, n_filters=64,
+ padding=1, stride=1)
+ self.iconv2 = nn.Sequential(conv2DBatchNormRelu(in_channels=96, k_size=3, n_filters=64,
+ padding=1, stride=1),
+ nn.Conv2d(64, outplanes,kernel_size=3, stride=1, padding=1, bias=True))
+
+ self.proj6 = nn.Conv2d(128, outplanes,kernel_size=3, stride=1, padding=1, bias=True)
+ self.proj5 = nn.Conv2d(64, outplanes,kernel_size=3, stride=1, padding=1, bias=True)
+ self.proj4 = nn.Conv2d(64, outplanes,kernel_size=3, stride=1, padding=1, bias=True)
+ self.proj3 = nn.Conv2d(64, outplanes,kernel_size=3, stride=1, padding=1, bias=True)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ if hasattr(m.bias,'data'):
+ m.bias.data.zero_()
+
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion),)
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
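+ # Hourglass decoder that emits predictions at five scales, from pred6 (coarsest) up to pred2 at the resolution of the input feature map x.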
+ proj = self.proj(x) # 4x
+ rconv3 = self.res_block3(proj) #8x
+ conv4 = self.res_block5(rconv3) #16x
+ conv5 = self.res_block6(conv4) #32x
+ conv6 = self.res_block7(conv5) #64x
+ conv6 = self.pyramid_pooling(conv6) #64x
+ pred6 = self.proj6(conv6)
+
+ conv6u = F.upsample(conv6, [conv5.size()[2],conv5.size()[3]], mode='bilinear')
+ concat5 = torch.cat((conv5,self.upconv6(conv6u)),dim=1)
+ conv5 = self.iconv5(concat5) #32x
+ pred5 = self.proj5(conv5)
+
+ conv5u = F.upsample(conv5, [conv4.size()[2],conv4.size()[3]], mode='bilinear')
+ concat4 = torch.cat((conv4,self.upconv5(conv5u)),dim=1)
+ conv4 = self.iconv4(concat4) #16x
+ pred4 = self.proj4(conv4)
+
+ conv4u = F.upsample(conv4, [rconv3.size()[2],rconv3.size()[3]], mode='bilinear')
+ concat3 = torch.cat((rconv3,self.upconv4(conv4u)),dim=1)
+ conv3 = self.iconv3(concat3) # 8x
+ pred3 = self.proj3(conv3)
+
+ conv3u = F.upsample(conv3, [x.size()[2],x.size()[3]], mode='bilinear')
+ concat2 = torch.cat((proj,self.upconv3(conv3u)),dim=1)
+ pred2 = self.iconv2(concat2) # 4x
+
+ return pred2, pred3, pred4, pred5, pred6
+
diff --git a/expansion/submission.py b/expansion/submission.py
new file mode 100755
index 0000000000000000000000000000000000000000..21ebb64061b13b9b00e3c52f606d42612de3812c
--- /dev/null
+++ b/expansion/submission.py
@@ -0,0 +1,95 @@
+from __future__ import print_function
+import sys
+import cv2
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.backends.cudnn as cudnn
+import torch.optim as optim
+import torch.nn.functional as F
+cudnn.benchmark = False
+
+class Expansion():
+
+ def __init__(self, loadmodel = 'pretrained_models/optical_expansion/robust.pth', testres = 1, maxdisp = 256, fac = 1):
+
+ maxw,maxh = [int(testres*1280), int(testres*384)]
+
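+ # Round the working resolution up to a multiple of 64 so every level of the 6-level pyramid divides evenly.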
+ max_h = int(maxh // 64 * 64)
+ max_w = int(maxw // 64 * 64)
+ if max_h < maxh: max_h += 64
+ if max_w < maxw: max_w += 64
+ maxh = max_h
+ maxw = max_w
+
+ mean_L = [[0.33,0.33,0.33]]
+ mean_R = [[0.33,0.33,0.33]]
+
+ # construct model, VCN-expansion
+ from expansion.models.VCN_exp import VCN
+ model = VCN([1, maxw, maxh], md=[int(4*(maxdisp/256)),4,4,4,4], fac=fac,
+ exp_unc=('robust' in loadmodel)) # expansion uncertainty only in the new model
+ model = nn.DataParallel(model, device_ids=[0])
+ model.cuda()
+
+ if loadmodel is not None:
+ pretrained_dict = torch.load(loadmodel)
+ mean_L=pretrained_dict['mean_L']
+ mean_R=pretrained_dict['mean_R']
+ pretrained_dict['state_dict'] = {k:v for k,v in pretrained_dict['state_dict'].items()}
+ model.load_state_dict(pretrained_dict['state_dict'],strict=False)
+ else:
+ print('dry run')
+
+ model.eval()
+ # resize
+ maxh = 256
+ maxw = 256
+ max_h = int(maxh // 64 * 64)
+ max_w = int(maxw // 64 * 64)
+ if max_h < maxh: max_h += 64
+ if max_w < maxw: max_w += 64
+
+ # modify module according to inputs
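+ # The per-level flow-regularization and warping modules are built for a fixed input size, so they are re-instantiated here for the 256x256 inference resolution chosen above.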
+ from expansion.models.VCN_exp import WarpModule, flow_reg
+ for i in range(len(model.module.reg_modules)):
+ model.module.reg_modules[i] = flow_reg([1,max_w//(2**(6-i)), max_h//(2**(6-i))],
+ ent=getattr(model.module, 'flow_reg%d'%2**(6-i)).ent,\
+ maxdisp=getattr(model.module, 'flow_reg%d'%2**(6-i)).md,\
+ fac=getattr(model.module, 'flow_reg%d'%2**(6-i)).fac).cuda()
+ for i in range(len(model.module.warp_modules)):
+ model.module.warp_modules[i] = WarpModule([1,max_w//(2**(6-i)), max_h//(2**(6-i))]).cuda()
+
+ mean_L = torch.from_numpy(np.asarray(mean_L).astype(np.float32).mean(0)[np.newaxis,:,np.newaxis,np.newaxis]).cuda()
+ mean_R = torch.from_numpy(np.asarray(mean_R).astype(np.float32).mean(0)[np.newaxis,:,np.newaxis,np.newaxis]).cuda()
+
+ self.max_h = max_h
+ self.max_w = max_w
+ self.model = model
+ self.mean_L = mean_L
+ self.mean_R = mean_R
+
+ def run(self, imgL_o, imgR_o):
+ model = self.model
+ mean_L = self.mean_L
+ mean_R = self.mean_R
+
+ imgL_o[imgL_o<-1] = -1
+ imgL_o[imgL_o>1] = 1
+ imgR_o[imgR_o<-1] = -1
+ imgR_o[imgR_o>1] = 1
+ imgL = (imgL_o+1.)*0.5-mean_L
+ imgR = (imgR_o+1.)*0.5-mean_R
+
+ with torch.no_grad():
+ imgLR = torch.cat([imgL,imgR],0)
+ model.eval()
+ torch.cuda.synchronize()
+ rts = model(imgLR)
+ torch.cuda.synchronize()
+ flow, occ, logmid, logexp = rts
+
+ torch.cuda.empty_cache()
+
+ return flow, logexp
diff --git a/expansion/utils/__init__.py b/expansion/utils/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/expansion/utils/__pycache__/__init__.cpython-38.pyc b/expansion/utils/__pycache__/__init__.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..2d86925f5b75d1d7208c8e21a5d8820c67d1396a
Binary files /dev/null and b/expansion/utils/__pycache__/__init__.cpython-38.pyc differ
diff --git a/expansion/utils/__pycache__/flowlib.cpython-38.pyc b/expansion/utils/__pycache__/flowlib.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..098a490b5ccbdea2a9fa168b66c74dba9c9eef95
Binary files /dev/null and b/expansion/utils/__pycache__/flowlib.cpython-38.pyc differ
diff --git a/expansion/utils/__pycache__/io.cpython-38.pyc b/expansion/utils/__pycache__/io.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..e065c426452bde7ddd36ae2952aa4b392728eea6
Binary files /dev/null and b/expansion/utils/__pycache__/io.cpython-38.pyc differ
diff --git a/expansion/utils/__pycache__/pfm.cpython-38.pyc b/expansion/utils/__pycache__/pfm.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..ed1807440f8c97584409a5335134a826ae6d83ea
Binary files /dev/null and b/expansion/utils/__pycache__/pfm.cpython-38.pyc differ
diff --git a/expansion/utils/__pycache__/util_flow.cpython-38.pyc b/expansion/utils/__pycache__/util_flow.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..510d6d73c7122b777e80935b6d9ac0b567da2cb0
Binary files /dev/null and b/expansion/utils/__pycache__/util_flow.cpython-38.pyc differ
diff --git a/expansion/utils/flowlib.py b/expansion/utils/flowlib.py
new file mode 100755
index 0000000000000000000000000000000000000000..59096c15da5e529ecb2d85be4881ee467f0c838f
--- /dev/null
+++ b/expansion/utils/flowlib.py
@@ -0,0 +1,656 @@
+"""
+# ==============================
+# flowlib.py
+# library for optical flow processing
+# Author: Ruoteng Li
+# Date: 6th Aug 2016
+# ==============================
+"""
+import png
+from . import pfm
+import numpy as np
+import matplotlib.colors as cl
+import matplotlib.pyplot as plt
+from PIL import Image
+import cv2
+import pdb
+
+
+UNKNOWN_FLOW_THRESH = 1e7
+SMALLFLOW = 0.0
+LARGEFLOW = 1e8
+
+"""
+=============
+Flow Section
+=============
+"""
+
+def point_vec(img,flow,skip=16):
+ #img[:] = 255
+ maxsize=256
+ extendfac=2.
+ resize_factor = max(1,int(max(maxsize/img.shape[0], maxsize/img.shape[1])))
+ meshgrid = np.meshgrid(range(img.shape[1]),range(img.shape[0]))
+ dispimg = cv2.resize(img[:,:,::-1].copy(), None,fx=resize_factor,fy=resize_factor)
+ colorflow = flow_to_image(flow).astype(int)
+
+ for i in range(img.shape[1]): # x
+ for j in range(img.shape[0]): # y
+ if flow[j,i,2] != 1: continue
+ if j%skip!=0 or i%skip!=0: continue
+ xend = int((meshgrid[0][j,i]+extendfac*flow[j,i,0])*resize_factor)
+ yend = int((meshgrid[1][j,i]+extendfac*flow[j,i,1])*resize_factor)
+ leng = np.linalg.norm(flow[j,i,:2]*extendfac)
+ if leng<3:continue
+ dispimg = cv2.arrowedLine(dispimg, (meshgrid[0][j,i]*resize_factor,meshgrid[1][j,i]*resize_factor),\
+ (xend,yend),
+ (int(colorflow[j,i,2]),int(colorflow[j,i,1]),int(colorflow[j,i,0])),4,tipLength=2/leng,line_type=cv2.LINE_AA)
+ return dispimg
+
+
+def show_flow(filename):
+ """
+ visualize optical flow map using matplotlib
+ :param filename: optical flow file
+ :return: None
+ """
+ flow = read_flow(filename)
+ img = flow_to_image(flow)
+ plt.imshow(img)
+ plt.show()
+
+
+def visualize_flow(flow, mode='Y'):
+ """
+ this function visualizes the input flow
+ :param flow: input flow in array
+ :param mode: choose which color mode to visualize the flow (Y: YCbCr, RGB: RGB color)
+ :return: None
+ """
+ if mode == 'Y':
+ # YCbCr color wheel
+ img = flow_to_image(flow)
+ plt.imshow(img)
+ plt.show()
+ elif mode == 'RGB':
+ (h, w) = flow.shape[0:2]
+ du = flow[:, :, 0]
+ dv = flow[:, :, 1]
+ valid = flow[:, :, 2]
+ max_flow = max(np.max(du), np.max(dv))
+ img = np.zeros((h, w, 3), dtype=np.float64)
+ # angle layer
+ img[:, :, 0] = np.arctan2(dv, du) / (2 * np.pi)
+ # magnitude layer, normalized to 1
+ img[:, :, 1] = np.sqrt(du * du + dv * dv) * 8 / max_flow
+ # phase layer
+ img[:, :, 2] = 8 - img[:, :, 1]
+ # clip to [0,1]
+ small_idx = img[:, :, 0:3] < 0
+ large_idx = img[:, :, 0:3] > 1
+ img[small_idx] = 0
+ img[large_idx] = 1
+ # convert to rgb
+ img = cl.hsv_to_rgb(img)
+ # remove invalid point
+ img[:, :, 0] = img[:, :, 0] * valid
+ img[:, :, 1] = img[:, :, 1] * valid
+ img[:, :, 2] = img[:, :, 2] * valid
+ # show
+ plt.imshow(img)
+ plt.show()
+
+ return None
+
+
+def read_flow(filename):
+ """
+ read optical flow data from flow file
+ :param filename: name of the flow file
+ :return: optical flow data in numpy array
+ """
+ if filename.endswith('.flo'):
+ flow = read_flo_file(filename)
+ elif filename.endswith('.png'):
+ flow = read_png_file(filename)
+ elif filename.endswith('.pfm'):
+ flow = read_pfm_file(filename)
+ else:
+ raise Exception('Invalid flow file format!')
+
+ return flow
+
+
+def write_flow(flow, filename):
+ """
+ write optical flow in Middlebury .flo format
+ :param flow: optical flow map
+ :param filename: optical flow file path to be saved
+ :return: None
+ """
+ f = open(filename, 'wb')
+ magic = np.array([202021.25], dtype=np.float32)
+ (height, width) = flow.shape[0:2]
+ w = np.array([width], dtype=np.int32)
+ h = np.array([height], dtype=np.int32)
+ magic.tofile(f)
+ w.tofile(f)
+ h.tofile(f)
+ flow.tofile(f)
+ f.close()
+
+
+def save_flow_image(flow, image_file):
+ """
+ save flow visualization into image file
+ :param flow: optical flow data
+ :param image_file: output image file path
+ :return: None
+ """
+ flow_img = flow_to_image(flow)
+ img_out = Image.fromarray(flow_img)
+ img_out.save(image_file)
+
+
+def flowfile_to_imagefile(flow_file, image_file):
+ """
+ convert flowfile into image file
+ :param flow_file: input optical flow file path
+ :param image_file: output image file path
+ :return: None
+ """
+ flow = read_flow(flow_file)
+ save_flow_image(flow, image_file)
+
+
+def segment_flow(flow):
+ h = flow.shape[0]
+ w = flow.shape[1]
+ u = flow[:, :, 0]
+ v = flow[:, :, 1]
+
+ idx = ((abs(u) > LARGEFLOW) | (abs(v) > LARGEFLOW))
+ idx2 = (abs(u) == SMALLFLOW)
+ class0 = (v == 0) & (u == 0)
+ u[idx2] = 0.00001
+ tan_value = v / u
+
+ class1 = (tan_value < 1) & (tan_value >= 0) & (u > 0) & (v >= 0)
+ class2 = (tan_value >= 1) & (u >= 0) & (v >= 0)
+ class3 = (tan_value < -1) & (u <= 0) & (v >= 0)
+ class4 = (tan_value < 0) & (tan_value >= -1) & (u < 0) & (v >= 0)
+ class8 = (tan_value >= -1) & (tan_value < 0) & (u > 0) & (v <= 0)
+ class7 = (tan_value < -1) & (u >= 0) & (v <= 0)
+ class6 = (tan_value >= 1) & (u <= 0) & (v <= 0)
+ class5 = (tan_value >= 0) & (tan_value < 1) & (u < 0) & (v <= 0)
+
+ seg = np.zeros((h, w))
+
+ seg[class1] = 1
+ seg[class2] = 2
+ seg[class3] = 3
+ seg[class4] = 4
+ seg[class5] = 5
+ seg[class6] = 6
+ seg[class7] = 7
+ seg[class8] = 8
+ seg[class0] = 0
+ seg[idx] = 0
+
+ return seg
+
+
+def flow_error(tu, tv, u, v):
+ """
+ Calculate average end point error
+ :param tu: ground-truth horizontal flow map
+ :param tv: ground-truth vertical flow map
+ :param u: estimated horizontal flow map
+ :param v: estimated vertical flow map
+ :return: End point error of the estimated flow
+ """
+ smallflow = 0.0
+ '''
+ stu = tu[bord+1:end-bord,bord+1:end-bord]
+ stv = tv[bord+1:end-bord,bord+1:end-bord]
+ su = u[bord+1:end-bord,bord+1:end-bord]
+ sv = v[bord+1:end-bord,bord+1:end-bord]
+ '''
+ stu = tu[:]
+ stv = tv[:]
+ su = u[:]
+ sv = v[:]
+
+ idxUnknow = (abs(stu) > UNKNOWN_FLOW_THRESH) | (abs(stv) > UNKNOWN_FLOW_THRESH)
+ stu[idxUnknow] = 0
+ stv[idxUnknow] = 0
+ su[idxUnknow] = 0
+ sv[idxUnknow] = 0
+
+ ind2 = [(np.absolute(stu) > smallflow) | (np.absolute(stv) > smallflow)]
+ index_su = su[ind2]
+ index_sv = sv[ind2]
+ an = 1.0 / np.sqrt(index_su ** 2 + index_sv ** 2 + 1)
+ un = index_su * an
+ vn = index_sv * an
+
+ index_stu = stu[ind2]
+ index_stv = stv[ind2]
+ tn = 1.0 / np.sqrt(index_stu ** 2 + index_stv ** 2 + 1)
+ tun = index_stu * tn
+ tvn = index_stv * tn
+
+ '''
+ angle = un * tun + vn * tvn + (an * tn)
+ index = [angle == 1.0]
+ angle[index] = 0.999
+ ang = np.arccos(angle)
+ mang = np.mean(ang)
+ mang = mang * 180 / np.pi
+ '''
+
+ epe = np.sqrt((stu - su) ** 2 + (stv - sv) ** 2)
+ epe = epe[ind2]
+ mepe = np.mean(epe)
+ return mepe
+
+
+def flow_to_image(flow):
+ """
+ Convert flow into middlebury color code image
+ :param flow: optical flow map
+ :return: optical flow image in middlebury color
+ """
+ u = flow[:, :, 0]
+ v = flow[:, :, 1]
+
+ maxu = -999.
+ maxv = -999.
+ minu = 999.
+ minv = 999.
+
+ idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
+ u[idxUnknow] = 0
+ v[idxUnknow] = 0
+
+ maxu = max(maxu, np.max(u))
+ minu = min(minu, np.min(u))
+
+ maxv = max(maxv, np.max(v))
+ minv = min(minv, np.min(v))
+
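+ # Normalize by the largest flow magnitude so the color coding uses the full saturation range of the Middlebury color wheel.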
+ rad = np.sqrt(u ** 2 + v ** 2)
+ maxrad = max(-1, np.max(rad))
+
+ u = u/(maxrad + np.finfo(float).eps)
+ v = v/(maxrad + np.finfo(float).eps)
+
+ img = compute_color(u, v)
+
+ idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)
+ img[idx] = 0
+
+ return np.uint8(img)
+
+
+def evaluate_flow_file(gt_file, pred_file):
+ """
+ evaluate the estimated optical flow end point error according to ground truth provided
+ :param gt_file: ground truth file path
+ :param pred_file: estimated optical flow file path
+ :return: end point error, float32
+ """
+ # Read flow files and calculate the errors
+ gt_flow = read_flow(gt_file) # ground truth flow
+ eva_flow = read_flow(pred_file) # predicted flow
+ # Calculate errors
+ average_pe = flow_error(gt_flow[:, :, 0], gt_flow[:, :, 1], eva_flow[:, :, 0], eva_flow[:, :, 1])
+ return average_pe
+
+
+def evaluate_flow(gt_flow, pred_flow):
+ """
+ gt: ground-truth flow
+ pred: estimated flow
+ """
+ average_pe = flow_error(gt_flow[:, :, 0], gt_flow[:, :, 1], pred_flow[:, :, 0], pred_flow[:, :, 1])
+ return average_pe
+
+
+"""
+==============
+Disparity Section
+==============
+"""
+
+
+def read_disp_png(file_name):
+ """
+ Read a KITTI disparity map from a .png file
+ :param file_name: name of the disparity file
+ :return: disparity data in matrix
+ """
+ image_object = png.Reader(filename=file_name)
+ image_direct = image_object.asDirect()
+ image_data = list(image_direct[2])
+ (w, h) = image_direct[3]['size']
+ channel = len(image_data[0]) // w
+ flow = np.zeros((h, w, channel), dtype=np.uint16)
+ for i in range(len(image_data)):
+ for j in range(channel):
+ flow[i, :, j] = image_data[i][j::channel]
+ return flow[:, :, 0] / 256
+
+
+def disp_to_flowfile(disp, filename):
+ """
+ Write a KITTI disparity map into a Middlebury .flo format file
+ :param disp: disparity matrix
+ :param filename: the flow file name to save
+ :return: None
+ """
+ f = open(filename, 'wb')
+ magic = np.array([202021.25], dtype=np.float32)
+ (height, width) = disp.shape[0:2]
+ w = np.array([width], dtype=np.int32)
+ h = np.array([height], dtype=np.int32)
+ empty_map = np.zeros((height, width), dtype=np.float32)
+ data = np.dstack((disp, empty_map))
+ magic.tofile(f)
+ w.tofile(f)
+ h.tofile(f)
+ data.tofile(f)
+ f.close()
+
+
+"""
+==============
+Image Section
+==============
+"""
+
+
+def read_image(filename):
+ """
+ Read normal image of any format
+ :param filename: name of the image file
+ :return: image data in matrix uint8 type
+ """
+ img = Image.open(filename)
+ im = np.array(img)
+ return im
+
+def warp_flow(img, flow):
+ h, w = flow.shape[:2]
+ flow = flow.copy().astype(np.float32)
+ flow[:,:,0] += np.arange(w)
+ flow[:,:,1] += np.arange(h)[:,np.newaxis]
+ res = cv2.remap(img, flow, None, cv2.INTER_LINEAR)
+ return res
+
+def warp_image(im, flow):
+ """
+ Use optical flow to warp an image to the next frame
+ :param im: image to warp
+ :param flow: optical flow
+ :return: warped image
+ """
+ from scipy import interpolate
+ image_height = im.shape[0]
+ image_width = im.shape[1]
+ flow_height = flow.shape[0]
+ flow_width = flow.shape[1]
+ n = image_height * image_width
+ (iy, ix) = np.mgrid[0:image_height, 0:image_width]
+ (fy, fx) = np.mgrid[0:flow_height, 0:flow_width]
+ fx = fx.astype(np.float64)
+ fy = fy.astype(np.float64)
+ fx += flow[:,:,0]
+ fy += flow[:,:,1]
+ mask = np.logical_or(fx <0 , fx > flow_width)
+ mask = np.logical_or(mask, fy < 0)
+ mask = np.logical_or(mask, fy > flow_height)
+ fx = np.minimum(np.maximum(fx, 0), flow_width)
+ fy = np.minimum(np.maximum(fy, 0), flow_height)
+ points = np.concatenate((ix.reshape(n,1), iy.reshape(n,1)), axis=1)
+ xi = np.concatenate((fx.reshape(n, 1), fy.reshape(n,1)), axis=1)
+ warp = np.zeros((image_height, image_width, im.shape[2]))
+ for i in range(im.shape[2]):
+ channel = im[:, :, i]
+ plt.imshow(channel, cmap='gray')
+ values = channel.reshape(n, 1)
+ new_channel = interpolate.griddata(points, values, xi, method='cubic')
+ new_channel = np.reshape(new_channel, [flow_height, flow_width])
+ new_channel[mask] = 1
+ warp[:, :, i] = new_channel.astype(np.uint8)
+
+ return warp.astype(np.uint8)
+
+
+"""
+==============
+Others
+==============
+"""
+
+def pfm_to_flo(pfm_file):
+ flow_filename = pfm_file[0:pfm_file.find('.pfm')] + '.flo'
+ (data, scale) = pfm.readPFM(pfm_file)
+ flow = data[:, :, 0:2]
+ write_flow(flow, flow_filename)
+
+
+def scale_image(image, new_range):
+ """
+ Linearly scale the image into desired range
+ :param image: input image
+ :param new_range: the new range to be aligned
+ :return: image normalized in new range
+ """
+ min_val = np.min(image).astype(np.float32)
+ max_val = np.max(image).astype(np.float32)
+ min_val_new = np.array(min(new_range), dtype=np.float32)
+ max_val_new = np.array(max(new_range), dtype=np.float32)
+ scaled_image = (image - min_val) / (max_val - min_val) * (max_val_new - min_val_new) + min_val_new
+ return scaled_image.astype(np.uint8)
+
+
+def compute_color(u, v):
+ """
+ compute optical flow color map
+ :param u: optical flow horizontal map
+ :param v: optical flow vertical map
+ :return: optical flow in color code
+ """
+ [h, w] = u.shape
+ img = np.zeros([h, w, 3])
+ nanIdx = np.isnan(u) | np.isnan(v)
+ u[nanIdx] = 0
+ v[nanIdx] = 0
+
+ colorwheel = make_color_wheel()
+ ncols = np.size(colorwheel, 0)
+
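+ # Map flow direction to an angle on the color wheel and interpolate between the two nearest wheel colors; small magnitudes blend toward white, out-of-range ones are dimmed.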
+ rad = np.sqrt(u**2+v**2)
+
+ a = np.arctan2(-v, -u) / np.pi
+
+ fk = (a+1) / 2 * (ncols - 1) + 1
+
+ k0 = np.floor(fk).astype(int)
+
+ k1 = k0 + 1
+ k1[k1 == ncols+1] = 1
+ f = fk - k0
+
+ for i in range(0, np.size(colorwheel,1)):
+ tmp = colorwheel[:, i]
+ col0 = tmp[k0-1] / 255
+ col1 = tmp[k1-1] / 255
+ col = (1-f) * col0 + f * col1
+
+ idx = rad <= 1
+ col[idx] = 1-rad[idx]*(1-col[idx])
+ notidx = np.logical_not(idx)
+
+ col[notidx] *= 0.75
+ img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx)))
+
+ return img
+
+
+def make_color_wheel():
+ """
+ Generate the color wheel according to the Middlebury color code
+ :return: Color wheel
+ """
+ RY = 15
+ YG = 6
+ GC = 4
+ CB = 11
+ BM = 13
+ MR = 6
+
+ ncols = RY + YG + GC + CB + BM + MR
+
+ colorwheel = np.zeros([ncols, 3])
+
+ col = 0
+
+ # RY
+ colorwheel[0:RY, 0] = 255
+ colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY))
+ col += RY
+
+ # YG
+ colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG))
+ colorwheel[col:col+YG, 1] = 255
+ col += YG
+
+ # GC
+ colorwheel[col:col+GC, 1] = 255
+ colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC))
+ col += GC
+
+ # CB
+ colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB))
+ colorwheel[col:col+CB, 2] = 255
+ col += CB
+
+ # BM
+ colorwheel[col:col+BM, 2] = 255
+ colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM))
+ col += BM
+
+ # MR
+ colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
+ colorwheel[col:col+MR, 0] = 255
+
+ return colorwheel
+
+
+def read_flo_file(filename):
+ """
+ Read from Middlebury .flo file
+ :param flow_file: name of the flow file
+ :return: optical flow data in matrix
+ """
+ f = open(filename, 'rb')
+ magic = np.fromfile(f, np.float32, count=1)
+ data2d = None
+
+ if 202021.25 != magic:
+ print('Magic number incorrect. Invalid .flo file')
+ else:
+ w = np.fromfile(f, np.int32, count=1)
+ h = np.fromfile(f, np.int32, count=1)
+ #print("Reading %d x %d flow file in .flo format" % (h, w))
+ flow = np.ones((h[0],w[0],3))
+ data2d = np.fromfile(f, np.float32, count=2 * w[0] * h[0])
+ # reshape data into 3D array (columns, rows, channels)
+ data2d = np.resize(data2d, (h[0], w[0], 2))
+ flow[:,:,:2] = data2d
+ f.close()
+ return flow
+
+
+def read_png_file(flow_file):
+ """
+ Read from KITTI .png file
+ :param flow_file: name of the flow file
+ :return: optical flow data in matrix
+ """
+ flow = cv2.imread(flow_file,-1)[:,:,::-1].astype(np.float64)
+ # flow_object = png.Reader(filename=flow_file)
+ # flow_direct = flow_object.asDirect()
+ # flow_data = list(flow_direct[2])
+ # (w, h) = flow_direct[3]['size']
+ # #print("Reading %d x %d flow file in .png format" % (h, w))
+ # flow = np.zeros((h, w, 3), dtype=np.float64)
+ # for i in range(len(flow_data)):
+ # flow[i, :, 0] = flow_data[i][0::3]
+ # flow[i, :, 1] = flow_data[i][1::3]
+ # flow[i, :, 2] = flow_data[i][2::3]
+
+ invalid_idx = (flow[:, :, 2] == 0)
+ flow[:, :, 0:2] = (flow[:, :, 0:2] - 2 ** 15) / 64.0
+ flow[invalid_idx, 0] = 0
+ flow[invalid_idx, 1] = 0
+ return flow
+
+
+def read_pfm_file(flow_file):
+ """
+ Read from .pfm file
+ :param flow_file: name of the flow file
+ :return: optical flow data in matrix
+ """
+ (data, scale) = pfm.readPFM(flow_file)
+ return data
+
+
+# fast resample layer
+def resample(img, sz):
+ """
+ img: flow map to be resampled
+ sz: new flow map size. Must be [height, width]
+ """
+ original_image_size = img.shape
+ in_height = img.shape[0]
+ in_width = img.shape[1]
+ out_height = sz[0]
+ out_width = sz[1]
+ out_flow = np.zeros((out_height, out_width, 2))
+ # find scale
+ height_scale = float(in_height) / float(out_height)
+ width_scale = float(in_width) / float(out_width)
+
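+ # Bilinear resampling: blend the four nearest flow vectors at each output location, then rescale the u/v components to the new image size.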
+ [x,y] = np.meshgrid(range(out_width), range(out_height))
+ xx = x * width_scale
+ yy = y * height_scale
+ x0 = np.floor(xx).astype(np.int32)
+ x1 = x0 + 1
+ y0 = np.floor(yy).astype(np.int32)
+ y1 = y0 + 1
+
+ x0 = np.clip(x0,0,in_width-1)
+ x1 = np.clip(x1,0,in_width-1)
+ y0 = np.clip(y0,0,in_height-1)
+ y1 = np.clip(y1,0,in_height-1)
+
+ Ia = img[y0,x0,:]
+ Ib = img[y1,x0,:]
+ Ic = img[y0,x1,:]
+ Id = img[y1,x1,:]
+
+ wa = (y1-yy) * (x1-xx)
+ wb = (yy-y0) * (x1-xx)
+ wc = (y1-yy) * (xx-x0)
+ wd = (yy-y0) * (xx-x0)
+ out_flow[:,:,0] = (Ia[:,:,0]*wa + Ib[:,:,0]*wb + Ic[:,:,0]*wc + Id[:,:,0]*wd) * out_width / in_width
+ out_flow[:,:,1] = (Ia[:,:,1]*wa + Ib[:,:,1]*wb + Ic[:,:,1]*wc + Id[:,:,1]*wd) * out_height / in_height
+
+ return out_flow
+
diff --git a/expansion/utils/io.py b/expansion/utils/io.py
new file mode 100755
index 0000000000000000000000000000000000000000..4677d32651b9f017142f4047f938052b1901ee4f
--- /dev/null
+++ b/expansion/utils/io.py
@@ -0,0 +1,164 @@
+import errno
+import os
+import shutil
+import sys
+import traceback
+import zipfile
+
+if sys.version_info[0] == 2:
+ import urllib2
+else:
+ import urllib.request
+
+
+# Converts a string to bytes (for writing the string into a file). Provided for
+# compatibility with Python 2 and 3.
+def StrToBytes(text):
+ if sys.version_info[0] == 2:
+ return text
+ else:
+ return bytes(text, 'UTF-8')
+
+
+# Outputs the given text and lets the user input a response (submitted by
+# pressing the return key). Provided for compatibility with Python 2 and 3.
+def GetUserInput(text):
+ if sys.version_info[0] == 2:
+ return raw_input(text)
+ else:
+ return input(text)
+
+
+# Creates the given directory (hierarchy), which may already exist. Provided for
+# compatibility with Python 2 and 3.
+def MakeDirsExistOk(directory_path):
+ try:
+ os.makedirs(directory_path)
+ except OSError as exception:
+ if exception.errno != errno.EEXIST:
+ raise
+
+
+# Deletes all files and folders within the given folder.
+def DeleteFolderContents(folder_path):
+ for file_name in os.listdir(folder_path):
+ file_path = os.path.join(folder_path, file_name)
+ try:
+ if os.path.isfile(file_path):
+ os.unlink(file_path)
+ else: #if os.path.isdir(file_path):
+ shutil.rmtree(file_path)
+ except Exception as e:
+ print('Exception in DeleteFolderContents():')
+ print(e)
+ print('Stack trace:')
+ print(traceback.format_exc())
+
+
+# Creates the given directory, respectively deletes all content of the directory
+# in case it already exists.
+def MakeCleanDirectory(folder_path):
+ if os.path.isdir(folder_path):
+ DeleteFolderContents(folder_path)
+ else:
+ MakeDirsExistOk(folder_path)
+
+
+# Downloads the given URL to a file in the given directory. Returns the
+# path to the downloaded file.
+# In part adapted from: https://stackoverflow.com/questions/22676
+def DownloadFile(url, dest_dir_path):
+ file_name = url.split('/')[-1]
+ dest_file_path = os.path.join(dest_dir_path, file_name)
+
+ if os.path.isfile(dest_file_path):
+ print('The following file already exists:')
+ print(dest_file_path)
+ print('Please choose whether to re-download and overwrite the file [o] or to skip downloading this file [s] by entering o or s.')
+ while True:
+ response = GetUserInput("> ")
+ if response == 's':
+ return dest_file_path
+ elif response == 'o':
+ break
+ else:
+ print('Please enter o or s.')
+
+ url_object = None
+ if sys.version_info[0] == 2:
+ url_object = urllib2.urlopen(url)
+ else:
+ url_object = urllib.request.urlopen(url)
+
+ with open(dest_file_path, 'wb') as outfile:
+ meta = url_object.info()
+ file_size = 0
+ if sys.version_info[0] == 2:
+ file_size = int(meta.getheaders("Content-Length")[0])
+ else:
+ file_size = int(meta["Content-Length"])
+ print("Downloading: %s (size [bytes]: %s)" % (url, file_size))
+
+ file_size_downloaded = 0
+ block_size = 8192
+ while True:
+ buffer = url_object.read(block_size)
+ if not buffer:
+ break
+
+ file_size_downloaded += len(buffer)
+ outfile.write(buffer)
+
+ sys.stdout.write("%d / %d (%.3f%%)\r" % (file_size_downloaded, file_size, file_size_downloaded * 100. / file_size))
+ sys.stdout.flush()
+
+ return dest_file_path
+
+
+# Unzips the given zip file into the given directory.
+def UnzipFile(file_path, unzip_dir_path, overwrite=True):
+ zip_ref = zipfile.ZipFile(open(file_path, 'rb'))
+
+ if not overwrite:
+ for f in zip_ref.namelist():
+ if not os.path.isfile(os.path.join(unzip_dir_path, f)):
+ zip_ref.extract(f, path=unzip_dir_path)
+ else:
+ print('Not overwriting {}'.format(f))
+ else:
+ zip_ref.extractall(unzip_dir_path)
+ zip_ref.close()
+
+
+# Creates a zip file with the contents of the given directory.
+# The archive_base_path must not include the extension .zip. The full, final
+# path of the archive is returned by the function.
+def ZipDirectory(archive_base_path, root_dir_path):
+ # return shutil.make_archive(archive_base_path, 'zip', root_dir_path) # THIS WILL ALWAYS HAVE ./ FOLDER INCLUDED
+ with zipfile.ZipFile(archive_base_path+'.zip', "w", compression=zipfile.ZIP_DEFLATED) as zf:
+ base_path = os.path.normpath(root_dir_path)
+ for dirpath, dirnames, filenames in os.walk(root_dir_path):
+ for name in sorted(dirnames):
+ path = os.path.normpath(os.path.join(dirpath, name))
+ zf.write(path, os.path.relpath(path, base_path))
+ for name in filenames:
+ path = os.path.normpath(os.path.join(dirpath, name))
+ if os.path.isfile(path):
+ zf.write(path, os.path.relpath(path, base_path))
+
+ return archive_base_path+'.zip'
+
+
+# Downloads a zip file and directly unzips it.
+def DownloadAndUnzipFile(url, archive_dir_path, unzip_dir_path, overwrite=True):
+ archive_path = DownloadFile(url, archive_dir_path)
+ UnzipFile(archive_path, unzip_dir_path, overwrite=overwrite)
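+
+# Usage sketch (illustrative only; the URL and directory names below are
+# placeholders, not resources shipped with this repository):
+#   MakeCleanDirectory('downloads')
+#   DownloadAndUnzipFile('https://example.com/archive.zip', 'downloads', 'downloads/extracted')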
+
+def mkdir_p(path):
+ try:
+ os.makedirs(path)
+ except OSError as exc: # Python >2.5
+ if exc.errno == errno.EEXIST and os.path.isdir(path):
+ pass
+ else:
+ raise
diff --git a/expansion/utils/logger.py b/expansion/utils/logger.py
new file mode 100755
index 0000000000000000000000000000000000000000..9bd31e79407f8e7e94d236c9b0e620403d1e3d85
--- /dev/null
+++ b/expansion/utils/logger.py
@@ -0,0 +1,113 @@
+"""
+File: logger.py
+Modified by: Senthil Purushwalkam
+Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
+Email: spurushwandrewcmuedu
+Github: https://github.com/senthilps8
+Description:
+"""
+import pdb
+import tensorflow as tf
+import torch
+from torch.autograd import Variable
+import numpy as np
+import scipy.misc
+import os
+try:
+ from StringIO import StringIO # Python 2.7
+except ImportError:
+ from io import BytesIO # Python 3.x
+
+
+class Logger(object):
+
+ def __init__(self, log_dir, name=None):
+ """Create a summary writer logging to log_dir."""
+ if name is None:
+ name = 'temp'
+ self.name = name
+ if name is not None:
+ try:
+ os.makedirs(os.path.join(log_dir, name))
+ except:
+ pass
+ self.writer = tf.summary.FileWriter(os.path.join(log_dir, name),
+ filename_suffix=name)
+ else:
+ self.writer = tf.summary.FileWriter(log_dir, filename_suffix=name)
+
+ def scalar_summary(self, tag, value, step):
+ """Log a scalar variable."""
+ summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
+ self.writer.add_summary(summary, step)
+
+ def image_summary(self, tag, images, step):
+ """Log a list of images."""
+
+ img_summaries = []
+ for i, img in enumerate(images):
+ # Write the image to a string
+ try:
+ s = StringIO()
+ except:
+ s = BytesIO()
+ scipy.misc.toimage(img).save(s, format="png")
+
+ # Create an Image object
+ img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
+ height=img.shape[0],
+ width=img.shape[1])
+ # Create a Summary value
+ img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))
+
+ # Create and write Summary
+ summary = tf.Summary(value=img_summaries)
+ self.writer.add_summary(summary, step)
+
+ def histo_summary(self, tag, values, step, bins=1000):
+ """Log a histogram of the tensor of values."""
+
+ # Create a histogram using numpy
+ counts, bin_edges = np.histogram(values, bins=bins)
+
+ # Fill the fields of the histogram proto
+ hist = tf.HistogramProto()
+ hist.min = float(np.min(values))
+ hist.max = float(np.max(values))
+ hist.num = int(np.prod(values.shape))
+ hist.sum = float(np.sum(values))
+ hist.sum_squares = float(np.sum(values**2))
+
+ # Drop the start of the first bin
+ bin_edges = bin_edges[1:]
+
+ # Add bin edges and counts
+ for edge in bin_edges:
+ hist.bucket_limit.append(edge)
+ for c in counts:
+ hist.bucket.append(c)
+
+ # Create and write Summary
+ summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
+ self.writer.add_summary(summary, step)
+ self.writer.flush()
+
+ def to_np(self, x):
+ return x.data.cpu().numpy()
+
+ def to_var(self, x):
+ if torch.cuda.is_available():
+ x = x.cuda()
+ return Variable(x)
+
+ def model_param_histo_summary(self, model, step):
+ """log histogram summary of model's parameters
+ and parameter gradients
+ """
+ for tag, value in model.named_parameters():
+ if value.grad is None:
+ continue
+ tag = tag.replace('.', '/')
+ tag = self.name+'/'+tag
+ self.histo_summary(tag, self.to_np(value), step)
+ self.histo_summary(tag+'/grad', self.to_np(value.grad), step)
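+
+ # Usage sketch (illustrative; assumes the TF1-style tf.summary.FileWriter imported
+ # above is available and that 'logs' is an arbitrary output directory):
+ # logger = Logger('logs', name='experiment')
+ # logger.scalar_summary('train/loss', 0.42, step=1)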
+
diff --git a/expansion/utils/multiscaleloss.py b/expansion/utils/multiscaleloss.py
new file mode 100755
index 0000000000000000000000000000000000000000..b3240c7d94090c9c1447b0b844d0c16c36204796
--- /dev/null
+++ b/expansion/utils/multiscaleloss.py
@@ -0,0 +1,86 @@
+"""
+Taken from https://github.com/ClementPinard/FlowNetPytorch
+"""
+import pdb
+import torch
+import torch.nn.functional as F
+
+
+def EPE(input_flow, target_flow, mask, sparse=False, mean=True):
+ #mask = target_flow[:,2]>0
+ target_flow = target_flow[:,:2]
+ EPE_map = torch.norm(target_flow-input_flow,2,1)
+ batch_size = EPE_map.size(0)
+ if sparse:
+ # invalid flow is defined with both flow coordinates to be exactly 0
+ mask = (target_flow[:,0] == 0) & (target_flow[:,1] == 0)
+
+ EPE_map = EPE_map[~mask]
+ if mean:
+ return EPE_map[mask].mean()
+ else:
+ return EPE_map[mask].sum()/batch_size
+
+def rob_EPE(input_flow, target_flow, mask, sparse=False, mean=True):
+ #mask = target_flow[:,2]>0
+ target_flow = target_flow[:,:2]
+ #TODO
+# EPE_map = torch.norm(target_flow-input_flow,2,1)
+ EPE_map = (torch.norm(target_flow-input_flow,1,1)+0.01).pow(0.4)
+ batch_size = EPE_map.size(0)
+ if sparse:
+ # invalid flow is defined with both flow coordinates to be exactly 0
+ mask = (target_flow[:,0] == 0) & (target_flow[:,1] == 0)
+
+ EPE_map = EPE_map[~mask]
+ if mean:
+ return EPE_map[mask].mean()
+ else:
+ return EPE_map[mask].sum()/batch_size
+
+def sparse_max_pool(input, size):
+ '''Downsample the input by considering 0 values as invalid.
+
+ Unfortunately, no generic interpolation mode can resize a sparse map correctly;
+ the strategy here is to use max pooling for positive values and "min pooling"
+ for negative values, and the two results are then summed.
+ This technique keeps sparsity to a minimum, contrary to nearest interpolation,
+ which could potentially lose information for isolated data points.'''
+
+ positive = (input > 0).float()
+ negative = (input < 0).float()
+ output = F.adaptive_max_pool2d(input * positive, size) - F.adaptive_max_pool2d(-input * negative, size)
+ return output
+
+
+def multiscaleEPE(network_output, target_flow, mask, weights=None, sparse=False, rob_loss = False):
+ def one_scale(output, target, mask, sparse):
+
+ b, _, h, w = output.size()
+
+ if sparse:
+ target_scaled = sparse_max_pool(target, (h, w))
+ else:
+ target_scaled = F.interpolate(target, (h, w), mode='area')
+ mask = F.interpolate(mask.float().unsqueeze(1), (h, w), mode='bilinear').squeeze(1)==1
+ if rob_loss:
+ return rob_EPE(output, target_scaled, mask, sparse, mean=False)
+ else:
+ return EPE(output, target_scaled, mask, sparse, mean=False)
+
+ if type(network_output) not in [tuple, list]:
+ network_output = [network_output]
+ if weights is None:
+ weights = [0.005, 0.01, 0.02, 0.08, 0.32] # as in original article
+ assert(len(weights) == len(network_output))
+
+ loss = 0
+ for output, weight in zip(network_output, weights):
+ loss += weight * one_scale(output, target_flow, mask, sparse)
+ return loss
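+
+# Usage sketch (illustrative): given a coarse-to-fine list of flow predictions
+# `preds`, a full-resolution target flow `gt`, and a boolean validity mask `mask`,
+# the weighted multi-scale loss would be computed as
+#   loss = multiscaleEPE(preds, gt, mask, weights=None, sparse=False)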
+
+
+def realEPE(output, target, mask, sparse=False):
+ b, _, h, w = target.size()
+ upsampled_output = F.interpolate(output, (h,w), mode='bilinear', align_corners=False)
+ return EPE(upsampled_output, target,mask, sparse, mean=True)
diff --git a/expansion/utils/pfm.py b/expansion/utils/pfm.py
new file mode 100755
index 0000000000000000000000000000000000000000..65a3dbb10ae96a72110294feb7cdc8df9d9c8b0d
--- /dev/null
+++ b/expansion/utils/pfm.py
@@ -0,0 +1,79 @@
+import re
+import numpy as np
+import sys
+
+def readPFM(file):
+ file = open(file, 'rb')
+
+ color = None
+ width = None
+ height = None
+ scale = None
+ endian = None
+
+ header = file.readline().rstrip()
+ if (sys.version[0]) == '3':
+ header = header.decode('utf-8')
+ if header == 'PF':
+ color = True
+ elif header == 'Pf':
+ color = False
+ else:
+ raise Exception('Not a PFM file.')
+
+ if (sys.version[0]) == '3':
+ dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
+ else:
+ dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline())
+ if dim_match:
+ width, height = map(int, dim_match.groups())
+ else:
+ raise Exception('Malformed PFM header.')
+
+ if (sys.version[0]) == '3':
+ scale = float(file.readline().rstrip().decode('utf-8'))
+ else:
+ scale = float(file.readline().rstrip())
+
+ if scale < 0: # little-endian
+ endian = '<'
+ scale = -scale
+ else:
+ endian = '>' # big-endian
+
+ data = np.fromfile(file, endian + 'f')
+ shape = (height, width, 3) if color else (height, width)
+
+ data = np.reshape(data, shape)
+ data = np.flipud(data)
+ return data, scale
+
+
+def writePFM(file, image, scale=1):
+ file = open(file, 'wb')
+
+ color = None
+
+ if image.dtype.name != 'float32':
+ raise Exception('Image dtype must be float32.')
+
+ image = np.flipud(image)
+
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
+ color = True
+ elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale
+ color = False
+ else:
+ raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')
+
+ file.write(('PF\n' if color else 'Pf\n').encode())
+ file.write(('%d %d\n' % (image.shape[1], image.shape[0])).encode())
+
+ endian = image.dtype.byteorder
+
+ if endian == '<' or endian == '=' and sys.byteorder == 'little':
+ scale = -scale
+
+ file.write(('%f\n' % scale).encode())
+
+ image.tofile(file)
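+
+# Usage sketch (illustrative; 'disp.pfm' is a placeholder path):
+#   data, scale = readPFM('disp.pfm')
+#   writePFM('disp_copy.pfm', data.astype(np.float32), scale)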
diff --git a/expansion/utils/readpfm.py b/expansion/utils/readpfm.py
new file mode 100755
index 0000000000000000000000000000000000000000..a79837854408d80bb033f5b794e3294b2da50897
--- /dev/null
+++ b/expansion/utils/readpfm.py
@@ -0,0 +1,51 @@
+import re
+import numpy as np
+import sys
+
+
+def readPFM(file):
+ file = open(file, 'rb')
+
+ color = None
+ width = None
+ height = None
+ scale = None
+ endian = None
+
+ header = file.readline().rstrip()
+ if (sys.version[0]) == '3':
+ header = header.decode('utf-8')
+ if header == 'PF':
+ color = True
+ elif header == 'Pf':
+ color = False
+ else:
+ raise Exception('Not a PFM file.')
+
+ if (sys.version[0]) == '3':
+ dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
+ else:
+ dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline())
+ if dim_match:
+ width, height = map(int, dim_match.groups())
+ else:
+ raise Exception('Malformed PFM header.')
+
+ if (sys.version[0]) == '3':
+ scale = float(file.readline().rstrip().decode('utf-8'))
+ else:
+ scale = float(file.readline().rstrip())
+
+ if scale < 0: # little-endian
+ endian = '<'
+ scale = -scale
+ else:
+ endian = '>' # big-endian
+
+ data = np.fromfile(file, endian + 'f')
+ shape = (height, width, 3) if color else (height, width)
+
+ data = np.reshape(data, shape)
+ data = np.flipud(data)
+ return data, scale
+
diff --git a/expansion/utils/sintel_io.py b/expansion/utils/sintel_io.py
new file mode 100755
index 0000000000000000000000000000000000000000..c3633c9d08a8026e8cd652a2df263c6e893c0976
--- /dev/null
+++ b/expansion/utils/sintel_io.py
@@ -0,0 +1,214 @@
+#! /usr/bin/env python2
+
+"""
+I/O script to save and load the data coming with the MPI-Sintel low-level
+computer vision benchmark.
+
+For more details about the benchmark, please visit www.mpi-sintel.de
+
+CHANGELOG:
+v1.0 (2015/02/03): First release
+
+Copyright (c) 2015 Jonas Wulff
+Max Planck Institute for Intelligent Systems, Tuebingen, Germany
+
+"""
+
+# Requirements: Numpy and PIL/Pillow
+import numpy as np
+from PIL import Image
+
+# Check for endianness, based on Daniel Scharstein's optical flow code.
+# Using little-endian architecture, these two should be equal.
+TAG_FLOAT = 202021.25
+TAG_CHAR = 'PIEH'
+
+def flow_read(filename):
+ """ Read optical flow from file, return (U,V) tuple.
+
+ Original code by Deqing Sun, adapted from Daniel Scharstein.
+ """
+ f = open(filename,'rb')
+ check = np.fromfile(f,dtype=np.float32,count=1)[0]
+ assert check == TAG_FLOAT, ' flow_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check)
+ width = np.fromfile(f,dtype=np.int32,count=1)[0]
+ height = np.fromfile(f,dtype=np.int32,count=1)[0]
+ size = width*height
+ assert width > 0 and height > 0 and size > 1 and size < 100000000, ' flow_read:: Wrong input size (width = {0}, height = {1}).'.format(width,height)
+ tmp = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width*2))
+ u = tmp[:,np.arange(width)*2]
+ v = tmp[:,np.arange(width)*2 + 1]
+ return u,v
+
+def flow_write(filename,uv,v=None):
+ """ Write optical flow to file.
+
+ If v is None, uv is assumed to contain both u and v channels,
+ stacked in depth.
+
+ Original code by Deqing Sun, adapted from Daniel Scharstein.
+ """
+ nBands = 2
+
+ if v is None:
+ assert(uv.ndim == 3)
+ assert(uv.shape[2] == 2)
+ u = uv[:,:,0]
+ v = uv[:,:,1]
+ else:
+ u = uv
+
+ assert(u.shape == v.shape)
+ height,width = u.shape
+ f = open(filename,'wb')
+ # write the header
+ f.write(TAG_CHAR)
+ np.array(width).astype(np.int32).tofile(f)
+ np.array(height).astype(np.int32).tofile(f)
+ # arrange into matrix form
+ tmp = np.zeros((height, width*nBands))
+ tmp[:,np.arange(width)*2] = u
+ tmp[:,np.arange(width)*2 + 1] = v
+ tmp.astype(np.float32).tofile(f)
+ f.close()
+
+
+def depth_read(filename):
+ """ Read depth data from file, return as numpy array. """
+ f = open(filename,'rb')
+ check = np.fromfile(f,dtype=np.float32,count=1)[0]
+ assert check == TAG_FLOAT, ' depth_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check)
+ width = np.fromfile(f,dtype=np.int32,count=1)[0]
+ height = np.fromfile(f,dtype=np.int32,count=1)[0]
+ size = width*height
+ assert width > 0 and height > 0 and size > 1 and size < 100000000, ' depth_read:: Wrong input size (width = {0}, height = {1}).'.format(width,height)
+ depth = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width))
+ return depth
+
+def depth_write(filename, depth):
+ """ Write depth to file. """
+ height,width = depth.shape[:2]
+ f = open(filename,'wb')
+ # write the header
+ f.write(TAG_CHAR)
+ np.array(width).astype(np.int32).tofile(f)
+ np.array(height).astype(np.int32).tofile(f)
+
+ depth.astype(np.float32).tofile(f)
+ f.close()
+
+
+def disparity_write(filename,disparity,bitdepth=16):
+ """ Write disparity to file.
+
+ bitdepth can be either 16 (default) or 32.
+
+ The maximum disparity is 1024, since the image width in Sintel
+ is 1024.
+ """
+ d = disparity.copy()
+
+ # Clip disparity.
+ d[d>1024] = 1024
+ d[d<0] = 0
+
+ d_r = (d / 4.0).astype('uint8')
+ d_g = ((d * (2.0**6)) % 256).astype('uint8')
+
+ out = np.zeros((d.shape[0],d.shape[1],3),dtype='uint8')
+ out[:,:,0] = d_r
+ out[:,:,1] = d_g
+
+ if bitdepth > 16:
+ d_b = (d * (2**14) % 256).astype('uint8')
+ out[:,:,2] = d_b
+
+ Image.fromarray(out,'RGB').save(filename,'PNG')
+
+
+def disparity_read(filename):
+ """ Return disparity read from filename. """
+ f_in = np.array(Image.open(filename))
+ d_r = f_in[:,:,0].astype('float64')
+ d_g = f_in[:,:,1].astype('float64')
+ d_b = f_in[:,:,2].astype('float64')
+
+ depth = d_r * 4 + d_g / (2**6) + d_b / (2**14)
+ return depth
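+
+# The PNG channels encode disparity as d = R*4 + G/2**6 + B/2**14, i.e. the exact
+# inverse of the packing performed in disparity_write above.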
+
+
+#def cam_read(filename):
+# """ Read camera data, return (M,N) tuple.
+#
+# M is the intrinsic matrix, N is the extrinsic matrix, so that
+#
+# x = M*N*X,
+# with x being a point in homogeneous image pixel coordinates, X being a
+# point in homogeneous world coordinates.
+# """
+# txtdata = np.loadtxt(filename)
+# intrinsic = txtdata[0,:9].reshape((3,3))
+# extrinsic = textdata[1,:12].reshape((3,4))
+# return intrinsic,extrinsic
+#
+#
+#def cam_write(filename,M,N):
+# """ Write intrinsic matrix M and extrinsic matrix N to file. """
+# Z = np.zeros((2,12))
+# Z[0,:9] = M.ravel()
+# Z[1,:12] = N.ravel()
+# np.savetxt(filename,Z)
+
+def cam_read(filename):
+ """ Read camera data, return (M,N) tuple.
+
+ M is the intrinsic matrix, N is the extrinsic matrix, so that
+
+ x = M*N*X,
+ with x being a point in homogeneous image pixel coordinates, X being a
+ point in homogeneous world coordinates.
+ """
+ f = open(filename,'rb')
+ check = np.fromfile(f,dtype=np.float32,count=1)[0]
+ assert check == TAG_FLOAT, ' cam_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check)
+ M = np.fromfile(f,dtype='float64',count=9).reshape((3,3))
+ N = np.fromfile(f,dtype='float64',count=12).reshape((3,4))
+ return M,N
+
+def cam_write(filename, M, N):
+ """ Write intrinsic matrix M and extrinsic matrix N to file. """
+ f = open(filename,'wb')
+ # write the header
+ f.write(TAG_CHAR)
+ M.astype('float64').tofile(f)
+ N.astype('float64').tofile(f)
+ f.close()
+
+
+def segmentation_write(filename,segmentation):
+ """ Write segmentation to file. """
+
+ segmentation_ = segmentation.astype('int32')
+ seg_r = np.floor(segmentation_ / (256**2)).astype('uint8')
+ seg_g = np.floor((segmentation_ % (256**2)) / 256).astype('uint8')
+ seg_b = np.floor(segmentation_ % 256).astype('uint8')
+
+ out = np.zeros((segmentation.shape[0],segmentation.shape[1],3),dtype='uint8')
+ out[:,:,0] = seg_r
+ out[:,:,1] = seg_g
+ out[:,:,2] = seg_b
+
+ Image.fromarray(out,'RGB').save(filename,'PNG')
+
+
+def segmentation_read(filename):
+ """ Return disparity read from filename. """
+ f_in = np.array(Image.open(filename))
+ seg_r = f_in[:,:,0].astype('int32')
+ seg_g = f_in[:,:,1].astype('int32')
+ seg_b = f_in[:,:,2].astype('int32')
+
+ segmentation = (seg_r * 256 + seg_g) * 256 + seg_b
+ return segmentation
+
+
diff --git a/expansion/utils/util_flow.py b/expansion/utils/util_flow.py
new file mode 100755
index 0000000000000000000000000000000000000000..13c683370f8f2b4b6ac6b077d05b0964753821bb
--- /dev/null
+++ b/expansion/utils/util_flow.py
@@ -0,0 +1,272 @@
+import math
+import png
+import struct
+import array
+import numpy as np
+import cv2
+import pdb
+
+from io import *
+
+UNKNOWN_FLOW_THRESH = 1e9;
+UNKNOWN_FLOW = 1e10;
+
+# Middlebury checks
+TAG_STRING = 'PIEH' # use this when WRITING the file
+TAG_FLOAT = 202021.25 # check for this when READING the file
+
+def readPFM(file):
+ import re
+ file = open(file, 'rb')
+
+ color = None
+ width = None
+ height = None
+ scale = None
+ endian = None
+
+ header = file.readline().rstrip()
+ if header == b'PF':
+ color = True
+ elif header == b'Pf':
+ color = False
+ else:
+ raise Exception('Not a PFM file.')
+
+ dim_match = re.match(b'^(\d+)\s(\d+)\s$', file.readline())
+ if dim_match:
+ width, height = map(int, dim_match.groups())
+ else:
+ raise Exception('Malformed PFM header.')
+
+ scale = float(file.readline().rstrip())
+ if scale < 0: # little-endian
+ endian = '<'
+ scale = -scale
+ else:
+ endian = '>' # big-endian
+
+ data = np.fromfile(file, endian + 'f')
+ shape = (height, width, 3) if color else (height, width)
+
+ data = np.reshape(data, shape)
+ data = np.flipud(data)
+ return data, scale
+
+
+def save_pfm(file, image, scale = 1):
+ import sys
+ color = None
+
+ if image.dtype.name != 'float32':
+ raise Exception('Image dtype must be float32.')
+
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
+ color = True
+ elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale
+ color = False
+ else:
+ raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')
+
+ file.write('PF\n' if color else 'Pf\n')
+ file.write('%d %d\n' % (image.shape[1], image.shape[0]))
+
+ endian = image.dtype.byteorder
+
+ if endian == '<' or endian == '=' and sys.byteorder == 'little':
+ scale = -scale
+
+ file.write('%f\n' % scale)
+
+ image.tofile(file)
+
+
+def ReadMiddleburyFloFile(path):
+ """ Read .FLO file as specified by Middlebury.
+
+ Returns tuple (width, height, u, v, mask), where u, v, mask are flat
+ arrays of values.
+ """
+
+ with open(path, 'rb') as fil:
+ tag = struct.unpack('f', fil.read(4))[0]
+ width = struct.unpack('i', fil.read(4))[0]
+ height = struct.unpack('i', fil.read(4))[0]
+
+ assert tag == TAG_FLOAT
+
+ #data = np.fromfile(path, dtype=np.float, count=-1)
+ #data = data[3:]
+
+ fmt = 'f' * width*height*2
+ data = struct.unpack(fmt, fil.read(4*width*height*2))
+
+ u = data[::2]
+ v = data[1::2]
+
+ mask = map(lambda x,y: abs(x) < UNKNOWN_FLOW_THRESH and abs(y) < UNKNOWN_FLOW_THRESH, u, v)
+ # print(u[ind], v[ind], mask[ind], row[3*x], row[3*x+1], row[3*x+2])
+
+ #png_reader.close()
+
+ return (width, height, u, v, mask)
+
+
+def WriteMiddleburyFloFile(path, width, height, u, v, mask=None):
+ """ Write .FLO file as specified by Middlebury.
+ """
+
+ if mask is not None:
+ u_masked = map(lambda x,y: x if y else UNKNOWN_FLOW, u, mask)
+ v_masked = map(lambda x,y: x if y else UNKNOWN_FLOW, v, mask)
+ else:
+ u_masked = u
+ v_masked = v
+
+ fmt = 'f' * width*height*2
+ # Interleave lists
+ data = [x for t in zip(u_masked,v_masked) for x in t]
+
+ with open(path, 'wb') as fil:
+ fil.write(str.encode(TAG_STRING))
+ fil.write(struct.pack('i', width))
+ fil.write(struct.pack('i', height))
+ fil.write(struct.pack(fmt, *data))
+
+
+def write_flow(path,flow):
+
+ invalid_idx = (flow[:, :, 2] == 0)
+ flow[:, :, 0:2] = flow[:, :, 0:2]*64.+ 2 ** 15
+ flow[invalid_idx, 0] = 0
+ flow[invalid_idx, 1] = 0
+
+ flow = flow.astype(np.uint16)
+ flow = cv2.imwrite(path, flow[:,:,::-1])
+
+ #WriteKittiPngFile(path,
+ # flow.shape[1], flow.shape[0], flow[:,:,0].flatten(),
+ # flow[:,:,1].flatten(), flow[:,:,2].flatten())
+
+
+
+def WriteKittiPngFile(path, width, height, u, v, mask=None):
+ """ Write 16-bit .PNG file as specified by KITTI-2015 (flow).
+
+ u, v are lists of float values
+ mask is a list of floats, denoting the *valid* pixels.
+ """
+
+ data = array.array('H',[0])*width*height*3
+
+ for i,(u_,v_,mask_) in enumerate(zip(u,v,mask)):
+ data[3*i] = int(u_*64.0+2**15)
+ data[3*i+1] = int(v_*64.0+2**15)
+ data[3*i+2] = int(mask_)
+
+ # if mask_ > 0:
+ # print(data[3*i], data[3*i+1],data[3*i+2])
+
+ with open(path, 'wb') as png_file:
+ png_writer = png.Writer(width=width, height=height, bitdepth=16, compression=3, greyscale=False)
+ png_writer.write_array(png_file, data)
+
+
+def ConvertMiddleburyFloToKittiPng(src_path, dest_path):
+ width, height, u, v, mask = ReadMiddleburyFloFile(src_path)
+ WriteKittiPngFile(dest_path, width, height, u, v, mask=mask)
+
+def ConvertKittiPngToMiddleburyFlo(src_path, dest_path):
+ width, height, u, v, mask = ReadKittiPngFile(src_path)
+ WriteMiddleburyFloFile(dest_path, width, height, u, v, mask=mask)
+
+
+def ParseFilenameKitti(filename):
+ # Parse kitti filename (seq_frameno.xx),
+ # return seq, frameno, ext.
+ # Be aware that seq might contain the dataset name (if contained as prefix)
+ ext = filename[filename.rfind('.'):]
+ frameno = filename[filename.rfind('_')+1:filename.rfind('.')]
+ frameno = int(frameno)
+ seq = filename[:filename.rfind('_')]
+ return seq, frameno, ext
+
+
+def read_calib_file(filepath):
+ """Read in a calibration file and parse into a dictionary."""
+ data = {}
+
+ with open(filepath, 'r') as f:
+ for line in f.readlines():
+ key, value = line.split(':', 1)
+ # The only non-float values in these files are dates, which
+ # we don't care about anyway
+ try:
+ data[key] = np.array([float(x) for x in value.split()])
+ except ValueError:
+ pass
+
+ return data
+
+def load_calib_cam_to_cam(cam_to_cam_file):
+ # We'll return the camera calibration as a dictionary
+ data = {}
+
+ # Load and parse the cam-to-cam calibration data
+ filedata = read_calib_file(cam_to_cam_file)
+
+ # Create 3x4 projection matrices
+ P_rect_00 = np.reshape(filedata['P_rect_00'], (3, 4))
+ P_rect_10 = np.reshape(filedata['P_rect_01'], (3, 4))
+ P_rect_20 = np.reshape(filedata['P_rect_02'], (3, 4))
+ P_rect_30 = np.reshape(filedata['P_rect_03'], (3, 4))
+
+ # Compute the camera intrinsics
+ data['K_cam0'] = P_rect_00[0:3, 0:3]
+ data['K_cam1'] = P_rect_10[0:3, 0:3]
+ data['K_cam2'] = P_rect_20[0:3, 0:3]
+ data['K_cam3'] = P_rect_30[0:3, 0:3]
+
+ data['b00'] = P_rect_00[0, 3] / P_rect_00[0, 0]
+ data['b10'] = P_rect_10[0, 3] / P_rect_10[0, 0]
+ data['b20'] = P_rect_20[0, 3] / P_rect_20[0, 0]
+ data['b30'] = P_rect_30[0, 3] / P_rect_30[0, 0]
+
+ return data
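+
+# Usage sketch (illustrative; the calibration file path is a placeholder):
+#   calib = load_calib_cam_to_cam('calib_cam_to_cam.txt')
+#   K = calib['K_cam2']  # 3x3 intrinsics of KITTI cam2 (left color camera)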
+
diff --git a/interface/flask_app.py b/interface/flask_app.py
new file mode 100755
index 0000000000000000000000000000000000000000..f4b28432fcf8439dad33abe2d5cb6b439025ea2b
--- /dev/null
+++ b/interface/flask_app.py
@@ -0,0 +1,107 @@
+from flask import Flask, render_template, request, redirect, url_for, abort
+import json
+
+app = Flask(__name__)
+
+import sys
+sys.path.append(".")
+sys.path.append("..")
+
+import argparse
+from PIL import Image, ImageOps
+import numpy as np
+import base64
+import cv2
+from inference import demo
+
+def Base64ToNdarry(img_base64):
+ img_data = base64.b64decode(img_base64)
+ img_np = np.frombuffer(img_data, np.uint8)
+ src = cv2.imdecode(img_np, cv2.IMREAD_ANYCOLOR)
+
+ return src
+
+def NdarrayToBase64(dst):
+ result, dst_data = cv2.imencode('.png', dst)
+ dst_base64 = base64.b64encode(dst_data)
+
+ return dst_base64
+
+parser = argparse.ArgumentParser(description='User controllable latent transformer')
+parser.add_argument('--checkpoint_path', default='pretrained_models/latent_transformer/cat.pt')
+args = parser.parse_args()
+
+demo = demo(args.checkpoint_path)
+
+@app.route("/", methods=["GET", "POST"])
+#@auth.login_required
+def init():
+ if request.method == "GET":
+ input_img = demo.run()
+ input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode()
+ return render_template("index.html", filepath1=input_base64, canvas_img=input_base64, result=True)
+ if request.method == "POST":
+ if 'zi' in request.form.keys():
+ input_img = demo.move(z=-0.05)
+ elif 'zo' in request.form.keys():
+ input_img = demo.move(z=0.05)
+ elif 'u' in request.form.keys():
+ input_img = demo.move(y=-0.5, z=-0.0)
+ elif 'd' in request.form.keys():
+ input_img = demo.move(y=0.5, z=-0.0)
+ elif 'l' in request.form.keys():
+ input_img = demo.move(x=-0.5, z=-0.0)
+ elif 'r' in request.form.keys():
+ input_img = demo.move(x=0.5, z=-0.0)
+ else:
+ input_img = demo.run()
+
+ input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode()
+ return render_template("index.html", filepath1=input_base64, canvas_img=input_base64, result=True)
+
+@app.route('/zoom', methods=["POST"])
+def zoom_func():
+
+ dz = json.loads(request.form['dz'])
+ sx = json.loads(request.form['sx'])
+ sy = json.loads(request.form['sy'])
+ stop_points = json.loads(request.form['stop_points'])
+
+ input_img = demo.zoom(dz,sxsy=[sx,sy],stop_points=stop_points)
+ input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode()
+ res = {'img':input_base64}
+ return json.dumps(res)
+
+@app.route('/translate', methods=["POST"])
+def translate_func():
+
+ dx = json.loads(request.form['dx'])
+ dy = json.loads(request.form['dy'])
+ dz = json.loads(request.form['dz'])
+ sx = json.loads(request.form['sx'])
+ sy = json.loads(request.form['sy'])
+ stop_points = json.loads(request.form['stop_points'])
+ zi = json.loads(request.form['zi'])
+ zo = json.loads(request.form['zo'])
+
+ input_img = demo.translate([dx,dy],sxsy=[sx,sy],stop_points=stop_points,zoom_in=zi,zoom_out=zo)
+ input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode()
+ res = {'img':input_base64}
+ return json.dumps(res)
+
+@app.route('/changestyle', methods=["POST"])
+def changestyle_func():
+ input_img = demo.change_style()
+ input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode()
+ res = {'img':input_base64}
+ return json.dumps(res)
+
+@app.route('/reset', methods=["POST"])
+def reset_func():
+ input_img = demo.reset()
+ input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode()
+ res = {'img':input_base64}
+ return json.dumps(res)
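+
+# The endpoints above expect form-encoded fields whose values are JSON strings.
+# Minimal client sketch (illustrative only, using the requests library):
+#   import requests
+#   r = requests.post('http://localhost:8000/translate',
+#                     data={'dx': '10', 'dy': '0', 'dz': '0', 'sx': '128', 'sy': '128',
+#                           'stop_points': '[]', 'zi': 'false', 'zo': 'false'})
+#   img_base64 = r.json()['img']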
+
+if __name__ == "__main__":
+ app.run(debug=False, host='0.0.0.0', port=8000)
\ No newline at end of file
diff --git a/interface/inference.py b/interface/inference.py
new file mode 100755
index 0000000000000000000000000000000000000000..568f8f8be5903b53505b30ca416bfbb57b3d7287
--- /dev/null
+++ b/interface/inference.py
@@ -0,0 +1,117 @@
+import os
+from argparse import Namespace
+import numpy as np
+import torch
+import sys
+
+sys.path.append(".")
+sys.path.append("..")
+
+from models.StyleGANControler import StyleGANControler
+
+class demo():
+
+ def __init__(self, checkpoint_path, truncation = 0.5, use_average_code_as_input = False):
+ self.truncation = truncation
+ self.use_average_code_as_input = use_average_code_as_input
+ ckpt = torch.load(checkpoint_path, map_location='cpu')
+ opts = ckpt['opts']
+ opts['checkpoint_path'] = checkpoint_path
+ self.opts = Namespace(**ckpt['opts'])
+
+ self.net = StyleGANControler(self.opts)
+ self.net.eval()
+ self.net.cuda()
+ self.target_layers = [0,1,2,3,4,5]
+
+ self.w1 = None
+ self.w1_after = None
+ self.f1 = None
+
+ def run(self):
+ z1 = torch.randn(1,512).to("cuda")
+ x1, self.w1, self.f1 = self.net.decoder([z1],input_is_latent=False,randomize_noise=False,return_feature_map=True,return_latents=True,truncation=self.truncation, truncation_latent=self.net.latent_avg[0])
+ self.w1_after = self.w1.clone()
+ x1 = self.net.face_pool(x1)
+ result = ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1]
+ return result
+
+ def translate(self, dxy, sxsy=[0,0], stop_points=[], zoom_in=False, zoom_out=False):
+ dz = -5. if zoom_in else 0.
+ dz = 5. if zoom_out else dz
+
+ dxyz = np.array([dxy[0],dxy[1],dz], dtype=np.float32)
+ dxy_norm = np.linalg.norm(dxyz[:2], ord=2)
+ dxyz[:2] = dxyz[:2]/dxy_norm
+ vec_num = dxy_norm/10
+
+ x = torch.from_numpy(np.array([[dxyz]],dtype=np.float32)).cuda()
+ f1 = torch.nn.functional.interpolate(self.f1, (256,256))
+ y = f1[:,:,sxsy[1],sxsy[0]].unsqueeze(0)
+
+ if len(stop_points)>0:
+ x = torch.cat([x, torch.zeros(x.shape[0],len(stop_points),x.shape[2]).cuda()], dim=1)
+ tmp = []
+ for sp in stop_points:
+ tmp.append(f1[:,:,sp[1],sp[0]].unsqueeze(1))
+ y = torch.cat([y,torch.cat(tmp, dim=1)],dim=1)
+
+ if not self.use_average_code_as_input:
+ w_hat = self.net.encoder(self.w1[:,self.target_layers].detach(), x.detach(), y.detach(), alpha=vec_num)
+ w1 = self.w1.clone()
+ w1[:,self.target_layers] = w_hat
+ else:
+ w_hat = self.net.encoder(self.net.latent_avg.unsqueeze(0)[:,self.target_layers].detach(), x.detach(), y.detach(), alpha=vec_num)
+ w1 = self.w1.clone()
+ w1[:,self.target_layers] = self.w1.clone()[:,self.target_layers] + w_hat - self.net.latent_avg.unsqueeze(0)[:,self.target_layers]
+
+ x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False)
+
+ self.w1_after = w1.clone()
+ x1 = self.net.face_pool(x1)
+ result = ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1]
+ return result
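+
+ # Usage note (not part of the original code): sxsy is the [x, y] position of the
+ # dragged point in the 256x256 preview, dxy is the drag vector whose length scales
+ # the edit strength, and stop_points lists [x, y] positions that should stay fixed.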
+
+ def zoom(self, dz, sxsy=[0,0], stop_points=[]):
+ vec_num = abs(dz)/5
+ dz = 100*np.sign(dz)
+ x = torch.from_numpy(np.array([[[1.,0,dz]]],dtype=np.float32)).cuda()
+ f1 = torch.nn.functional.interpolate(self.f1, (256,256))
+ y = f1[:,:,sxsy[1],sxsy[0]].unsqueeze(0)
+
+ if len(stop_points)>0:
+ x = torch.cat([x, torch.zeros(x.shape[0],len(stop_points),x.shape[2]).cuda()], dim=1)
+ tmp = []
+ for sp in stop_points:
+ tmp.append(f1[:,:,sp[1],sp[0]].unsqueeze(1))
+ y = torch.cat([y,torch.cat(tmp, dim=1)],dim=1)
+
+ if not self.use_average_code_as_input:
+ w_hat = self.net.encoder(self.w1[:,self.target_layers].detach(), x.detach(), y.detach(), alpha=vec_num)
+ w1 = self.w1.clone()
+ w1[:,self.target_layers] = w_hat
+ else:
+ w_hat = self.net.encoder(self.net.latent_avg.unsqueeze(0)[:,self.target_layers].detach(), x.detach(), y.detach(), alpha=vec_num)
+ w1 = self.w1.clone()
+ w1[:,self.target_layers] = self.w1.clone()[:,self.target_layers] + w_hat - self.net.latent_avg.unsqueeze(0)[:,self.target_layers]
+
+
+ x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False)
+
+ x1 = self.net.face_pool(x1)
+ result = ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1]
+ return result
+
+ def change_style(self):
+ z1 = torch.randn(1,512).to("cuda")
+ x1, w2 = self.net.decoder([z1],input_is_latent=False,randomize_noise=False,return_latents=True, truncation=self.truncation, truncation_latent=self.net.latent_avg[0])
+ self.w1_after[:,6:] = w2.detach()[:,0]
+ x1, _ = self.net.decoder([self.w1_after], input_is_latent=True, randomize_noise=False, return_latents=False)
+ result = ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1]
+ return result
+
+ def reset(self):
+ x1, _ = self.net.decoder([self.w1], input_is_latent=True, randomize_noise=False, return_latents=False)
+ result = ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1]
+ return result
+
\ No newline at end of file
diff --git a/interface/templates/index.html b/interface/templates/index.html
new file mode 100755
index 0000000000000000000000000000000000000000..422afd27e84c1b28c030a4f480776d53de9fc5ee
--- /dev/null
+++ b/interface/templates/index.html
@@ -0,0 +1,195 @@
+
+
+
+
+
+
+
+
+
+ Mouse drag: | Translation |
+ Middle mouse button: | Set anchor point |
+ Mouse wheel: | Zoom in & out |
+ 'i' or 'o' key + mouse drag: | Translation with zooming in & out |
+ 's' key: | style mixing |
+
+
+
+
+
+
\ No newline at end of file
diff --git a/licenses/LICENSE_ gengshan-y_expansion b/licenses/LICENSE_ gengshan-y_expansion
new file mode 100755
index 0000000000000000000000000000000000000000..e685a844ef325e623ba048fee0b6de7ad1c57553
--- /dev/null
+++ b/licenses/LICENSE_ gengshan-y_expansion
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2020 Carnegie Mellon University
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/licenses/LICENSE_HuangYG123 b/licenses/LICENSE_HuangYG123
new file mode 100755
index 0000000000000000000000000000000000000000..c539fea307a624b0941fb808d6ae3ab4db552529
--- /dev/null
+++ b/licenses/LICENSE_HuangYG123
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2020 HuangYG123
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/licenses/LICENSE_S-aiueo32 b/licenses/LICENSE_S-aiueo32
new file mode 100755
index 0000000000000000000000000000000000000000..81e7b18bd6fcfd5a81e08d0bcb192be28cd6723c
--- /dev/null
+++ b/licenses/LICENSE_S-aiueo32
@@ -0,0 +1,25 @@
+BSD 2-Clause License
+
+Copyright (c) 2020, Sou Uchida
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
diff --git a/licenses/LICENSE_TreB1eN b/licenses/LICENSE_TreB1eN
new file mode 100755
index 0000000000000000000000000000000000000000..1c7d3585c795c41d2334036b01a8d660a5235671
--- /dev/null
+++ b/licenses/LICENSE_TreB1eN
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2018 TreB1eN
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/licenses/LICENSE_lessw2020 b/licenses/LICENSE_lessw2020
new file mode 100755
index 0000000000000000000000000000000000000000..f49a4e16e68b128803cc2dcea614603632b04eac
--- /dev/null
+++ b/licenses/LICENSE_lessw2020
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/licenses/LICENSE_pixel2style2pixel b/licenses/LICENSE_pixel2style2pixel
new file mode 100755
index 0000000000000000000000000000000000000000..272e5a8959a88dce922045124af98cf74b9ee9a4
--- /dev/null
+++ b/licenses/LICENSE_pixel2style2pixel
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2020 Elad Richardson, Yuval Alaluf
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/licenses/LICENSE_rosinality b/licenses/LICENSE_rosinality
new file mode 100755
index 0000000000000000000000000000000000000000..81da3fce025084b7005be5405d3842fbea29b5ba
--- /dev/null
+++ b/licenses/LICENSE_rosinality
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2019 Kim Seonghyeon
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/models/StyleGANControler.py b/models/StyleGANControler.py
new file mode 100755
index 0000000000000000000000000000000000000000..b0766fd8fe51056ee548ac690c435abd554b301a
--- /dev/null
+++ b/models/StyleGANControler.py
@@ -0,0 +1,70 @@
+import torch
+from torch import nn
+from models.networks import latent_transformer
+from models.stylegan2.model import Generator
+import numpy as np
+
+def get_keys(d, name):
+ if 'state_dict' in d:
+ d = d['state_dict']
+ d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
+ return d_filt
+
+
+class StyleGANControler(nn.Module):
+
+ def __init__(self, opts):
+ super(StyleGANControler, self).__init__()
+ self.set_opts(opts)
+ # Define architecture
+
+ if 'ffhq' in self.opts.stylegan_weights:
+ self.style_num = 18
+ elif 'car' in self.opts.stylegan_weights:
+ self.style_num = 16
+ elif 'cat' in self.opts.stylegan_weights:
+ self.style_num = 14
+ elif 'church' in self.opts.stylegan_weights:
+ self.style_num = 14
+ elif 'anime' in self.opts.stylegan_weights:
+ self.style_num = 16
+
+ self.encoder = self.set_encoder()
+ if self.style_num==18:
+ self.decoder = Generator(1024, 512, 8, channel_multiplier=2)
+ elif self.style_num==16:
+ self.decoder = Generator(512, 512, 8, channel_multiplier=2)
+ elif self.style_num==14:
+ self.decoder = Generator(256, 512, 8, channel_multiplier=2)
+ self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
+
+ # Load weights if needed
+ self.load_weights()
+
+ def set_encoder(self):
+ encoder = latent_transformer.Network(self.opts)
+ return encoder
+
+ def load_weights(self):
+ if self.opts.checkpoint_path is not None:
+ print('Loading from checkpoint: {}'.format(self.opts.checkpoint_path))
+ ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
+ self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
+ self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
+ self.__load_latent_avg(ckpt)
+ else:
+ print('Loading decoder weights from pretrained!')
+ ckpt = torch.load(self.opts.stylegan_weights)
+ self.decoder.load_state_dict(ckpt['g_ema'], strict=True)
+ self.__load_latent_avg(ckpt, repeat=self.opts.style_num)
+
+ def set_opts(self, opts):
+ self.opts = opts
+
+ def __load_latent_avg(self, ckpt, repeat=None):
+ if 'latent_avg' in ckpt:
+ self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
+ if repeat is not None:
+ self.latent_avg = self.latent_avg.repeat(repeat, 1)
+ else:
+ self.latent_avg = None
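+
+# Construction sketch (illustrative; mirrors how interface/inference.py builds the model):
+#   import argparse, torch
+#   ckpt = torch.load('pretrained_models/latent_transformer/cat.pt', map_location='cpu')
+#   opts = ckpt['opts']
+#   opts['checkpoint_path'] = 'pretrained_models/latent_transformer/cat.pt'
+#   net = StyleGANControler(argparse.Namespace(**opts)).eval().cuda()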
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/networks/__init__.py b/models/networks/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/networks/latent_transformer.py b/models/networks/latent_transformer.py
new file mode 100755
index 0000000000000000000000000000000000000000..f465ce7285397ed3460e668f1c734e02483b2903
--- /dev/null
+++ b/models/networks/latent_transformer.py
@@ -0,0 +1,162 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange
+
+# classes
+class PreNorm(nn.Module):
+ def __init__(self, dim, fn):
+ super().__init__()
+ self.norm = nn.LayerNorm(dim)
+ self.fn = fn
+ def forward(self, x, **kwargs):
+ return self.fn(self.norm(x), **kwargs)
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, hidden_dim, dropout = 0.):
+ super().__init__()
+ self.net = nn.Sequential(
+ nn.Linear(dim, hidden_dim),
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(hidden_dim, dim),
+ nn.Dropout(dropout)
+ )
+ def forward(self, x):
+ return self.net(x)
+
+class Attention(nn.Module):
+ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ project_out = not (heads == 1 and dim_head == dim)
+
+ self.heads = heads
+ self.scale = dim_head ** -0.5
+
+ self.attend = nn.Softmax(dim = -1)
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, dim),
+ nn.Dropout(dropout)
+ ) if project_out else nn.Identity()
+
+ def forward(self, x):
+ qkv = self.to_qkv(x).chunk(3, dim = -1)
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
+
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+
+ attn = self.attend(dots)
+
+ out = torch.matmul(attn, v)
+ out = rearrange(out, 'b h n d -> b n (h d)')
+ return self.to_out(out)
+
+class CrossAttention(nn.Module):
+ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ project_out = not (heads == 1 and dim_head == dim)
+
+ self.heads = heads
+ self.scale = dim_head ** -0.5
+
+ self.to_k = nn.Linear(dim, inner_dim , bias=False)
+ self.to_v = nn.Linear(dim, inner_dim , bias = False)
+ self.to_q = nn.Linear(dim, inner_dim, bias = False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, dim),
+ nn.Dropout(dropout)
+ ) if project_out else nn.Identity()
+
+ def forward(self, x_qkv, query_length=1):
+ h = self.heads
+
+ k = self.to_k(x_qkv)[:, query_length:]
+ k = rearrange(k, 'b n (h d) -> b h n d', h = h)
+
+ v = self.to_v(x_qkv)[:, query_length:]
+ v = rearrange(v, 'b n (h d) -> b h n d', h = h)
+
+ q = self.to_q(x_qkv)[:, :query_length]
+ q = rearrange(q, 'b n (h d) -> b h n d', h = h)
+
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
+
+ attn = dots.softmax(dim=-1)
+
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
+ out = rearrange(out, 'b h n d -> b n (h d)')
+ out = self.to_out(out)
+
+ return out
+
+class TransformerEncoder(nn.Module):
+ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+ super().__init__()
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(nn.ModuleList([
+ PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
+ PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+ ]))
+ def forward(self, x):
+ for attn, ff in self.layers:
+ x = attn(x) + x
+ x = ff(x) + x
+ return x
+
+class TransformerDecoder(nn.Module):
+ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+ super().__init__()
+ self.pos_embedding = nn.Parameter(torch.randn(1, 6, dim))
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(nn.ModuleList([
+ PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
+ PreNorm(dim, CrossAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
+ PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+ ]))
+ def forward(self, x, y):
+ x = x + self.pos_embedding[:, :x.shape[1]]
+ for sattn, cattn, ff in self.layers:
+ x = sattn(x) + x
+ xy = torch.cat((x,y), dim=1)
+ x = cattn(xy, query_length=x.shape[1]) + x
+ x = ff(x) + x
+ return x
+
+class Network(nn.Module):
+ def __init__(self, opts):
+ super(Network, self).__init__()
+
+ self.transformer_encoder = TransformerEncoder(dim=512, depth=6, heads=8, dim_head=64, mlp_dim=512, dropout=0)
+ self.transformer_decoder = TransformerDecoder(dim=512, depth=6, heads=8, dim_head=64, mlp_dim=512, dropout=0)
+ self.layer1 = nn.Linear(3, 256)
+ self.layer2 = nn.Linear(512, 256)
+ self.layer3 = nn.Linear(512, 512)
+ self.layer4 = nn.Linear(512, 512)
+ self.mlp_head = nn.Sequential(
+ nn.Linear(512, 512)
+ )
+
+ def forward(self, w, x, y, alpha=1.):
+ #w: latent vectors
+ #x: flow vectors
+ #y: StyleGAN features
+ xh = F.relu(self.layer1(x))
+ yh = F.relu(self.layer2(y))
+ xyh = torch.cat([xh,yh], dim=2)
+ xyh = F.relu(self.layer3(xyh))
+ xyh = self.transformer_encoder(xyh)
+
+ wh = F.relu(self.layer4(w))
+
+ h = self.transformer_decoder(wh, xyh)
+ h = self.mlp_head(h)
+ w_hat = w+alpha*h
+ return w_hat
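
A minimal sketch of the shapes `Network.forward` expects, matching how `training/coach.py` calls it: `w` holds the latent codes of the edited layers (at most six, given the positional embedding above), `x` the per-point flow vectors, and `y` the StyleGAN feature vectors sampled at the same points. The batch and point counts below are arbitrary.
```
import torch
from models.networks.latent_transformer import Network

class DummyOpts:  # Network does not read any attribute from opts
    pass

net = Network(DummyOpts())
B, N, L = 2, 256, 6                 # batch, sampled points, edited latent layers (<= 6)
w = torch.randn(B, L, 512)          # latent codes of the target layers
x = torch.randn(B, N, 3)            # flow vectors: (dx, dy, log-expansion) per point
y = torch.randn(B, N, 512)          # StyleGAN feature vectors at the same points
w_hat = net(w, x, y, alpha=1.0)     # residual update: w + alpha * h
print(w_hat.shape)                  # torch.Size([2, 6, 512])
```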
diff --git a/models/stylegan2/__init__.py b/models/stylegan2/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/stylegan2/model.py b/models/stylegan2/model.py
new file mode 100755
index 0000000000000000000000000000000000000000..f6e4c3441f4d12294f3961be3e7ed932a93090e8
--- /dev/null
+++ b/models/stylegan2/model.py
@@ -0,0 +1,714 @@
+import math
+import random
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
+
+
+class PixelNorm(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, input):
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
+
+
+def make_kernel(k):
+ k = torch.tensor(k, dtype=torch.float32)
+
+ if k.ndim == 1:
+ k = k[None, :] * k[:, None]
+
+ k /= k.sum()
+
+ return k
+
+
+class Upsample(nn.Module):
+ def __init__(self, kernel, factor=2):
+ super().__init__()
+
+ self.factor = factor
+ kernel = make_kernel(kernel) * (factor ** 2)
+ self.register_buffer('kernel', kernel)
+
+ p = kernel.shape[0] - factor
+
+ pad0 = (p + 1) // 2 + factor - 1
+ pad1 = p // 2
+
+ self.pad = (pad0, pad1)
+
+ def forward(self, input):
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
+
+ return out
+
+
+class Downsample(nn.Module):
+ def __init__(self, kernel, factor=2):
+ super().__init__()
+
+ self.factor = factor
+ kernel = make_kernel(kernel)
+ self.register_buffer('kernel', kernel)
+
+ p = kernel.shape[0] - factor
+
+ pad0 = (p + 1) // 2
+ pad1 = p // 2
+
+ self.pad = (pad0, pad1)
+
+ def forward(self, input):
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
+
+ return out
+
+
+class Blur(nn.Module):
+ def __init__(self, kernel, pad, upsample_factor=1):
+ super().__init__()
+
+ kernel = make_kernel(kernel)
+
+ if upsample_factor > 1:
+ kernel = kernel * (upsample_factor ** 2)
+
+ self.register_buffer('kernel', kernel)
+
+ self.pad = pad
+
+ def forward(self, input):
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
+
+ return out
+
+
+class EqualConv2d(nn.Module):
+ def __init__(
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
+ ):
+ super().__init__()
+
+ self.weight = nn.Parameter(
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
+ )
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
+
+ self.stride = stride
+ self.padding = padding
+
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(out_channel))
+
+ else:
+ self.bias = None
+
+ def forward(self, input):
+ out = F.conv2d(
+ input,
+ self.weight * self.scale,
+ bias=self.bias,
+ stride=self.stride,
+ padding=self.padding,
+ )
+
+ return out
+
+ def __repr__(self):
+ return (
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
+ )
+
+
+class EqualLinear(nn.Module):
+ def __init__(
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
+ ):
+ super().__init__()
+
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
+
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
+
+ else:
+ self.bias = None
+
+ self.activation = activation
+
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
+ self.lr_mul = lr_mul
+
+ def forward(self, input):
+ if self.activation:
+ out = F.linear(input, self.weight * self.scale)
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
+
+ else:
+ out = F.linear(
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
+ )
+
+ return out
+
+ def __repr__(self):
+ return (
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
+ )
+
+
+class ScaledLeakyReLU(nn.Module):
+ def __init__(self, negative_slope=0.2):
+ super().__init__()
+
+ self.negative_slope = negative_slope
+
+ def forward(self, input):
+ out = F.leaky_relu(input, negative_slope=self.negative_slope)
+
+ return out * math.sqrt(2)
+
+
+class ModulatedConv2d(nn.Module):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ style_dim,
+ demodulate=True,
+ upsample=False,
+ downsample=False,
+ blur_kernel=[1, 3, 3, 1],
+ ):
+ super().__init__()
+
+ self.eps = 1e-8
+ self.kernel_size = kernel_size
+ self.in_channel = in_channel
+ self.out_channel = out_channel
+ self.upsample = upsample
+ self.downsample = downsample
+
+ if upsample:
+ factor = 2
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
+ pad0 = (p + 1) // 2 + factor - 1
+ pad1 = p // 2 + 1
+
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
+
+ if downsample:
+ factor = 2
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
+ pad0 = (p + 1) // 2
+ pad1 = p // 2
+
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
+
+ fan_in = in_channel * kernel_size ** 2
+ self.scale = 1 / math.sqrt(fan_in)
+ self.padding = kernel_size // 2
+
+ self.weight = nn.Parameter(
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
+ )
+
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
+
+ self.demodulate = demodulate
+
+ def __repr__(self):
+ return (
+ f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
+ f'upsample={self.upsample}, downsample={self.downsample})'
+ )
+
+ def forward(self, input, style, input_is_stylespace=False):
+ batch, in_channel, height, width = input.shape
+
+ if not input_is_stylespace:
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
+ weight = self.scale * self.weight * style
+
+ if self.demodulate:
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
+
+ weight = weight.view(
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
+ )
+
+ if self.upsample:
+ input = input.view(1, batch * in_channel, height, width)
+ weight = weight.view(
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
+ )
+ weight = weight.transpose(1, 2).reshape(
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
+ )
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
+ _, _, height, width = out.shape
+ out = out.view(batch, self.out_channel, height, width)
+ out = self.blur(out)
+
+ elif self.downsample:
+ input = self.blur(input)
+ _, _, height, width = input.shape
+ input = input.view(1, batch * in_channel, height, width)
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
+ _, _, height, width = out.shape
+ out = out.view(batch, self.out_channel, height, width)
+
+ else:
+ input = input.view(1, batch * in_channel, height, width)
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
+ _, _, height, width = out.shape
+ out = out.view(batch, self.out_channel, height, width)
+
+ return out, style
+
+
+class NoiseInjection(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ self.weight = nn.Parameter(torch.zeros(1))
+
+ def forward(self, image, noise=None):
+ if noise is None:
+ batch, _, height, width = image.shape
+ noise = image.new_empty(batch, 1, height, width).normal_()
+
+ return image + self.weight * noise
+
+
+class ConstantInput(nn.Module):
+ def __init__(self, channel, size=4):
+ super().__init__()
+
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
+
+ def forward(self, input):
+ batch = input.shape[0]
+ out = self.input.repeat(batch, 1, 1, 1)
+
+ return out
+
+
+class StyledConv(nn.Module):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ style_dim,
+ upsample=False,
+ blur_kernel=[1, 3, 3, 1],
+ demodulate=True,
+ ):
+ super().__init__()
+
+ self.conv = ModulatedConv2d(
+ in_channel,
+ out_channel,
+ kernel_size,
+ style_dim,
+ upsample=upsample,
+ blur_kernel=blur_kernel,
+ demodulate=demodulate,
+ )
+
+ self.noise = NoiseInjection()
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
+ # self.activate = ScaledLeakyReLU(0.2)
+ self.activate = FusedLeakyReLU(out_channel)
+
+ def forward(self, input, style, noise=None, input_is_stylespace=False):
+ out, style = self.conv(input, style, input_is_stylespace=input_is_stylespace)
+ out = self.noise(out, noise=noise)
+ # out = out + self.bias
+ out = self.activate(out)
+
+ return out, style
+
+class ToRGB(nn.Module):
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
+ super().__init__()
+
+ if upsample:
+ self.upsample = Upsample(blur_kernel)
+
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
+
+ def forward(self, input, style, skip=None, input_is_stylespace=False):
+ out, style = self.conv(input, style, input_is_stylespace=input_is_stylespace)
+ out = out + self.bias
+
+ if skip is not None:
+ skip = self.upsample(skip)
+
+ out = out + skip
+
+ return out, style
+
+class Generator(nn.Module):
+ def __init__(
+ self,
+ size,
+ style_dim,
+ n_mlp,
+ channel_multiplier=2,
+ blur_kernel=[1, 3, 3, 1],
+ lr_mlp=0.01,
+ ):
+ super().__init__()
+
+ self.size = size
+
+ self.style_dim = style_dim
+
+ layers = [PixelNorm()]
+
+ for i in range(n_mlp):
+ layers.append(
+ EqualLinear(
+ style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
+ )
+ )
+
+ self.style = nn.Sequential(*layers)
+
+ self.channels = {
+ 4: 512,
+ 8: 512,
+ 16: 512,
+ 32: 512,
+ 64: 256 * channel_multiplier,
+ 128: 128 * channel_multiplier,
+ 256: 64 * channel_multiplier,
+ 512: 32 * channel_multiplier,
+ 1024: 16 * channel_multiplier,
+ }
+
+ self.input = ConstantInput(self.channels[4])
+ self.conv1 = StyledConv(
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
+ )
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
+
+ self.log_size = int(math.log(size, 2))
+ self.num_layers = (self.log_size - 2) * 2 + 1
+
+ self.convs = nn.ModuleList()
+ self.upsamples = nn.ModuleList()
+ self.to_rgbs = nn.ModuleList()
+ self.noises = nn.Module()
+
+ in_channel = self.channels[4]
+
+ for layer_idx in range(self.num_layers):
+ res = (layer_idx + 5) // 2
+ shape = [1, 1, 2 ** res, 2 ** res]
+ self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
+
+ for i in range(3, self.log_size + 1):
+ out_channel = self.channels[2 ** i]
+
+ self.convs.append(
+ StyledConv(
+ in_channel,
+ out_channel,
+ 3,
+ style_dim,
+ upsample=True,
+ blur_kernel=blur_kernel,
+ )
+ )
+
+ self.convs.append(
+ StyledConv(
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
+ )
+ )
+
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
+
+ in_channel = out_channel
+
+ self.n_latent = self.log_size * 2 - 2
+
+ def make_noise(self):
+ device = self.input.input.device
+
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
+
+ for i in range(3, self.log_size + 1):
+ for _ in range(2):
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
+
+ return noises
+
+ def mean_latent(self, n_latent):
+ latent_in = torch.randn(
+ n_latent, self.style_dim, device=self.input.input.device
+ )
+ latent = self.style(latent_in).mean(0, keepdim=True)
+
+ return latent
+
+ def get_latent(self, input):
+ return self.style(input)
+
+ def forward(
+ self,
+ styles,
+ return_latents=False,
+ return_features=False,
+ inject_index=None,
+ truncation=1,
+ truncation_latent=None,
+ input_is_latent=False,
+ input_is_stylespace=False,
+ noise=None,
+ randomize_noise=True,
+ return_feature_map=False,
+ return_s=False
+ ):
+
+ if not input_is_latent and not input_is_stylespace:
+ styles = [self.style(s) for s in styles]
+
+ if noise is None:
+ if randomize_noise:
+ noise = [None] * self.num_layers
+ else:
+ noise = [
+ getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
+ ]
+
+ if truncation < 1 and not input_is_stylespace:
+ style_t = []
+
+ for style in styles:
+ style_t.append(
+ truncation_latent + truncation * (style - truncation_latent)
+ )
+
+ styles = style_t
+
+ if input_is_stylespace:
+ latent = styles[0]
+ elif len(styles) < 2:
+ inject_index = self.n_latent
+
+ if styles[0].ndim < 3:
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+ else:
+ latent = styles[0]
+
+ else:
+ if inject_index is None:
+ inject_index = random.randint(1, self.n_latent - 1)
+
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
+
+ latent = torch.cat([latent, latent2], 1)
+
+ style_vector = []
+
+ if not input_is_stylespace:
+ out = self.input(latent)
+ out, out_style = self.conv1(out, latent[:, 0], noise=noise[0])
+ style_vector.append(out_style)
+
+ skip, out_style = self.to_rgb1(out, latent[:, 1])
+ style_vector.append(out_style)
+
+ i = 1
+ else:
+ out = self.input(latent[0])
+ out, out_style = self.conv1(out, latent[0], noise=noise[0], input_is_stylespace=input_is_stylespace)
+ style_vector.append(out_style)
+
+ skip, out_style = self.to_rgb1(out, latent[1], input_is_stylespace=input_is_stylespace)
+ style_vector.append(out_style)
+
+ i = 2
+
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
+ ):
+ if not input_is_stylespace:
+ out, out_style1 = conv1(out, latent[:, i], noise=noise1)
+ out, out_style2 = conv2(out, latent[:, i + 1], noise=noise2)
+ skip, rgb_style = to_rgb(out, latent[:, i + 2], skip)
+ if i==7:
+ feature_map = out
+ style_vector.extend([out_style1, out_style2, rgb_style])
+ i += 2
+ else:
+ out, out_style1 = conv1(out, latent[i], noise=noise1, input_is_stylespace=input_is_stylespace)
+ out, out_style2 = conv2(out, latent[i + 1], noise=noise2, input_is_stylespace=input_is_stylespace)
+ skip, rgb_style = to_rgb(out, latent[i + 2], skip, input_is_stylespace=input_is_stylespace)
+
+ style_vector.extend([out_style1, out_style2, rgb_style])
+
+ i += 3
+
+ image = skip
+
+ if return_feature_map:
+ if return_latents:
+ return image, latent, feature_map
+ elif return_s:
+ return image, style_vector, feature_map
+ else:
+ return image, feature_map
+
+ if return_latents:
+ return image, latent
+ elif return_s:
+ return image, style_vector
+ elif return_features:
+ return image, out
+ else:
+ return image, None
+
+
+class ConvLayer(nn.Sequential):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ downsample=False,
+ blur_kernel=[1, 3, 3, 1],
+ bias=True,
+ activate=True,
+ ):
+ layers = []
+
+ if downsample:
+ factor = 2
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
+ pad0 = (p + 1) // 2
+ pad1 = p // 2
+
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
+
+ stride = 2
+ self.padding = 0
+
+ else:
+ stride = 1
+ self.padding = kernel_size // 2
+
+ layers.append(
+ EqualConv2d(
+ in_channel,
+ out_channel,
+ kernel_size,
+ padding=self.padding,
+ stride=stride,
+ bias=bias and not activate,
+ )
+ )
+
+ if activate:
+ if bias:
+ layers.append(FusedLeakyReLU(out_channel))
+
+ else:
+ layers.append(ScaledLeakyReLU(0.2))
+
+ super().__init__(*layers)
+
+
+class ResBlock(nn.Module):
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
+ super().__init__()
+
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
+
+ self.skip = ConvLayer(
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
+ )
+
+ def forward(self, input):
+ out = self.conv1(input)
+ out = self.conv2(out)
+
+ skip = self.skip(input)
+ out = (out + skip) / math.sqrt(2)
+
+ return out
+
+
+class Discriminator(nn.Module):
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
+ super().__init__()
+
+ channels = {
+ 4: 512,
+ 8: 512,
+ 16: 512,
+ 32: 512,
+ 64: 256 * channel_multiplier,
+ 128: 128 * channel_multiplier,
+ 256: 64 * channel_multiplier,
+ 512: 32 * channel_multiplier,
+ 1024: 16 * channel_multiplier,
+ }
+
+ convs = [ConvLayer(3, channels[size], 1)]
+
+ log_size = int(math.log(size, 2))
+
+ in_channel = channels[size]
+
+ for i in range(log_size, 2, -1):
+ out_channel = channels[2 ** (i - 1)]
+
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
+
+ in_channel = out_channel
+
+ self.convs = nn.Sequential(*convs)
+
+ self.stddev_group = 4
+ self.stddev_feat = 1
+
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
+ self.final_linear = nn.Sequential(
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
+ EqualLinear(channels[4], 1),
+ )
+
+ def forward(self, input):
+ out = self.convs(input)
+
+ batch, channel, height, width = out.shape
+ group = min(batch, self.stddev_group)
+ stddev = out.view(
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
+ )
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
+ stddev = stddev.repeat(group, 1, height, width)
+ out = torch.cat([out, stddev], 1)
+
+ out = self.final_conv(out)
+
+ out = out.view(batch, -1)
+ out = self.final_linear(out)
+
+ return out
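
A short, hedged sketch of how the `Generator` in this file is driven during training: sampling images together with their W+ latents and the intermediate feature map used as keys/values for the latent transformer. It assumes a CUDA device, since the custom ops are compiled as CUDA extensions, and uses the 256px configuration.
```
import torch
from models.stylegan2.model import Generator

g = Generator(256, 512, 8, channel_multiplier=2).cuda().eval()
with torch.no_grad():
    z = torch.randn(4, 512, device='cuda')
    mean_w = g.mean_latent(4096)                 # truncation center
    img, w_plus, feat = g([z],
                          truncation=0.7,
                          truncation_latent=mean_w,
                          return_latents=True,
                          return_feature_map=True)
print(img.shape)      # torch.Size([4, 3, 256, 256])
print(w_plus.shape)   # torch.Size([4, 14, 512]); n_latent = 14 for a 256px generator
print(feat.shape)     # torch.Size([4, 512, 64, 64]); taken where i == 7 in forward()
```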
diff --git a/models/stylegan2/op/__init__.py b/models/stylegan2/op/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..d0918d92285955855be89f00096b888ee5597ce3
--- /dev/null
+++ b/models/stylegan2/op/__init__.py
@@ -0,0 +1,2 @@
+from .fused_act import FusedLeakyReLU, fused_leaky_relu
+from .upfirdn2d import upfirdn2d
diff --git a/models/stylegan2/op/fused_act.py b/models/stylegan2/op/fused_act.py
new file mode 100755
index 0000000000000000000000000000000000000000..76ae78e49570971ee3fe303844b2c3b3fee77fa0
--- /dev/null
+++ b/models/stylegan2/op/fused_act.py
@@ -0,0 +1,85 @@
+import os
+
+import torch
+from torch import nn
+from torch.autograd import Function
+from torch.utils.cpp_extension import load
+
+module_path = os.path.dirname(__file__)
+fused = load(
+ 'fused',
+ sources=[
+ os.path.join(module_path, 'fused_bias_act.cpp'),
+ os.path.join(module_path, 'fused_bias_act_kernel.cu'),
+ ],
+)
+
+
+class FusedLeakyReLUFunctionBackward(Function):
+ @staticmethod
+ def forward(ctx, grad_output, out, negative_slope, scale):
+ ctx.save_for_backward(out)
+ ctx.negative_slope = negative_slope
+ ctx.scale = scale
+
+ empty = grad_output.new_empty(0)
+
+ grad_input = fused.fused_bias_act(
+ grad_output, empty, out, 3, 1, negative_slope, scale
+ )
+
+ dim = [0]
+
+ if grad_input.ndim > 2:
+ dim += list(range(2, grad_input.ndim))
+
+ grad_bias = grad_input.sum(dim).detach()
+
+ return grad_input, grad_bias
+
+ @staticmethod
+ def backward(ctx, gradgrad_input, gradgrad_bias):
+ out, = ctx.saved_tensors
+ gradgrad_out = fused.fused_bias_act(
+ gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
+ )
+
+ return gradgrad_out, None, None, None
+
+
+class FusedLeakyReLUFunction(Function):
+ @staticmethod
+ def forward(ctx, input, bias, negative_slope, scale):
+ empty = input.new_empty(0)
+ out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
+ ctx.save_for_backward(out)
+ ctx.negative_slope = negative_slope
+ ctx.scale = scale
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ out, = ctx.saved_tensors
+
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
+ grad_output, out, ctx.negative_slope, ctx.scale
+ )
+
+ return grad_input, grad_bias, None, None
+
+
+class FusedLeakyReLU(nn.Module):
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
+ super().__init__()
+
+ self.bias = nn.Parameter(torch.zeros(channel))
+ self.negative_slope = negative_slope
+ self.scale = scale
+
+ def forward(self, input):
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+
+def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
+ return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
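
For reference, a small sketch of both entry points exported by this module (CUDA only, since the extension is JIT-compiled against the .cpp/.cu sources below): the module form owns a learnable per-channel bias, while the functional form takes the bias explicitly.
```
import torch
from models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu

act = FusedLeakyReLU(64).cuda()
x = torch.randn(2, 64, 32, 32, device='cuda')
y1 = act(x)                                                  # bias + leaky ReLU, scaled by sqrt(2)
y2 = fused_leaky_relu(x, torch.zeros(64, device='cuda'))     # functional form, explicit bias
```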
diff --git a/models/stylegan2/op/fused_bias_act.cpp b/models/stylegan2/op/fused_bias_act.cpp
new file mode 100755
index 0000000000000000000000000000000000000000..02be898f970bcc8ea297867fcaa4e71b24b3d949
--- /dev/null
+++ b/models/stylegan2/op/fused_bias_act.cpp
@@ -0,0 +1,21 @@
+#include <torch/extension.h>
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale) {
+ CHECK_CUDA(input);
+ CHECK_CUDA(bias);
+
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
+}
\ No newline at end of file
diff --git a/models/stylegan2/op/fused_bias_act_kernel.cu b/models/stylegan2/op/fused_bias_act_kernel.cu
new file mode 100755
index 0000000000000000000000000000000000000000..c9fa56fea7ede7072dc8925cfb0148f136eb85b8
--- /dev/null
+++ b/models/stylegan2/op/fused_bias_act_kernel.cu
@@ -0,0 +1,99 @@
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+
+template <typename scalar_t>
+static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
+ int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
+
+ scalar_t zero = 0.0;
+
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
+ scalar_t x = p_x[xi];
+
+ if (use_bias) {
+ x += p_b[(xi / step_b) % size_b];
+ }
+
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
+
+ scalar_t y;
+
+ switch (act * 10 + grad) {
+ default:
+ case 10: y = x; break;
+ case 11: y = x; break;
+ case 12: y = 0.0; break;
+
+ case 30: y = (x > 0.0) ? x : x * alpha; break;
+ case 31: y = (ref > 0.0) ? x : x * alpha; break;
+ case 32: y = 0.0; break;
+ }
+
+ out[xi] = y * scale;
+ }
+}
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale) {
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+ auto x = input.contiguous();
+ auto b = bias.contiguous();
+ auto ref = refer.contiguous();
+
+ int use_bias = b.numel() ? 1 : 0;
+ int use_ref = ref.numel() ? 1 : 0;
+
+ int size_x = x.numel();
+ int size_b = b.numel();
+ int step_b = 1;
+
+ for (int i = 1 + 1; i < x.dim(); i++) {
+ step_b *= x.size(i);
+ }
+
+ int loop_x = 4;
+ int block_size = 4 * 32;
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
+
+ auto y = torch::empty_like(x);
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
+ fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
+ y.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ b.data_ptr<scalar_t>(),
+ ref.data_ptr<scalar_t>(),
+ act,
+ grad,
+ alpha,
+ scale,
+ loop_x,
+ size_x,
+ step_b,
+ size_b,
+ use_bias,
+ use_ref
+ );
+ });
+
+ return y;
+}
\ No newline at end of file
diff --git a/models/stylegan2/op/upfirdn2d.cpp b/models/stylegan2/op/upfirdn2d.cpp
new file mode 100755
index 0000000000000000000000000000000000000000..d2e633dc896433c205e18bc3e455539192ff968e
--- /dev/null
+++ b/models/stylegan2/op/upfirdn2d.cpp
@@ -0,0 +1,23 @@
+#include <torch/extension.h>
+
+
+torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
+ int up_x, int up_y, int down_x, int down_y,
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
+ int up_x, int up_y, int down_x, int down_y,
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
+ CHECK_CUDA(input);
+ CHECK_CUDA(kernel);
+
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
+}
\ No newline at end of file
diff --git a/models/stylegan2/op/upfirdn2d.py b/models/stylegan2/op/upfirdn2d.py
new file mode 100755
index 0000000000000000000000000000000000000000..7bc5a1e331c2bbb1893ac748cfd0f144ff0651b4
--- /dev/null
+++ b/models/stylegan2/op/upfirdn2d.py
@@ -0,0 +1,184 @@
+import os
+
+import torch
+from torch.autograd import Function
+from torch.nn import functional as F  # needed by upfirdn2d_native below
+from torch.utils.cpp_extension import load
+
+module_path = os.path.dirname(__file__)
+upfirdn2d_op = load(
+ 'upfirdn2d',
+ sources=[
+ os.path.join(module_path, 'upfirdn2d.cpp'),
+ os.path.join(module_path, 'upfirdn2d_kernel.cu'),
+ ],
+)
+
+
+class UpFirDn2dBackward(Function):
+ @staticmethod
+ def forward(
+ ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
+ ):
+ up_x, up_y = up
+ down_x, down_y = down
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
+
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
+
+ grad_input = upfirdn2d_op.upfirdn2d(
+ grad_output,
+ grad_kernel,
+ down_x,
+ down_y,
+ up_x,
+ up_y,
+ g_pad_x0,
+ g_pad_x1,
+ g_pad_y0,
+ g_pad_y1,
+ )
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
+
+ ctx.save_for_backward(kernel)
+
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+ ctx.up_x = up_x
+ ctx.up_y = up_y
+ ctx.down_x = down_x
+ ctx.down_y = down_y
+ ctx.pad_x0 = pad_x0
+ ctx.pad_x1 = pad_x1
+ ctx.pad_y0 = pad_y0
+ ctx.pad_y1 = pad_y1
+ ctx.in_size = in_size
+ ctx.out_size = out_size
+
+ return grad_input
+
+ @staticmethod
+ def backward(ctx, gradgrad_input):
+ kernel, = ctx.saved_tensors
+
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
+
+ gradgrad_out = upfirdn2d_op.upfirdn2d(
+ gradgrad_input,
+ kernel,
+ ctx.up_x,
+ ctx.up_y,
+ ctx.down_x,
+ ctx.down_y,
+ ctx.pad_x0,
+ ctx.pad_x1,
+ ctx.pad_y0,
+ ctx.pad_y1,
+ )
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
+ gradgrad_out = gradgrad_out.view(
+ ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
+ )
+
+ return gradgrad_out, None, None, None, None, None, None, None, None
+
+
+class UpFirDn2d(Function):
+ @staticmethod
+ def forward(ctx, input, kernel, up, down, pad):
+ up_x, up_y = up
+ down_x, down_y = down
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+ kernel_h, kernel_w = kernel.shape
+ batch, channel, in_h, in_w = input.shape
+ ctx.in_size = input.shape
+
+ input = input.reshape(-1, in_h, in_w, 1)
+
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+ ctx.out_size = (out_h, out_w)
+
+ ctx.up = (up_x, up_y)
+ ctx.down = (down_x, down_y)
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
+
+ g_pad_x0 = kernel_w - pad_x0 - 1
+ g_pad_y0 = kernel_h - pad_y0 - 1
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
+
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
+
+ out = upfirdn2d_op.upfirdn2d(
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
+ )
+ # out = out.view(major, out_h, out_w, minor)
+ out = out.view(-1, channel, out_h, out_w)
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ kernel, grad_kernel = ctx.saved_tensors
+
+ grad_input = UpFirDn2dBackward.apply(
+ grad_output,
+ kernel,
+ grad_kernel,
+ ctx.up,
+ ctx.down,
+ ctx.pad,
+ ctx.g_pad,
+ ctx.in_size,
+ ctx.out_size,
+ )
+
+ return grad_input, None, None, None, None
+
+
+def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+ out = UpFirDn2d.apply(
+ input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
+ )
+
+ return out
+
+
+def upfirdn2d_native(
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
+):
+ _, in_h, in_w, minor = input.shape
+ kernel_h, kernel_w = kernel.shape
+
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+ out = F.pad(
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
+ )
+ out = out[
+ :,
+ max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0),
+ max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0),
+ :,
+ ]
+
+ out = out.permute(0, 3, 1, 2)
+ out = out.reshape(
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
+ )
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+ out = F.conv2d(out, w)
+ out = out.reshape(
+ -1,
+ minor,
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+ )
+ out = out.permute(0, 2, 3, 1)
+
+ return out[:, ::down_y, ::down_x, :]
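
As a quick sanity check of the padding arithmetic used by `Upsample` in models/stylegan2/model.py, the sketch below reproduces a 2x upsample with the default [1, 3, 3, 1] blur kernel directly through `upfirdn2d` (CUDA required; `make_kernel` comes from the model file).
```
import torch
from models.stylegan2.model import make_kernel
from models.stylegan2.op import upfirdn2d

factor = 2
k = (make_kernel([1, 3, 3, 1]) * factor ** 2).cuda()   # scaled by factor**2 for upsampling
x = torch.randn(1, 512, 16, 16, device='cuda')
p = k.shape[0] - factor
pad = ((p + 1) // 2 + factor - 1, p // 2)              # (2, 1), as in Upsample.__init__
y = upfirdn2d(x, k, up=factor, down=1, pad=pad)
print(y.shape)                                         # torch.Size([1, 512, 32, 32])
```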
diff --git a/models/stylegan2/op/upfirdn2d_kernel.cu b/models/stylegan2/op/upfirdn2d_kernel.cu
new file mode 100755
index 0000000000000000000000000000000000000000..2a710aa6adc3d43ac93136a1814e3c39970e1c7e
--- /dev/null
+++ b/models/stylegan2/op/upfirdn2d_kernel.cu
@@ -0,0 +1,272 @@
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+
+static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
+ int c = a / b;
+
+ if (c * b > a) {
+ c--;
+ }
+
+ return c;
+}
+
+
+struct UpFirDn2DKernelParams {
+ int up_x;
+ int up_y;
+ int down_x;
+ int down_y;
+ int pad_x0;
+ int pad_x1;
+ int pad_y0;
+ int pad_y1;
+
+ int major_dim;
+ int in_h;
+ int in_w;
+ int minor_dim;
+ int kernel_h;
+ int kernel_w;
+ int out_h;
+ int out_w;
+ int loop_major;
+ int loop_x;
+};
+
+
+template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
+__global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) {
+ const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
+ const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
+
+ __shared__ volatile float sk[kernel_h][kernel_w];
+ __shared__ volatile float sx[tile_in_h][tile_in_w];
+
+ int minor_idx = blockIdx.x;
+ int tile_out_y = minor_idx / p.minor_dim;
+ minor_idx -= tile_out_y * p.minor_dim;
+ tile_out_y *= tile_out_h;
+ int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
+ int major_idx_base = blockIdx.z * p.loop_major;
+
+ if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) {
+ return;
+ }
+
+ for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) {
+ int ky = tap_idx / kernel_w;
+ int kx = tap_idx - ky * kernel_w;
+ scalar_t v = 0.0;
+
+ if (kx < p.kernel_w & ky < p.kernel_h) {
+ v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
+ }
+
+ sk[ky][kx] = v;
+ }
+
+ for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) {
+ for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) {
+ int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
+ int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
+ int tile_in_x = floor_div(tile_mid_x, up_x);
+ int tile_in_y = floor_div(tile_mid_y, up_y);
+
+ __syncthreads();
+
+ for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) {
+ int rel_in_y = in_idx / tile_in_w;
+ int rel_in_x = in_idx - rel_in_y * tile_in_w;
+ int in_x = rel_in_x + tile_in_x;
+ int in_y = rel_in_y + tile_in_y;
+
+ scalar_t v = 0.0;
+
+ if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
+ v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx];
+ }
+
+ sx[rel_in_y][rel_in_x] = v;
+ }
+
+ __syncthreads();
+ for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) {
+ int rel_out_y = out_idx / tile_out_w;
+ int rel_out_x = out_idx - rel_out_y * tile_out_w;
+ int out_x = rel_out_x + tile_out_x;
+ int out_y = rel_out_y + tile_out_y;
+
+ int mid_x = tile_mid_x + rel_out_x * down_x;
+ int mid_y = tile_mid_y + rel_out_y * down_y;
+ int in_x = floor_div(mid_x, up_x);
+ int in_y = floor_div(mid_y, up_y);
+ int rel_in_x = in_x - tile_in_x;
+ int rel_in_y = in_y - tile_in_y;
+ int kernel_x = (in_x + 1) * up_x - mid_x - 1;
+ int kernel_y = (in_y + 1) * up_y - mid_y - 1;
+
+ scalar_t v = 0.0;
+
+ #pragma unroll
+ for (int y = 0; y < kernel_h / up_y; y++)
+ #pragma unroll
+ for (int x = 0; x < kernel_w / up_x; x++)
+ v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x];
+
+ if (out_x < p.out_w & out_y < p.out_h) {
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v;
+ }
+ }
+ }
+ }
+}
+
+
+torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
+ int up_x, int up_y, int down_x, int down_y,
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+ UpFirDn2DKernelParams p;
+
+ auto x = input.contiguous();
+ auto k = kernel.contiguous();
+
+ p.major_dim = x.size(0);
+ p.in_h = x.size(1);
+ p.in_w = x.size(2);
+ p.minor_dim = x.size(3);
+ p.kernel_h = k.size(0);
+ p.kernel_w = k.size(1);
+ p.up_x = up_x;
+ p.up_y = up_y;
+ p.down_x = down_x;
+ p.down_y = down_y;
+ p.pad_x0 = pad_x0;
+ p.pad_x1 = pad_x1;
+ p.pad_y0 = pad_y0;
+ p.pad_y1 = pad_y1;
+
+ p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y;
+ p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x;
+
+ auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
+
+ int mode = -1;
+
+ int tile_out_h;
+ int tile_out_w;
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 1;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) {
+ mode = 2;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 3;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) {
+ mode = 4;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 5;
+ tile_out_h = 8;
+ tile_out_w = 32;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) {
+ mode = 6;
+ tile_out_h = 8;
+ tile_out_w = 32;
+ }
+
+ dim3 block_size;
+ dim3 grid_size;
+
+ if (tile_out_h > 0 && tile_out_w) {
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
+ p.loop_x = 1;
+ block_size = dim3(32 * 8, 1, 1);
+ grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
+ (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
+ (p.major_dim - 1) / p.loop_major + 1);
+ }
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
+ switch (mode) {
+ case 1:
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
+ );
+
+ break;
+
+ case 2:
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>(
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
+ );
+
+ break;
+
+ case 3:
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
+ );
+
+ break;
+
+ case 4:
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>(
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
+ );
+
+ break;
+
+ case 5:
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
+ );
+
+ break;
+
+ case 6:
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32><<<grid_size, block_size, 0, stream>>>(
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
+ );
+
+ break;
+ }
+ });
+
+ return out;
+}
\ No newline at end of file
diff --git a/options/__init__.py b/options/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/options/train_options.py b/options/train_options.py
new file mode 100755
index 0000000000000000000000000000000000000000..c33c9217be048f5574bde91eb1c6ebd186a265bd
--- /dev/null
+++ b/options/train_options.py
@@ -0,0 +1,33 @@
+from argparse import ArgumentParser
+
+class TrainOptions:
+
+ def __init__(self):
+ self.parser = ArgumentParser()
+ self.initialize()
+
+ def initialize(self):
+ self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory')
+
+ self.parser.add_argument('--batch_size', default=1, type=int, help='Batch size for training')
+ self.parser.add_argument('--learning_rate', default=0.001, type=float, help='Optimizer learning rate')
+ self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use')
+ self.parser.add_argument('--train_decoder', default=False, type=bool, help='Whether to train the decoder model')
+
+ self.parser.add_argument('--lpips_lambda', default=0., type=float, help='LPIPS loss multiplier factor')
+ self.parser.add_argument('--l2_lambda', default=0, type=float, help='L2 loss multiplier factor')
+ self.parser.add_argument('--l2latent_lambda', default=1.0, type=float, help='Latent-space L2 loss multiplier factor')
+
+ self.parser.add_argument('--stylegan_weights', default='pretrained_models/stylegan2-cat-config-f.pt', type=str, help='Path to StyleGAN model weights')
+ self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to pSp model checkpoint')
+
+ self.parser.add_argument('--max_steps', default=60100, type=int, help='Maximum number of training steps')
+ self.parser.add_argument('--image_interval', default=100, type=int, help='Interval for logging train images during training')
+ self.parser.add_argument('--save_interval', default=10000, type=int, help='Model checkpoint interval')
+
+ self.parser.add_argument('--style_num', default=14, type=int, help='Number of StyleGAN layers that receive latent codes')
+ self.parser.add_argument('--channel_multiplier', default=2, type=int, help='StyleGAN parameter')
+
+ def parse(self):
+ opts = self.parser.parse_args()
+ return opts
\ No newline at end of file
diff --git a/scripts/train.py b/scripts/train.py
new file mode 100755
index 0000000000000000000000000000000000000000..7450aef4a544d3f36568d86a34870fe63ec521d3
--- /dev/null
+++ b/scripts/train.py
@@ -0,0 +1,29 @@
+"""
+This file runs the main training/val loop
+"""
+import os
+import json
+import sys
+import pprint
+
+sys.path.append(".")
+sys.path.append("..")
+
+from options.train_options import TrainOptions
+from training.coach import Coach
+
+
+def main():
+ opts = TrainOptions().parse()
+ os.makedirs(opts.exp_dir, exist_ok=True)
+
+ opts_dict = vars(opts)
+ pprint.pprint(opts_dict)
+ with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f:
+ json.dump(opts_dict, f, indent=4, sort_keys=True)
+
+ coach = Coach(opts)
+ coach.train()
+
+if __name__ == '__main__':
+ main()
diff --git a/training/__init__.py b/training/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/training/coach.py b/training/coach.py
new file mode 100755
index 0000000000000000000000000000000000000000..30e7009653b8ddecb2bc53bfd4e5aa6f1f23b6ef
--- /dev/null
+++ b/training/coach.py
@@ -0,0 +1,211 @@
+import os
+import math, random
+import numpy as np
+import matplotlib
+import matplotlib.pyplot as plt
+matplotlib.use('Agg')
+
+import torch
+from torch import nn
+from torch.utils.tensorboard import SummaryWriter
+import torch.nn.functional as F
+
+from utils import common
+from criteria.lpips.lpips import LPIPS
+from models.StyleGANControler import StyleGANControler
+from training.ranger import Ranger
+
+from expansion.submission import Expansion
+from expansion.utils.flowlib import point_vec
+
+class Coach:
+ def __init__(self, opts):
+ self.opts = opts
+ if self.opts.checkpoint_path is None:
+ self.global_step = 0
+ else:
+ self.global_step = int(os.path.splitext(os.path.basename(self.opts.checkpoint_path))[0].split('_')[-1])
+
+ self.device = 'cuda:0' # TODO: Allow multiple GPU? currently using CUDA_VISIBLE_DEVICES
+ self.opts.device = self.device
+
+ # Initialize network
+ self.net = StyleGANControler(self.opts).to(self.device)
+
+ # Initialize loss
+ if self.opts.lpips_lambda > 0:
+ self.lpips_loss = LPIPS(net_type='alex').to(self.device).eval()
+ self.mse_loss = nn.MSELoss().to(self.device).eval()
+
+ # Initialize optimizer
+ self.optimizer = self.configure_optimizers()
+
+ # Initialize logger
+ log_dir = os.path.join(opts.exp_dir, 'logs')
+ os.makedirs(log_dir, exist_ok=True)
+ self.logger = SummaryWriter(log_dir=log_dir)
+
+ # Initialize checkpoint dir
+ self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints')
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
+ self.best_val_loss = None
+ if self.opts.save_interval is None:
+ self.opts.save_interval = self.opts.max_steps
+
+ # Initialize optical flow estimator
+ self.ex = Expansion()
+
+ # Set flow normalization values
+ if 'ffhq' in self.opts.stylegan_weights:
+ self.sigma_f = 4
+ self.sigma_e = 0.02
+ elif 'car' in self.opts.stylegan_weights:
+ self.sigma_f = 5
+ self.sigma_e = 0.03
+ elif 'cat' in self.opts.stylegan_weights:
+ self.sigma_f = 12
+ self.sigma_e = 0.04
+ elif 'church' in self.opts.stylegan_weights:
+ self.sigma_f = 8
+ self.sigma_e = 0.02
+ elif 'anime' in self.opts.stylegan_weights:
+ self.sigma_f = 7
+ self.sigma_e = 0.025
+
+ def train(self, truncation = 0.3, sigma = 0.1, target_layers = [0,1,2,3,4,5]):
+
+ x = np.array(range(0,256,16)).astype(np.float32)/127.5-1.
+ y = np.array(range(0,256,16)).astype(np.float32)/127.5-1.
+ xx, yy = np.meshgrid(x,y)
+ grid = np.concatenate([xx[:,:,None],yy[:,:,None]], axis=2)
+ grid = torch.from_numpy(grid[None,:]).cuda()
+ grid = grid.repeat(self.opts.batch_size,1,1,1)
+
+ while self.global_step < self.opts.max_steps:
+ with torch.no_grad():
+ z1 = torch.randn(self.opts.batch_size,512).to("cuda")
+ z2 = torch.randn(self.opts.batch_size,self.net.style_num, 512).to("cuda")
+
+ x1, w1, f1 = self.net.decoder([z1],input_is_latent=False,randomize_noise=False,return_feature_map=True,return_latents=True,truncation=truncation, truncation_latent=self.net.latent_avg[0])
+ x1 = self.net.face_pool(x1)
+ x2, w2 = self.net.decoder([z2],input_is_latent=False,randomize_noise=False,return_latents=True, truncation_latent=self.net.latent_avg[0])
+ x2 = self.net.face_pool(x2)
+ w_mid = w1.clone()
+ w_mid[:,target_layers] = w_mid[:,target_layers]+sigma*(w2[:,target_layers]-w_mid[:,target_layers])
+ x_mid, _ = self.net.decoder([w_mid], input_is_latent=True, randomize_noise=False, return_latents=False)
+ x_mid = self.net.face_pool(x_mid)
+
+ flow, logexp = self.ex.run(x1.detach(),x_mid.detach())
+ flow_feature = torch.cat([flow/self.sigma_f, logexp/self.sigma_e], dim=1)
+ f1 = F.interpolate(f1, (flow_feature.shape[2:]))
+ f1 = F.grid_sample(f1, grid, mode='nearest', align_corners=True)
+ flow_feature = F.grid_sample(flow_feature, grid, mode='nearest', align_corners=True)
+ flow_feature = flow_feature.view(flow_feature.shape[0], flow_feature.shape[1], -1).permute(0,2,1)
+ f1 = f1.view(f1.shape[0], f1.shape[1], -1).permute(0,2,1)
+
+ self.net.train()
+ self.optimizer.zero_grad()
+ w_hat = self.net.encoder(w1[:,target_layers].detach(), flow_feature.detach(), f1.detach())
+ loss, loss_dict, id_logs = self.calc_loss(w_hat, w_mid[:,target_layers].detach())
+ loss.backward()
+ self.optimizer.step()
+
+ w_mid[:,target_layers] = w_hat.detach()
+ x_hat, _ = self.net.decoder([w_mid], input_is_latent=True, randomize_noise=False)
+ x_hat = self.net.face_pool(x_hat)
+ if self.global_step % self.opts.image_interval == 0 or (
+ self.global_step < 1000 and self.global_step % 100 == 0):
+ imgL_o = ((x1.detach()+1.)*127.5)[0].permute(1,2,0).cpu().numpy()
+ flow = torch.cat((flow,torch.ones_like(flow)[:,:1]), dim=1)[0].permute(1,2,0).cpu().numpy()
+ flowvis = point_vec(imgL_o, flow)
+ flowvis = torch.from_numpy(flowvis[:,:,::-1].copy()).permute(2,0,1).unsqueeze(0)/127.5-1.
+ self.parse_and_log_images(None, flowvis, x_mid, x_hat, title='trained_images')
+ print(loss_dict)
+
+ if self.global_step % self.opts.save_interval == 0 or self.global_step == self.opts.max_steps:
+ self.checkpoint_me(loss_dict, is_best=False)
+
+ if self.global_step == self.opts.max_steps:
+ print('OMG, finished training!')
+ break
+
+ self.global_step += 1
+
+ def checkpoint_me(self, loss_dict, is_best):
+ save_name = 'best_model.pt' if is_best else 'iteration_{}.pt'.format(self.global_step)
+ save_dict = self.__get_save_dict()
+ checkpoint_path = os.path.join(self.checkpoint_dir, save_name)
+ torch.save(save_dict, checkpoint_path)
+ with open(os.path.join(self.checkpoint_dir, 'timestamp.txt'), 'a') as f:
+ if is_best:
+ f.write('**Best**: Step - {}, Loss - {:.3f} \n{}\n'.format(self.global_step, self.best_val_loss, loss_dict))
+ else:
+ f.write('Step - {}, \n{}\n'.format(self.global_step, loss_dict))
+
+ def configure_optimizers(self):
+ params = list(self.net.encoder.parameters())
+ if self.opts.train_decoder:
+ params += list(self.net.decoder.parameters())
+ if self.opts.optim_name == 'adam':
+ optimizer = torch.optim.Adam(params, lr=self.opts.learning_rate)
+ else:
+ optimizer = Ranger(params, lr=self.opts.learning_rate)
+ return optimizer
+
+ def calc_loss(self, latent, w, y_hat=None, y=None):
+ loss_dict = {}
+ loss = 0.0
+ id_logs = None
+
+ if self.opts.l2_lambda > 0 and (y_hat is not None) and (y is not None):
+ loss_l2 = F.mse_loss(y_hat, y)
+ loss_dict['loss_l2'] = float(loss_l2)
+ loss += loss_l2 * self.opts.l2_lambda
+ if self.opts.lpips_lambda > 0 and (y_hat is not None) and (y is not None):
+ loss_lpips = self.lpips_loss(y_hat, y)
+ loss_dict['loss_lpips'] = float(loss_lpips)
+ loss += loss_lpips * self.opts.lpips_lambda
+ if self.opts.l2latent_lambda > 0:
+ loss_l2 = F.mse_loss(latent, w)
+ loss_dict['loss_l2latent'] = float(loss_l2)
+ loss += loss_l2 * self.opts.l2latent_lambda
+
+ loss_dict['loss'] = float(loss)
+ return loss, loss_dict, id_logs
+
+ def parse_and_log_images(self, id_logs, x, y, y_hat, title, subscript=None, display_count=1):
+ im_data = []
+ for i in range(display_count):
+ cur_im_data = {
+ 'input_face': common.tensor2im(x[i]),
+ 'target_face': common.tensor2im(y[i]),
+ 'output_face': common.tensor2im(y_hat[i]),
+ }
+ if id_logs is not None:
+ for key in id_logs[i]:
+ cur_im_data[key] = id_logs[i][key]
+ im_data.append(cur_im_data)
+ self.log_images(title, im_data=im_data, subscript=subscript)
+
+
+ def log_images(self, name, im_data, subscript=None, log_latest=False):
+ fig = common.vis_faces(im_data)
+ step = self.global_step
+ if log_latest:
+ step = 0
+ if subscript:
+ path = os.path.join(self.logger.log_dir, name, '{}_{:04d}.jpg'.format(subscript, step))
+ else:
+ path = os.path.join(self.logger.log_dir, name, '{:04d}.jpg'.format(step))
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ fig.savefig(path)
+ plt.close(fig)
+
+ def __get_save_dict(self):
+ save_dict = {
+ 'state_dict': self.net.state_dict(),
+ 'opts': vars(self.opts)
+ }
+
+ save_dict['latent_avg'] = self.net.latent_avg
+ return save_dict
\ No newline at end of file
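
The training loop above learns to map a sampled flow field plus StyleGAN features to a latent edit. The interactive editing path (interface/flask_app.py, not part of this diff) follows the same recipe at inference time; the sketch below is a hedged approximation of that flow for a single user drag, reusing the feature sampling from `Coach.train()`. It assumes the checkpoint was written by this coach (so it stores the training opts), and the drag vector is given in the normalized units (sigma_f / sigma_e scaling) used during training.
```
import torch
import torch.nn.functional as F
from argparse import Namespace
from models.StyleGANControler import StyleGANControler

# restore a trained controller; checkpoints written by Coach store the training opts
ckpt_path = 'pretrained_models/latent_transformer/cat.pt'
opts = Namespace(**torch.load(ckpt_path, map_location='cpu')['opts'])
opts.checkpoint_path = ckpt_path
opts.device = 'cuda:0'
net = StyleGANControler(opts).to(opts.device).eval()

target_layers = [0, 1, 2, 3, 4, 5]
with torch.no_grad():
    z = torch.randn(1, 512, device=opts.device)
    x1, w1, f1 = net.decoder([z], input_is_latent=False, randomize_noise=False,
                             return_feature_map=True, return_latents=True,
                             truncation=0.7, truncation_latent=net.latent_avg[0])

    # one drag at normalized image coords (px, py), displaced by (dx, dy), no zoom
    px, py, dx, dy = 0.0, 0.0, 0.3, 0.0
    point = torch.tensor([[[[px, py]]]], device=opts.device)
    feat = F.grid_sample(f1, point, mode='nearest', align_corners=True).view(1, 1, -1)
    flow_feature = torch.tensor([[[dx, dy, 0.0]]], device=opts.device)

    w_hat = net.encoder(w1[:, target_layers], flow_feature, feat)
    w_edit = w1.clone()
    w_edit[:, target_layers] = w_hat
    x_edit, _ = net.decoder([w_edit], input_is_latent=True, randomize_noise=False)
```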
diff --git a/training/ranger.py b/training/ranger.py
new file mode 100755
index 0000000000000000000000000000000000000000..441cded317542a82229ddfbd3073639f8d3e706b
--- /dev/null
+++ b/training/ranger.py
@@ -0,0 +1,164 @@
+# Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer.
+
+# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
+# and/or
+# https://github.com/lessw2020/Best-Deep-Learning-Optimizers
+
+# Ranger has now been used to capture 12 records on the FastAI leaderboard.
+
+# This version = 20.4.11
+
+# Credits:
+# Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization
+# RAdam --> https://github.com/LiyuanLucasLiu/RAdam
+# Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code.
+# Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610
+
+# summary of changes:
+# 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init.
+# full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights),
+# supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues.
+# changes 8/31/19 - fix references to *self*.N_sma_threshold;
+# changed eps to 1e-5 as better default than 1e-8.
+
+import math
+import torch
+from torch.optim.optimizer import Optimizer
+
+
+class Ranger(Optimizer):
+
+ def __init__(self, params, lr=1e-3, # lr
+ alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options
+ betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options
+ use_gc=True, gc_conv_only=False
+ # Gradient centralization on or off, applied to conv layers only or conv + fc layers
+ ):
+
+ # parameter checks
+ if not 0.0 <= alpha <= 1.0:
+ raise ValueError(f'Invalid slow update rate: {alpha}')
+ if not 1 <= k:
+ raise ValueError(f'Invalid lookahead steps: {k}')
+ if not lr > 0:
+ raise ValueError(f'Invalid Learning Rate: {lr}')
+ if not eps > 0:
+ raise ValueError(f'Invalid eps: {eps}')
+
+ # parameter comments:
+ # beta1 (momentum) of .95 seems to work better than .90...
+ # N_sma_threshold of 5 seems better in testing than 4.
+ # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you.
+
+ # prep defaults and init torch.optim base
+ defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold,
+ eps=eps, weight_decay=weight_decay)
+ super().__init__(params, defaults)
+
+ # adjustable threshold
+ self.N_sma_threshhold = N_sma_threshhold
+
+ # look ahead params
+
+ self.alpha = alpha
+ self.k = k
+
+ # radam buffer for state
+ self.radam_buffer = [[None, None, None] for ind in range(10)]
+
+ # gc on or off
+ self.use_gc = use_gc
+
+ # level of gradient centralization
+ self.gc_gradient_threshold = 3 if gc_conv_only else 1
+
+ def __setstate__(self, state):
+ super(Ranger, self).__setstate__(state)
+
+ def step(self, closure=None):
+ loss = None
+
+ # Evaluate averages and grad, update param tensors
+ for group in self.param_groups:
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad.data.float()
+
+ if grad.is_sparse:
+ raise RuntimeError('Ranger optimizer does not support sparse gradients')
+
+ p_data_fp32 = p.data.float()
+
+ state = self.state[p] # get state dict for this param
+
+ if len(state) == 0: # if first time to run...init dictionary with our desired entries
+ # if self.first_run_check==0:
+ # self.first_run_check=1
+ # print("Initializing slow buffer...should not see this at load from saved model!")
+ state['step'] = 0
+ state['exp_avg'] = torch.zeros_like(p_data_fp32)
+ state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
+
+ # look ahead weight storage now in state dict
+ state['slow_buffer'] = torch.empty_like(p.data)
+ state['slow_buffer'].copy_(p.data)
+
+ else:
+ state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
+ state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+
+ # begin computations
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+ beta1, beta2 = group['betas']
+
+ # GC operation for Conv layers and FC layers
+ if grad.dim() > self.gc_gradient_threshold:
+ grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True))
+
+ state['step'] += 1
+
+ # compute variance mov avg
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+ # compute mean moving avg
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+
+ buffered = self.radam_buffer[int(state['step'] % 10)]
+
+ if state['step'] == buffered[0]:
+ N_sma, step_size = buffered[1], buffered[2]
+ else:
+ buffered[0] = state['step']
+ beta2_t = beta2 ** state['step']
+ N_sma_max = 2 / (1 - beta2) - 1
+ N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
+ buffered[1] = N_sma
+ if N_sma > self.N_sma_threshhold:
+ step_size = math.sqrt(
+ (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
+ N_sma_max - 2)) / (1 - beta1 ** state['step'])
+ else:
+ step_size = 1.0 / (1 - beta1 ** state['step'])
+ buffered[2] = step_size
+
+ if group['weight_decay'] != 0:
+ p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
+
+ # apply lr
+ if N_sma > self.N_sma_threshhold:
+ denom = exp_avg_sq.sqrt().add_(group['eps'])
+ p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
+ else:
+ p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
+
+ p.data.copy_(p_data_fp32)
+
+ # integrated look ahead...
+ # we do it at the param level instead of group level
+ if state['step'] % group['k'] == 0:
+ slow_p = state['slow_buffer'] # get access to slow param tensor
+ slow_p.add_(p.data - slow_p, alpha=self.alpha) # (fast weights - slow weights) * alpha
+ p.data.copy_(slow_p) # copy interpolated weights to RAdam param tensor
+
+ return loss
\ No newline at end of file
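
A minimal usage sketch of the bundled Ranger optimizer (RAdam update plus a Lookahead sync every k-th step, with gradient centralization on conv/fc weights); the toy model below is only for illustration.
```
import torch
from torch import nn
from training.ranger import Ranger

model = nn.Linear(512, 512)
opt = Ranger(model.parameters(), lr=1e-3, alpha=0.5, k=6)

x = torch.randn(8, 512)
loss = model(x).pow(2).mean()
opt.zero_grad()
loss.backward()
opt.step()     # RAdam step; slow weights are synced every k-th call
```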
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/utils/common.py b/utils/common.py
new file mode 100755
index 0000000000000000000000000000000000000000..5b11560ef7538dd12ec7d1d51ab650efb3f15ea4
--- /dev/null
+++ b/utils/common.py
@@ -0,0 +1,93 @@
+import cv2
+import numpy as np
+from PIL import Image
+import matplotlib.pyplot as plt
+import random
+
+
+# Log images
+def log_input_image(x, opts):
+ if opts.label_nc == 0:
+ return tensor2im(x)
+ elif opts.label_nc == 1:
+ return tensor2sketch(x)
+ else:
+ return tensor2map(x)
+
+
+def tensor2im(var):
+ var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy()
+ var = ((var + 1) / 2)
+ var[var < 0] = 0
+ var[var > 1] = 1
+ var = var * 255
+ return Image.fromarray(var.astype('uint8'))
+
+
+def tensor2map(var):
+ mask = np.argmax(var.data.cpu().numpy(), axis=0)
+ colors = get_colors()
+ mask_image = np.ones(shape=(mask.shape[0], mask.shape[1], 3))
+ for class_idx in np.unique(mask):
+ mask_image[mask == class_idx] = colors[class_idx]
+ mask_image = mask_image.astype('uint8')
+ return Image.fromarray(mask_image)
+
+
+def tensor2sketch(var):
+ im = var[0].cpu().detach().numpy()
+ im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
+ im = (im * 255).astype(np.uint8)
+ return Image.fromarray(im)
+
+
+# Visualization utils
+def get_colors():
+ # currently supports up to 19 classes (for the CelebAMask-HQ dataset)
+ colors = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255],
+ [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204],
+ [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]]
+
+ # assign random colors to 200 additional classes
+ random.seed(0)
+ for i in range(200):
+ colors.append([random.randint(0,255),random.randint(0,255),random.randint(0,255)])
+ return colors
+
+
+def vis_faces(log_hooks):
+ display_count = len(log_hooks)
+ fig = plt.figure(figsize=(8, 4 * display_count))
+ gs = fig.add_gridspec(display_count, 3)
+ for i in range(display_count):
+ hooks_dict = log_hooks[i]
+ fig.add_subplot(gs[i, 0])
+ if 'diff_input' in hooks_dict:
+ vis_faces_with_id(hooks_dict, fig, gs, i)
+ else:
+ vis_faces_no_id(hooks_dict, fig, gs, i)
+ plt.tight_layout()
+ return fig
+
+
+def vis_faces_with_id(hooks_dict, fig, gs, i):
+ plt.imshow(hooks_dict['input_face'])
+ plt.title('Input\nOut Sim={:.2f}'.format(float(hooks_dict['diff_input'])))
+ fig.add_subplot(gs[i, 1])
+ plt.imshow(hooks_dict['target_face'])
+ plt.title('Target\nIn={:.2f}, Out={:.2f}'.format(float(hooks_dict['diff_views']),
+ float(hooks_dict['diff_target'])))
+ fig.add_subplot(gs[i, 2])
+ plt.imshow(hooks_dict['output_face'])
+ plt.title('Output\n Target Sim={:.2f}'.format(float(hooks_dict['diff_target'])))
+
+
+def vis_faces_no_id(hooks_dict, fig, gs, i):
+ plt.imshow(hooks_dict['input_face'], cmap="gray")
+ plt.title('Input')
+ fig.add_subplot(gs[i, 1])
+ plt.imshow(hooks_dict['target_face'])
+ plt.title('Target')
+ fig.add_subplot(gs[i, 2])
+ plt.imshow(hooks_dict['output_face'])
+ plt.title('Output')
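
These helpers are used by the coach for logging; `tensor2im` in particular converts generator outputs in [-1, 1] back to PIL images. A minimal sketch:
```
import torch
from utils.common import tensor2im

img_tensor = torch.rand(3, 256, 256) * 2 - 1   # (C, H, W) in [-1, 1], e.g. a generator output
pil_img = tensor2im(img_tensor)
pil_img.save('sample.jpg')
```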