diff --git a/LICENSE b/LICENSE new file mode 100755 index 0000000000000000000000000000000000000000..7f570ae807e29e43c7d7b8c56954bc2b6c969a0b --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Yuki Endo + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100755 index 0000000000000000000000000000000000000000..bab85c36e531f0723147e22554a45c2f9a7bfbd2 --- /dev/null +++ b/README.md @@ -0,0 +1,51 @@ +# User-Controllable Latent Transformer for StyleGAN Image Layout Editing + + +

+ +

+ +This repository contains our implementation of the following paper: + +Yuki Endo: "User-Controllable Latent Transformer for StyleGAN Image Layout Editing," Computer Graphics Forum (Pacific Graphics 2022) [[Project](http://www.cgg.cs.tsukuba.ac.jp/~endo/projects/UserControllableLT)] [[PDF (preprint)]()] + +## Prerequisites +1. Python 3.8 +2. PyTorch 1.9.0 +3. Flask +4. Others (see env.yaml) + +## Preparation +Download and decompress our pre-trained models. + +## Inference with our pre-trained models +
+We provide an interactive interface based on Flask. This interface can be locally launched with +``` +python interface/flask_app.py --checkpoint_path=pretrained_models/latent_transformer/cat.pt +``` +The interface can be accessed via http://localhost:8000/. + +## Training +The latent transformer can be trained with +``` +python scripts/train.py --exp_dir=results --stylegan_weights=pretrained_models/stylegan2-cat-config-f.pt +``` + +## Citation +Please cite our paper if you find the code useful: +``` +@Article{endoPG2022, +Title = {User-Controllable Latent Transformer for StyleGAN Image Layout Editing}, +Author = {Yuki Endo}, +Journal = {Computer Graphics Forum}, +volume = {}, +number = {}, +pages = {}, +doi = {}, +Year = {2022} +} +``` + +## Acknowledgements +This code heavily borrows from the [pixel2style2pixel](https://github.com/eladrich/pixel2style2pixel) and [expansion](https://github.com/gengshan-y/expansion) repositories. diff --git a/criteria/__init__.py b/criteria/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/criteria/lpips/__init__.py b/criteria/lpips/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/criteria/lpips/lpips.py b/criteria/lpips/lpips.py new file mode 100755 index 0000000000000000000000000000000000000000..36b220b37a60c08391a2e31b24c4fde49b726a8a --- /dev/null +++ b/criteria/lpips/lpips.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn + +from criteria.lpips.networks import get_network, LinLayers +from criteria.lpips.utils import get_state_dict + + +class LPIPS(nn.Module): + r"""Creates a criterion that measures + Learned Perceptual Image Patch Similarity (LPIPS). + Arguments: + net_type (str): the network type to compare the features: + 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. + version (str): the version of LPIPS. Default: 0.1. 
+ """ + def __init__(self, net_type: str = 'alex', version: str = '0.1'): + + assert version in ['0.1'], 'v0.1 is only supported now' + + super(LPIPS, self).__init__() + + # pretrained network + self.net = get_network(net_type).to("cuda") + + # linear layers + self.lin = LinLayers(self.net.n_channels_list).to("cuda") + self.lin.load_state_dict(get_state_dict(net_type, version)) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + feat_x, feat_y = self.net(x), self.net(y) + + diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] + res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] + + return torch.sum(torch.cat(res, 0)) / x.shape[0] diff --git a/criteria/lpips/networks.py b/criteria/lpips/networks.py new file mode 100755 index 0000000000000000000000000000000000000000..7b182d0dfd3f7da57f13fdb9d724efd2f2ec5615 --- /dev/null +++ b/criteria/lpips/networks.py @@ -0,0 +1,96 @@ +from typing import Sequence + +from itertools import chain + +import torch +import torch.nn as nn +from torchvision import models + +from criteria.lpips.utils import normalize_activation + + +def get_network(net_type: str): + if net_type == 'alex': + return AlexNet() + elif net_type == 'squeeze': + return SqueezeNet() + elif net_type == 'vgg': + return VGG16() + else: + raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') + + +class LinLayers(nn.ModuleList): + def __init__(self, n_channels_list: Sequence[int]): + super(LinLayers, self).__init__([ + nn.Sequential( + nn.Identity(), + nn.Conv2d(nc, 1, 1, 1, 0, bias=False) + ) for nc in n_channels_list + ]) + + for param in self.parameters(): + param.requires_grad = False + + +class BaseNet(nn.Module): + def __init__(self): + super(BaseNet, self).__init__() + + # register buffer + self.register_buffer( + 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer( + 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) + + def set_requires_grad(self, state: bool): + for param in chain(self.parameters(), self.buffers()): + param.requires_grad = state + + def z_score(self, x: torch.Tensor): + return (x - self.mean) / self.std + + def forward(self, x: torch.Tensor): + x = self.z_score(x) + + output = [] + for i, (_, layer) in enumerate(self.layers._modules.items(), 1): + x = layer(x) + if i in self.target_layers: + output.append(normalize_activation(x)) + if len(output) == len(self.target_layers): + break + return output + + +class SqueezeNet(BaseNet): + def __init__(self): + super(SqueezeNet, self).__init__() + + self.layers = models.squeezenet1_1(True).features + self.target_layers = [2, 5, 8, 10, 11, 12, 13] + self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] + + self.set_requires_grad(False) + + +class AlexNet(BaseNet): + def __init__(self): + super(AlexNet, self).__init__() + + self.layers = models.alexnet(True).features + self.target_layers = [2, 5, 8, 10, 12] + self.n_channels_list = [64, 192, 384, 256, 256] + + self.set_requires_grad(False) + + +class VGG16(BaseNet): + def __init__(self): + super(VGG16, self).__init__() + + self.layers = models.vgg16(True).features + self.target_layers = [4, 9, 16, 23, 30] + self.n_channels_list = [64, 128, 256, 512, 512] + + self.set_requires_grad(False) \ No newline at end of file diff --git a/criteria/lpips/utils.py b/criteria/lpips/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..28f1867edc419a4b679b2775bcfa26ad6c9704b4 --- /dev/null +++ b/criteria/lpips/utils.py @@ -0,0 +1,30 @@ +from collections import OrderedDict + 
+import torch + + +def normalize_activation(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) + return x / (norm_factor + eps) + + +def get_state_dict(net_type: str = 'alex', version: str = '0.1'): + # build url + url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ + + f'master/lpips/weights/v{version}/{net_type}.pth' + + # download + old_state_dict = torch.hub.load_state_dict_from_url( + url, progress=True, + map_location=None if torch.cuda.is_available() else torch.device('cpu') + ) + + # rename keys + new_state_dict = OrderedDict() + for key, val in old_state_dict.items(): + new_key = key + new_key = new_key.replace('lin', '') + new_key = new_key.replace('model.', '') + new_state_dict[new_key] = val + + return new_state_dict diff --git a/docs/teaser.jpg b/docs/teaser.jpg new file mode 100755 index 0000000000000000000000000000000000000000..9edca9dea6895789be942a061776ead041efdee6 Binary files /dev/null and b/docs/teaser.jpg differ diff --git a/docs/thumb.gif b/docs/thumb.gif new file mode 100755 index 0000000000000000000000000000000000000000..80f434f5e416339e96def927957d8ada091b7130 Binary files /dev/null and b/docs/thumb.gif differ diff --git a/env.yaml b/env.yaml new file mode 100755 index 0000000000000000000000000000000000000000..95f073b68ae7f500d5c0311e5912350e0fd6c116 --- /dev/null +++ b/env.yaml @@ -0,0 +1,380 @@ +name: uclt +channels: + - pytorch + - anaconda + - nvidia + - conda-forge + - defaults +dependencies: + - _ipyw_jlab_nb_ext_conf=0.1.0=py38_0 + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=1_llvm + - absl-py=0.13.0=pyhd8ed1ab_0 + - aiohttp=3.7.4.post0=py38h497a2fe_0 + - albumentations=1.0.3=pyhd8ed1ab_0 + - alsa-lib=1.2.3=h516909a_0 + - anaconda-client=1.8.0=py38h06a4308_0 + - anaconda-navigator=2.0.4=py38_0 + - anyio=2.2.0=py38h06a4308_1 + - appdirs=1.4.4=pyh9f0ad1d_0 + - argon2-cffi=20.1.0=py38h27cfd23_1 + - async-timeout=3.0.1=py_1000 + - async_generator=1.10=pyhd3eb1b0_0 + - attrs=21.2.0=pyhd3eb1b0_0 + - babel=2.9.1=pyhd3eb1b0_0 + - backcall=0.2.0=pyhd3eb1b0_0 + - backports=1.0=pyhd3eb1b0_2 + - backports.functools_lru_cache=1.6.4=pyhd3eb1b0_0 + - backports.tempfile=1.0=pyhd3eb1b0_1 + - backports.weakref=1.0.post1=py_1 + - beautifulsoup4=4.9.3=pyha847dfd_0 + - blas=1.0=mkl + - bleach=4.0.0=pyhd3eb1b0_0 + - blinker=1.4=py_1 + - brotli=1.0.9=h7f98852_5 + - brotli-bin=1.0.9=h7f98852_5 + - brotlipy=0.7.0=py38h27cfd23_1003 + - bzip2=1.0.8=h7b6447c_0 + - c-ares=1.17.1=h27cfd23_0 + - ca-certificates=2021.10.8=ha878542_0 + - cachetools=4.2.2=pyhd8ed1ab_0 + - cairo=1.16.0=hf32fb01_1 + - certifi=2021.10.8=py38h578d9bd_1 + - cffi=1.14.6=py38h400218f_0 + - chardet=4.0.0=py38h06a4308_1003 + - click=8.0.1=pyhd3eb1b0_0 + - cloudpickle=1.6.0=py_0 + - clyent=1.2.2=py38_1 + - conda=4.11.0=py38h578d9bd_0 + - conda-build=3.21.4=py38h06a4308_0 + - conda-content-trust=0.1.1=pyhd3eb1b0_0 + - conda-env=2.6.0=1 + - conda-package-handling=1.7.3=py38h27cfd23_1 + - conda-repo-cli=1.0.4=pyhd3eb1b0_0 + - conda-token=0.3.0=pyhd3eb1b0_0 + - conda-verify=3.4.2=py_1 + - cryptography=3.4.7=py38hd23ed53_0 + - cudatoolkit=11.1.74=h6bb024c_0 + - cycler=0.10.0=py_2 + - cytoolz=0.11.0=py38h497a2fe_3 + - dask-core=2021.8.1=pyhd8ed1ab_0 + - dbus=1.13.18=hb2f20db_0 + - decorator=5.0.9=pyhd3eb1b0_0 + - defusedxml=0.7.1=pyhd3eb1b0_0 + - dill=0.3.4=pyhd8ed1ab_0 + - dominate=2.6.0=pyhd8ed1ab_0 + - entrypoints=0.3=py38_0 + - enum34=1.1.10=py38h32f6830_2 + - expat=2.4.1=h2531618_2 + - ffmpeg=4.3.2=hca11adc_0 + - filelock=3.0.12=pyhd3eb1b0_1 + - 
flask=1.1.2=pyh9f0ad1d_0 + - flask-httpauth=4.4.0=pyhd8ed1ab_0 + - fontconfig=2.13.1=h6c09931_0 + - fonttools=4.25.0=pyhd3eb1b0_0 + - freetype=2.10.4=h5ab3b9f_0 + - fsspec=2021.7.0=pyhd8ed1ab_0 + - ftfy=6.0.3=pyhd8ed1ab_0 + - func_timeout=4.3.5=py_0 + - future=0.18.2=py38_1 + - gdown=4.2.0=pyhd8ed1ab_0 + - geos=3.10.0=h9c3ff4c_0 + - gettext=0.19.8.1=h0b5b191_1005 + - git=2.23.0=pl526hacde149_0 + - glib=2.68.4=h9c3ff4c_0 + - glib-tools=2.68.4=h9c3ff4c_0 + - glob2=0.7=pyhd3eb1b0_0 + - gmp=6.2.1=h58526e2_0 + - gnutls=3.6.13=h85f3911_1 + - google-auth=1.35.0=pyh6c4a22f_0 + - google-auth-oauthlib=0.4.5=pyhd8ed1ab_0 + - gputil=1.4.0=pyh9f0ad1d_0 + - graphite2=1.3.13=h58526e2_1001 + - gst-plugins-base=1.18.4=hf529b03_2 + - gstreamer=1.18.4=h76c114f_2 + - harfbuzz=2.9.0=h83ec7ef_0 + - hdf5=1.10.6=nompi_h6a2412b_1114 + - icu=68.1=h58526e2_0 + - idna=2.10=pyhd3eb1b0_0 + - imagecodecs-lite=2019.12.3=py38h5c078b8_3 + - imageio=2.9.0=py_0 + - imageio-ffmpeg=0.4.5=pyhd8ed1ab_0 + - imgaug=0.4.0=py_1 + - importlib-metadata=3.10.0=py38h06a4308_0 + - importlib_metadata=3.10.0=hd3eb1b0_0 + - intel-openmp=2021.3.0=h06a4308_3350 + - ipykernel=5.3.4=py38h5ca1d4c_0 + - ipympl=0.8.2=pyhd8ed1ab_0 + - ipython=7.26.0=py38hb070fc8_0 + - ipython_genutils=0.2.0=pyhd3eb1b0_1 + - ipywidgets=7.6.3=pyhd3eb1b0_1 + - itsdangerous=2.0.1=pyhd3eb1b0_0 + - jasper=1.900.1=h07fcdf6_1006 + - jedi=0.18.0=py38h06a4308_1 + - jinja2=2.11.3=pyhd3eb1b0_0 + - joblib=1.1.0=pyhd8ed1ab_0 + - jpeg=9d=h36c2ea0_0 + - json5=0.9.6=pyhd3eb1b0_0 + - jsonnet=0.17.0=py38hadf7658_0 + - jsonschema=3.2.0=py_2 + - jupyter_client=6.1.12=pyhd3eb1b0_0 + - jupyter_core=4.7.1=py38h06a4308_0 + - jupyter_server=1.4.1=py38h06a4308_0 + - jupyterlab=3.1.7=pyhd3eb1b0_0 + - jupyterlab_pygments=0.1.2=py_0 + - jupyterlab_server=2.7.1=pyhd3eb1b0_0 + - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1 + - kiwisolver=1.3.1=py38h1fd1430_1 + - krb5=1.19.2=hcc1bbae_0 + - lame=3.100=h7f98852_1001 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.35.1=h7274673_9 + - libarchive=3.4.2=h62408e4_0 + - libblas=3.9.0=11_linux64_mkl + - libbrotlicommon=1.0.9=h7f98852_5 + - libbrotlidec=1.0.9=h7f98852_5 + - libbrotlienc=1.0.9=h7f98852_5 + - libcblas=3.9.0=11_linux64_mkl + - libcurl=7.78.0=h2574ce0_0 + - libedit=3.1.20191231=he28a2e2_2 + - libev=4.33=h516909a_1 + - libevent=2.1.10=hcdb4288_3 + - libffi=3.3=he6710b0_2 + - libgcc-ng=11.1.0=hc902ee8_8 + - libgfortran-ng=11.1.0=h69a702a_8 + - libgfortran5=11.1.0=h6c583b3_8 + - libglib=2.68.4=h3e27bee_0 + - libiconv=1.16=h516909a_0 + - liblapack=3.9.0=11_linux64_mkl + - liblapacke=3.9.0=11_linux64_mkl + - liblief=0.10.1=he6710b0_0 + - libllvm11=11.1.0=hf817b99_2 + - libnghttp2=1.43.0=h812cca2_0 + - libogg=1.3.4=h7f98852_1 + - libopencv=4.5.2=py38hcdf9bf1_0 + - libopus=1.3.1=h7f98852_1 + - libpng=1.6.37=hbc83047_0 + - libpq=13.3=hd57d9b9_0 + - libprotobuf=3.15.8=h780b84a_0 + - libsodium=1.0.18=h7b6447c_0 + - libssh2=1.9.0=ha56f1ee_6 + - libstdcxx-ng=11.1.0=h56837e0_8 + - libtiff=4.2.0=h85742a9_0 + - libuuid=1.0.3=h1bed415_2 + - libuv=1.40.0=h7b6447c_0 + - libvorbis=1.3.7=h9c3ff4c_0 + - libwebp-base=1.2.0=h27cfd23_0 + - libxcb=1.14=h7b6447c_0 + - libxkbcommon=1.0.3=he3ba5ed_0 + - libxml2=2.9.12=h72842e0_0 + - llvm-openmp=12.0.1=h4bd325d_1 + - locket=0.2.0=py_2 + - lz4-c=1.9.3=h295c915_1 + - markdown=3.3.4=pyhd8ed1ab_0 + - markupsafe=2.0.1=py38h27cfd23_0 + - matplotlib=3.4.2=py38h578d9bd_0 + - matplotlib-base=3.4.2=py38hab158f2_0 + - matplotlib-inline=0.1.2=pyhd3eb1b0_2 + - mistune=0.8.4=py38h7b6447c_1000 + - mkl=2021.3.0=h06a4308_520 + - 
mkl-service=2.4.0=py38h7f8727e_0 + - mkl_fft=1.3.0=py38h42c9631_2 + - mkl_random=1.2.2=py38h51133e4_0 + - multidict=5.1.0=py38h497a2fe_1 + - munkres=1.1.4=pyh9f0ad1d_0 + - mysql-common=8.0.25=ha770c72_0 + - mysql-libs=8.0.25=h935591d_0 + - navigator-updater=0.2.1=py38_0 + - nbclassic=0.2.6=pyhd3eb1b0_0 + - nbclient=0.5.3=pyhd3eb1b0_0 + - nbconvert=6.1.0=py38h06a4308_0 + - nbformat=5.1.3=pyhd3eb1b0_0 + - ncurses=6.2=he6710b0_1 + - nest-asyncio=1.5.1=pyhd3eb1b0_0 + - nettle=3.6=he412f7d_0 + - networkx=2.3=py_0 + - ninja=1.10.2=hff7bd54_1 + - notebook=6.4.3=py38h06a4308_0 + - nspr=4.30=h9c3ff4c_0 + - nss=3.69=hb5efdd6_0 + - numpy=1.20.3=py38hf144106_0 + - numpy-base=1.20.3=py38h74d4b33_0 + - oauthlib=3.1.1=pyhd8ed1ab_0 + - olefile=0.46=py_0 + - opencv=4.5.2=py38h578d9bd_0 + - openh264=2.1.1=h780b84a_0 + - openjpeg=2.3.0=h05c96fa_1 + - openssl=1.1.1l=h7f98852_0 + - packaging=21.0=pyhd3eb1b0_0 + - pandas=1.3.2=py38h43a58ef_0 + - pandocfilters=1.4.3=py38h06a4308_1 + - parso=0.8.2=pyhd3eb1b0_0 + - partd=1.2.0=pyhd8ed1ab_0 + - patchelf=0.12=h2531618_1 + - pathlib=1.0.1=py38h578d9bd_4 + - patsy=0.5.1=py_0 + - pcre=8.45=h295c915_0 + - perl=5.26.2=h14c3975_0 + - pexpect=4.8.0=pyhd3eb1b0_3 + - pickleshare=0.7.5=pyhd3eb1b0_1003 + - pillow=8.3.1=py38h2c7a002_0 + - pip=21.2.2=py38h06a4308_0 + - pixman=0.40.0=h36c2ea0_0 + - pkginfo=1.7.1=py38h06a4308_0 + - pooch=1.5.1=pyhd8ed1ab_0 + - portalocker=1.7.0=py38h578d9bd_1 + - prometheus_client=0.11.0=pyhd3eb1b0_0 + - prompt-toolkit=3.0.17=pyh06a4308_0 + - protobuf=3.15.8=py38h709712a_0 + - psutil=5.8.0=py38h27cfd23_1 + - ptyprocess=0.7.0=pyhd3eb1b0_2 + - py-lief=0.10.1=py38h403a769_0 + - py-opencv=4.5.2=py38hd0cf306_0 + - pyasn1=0.4.8=py_0 + - pyasn1-modules=0.2.7=py_0 + - pycosat=0.6.3=py38h7b6447c_1 + - pycparser=2.20=py_2 + - pygments=2.10.0=pyhd3eb1b0_0 + - pyjwt=2.1.0=pyhd8ed1ab_0 + - pyopenssl=20.0.1=pyhd3eb1b0_1 + - pyparsing=2.4.7=pyhd3eb1b0_0 + - pypng=0.0.20=py_0 + - pyqt=5.12.3=py38h578d9bd_7 + - pyqt-impl=5.12.3=py38h7400c14_7 + - pyqt5-sip=4.19.18=py38h709712a_7 + - pyqtchart=5.12=py38h7400c14_7 + - pyqtwebengine=5.12.1=py38h7400c14_7 + - pyrsistent=0.17.3=py38h7b6447c_0 + - pysocks=1.7.1=py38h06a4308_0 + - python=3.8.10=h12debd9_8 + - python-dateutil=2.8.2=pyhd3eb1b0_0 + - python-libarchive-c=2.9=pyhd3eb1b0_1 + - python-lmdb=0.99=py38h709712a_0 + - python_abi=3.8=2_cp38 + - pytorch=1.9.0=py3.8_cuda11.1_cudnn8.0.5_0 + - pytz=2021.1=pyhd3eb1b0_0 + - pyu2f=0.1.5=pyhd8ed1ab_0 + - pywavelets=1.1.1=py38h5c078b8_3 + - pyyaml=5.4.1=py38h27cfd23_1 + - pyzmq=22.2.1=py38h295c915_1 + - qt=5.12.9=hda022c4_4 + - qtpy=1.9.0=py_0 + - readline=8.1=h27cfd23_0 + - regex=2021.8.28=py38h497a2fe_0 + - requests=2.25.1=pyhd3eb1b0_0 + - requests-oauthlib=1.3.0=pyh9f0ad1d_0 + - ripgrep=12.1.1=0 + - rsa=4.7.2=pyh44b312d_0 + - ruamel_yaml=0.15.100=py38h27cfd23_0 + - scikit-image=0.18.3=py38h43a58ef_0 + - scikit-learn=1.0=py38hacb3eff_1 + - scipy=1.7.1=py38h56a6a73_0 + - seaborn=0.11.2=hd8ed1ab_0 + - seaborn-base=0.11.2=pyhd8ed1ab_0 + - send2trash=1.5.0=pyhd3eb1b0_1 + - setuptools=52.0.0=py38h06a4308_0 + - shapely=1.8.0=py38hf7953bd_1 + - sip=4.19.13=py38he6710b0_0 + - sniffio=1.2.0=py38h06a4308_1 + - soupsieve=2.2.1=pyhd3eb1b0_0 + - sqlite=3.36.0=hc218d9a_0 + - statsmodels=0.12.2=py38h5c078b8_0 + - tensorboard=2.6.0=pyhd8ed1ab_1 + - tensorboard-data-server=0.6.0=py38h2b97feb_0 + - tensorboard-plugin-wit=1.8.0=pyh44b312d_0 + - tensorboardx=2.4=pyhd8ed1ab_0 + - terminado=0.9.4=py38h06a4308_0 + - testpath=0.5.0=pyhd3eb1b0_0 + - threadpoolctl=3.0.0=pyh8a188c0_0 + - 
tifffile=2019.7.26.2=py38_0 + - tk=8.6.10=hbc83047_0 + - toolz=0.11.1=py_0 + - torchfile=0.1.0=py_0 + - tornado=6.1=py38h27cfd23_0 + - tqdm=4.62.1=pyhd3eb1b0_1 + - traitlets=5.0.5=pyhd3eb1b0_0 + - typing_extensions=3.10.0.0=pyh06a4308_0 + - urllib3=1.26.6=pyhd3eb1b0_1 + - wcwidth=0.2.5=py_0 + - webencodings=0.5.1=py38_1 + - werkzeug=1.0.1=pyhd3eb1b0_0 + - wheel=0.37.0=pyhd3eb1b0_0 + - widgetsnbextension=3.5.1=py38_0 + - x264=1!161.3030=h7f98852_1 + - xmltodict=0.12.0=py_0 + - xz=5.2.5=h7b6447c_0 + - yacs=0.1.6=py_0 + - yaml=0.2.5=h7b6447c_0 + - yarl=1.6.3=py38h497a2fe_2 + - zeromq=4.3.4=h2531618_0 + - zipp=3.5.0=pyhd3eb1b0_0 + - zlib=1.2.11=h7b6447c_3 + - zstd=1.4.9=haebb681_0 + - pip: + - addict==2.4.0 + - altair==4.2.0 + - astor==0.8.1 + - astunparse==1.6.3 + - backports-zoneinfo==0.2.1 + - base58==2.1.1 + - basicsr==1.3.4.1 + - boto3==1.18.33 + - botocore==1.21.33 + - clang==5.0 + - clean-fid==0.1.22 + - clip==1.0 + - colorama==0.4.4 + - commonmark==0.9.1 + - cython==0.29.30 + - einops==0.3.2 + - enum-compat==0.0.3 + - facexlib==0.2.0.3 + - filterpy==1.4.5 + - flatbuffers==1.12 + - gast==0.4.0 + - google-pasta==0.2.0 + - grpcio==1.39.0 + - h5py==3.1.0 + - ipdb==0.13.9 + - jacinle==1.0.0 + - jmespath==0.10.0 + - jsonpickle==2.2.0 + - keras==2.7.0 + - keras-preprocessing==1.1.2 + - libclang==12.0.0 + - llvmlite==0.37.0 + - lpips==0.1.4 + - numba==0.54.0 + - opencv-python==4.5.3.56 + - opt-einsum==3.3.0 + - pkgconfig==1.5.5 + - pyarrow==8.0.0 + - pydantic==1.8.2 + - pydeck==0.7.1 + - pyhocon==0.3.58 + - pytz-deprecation-shim==0.1.0.post0 + - pyvis==0.2.1 + - realesrgan==0.2.2.3 + - rich==10.9.0 + - s3transfer==0.5.0 + - six==1.15.0 + - sklearn==0.0 + - streamlit==0.64.0 + - tabulate==0.8.9 + - tb-nightly==2.7.0a20210827 + - tensorflow-estimator==2.7.0 + - tensorflow-gpu==2.7.0 + - tensorflow-io-gcs-filesystem==0.21.0 + - tensorfn==0.1.19 + - termcolor==1.1.0 + - toml==0.10.2 + - torchsample==0.1.3 + - torchvision==0.10.0+cu111 + - typing-extensions==3.7.4.3 + - tzdata==2022.1 + - tzlocal==4.2 + - validators==0.19.0 + - vit-pytorch==0.24.3 + - watchdog==2.1.8 + - wrapt==1.12.1 + - yapf==0.31.0 \ No newline at end of file diff --git a/expansion/__init__.py b/expansion/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/expansion/dataloader/__init__.py b/expansion/dataloader/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/expansion/dataloader/__pycache__/__init__.cpython-38.pyc b/expansion/dataloader/__pycache__/__init__.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..36079aff3f6dfa75f083f17f30c7a0dbc6f69dd5 Binary files /dev/null and b/expansion/dataloader/__pycache__/__init__.cpython-38.pyc differ diff --git a/expansion/dataloader/__pycache__/seqlist.cpython-38.pyc b/expansion/dataloader/__pycache__/seqlist.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..980a3a34d42ed22b8c711cb73be9424ace31bf95 Binary files /dev/null and b/expansion/dataloader/__pycache__/seqlist.cpython-38.pyc differ diff --git a/expansion/dataloader/chairslist.py b/expansion/dataloader/chairslist.py new file mode 100755 index 0000000000000000000000000000000000000000..107f4ee2914a369d31fd2c89b49a7d4ebba51638 --- /dev/null +++ b/expansion/dataloader/chairslist.py @@ -0,0 +1,33 @@ +import torch.utils.data as data + +from PIL import Image +import os +import os.path +import numpy as np 
+import glob + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + +def dataloader(filepath): + l0_train = [] + l1_train = [] + flow_train = [] + for flow_map in sorted(glob.glob(os.path.join(filepath,'*_flow.flo'))): + root_filename = flow_map[:-9] + img1 = root_filename+'_img1.ppm' + img2 = root_filename+'_img2.ppm' + if not (os.path.isfile(os.path.join(filepath,img1)) and os.path.isfile(os.path.join(filepath,img2))): + continue + + l0_train.append(img1) + l1_train.append(img2) + flow_train.append(flow_map) + + return l0_train, l1_train, flow_train diff --git a/expansion/dataloader/chairssdlist.py b/expansion/dataloader/chairssdlist.py new file mode 100755 index 0000000000000000000000000000000000000000..714c81c8a801beb549806bd72ac6da7d284d25e9 --- /dev/null +++ b/expansion/dataloader/chairssdlist.py @@ -0,0 +1,30 @@ +import torch.utils.data as data + +from PIL import Image +import os +import os.path +import numpy as np +import glob + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + +def dataloader(filepath): + l0_train = [] + l1_train = [] + flow_train = [] + for flow_map in sorted(glob.glob('%s/flow/*.pfm'%filepath)): + img1 = flow_map.replace('flow','t0').replace('.pfm','.png') + img2 = flow_map.replace('flow','t1').replace('.pfm','.png') + + l0_train.append(img1) + l1_train.append(img2) + flow_train.append(flow_map) + + return l0_train, l1_train, flow_train diff --git a/expansion/dataloader/depth_transforms.py b/expansion/dataloader/depth_transforms.py new file mode 100755 index 0000000000000000000000000000000000000000..19a768c788137bfb89077b6c415dc9401a540e4e --- /dev/null +++ b/expansion/dataloader/depth_transforms.py @@ -0,0 +1,471 @@ +from __future__ import division +import torch +import random +import numpy as np +import numbers +import types +import scipy.ndimage as ndimage +import pdb +import torchvision +import PIL.Image as Image +import cv2 +from torch.nn import functional as F + + +class Compose(object): + """ Composes several co_transforms together. + For example: + >>> co_transforms.Compose([ + >>> co_transforms.CenterCrop(10), + >>> co_transforms.ToTensor(), + >>> ]) + """ + + def __init__(self, co_transforms): + self.co_transforms = co_transforms + + def __call__(self, input, target,intr): + for t in self.co_transforms: + input,target,intr = t(input,target,intr) + return input,target,intr + + +class Scale(object): + """ Rescales the inputs and target arrays to the given 'size'. + 'size' will be the size of the smaller edge. 
+ For example, if height > width, then image will be + rescaled to (size * height / width, size) + size: size of the smaller edge + interpolation order: Default: 2 (bilinear) + """ + + def __init__(self, size, order=1): + self.ratio = size + self.order = order + if order==0: + self.code=cv2.INTER_NEAREST + elif order==1: + self.code=cv2.INTER_LINEAR + elif order==2: + self.code=cv2.INTER_CUBIC + + def __call__(self, inputs, target): + if self.ratio==1: + return inputs, target + h, w, _ = inputs[0].shape + ratio = self.ratio + + inputs[0] = cv2.resize(inputs[0], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_LINEAR) + inputs[1] = cv2.resize(inputs[1], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_LINEAR) + # keep the mask same + tmp = cv2.resize(target[:,:,2], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_NEAREST) + target = cv2.resize(target, None, fx=ratio,fy=ratio,interpolation=self.code) * ratio + target[:,:,2] = tmp + + + return inputs, target + + +class RandomCrop(object): + """Crops the given PIL.Image at a random location to have a region of + the given size. size can be a tuple (target_height, target_width) + or an integer, in which case the target will be of a square shape (size, size) + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, inputs,target,intr): + h, w, _ = inputs[0].shape + th, tw = self.size + if w < tw: tw=w + if h < th: th=h + + x1 = random.randint(0, w - tw) + y1 = random.randint(0, h - th) + intr[1] -= x1 + intr[2] -= y1 + + inputs[0] = inputs[0][y1: y1 + th,x1: x1 + tw].astype(float) + inputs[1] = inputs[1][y1: y1 + th,x1: x1 + tw].astype(float) + return inputs, target[y1: y1 + th,x1: x1 + tw].astype(float), list(np.asarray(intr).astype(float)) + list(np.asarray([1.,0.,0.,1.,0.,0.]).astype(float)) + + + +class SpatialAug(object): + def __init__(self, crop, scale=None, rot=None, trans=None, squeeze=None, schedule_coeff=1, order=1, black=False): + self.crop = crop + self.scale = scale + self.rot = rot + self.trans = trans + self.squeeze = squeeze + self.t = np.zeros(6) + self.schedule_coeff = schedule_coeff + self.order = order + self.black = black + + def to_identity(self): + self.t[0] = 1; self.t[2] = 0; self.t[4] = 0; self.t[1] = 0; self.t[3] = 1; self.t[5] = 0; + + def left_multiply(self, u0, u1, u2, u3, u4, u5): + result = np.zeros(6) + result[0] = self.t[0]*u0 + self.t[1]*u2; + result[1] = self.t[0]*u1 + self.t[1]*u3; + + result[2] = self.t[2]*u0 + self.t[3]*u2; + result[3] = self.t[2]*u1 + self.t[3]*u3; + + result[4] = self.t[4]*u0 + self.t[5]*u2 + u4; + result[5] = self.t[4]*u1 + self.t[5]*u3 + u5; + self.t = result + + def inverse(self): + result = np.zeros(6) + a = self.t[0]; c = self.t[2]; e = self.t[4]; + b = self.t[1]; d = self.t[3]; f = self.t[5]; + + denom = a*d - b*c; + + result[0] = d / denom; + result[1] = -b / denom; + result[2] = -c / denom; + result[3] = a / denom; + result[4] = (c*f-d*e) / denom; + result[5] = (b*e-a*f) / denom; + + return result + + def grid_transform(self, meshgrid, t, normalize=True, gridsize=None): + if gridsize is None: + h, w = meshgrid[0].shape + else: + h, w = gridsize + vgrid = torch.cat([(meshgrid[0] * t[0] + meshgrid[1] * t[2] + t[4])[:,:,np.newaxis], + (meshgrid[0] * t[1] + meshgrid[1] * t[3] + t[5])[:,:,np.newaxis]],-1) + if normalize: + vgrid[:,:,0] = 2.0*vgrid[:,:,0]/max(w-1,1)-1.0 + vgrid[:,:,1] = 2.0*vgrid[:,:,1]/max(h-1,1)-1.0 + return vgrid + + + def __call__(self, inputs, target, intr): + h, 
w, _ = inputs[0].shape + th, tw = self.crop + meshgrid = torch.meshgrid([torch.Tensor(range(th)), torch.Tensor(range(tw))])[::-1] + cornergrid = torch.meshgrid([torch.Tensor([0,th-1]), torch.Tensor([0,tw-1])])[::-1] + + for i in range(50): + # im0 + self.to_identity() + #TODO add mirror + if np.random.binomial(1,0.5): + mirror = True + else: + mirror = False + ##TODO + #mirror = False + if mirror: + self.left_multiply(-1, 0, 0, 1, .5 * tw, -.5 * th); + else: + self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th); + scale0 = 1; scale1 = 1; squeeze0 = 1; squeeze1 = 1; + if not self.rot is None: + rot0 = np.random.uniform(-self.rot[0],+self.rot[0]) + rot1 = np.random.uniform(-self.rot[1]*self.schedule_coeff, self.rot[1]*self.schedule_coeff) + rot0 + self.left_multiply(np.cos(rot0), np.sin(rot0), -np.sin(rot0), np.cos(rot0), 0, 0) + if not self.trans is None: + trans0 = np.random.uniform(-self.trans[0],+self.trans[0], 2) + trans1 = np.random.uniform(-self.trans[1]*self.schedule_coeff,+self.trans[1]*self.schedule_coeff, 2) + trans0 + self.left_multiply(1, 0, 0, 1, trans0[0] * tw, trans0[1] * th) + if not self.squeeze is None: + squeeze0 = np.exp(np.random.uniform(-self.squeeze[0], self.squeeze[0])) + squeeze1 = np.exp(np.random.uniform(-self.squeeze[1]*self.schedule_coeff, self.squeeze[1]*self.schedule_coeff)) * squeeze0 + if not self.scale is None: + scale0 = np.exp(np.random.uniform(self.scale[2]-self.scale[0], self.scale[2]+self.scale[0])) + scale1 = np.exp(np.random.uniform(-self.scale[1]*self.schedule_coeff, self.scale[1]*self.schedule_coeff)) * scale0 + self.left_multiply(1.0/(scale0*squeeze0), 0, 0, 1.0/(scale0/squeeze0), 0, 0) + + self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h); + transmat0 = self.t.copy() + + # im1 + self.to_identity() + if mirror: + self.left_multiply(-1, 0, 0, 1, .5 * tw, -.5 * th); + else: + self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th); + if not self.rot is None: + self.left_multiply(np.cos(rot1), np.sin(rot1), -np.sin(rot1), np.cos(rot1), 0, 0) + if not self.trans is None: + self.left_multiply(1, 0, 0, 1, trans1[0] * tw, trans1[1] * th) + self.left_multiply(1.0/(scale1*squeeze1), 0, 0, 1.0/(scale1/squeeze1), 0, 0) + self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h); + transmat1 = self.t.copy() + transmat1_inv = self.inverse() + + if self.black: + # black augmentation, allowing 0 values in the input images + # https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/black_augmentation_layer.cu + break + else: + if ((self.grid_transform(cornergrid, transmat0, gridsize=[float(h),float(w)]).abs()>1).sum() +\ + (self.grid_transform(cornergrid, transmat1, gridsize=[float(h),float(w)]).abs()>1).sum()) == 0: + break + if i==49: + print('max_iter in augmentation') + self.to_identity() + self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th); + self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h); + transmat0 = self.t.copy() + transmat1 = self.t.copy() + + # do the real work + vgrid = self.grid_transform(meshgrid, transmat0,gridsize=[float(h),float(w)]) + inputs_0 = F.grid_sample(torch.Tensor(inputs[0]).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0) + if self.order == 0: + target_0 = F.grid_sample(torch.Tensor(target).permute(2,0,1)[np.newaxis], vgrid[np.newaxis], mode='nearest')[0].permute(1,2,0) + else: + target_0 = F.grid_sample(torch.Tensor(target).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0) + + mask_0 = target[:,:,2:3].copy(); mask_0[mask_0==0]=np.nan + if self.order == 0: + mask_0 = 
F.grid_sample(torch.Tensor(mask_0).permute(2,0,1)[np.newaxis], vgrid[np.newaxis], mode='nearest')[0].permute(1,2,0) + else: + mask_0 = F.grid_sample(torch.Tensor(mask_0).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0) + mask_0[torch.isnan(mask_0)] = 0 + + + vgrid = self.grid_transform(meshgrid, transmat1,gridsize=[float(h),float(w)]) + inputs_1 = F.grid_sample(torch.Tensor(inputs[1]).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0) + + # flow + pos = target_0[:,:,:2] + self.grid_transform(meshgrid, transmat0,normalize=False) + pos = self.grid_transform(pos.permute(2,0,1),transmat1_inv,normalize=False) + if target_0.shape[2]>=4: + # scale + exp = target_0[:,:,3:] * scale1 / scale0 + target = torch.cat([ (pos[:,:,0] - meshgrid[0]).unsqueeze(-1), + (pos[:,:,1] - meshgrid[1]).unsqueeze(-1), + mask_0, + exp], -1) + else: + target = torch.cat([ (pos[:,:,0] - meshgrid[0]).unsqueeze(-1), + (pos[:,:,1] - meshgrid[1]).unsqueeze(-1), + mask_0], -1) + inputs = [np.asarray(inputs_0).astype(float), np.asarray(inputs_1).astype(float)] + target = np.asarray(target).astype(float) + return inputs,target, list(np.asarray(intr+list(transmat0)).astype(float)) + + + +class pseudoPCAAug(object): + """ + Chromatic Eigen Augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu + This version is faster. + """ + def __init__(self, schedule_coeff=1): + self.augcolor = torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.5, hue=0.5/3.14) + + def __call__(self, inputs, target,intr): + img = np.concatenate([inputs[0],inputs[1]],0) + shape = img.shape[0]//2 + aug_img = np.asarray(self.augcolor(Image.fromarray(np.uint8(img*255))))/255. + inputs[0] = aug_img[:shape] + inputs[1] = aug_img[shape:] + #inputs[0] = np.asarray(self.augcolor(Image.fromarray(np.uint8(inputs[0]*255))))/255. + #inputs[1] = np.asarray(self.augcolor(Image.fromarray(np.uint8(inputs[1]*255))))/255. 
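+        # note: the two frames are stacked into one image before ColorJitter above,
+        # so a single call applies identical photometric parameters to both frames
+        # (the commented-out per-frame calls would jitter each frame independently)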
+ return inputs,target,intr + + +class PCAAug(object): + """ + Chromatic Eigen Augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu + """ + def __init__(self, lmult_pow =[0.4, 0,-0.2], + lmult_mult =[0.4, 0,0, ], + lmult_add =[0.03,0,0, ], + sat_pow =[0.4, 0,0, ], + sat_mult =[0.5, 0,-0.3], + sat_add =[0.03,0,0, ], + col_pow =[0.4, 0,0, ], + col_mult =[0.2, 0,0, ], + col_add =[0.02,0,0, ], + ladd_pow =[0.4, 0,0, ], + ladd_mult =[0.4, 0,0, ], + ladd_add =[0.04,0,0, ], + col_rotate =[1., 0,0, ], + schedule_coeff=1): + # no mean + self.pow_nomean = [1,1,1] + self.add_nomean = [0,0,0] + self.mult_nomean = [1,1,1] + self.pow_withmean = [1,1,1] + self.add_withmean = [0,0,0] + self.mult_withmean = [1,1,1] + self.lmult_pow = 1 + self.lmult_mult = 1 + self.lmult_add = 0 + self.col_angle = 0 + if not ladd_pow is None: + self.pow_nomean[0] =np.exp(np.random.normal(ladd_pow[2], ladd_pow[0])) + if not col_pow is None: + self.pow_nomean[1] =np.exp(np.random.normal(col_pow[2], col_pow[0])) + self.pow_nomean[2] =np.exp(np.random.normal(col_pow[2], col_pow[0])) + + if not ladd_add is None: + self.add_nomean[0] =np.random.normal(ladd_add[2], ladd_add[0]) + if not col_add is None: + self.add_nomean[1] =np.random.normal(col_add[2], col_add[0]) + self.add_nomean[2] =np.random.normal(col_add[2], col_add[0]) + + if not ladd_mult is None: + self.mult_nomean[0] =np.exp(np.random.normal(ladd_mult[2], ladd_mult[0])) + if not col_mult is None: + self.mult_nomean[1] =np.exp(np.random.normal(col_mult[2], col_mult[0])) + self.mult_nomean[2] =np.exp(np.random.normal(col_mult[2], col_mult[0])) + + # with mean + if not sat_pow is None: + self.pow_withmean[1] =np.exp(np.random.uniform(sat_pow[2]-sat_pow[0], sat_pow[2]+sat_pow[0])) + self.pow_withmean[2] =self.pow_withmean[1] + if not sat_add is None: + self.add_withmean[1] =np.random.uniform(sat_add[2]-sat_add[0], sat_add[2]+sat_add[0]) + self.add_withmean[2] =self.add_withmean[1] + if not sat_mult is None: + self.mult_withmean[1] = np.exp(np.random.uniform(sat_mult[2]-sat_mult[0], sat_mult[2]+sat_mult[0])) + self.mult_withmean[2] = self.mult_withmean[1] + + if not lmult_pow is None: + self.lmult_pow = np.exp(np.random.uniform(lmult_pow[2]-lmult_pow[0], lmult_pow[2]+lmult_pow[0])) + if not lmult_mult is None: + self.lmult_mult= np.exp(np.random.uniform(lmult_mult[2]-lmult_mult[0], lmult_mult[2]+lmult_mult[0])) + if not lmult_add is None: + self.lmult_add = np.random.uniform(lmult_add[2]-lmult_add[0], lmult_add[2]+lmult_add[0]) + if not col_rotate is None: + self.col_angle= np.random.uniform(col_rotate[2]-col_rotate[0], col_rotate[2]+col_rotate[0]) + + # eigen vectors + self.eigvec = np.reshape([0.51,0.56,0.65,0.79,0.01,-0.62,0.35,-0.83,0.44],[3,3]).transpose() + + + def __call__(self, inputs, target, intr): + inputs[0] = self.pca_image(inputs[0]) + inputs[1] = self.pca_image(inputs[1]) + return inputs,target,intr + + def pca_image(self, rgb): + eig = np.dot(rgb, self.eigvec) + max_rgb = np.clip(rgb,0,np.inf).max((0,1)) + min_rgb = rgb.min((0,1)) + mean_rgb = rgb.mean((0,1)) + max_abs_eig = np.abs(eig).max((0,1)) + max_l = np.sqrt(np.sum(max_abs_eig*max_abs_eig)) + mean_eig = np.dot(mean_rgb, self.eigvec) + + # no-mean stuff + eig -= mean_eig[np.newaxis, np.newaxis] + + for c in range(3): + if max_abs_eig[c] > 1e-2: + mean_eig[c] /= max_abs_eig[c] + eig[:,:,c] = eig[:,:,c] / max_abs_eig[c]; + eig[:,:,c] = np.power(np.abs(eig[:,:,c]),self.pow_nomean[c]) *\ + ((eig[:,:,c] > 0) -0.5)*2 + eig[:,:,c] = eig[:,:,c] + 
self.add_nomean[c] + eig[:,:,c] = eig[:,:,c] * self.mult_nomean[c] + eig += mean_eig[np.newaxis,np.newaxis] + + # withmean stuff + if max_abs_eig[0] > 1e-2: + eig[:,:,0] = np.power(np.abs(eig[:,:,0]),self.pow_withmean[0]) * \ + ((eig[:,:,0]>0)-0.5)*2; + eig[:,:,0] = eig[:,:,0] + self.add_withmean[0]; + eig[:,:,0] = eig[:,:,0] * self.mult_withmean[0]; + + s = np.sqrt(eig[:,:,1]*eig[:,:,1] + eig[:,:,2] * eig[:,:,2]) + smask = s > 1e-2 + s1 = np.power(s, self.pow_withmean[1]); + s1 = np.clip(s1 + self.add_withmean[1], 0,np.inf) + s1 = s1 * self.mult_withmean[1] + s1 = s1 * smask + s*(1-smask) + + # color angle + if self.col_angle!=0: + temp1 = np.cos(self.col_angle) * eig[:,:,1] - np.sin(self.col_angle) * eig[:,:,2] + temp2 = np.sin(self.col_angle) * eig[:,:,1] + np.cos(self.col_angle) * eig[:,:,2] + eig[:,:,1] = temp1 + eig[:,:,2] = temp2 + + # to origin magnitude + for c in range(3): + if max_abs_eig[c] > 1e-2: + eig[:,:,c] = eig[:,:,c] * max_abs_eig[c] + + if max_l > 1e-2: + l1 = np.sqrt(eig[:,:,0]*eig[:,:,0] + eig[:,:,1]*eig[:,:,1] + eig[:,:,2]*eig[:,:,2]) + l1 = l1 / max_l + + eig[:,:,1][smask] = (eig[:,:,1] / s * s1)[smask] + eig[:,:,2][smask] = (eig[:,:,2] / s * s1)[smask] + #eig[:,:,1] = (eig[:,:,1] / s * s1) * smask + eig[:,:,1] * (1-smask) + #eig[:,:,2] = (eig[:,:,2] / s * s1) * smask + eig[:,:,2] * (1-smask) + + if max_l > 1e-2: + l = np.sqrt(eig[:,:,0]*eig[:,:,0] + eig[:,:,1]*eig[:,:,1] + eig[:,:,2]*eig[:,:,2]) + l1 = np.power(l1, self.lmult_pow) + l1 = np.clip(l1 + self.lmult_add, 0, np.inf) + l1 = l1 * self.lmult_mult + l1 = l1 * max_l + lmask = l > 1e-2 + eig[lmask] = (eig / l[:,:,np.newaxis] * l1[:,:,np.newaxis])[lmask] + for c in range(3): + eig[:,:,c][lmask] = (np.clip(eig[:,:,c], -np.inf, max_abs_eig[c]))[lmask] + # for c in range(3): +# # eig[:,:,c][lmask] = (eig[:,:,c] / l * l1)[lmask] * lmask + eig[:,:,c] * (1-lmask) + # eig[:,:,c][lmask] = (eig[:,:,c] / l * l1)[lmask] + # eig[:,:,c] = (np.clip(eig[:,:,c], -np.inf, max_abs_eig[c])) * lmask + eig[:,:,c] * (1-lmask) + + return np.clip(np.dot(eig, self.eigvec.transpose()), 0, 1) + + +class ChromaticAug(object): + """ + Chromatic augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu + """ + def __init__(self, noise = 0.06, + gamma = 0.02, + brightness = 0.02, + contrast = 0.02, + color = 0.02, + schedule_coeff=1): + + self.noise = np.random.uniform(0,noise) + self.gamma = np.exp(np.random.normal(0, gamma*schedule_coeff)) + self.brightness = np.random.normal(0, brightness*schedule_coeff) + self.contrast = np.exp(np.random.normal(0, contrast*schedule_coeff)) + self.color = np.exp(np.random.normal(0, color*schedule_coeff,3)) + + def __call__(self, inputs, target, intr): + inputs[1] = self.chrom_aug(inputs[1]) + # noise + inputs[0]+=np.random.normal(0, self.noise, inputs[0].shape) + inputs[1]+=np.random.normal(0, self.noise, inputs[0].shape) + return inputs,target,intr + + def chrom_aug(self, rgb): + # color change + mean_in = rgb.sum(-1) + rgb = rgb*self.color[np.newaxis,np.newaxis] + brightness_coeff = mean_in / (rgb.sum(-1)+0.01) + rgb = np.clip(rgb*brightness_coeff[:,:,np.newaxis],0,1) + # gamma + rgb = np.power(rgb,self.gamma) + # brightness + rgb += self.brightness + # contrast + rgb = 0.5 + ( rgb-0.5)*self.contrast + rgb = np.clip(rgb, 0, 1) + return rgb diff --git a/expansion/dataloader/depthloader.py b/expansion/dataloader/depthloader.py new file mode 100755 index 0000000000000000000000000000000000000000..b189ba5b70a546eeb0a7e78812f19057e43e02fa --- /dev/null +++ 
b/expansion/dataloader/depthloader.py @@ -0,0 +1,222 @@ +import os +import numbers +import torch +import torch.utils.data as data +import torch +import torchvision.transforms as transforms +import random +from PIL import Image, ImageOps +import numpy as np +import torchvision +from . import depth_transforms as flow_transforms +import pdb +import cv2 +from utils.flowlib import read_flow +from utils.util_flow import readPFM, load_calib_cam_to_cam + +def default_loader(path): + return Image.open(path).convert('RGB') + +def flow_loader(path): + if '.pfm' in path: + data = readPFM(path)[0] + data[:,:,2] = 1 + return data + else: + return read_flow(path) + +def load_exts(cam_file): + with open(cam_file, 'r') as f: + lines = f.readlines() + + l_exts = [] + r_exts = [] + for l in lines: + if 'L ' in l: + l_exts.append(np.asarray([float(i) for i in l[2:].strip().split(' ')]).reshape(4,4)) + if 'R ' in l: + r_exts.append(np.asarray([float(i) for i in l[2:].strip().split(' ')]).reshape(4,4)) + return l_exts,r_exts + +def disparity_loader(path): + if '.png' in path: + data = Image.open(path) + data = np.ascontiguousarray(data,dtype=np.float32)/256 + return data + else: + return readPFM(path)[0] + +# triangulation +def triangulation(disp, xcoord, ycoord, bl=1, fl = 450, cx = 479.5, cy = 269.5): + depth = bl*fl / disp # 450px->15mm focal length + X = (xcoord - cx) * depth / fl + Y = (ycoord - cy) * depth / fl + Z = depth + P = np.concatenate((X[np.newaxis],Y[np.newaxis],Z[np.newaxis]),0).reshape(3,-1) + P = np.concatenate((P,np.ones((1,P.shape[-1]))),0) + return P + +class myImageFloder(data.Dataset): + def __init__(self, iml0, iml1, flowl0, loader=default_loader, dploader= flow_loader, scale=1.,shape=[320,448], order=1, noise=0.06, pca_augmentor=True, prob = 1.,sc=False,disp0=None,disp1=None,calib=None ): + self.iml0 = iml0 + self.iml1 = iml1 + self.flowl0 = flowl0 + self.loader = loader + self.dploader = dploader + self.scale=scale + self.shape=shape + self.order=order + self.noise = noise + self.pca_augmentor = pca_augmentor + self.prob = prob + self.sc = sc + self.disp0 = disp0 + self.disp1 = disp1 + self.calib = calib + + def __getitem__(self, index): + iml0 = self.iml0[index] + iml1 = self.iml1[index] + flowl0= self.flowl0[index] + th, tw = self.shape + + iml0 = self.loader(iml0) + iml1 = self.loader(iml1) + + # get disparity + if self.sc: + flowl0 = self.dploader(flowl0) + flowl0 = np.ascontiguousarray(flowl0,dtype=np.float32) + flowl0[np.isnan(flowl0)] = 1e6 # set to max + if 'camera_data.txt' in self.calib[index]: + bl=1 + if '15mm_' in self.calib[index]: + fl=450 # 450 + else: + fl=1050 + cx = 479.5 + cy = 269.5 + # negative disp + d1 = np.abs(disparity_loader(self.disp0[index])) + d2 = np.abs(disparity_loader(self.disp1[index]) + d1) + elif 'Sintel' in self.calib[index]: + fl = 1000 + bl = 1 + cx = 511.5 + cy = 217.5 + d1 = np.zeros(flowl0.shape[:2]) + d2 = np.zeros(flowl0.shape[:2]) + else: + ints = load_calib_cam_to_cam(self.calib[index]) + fl = ints['K_cam2'][0,0] + cx = ints['K_cam2'][0,2] + cy = ints['K_cam2'][1,2] + bl = ints['b20']-ints['b30'] + d1 = disparity_loader(self.disp0[index]) + d2 = disparity_loader(self.disp1[index]) + #flowl0[:,:,2] = (flowl0[:,:,2]==1).astype(float) + flowl0[:,:,2] = np.logical_and(np.logical_and(flowl0[:,:,2]==1, d1!=0), d2!=0).astype(float) + + shape = d1.shape + mesh = np.meshgrid(range(shape[1]),range(shape[0])) + xcoord = mesh[0].astype(float) + ycoord = mesh[1].astype(float) + + # triangulation in two frames + P0 = triangulation(d1, xcoord, ycoord, 
bl=bl, fl = fl, cx = cx, cy = cy) + P1 = triangulation(d2, xcoord + flowl0[:,:,0], ycoord + flowl0[:,:,1], bl=bl, fl = fl, cx = cx, cy = cy) + dis0 = P0[2] + dis1 = P1[2] + + change_size = dis0.reshape(shape).astype(np.float32) + flow3d = (P1-P0)[:3].reshape((3,)+shape).transpose((1,2,0)) + + gt_normal = np.concatenate((d1[:,:,np.newaxis],d2[:,:,np.newaxis],d2[:,:,np.newaxis]),-1) + change_size = np.concatenate((change_size[:,:,np.newaxis],gt_normal,flow3d),2) + else: + shape = iml0.size + shape=[shape[1],shape[0]] + flowl0 = np.zeros((shape[0],shape[1],3)) + change_size = np.zeros((shape[0],shape[1],7)) + depth = disparity_loader(self.iml1[index].replace('camera','groundtruth')) + change_size[:,:,0] = depth + + seqid = self.iml0[index].split('/')[-5].rsplit('_',3)[0] + ints = load_calib_cam_to_cam('/data/gengshay/KITTI/%s/calib_cam_to_cam.txt'%seqid) + fl = ints['K_cam2'][0,0] + cx = ints['K_cam2'][0,2] + cy = ints['K_cam2'][1,2] + bl = ints['b20']-ints['b30'] + + + iml1 = np.asarray(iml1)/255. + iml0 = np.asarray(iml0)/255. + iml0 = iml0[:,:,::-1].copy() + iml1 = iml1[:,:,::-1].copy() + + ## following data augmentation procedure in PWCNet + ## https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu + import __main__ # a workaround for "discount_coeff" + try: + with open('/scratch/gengshay/iter_counts-%d.txt'%int(__main__.args.logname.split('-')[-1]), 'r') as f: + iter_counts = int(f.readline()) + except: + iter_counts = 0 + schedule = [0.5, 1., 50000.] # initial coeff, final_coeff, half life + schedule_coeff = schedule[0] + (schedule[1] - schedule[0]) * \ + (2/(1+np.exp(-1.0986*iter_counts/schedule[2])) - 1) + + if self.pca_augmentor: + pca_augmentor = flow_transforms.pseudoPCAAug( schedule_coeff=schedule_coeff) + else: + pca_augmentor = flow_transforms.Scale(1., order=0) + + if np.random.binomial(1,self.prob): + co_transform1 = flow_transforms.Compose([ + flow_transforms.SpatialAug([th,tw], + scale=[0.2,0.,0.1], + rot=[0.4,0.], + trans=[0.4,0.], + squeeze=[0.3,0.], schedule_coeff=schedule_coeff, order=self.order), + ]) + else: + co_transform1 = flow_transforms.Compose([ + flow_transforms.RandomCrop([th,tw]), + ]) + + co_transform2 = flow_transforms.Compose([ + flow_transforms.pseudoPCAAug( schedule_coeff=schedule_coeff), + #flow_transforms.PCAAug(schedule_coeff=schedule_coeff), + flow_transforms.ChromaticAug( schedule_coeff=schedule_coeff, noise=self.noise), + ]) + + flowl0 = np.concatenate([flowl0,change_size],-1) + augmented,flowl0,intr = co_transform1([iml0, iml1], flowl0, [fl,cx,cy,bl]) + imol0 = augmented[0] + imol1 = augmented[1] + augmented,flowl0,intr = co_transform2(augmented, flowl0, intr) + + iml0 = augmented[0] + iml1 = augmented[1] + flowl0 = flowl0.astype(np.float32) + change_size = flowl0[:,:,3:] + flowl0 = flowl0[:,:,:3] + + # randomly cover a region + sx=0;sy=0;cx=0;cy=0 + if np.random.binomial(1,0.5): + sx = int(np.random.uniform(25,100)) + sy = int(np.random.uniform(25,100)) + #sx = int(np.random.uniform(50,150)) + #sy = int(np.random.uniform(50,150)) + cx = int(np.random.uniform(sx,iml1.shape[0]-sx)) + cy = int(np.random.uniform(sy,iml1.shape[1]-sy)) + iml1[cx-sx:cx+sx,cy-sy:cy+sy] = np.mean(np.mean(iml1,0),0)[np.newaxis,np.newaxis] + + iml0 = torch.Tensor(np.transpose(iml0,(2,0,1))) + iml1 = torch.Tensor(np.transpose(iml1,(2,0,1))) + + return iml0, iml1, flowl0, change_size, intr, imol0, imol1, np.asarray([cx-sx,cx+sx,cy-sy,cy+sy]) + + def __len__(self): + return len(self.iml0) diff --git 
a/expansion/dataloader/flow_transforms.py b/expansion/dataloader/flow_transforms.py new file mode 100755 index 0000000000000000000000000000000000000000..3bff1ac448a408012a5aa3a1c1035cbf2695e240 --- /dev/null +++ b/expansion/dataloader/flow_transforms.py @@ -0,0 +1,440 @@ +from __future__ import division +import torch +import random +import numpy as np +import numbers +import types +import scipy.ndimage as ndimage +import pdb +import torchvision +import PIL.Image as Image +import cv2 +from torch.nn import functional as F + + +class Compose(object): + """ Composes several co_transforms together. + For example: + >>> co_transforms.Compose([ + >>> co_transforms.CenterCrop(10), + >>> co_transforms.ToTensor(), + >>> ]) + """ + + def __init__(self, co_transforms): + self.co_transforms = co_transforms + + def __call__(self, input, target): + for t in self.co_transforms: + input,target = t(input,target) + return input,target + + +class Scale(object): + """ Rescales the inputs and target arrays to the given 'size'. + 'size' will be the size of the smaller edge. + For example, if height > width, then image will be + rescaled to (size * height / width, size) + size: size of the smaller edge + interpolation order: Default: 2 (bilinear) + """ + + def __init__(self, size, order=1): + self.ratio = size + self.order = order + if order==0: + self.code=cv2.INTER_NEAREST + elif order==1: + self.code=cv2.INTER_LINEAR + elif order==2: + self.code=cv2.INTER_CUBIC + + def __call__(self, inputs, target): + if self.ratio==1: + return inputs, target + h, w, _ = inputs[0].shape + ratio = self.ratio + + inputs[0] = cv2.resize(inputs[0], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_LINEAR) + inputs[1] = cv2.resize(inputs[1], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_LINEAR) + # keep the mask same + tmp = cv2.resize(target[:,:,2], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_NEAREST) + target = cv2.resize(target, None, fx=ratio,fy=ratio,interpolation=self.code) * ratio + target[:,:,2] = tmp + + + return inputs, target + + + + +class SpatialAug(object): + def __init__(self, crop, scale=None, rot=None, trans=None, squeeze=None, schedule_coeff=1, order=1, black=False): + self.crop = crop + self.scale = scale + self.rot = rot + self.trans = trans + self.squeeze = squeeze + self.t = np.zeros(6) + self.schedule_coeff = schedule_coeff + self.order = order + self.black = black + + def to_identity(self): + self.t[0] = 1; self.t[2] = 0; self.t[4] = 0; self.t[1] = 0; self.t[3] = 1; self.t[5] = 0; + + def left_multiply(self, u0, u1, u2, u3, u4, u5): + result = np.zeros(6) + result[0] = self.t[0]*u0 + self.t[1]*u2; + result[1] = self.t[0]*u1 + self.t[1]*u3; + + result[2] = self.t[2]*u0 + self.t[3]*u2; + result[3] = self.t[2]*u1 + self.t[3]*u3; + + result[4] = self.t[4]*u0 + self.t[5]*u2 + u4; + result[5] = self.t[4]*u1 + self.t[5]*u3 + u5; + self.t = result + + def inverse(self): + result = np.zeros(6) + a = self.t[0]; c = self.t[2]; e = self.t[4]; + b = self.t[1]; d = self.t[3]; f = self.t[5]; + + denom = a*d - b*c; + + result[0] = d / denom; + result[1] = -b / denom; + result[2] = -c / denom; + result[3] = a / denom; + result[4] = (c*f-d*e) / denom; + result[5] = (b*e-a*f) / denom; + + return result + + def grid_transform(self, meshgrid, t, normalize=True, gridsize=None): + if gridsize is None: + h, w = meshgrid[0].shape + else: + h, w = gridsize + vgrid = torch.cat([(meshgrid[0] * t[0] + meshgrid[1] * t[2] + t[4])[:,:,np.newaxis], + (meshgrid[0] * t[1] + meshgrid[1] * t[3] + t[5])[:,:,np.newaxis]],-1) + if normalize: 
+ vgrid[:,:,0] = 2.0*vgrid[:,:,0]/max(w-1,1)-1.0 + vgrid[:,:,1] = 2.0*vgrid[:,:,1]/max(h-1,1)-1.0 + return vgrid + + + def __call__(self, inputs, target): + h, w, _ = inputs[0].shape + th, tw = self.crop + meshgrid = torch.meshgrid([torch.Tensor(range(th)), torch.Tensor(range(tw))])[::-1] + cornergrid = torch.meshgrid([torch.Tensor([0,th-1]), torch.Tensor([0,tw-1])])[::-1] + + for i in range(50): + # im0 + self.to_identity() + #TODO add mirror + if np.random.binomial(1,0.5): + mirror = True + else: + mirror = False + ##TODO + #mirror = False + if mirror: + self.left_multiply(-1, 0, 0, 1, .5 * tw, -.5 * th); + else: + self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th); + scale0 = 1; scale1 = 1; squeeze0 = 1; squeeze1 = 1; + if not self.rot is None: + rot0 = np.random.uniform(-self.rot[0],+self.rot[0]) + rot1 = np.random.uniform(-self.rot[1]*self.schedule_coeff, self.rot[1]*self.schedule_coeff) + rot0 + self.left_multiply(np.cos(rot0), np.sin(rot0), -np.sin(rot0), np.cos(rot0), 0, 0) + if not self.trans is None: + trans0 = np.random.uniform(-self.trans[0],+self.trans[0], 2) + trans1 = np.random.uniform(-self.trans[1]*self.schedule_coeff,+self.trans[1]*self.schedule_coeff, 2) + trans0 + self.left_multiply(1, 0, 0, 1, trans0[0] * tw, trans0[1] * th) + if not self.squeeze is None: + squeeze0 = np.exp(np.random.uniform(-self.squeeze[0], self.squeeze[0])) + squeeze1 = np.exp(np.random.uniform(-self.squeeze[1]*self.schedule_coeff, self.squeeze[1]*self.schedule_coeff)) * squeeze0 + if not self.scale is None: + scale0 = np.exp(np.random.uniform(self.scale[2]-self.scale[0], self.scale[2]+self.scale[0])) + scale1 = np.exp(np.random.uniform(-self.scale[1]*self.schedule_coeff, self.scale[1]*self.schedule_coeff)) * scale0 + self.left_multiply(1.0/(scale0*squeeze0), 0, 0, 1.0/(scale0/squeeze0), 0, 0) + + self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h); + transmat0 = self.t.copy() + + # im1 + self.to_identity() + if mirror: + self.left_multiply(-1, 0, 0, 1, .5 * tw, -.5 * th); + else: + self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th); + if not self.rot is None: + self.left_multiply(np.cos(rot1), np.sin(rot1), -np.sin(rot1), np.cos(rot1), 0, 0) + if not self.trans is None: + self.left_multiply(1, 0, 0, 1, trans1[0] * tw, trans1[1] * th) + self.left_multiply(1.0/(scale1*squeeze1), 0, 0, 1.0/(scale1/squeeze1), 0, 0) + self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h); + transmat1 = self.t.copy() + transmat1_inv = self.inverse() + + if self.black: + # black augmentation, allowing 0 values in the input images + # https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/black_augmentation_layer.cu + break + else: + if ((self.grid_transform(cornergrid, transmat0, gridsize=[float(h),float(w)]).abs()>1).sum() +\ + (self.grid_transform(cornergrid, transmat1, gridsize=[float(h),float(w)]).abs()>1).sum()) == 0: + break + if i==49: + print('max_iter in augmentation') + self.to_identity() + self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th); + self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h); + transmat0 = self.t.copy() + transmat1 = self.t.copy() + + # do the real work + vgrid = self.grid_transform(meshgrid, transmat0,gridsize=[float(h),float(w)]) + inputs_0 = F.grid_sample(torch.Tensor(inputs[0]).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0) + if self.order == 0: + target_0 = F.grid_sample(torch.Tensor(target).permute(2,0,1)[np.newaxis], vgrid[np.newaxis], mode='nearest')[0].permute(1,2,0) + else: + target_0 = F.grid_sample(torch.Tensor(target).permute(2,0,1)[np.newaxis], 
vgrid[np.newaxis])[0].permute(1,2,0) + + mask_0 = target[:,:,2:3].copy(); mask_0[mask_0==0]=np.nan + if self.order == 0: + mask_0 = F.grid_sample(torch.Tensor(mask_0).permute(2,0,1)[np.newaxis], vgrid[np.newaxis], mode='nearest')[0].permute(1,2,0) + else: + mask_0 = F.grid_sample(torch.Tensor(mask_0).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0) + mask_0[torch.isnan(mask_0)] = 0 + + + vgrid = self.grid_transform(meshgrid, transmat1,gridsize=[float(h),float(w)]) + inputs_1 = F.grid_sample(torch.Tensor(inputs[1]).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0) + + # flow + pos = target_0[:,:,:2] + self.grid_transform(meshgrid, transmat0,normalize=False) + pos = self.grid_transform(pos.permute(2,0,1),transmat1_inv,normalize=False) + if target_0.shape[2]>=4: + # scale + exp = target_0[:,:,3:] * scale1 / scale0 + target = torch.cat([ (pos[:,:,0] - meshgrid[0]).unsqueeze(-1), + (pos[:,:,1] - meshgrid[1]).unsqueeze(-1), + mask_0, + exp], -1) + else: + target = torch.cat([ (pos[:,:,0] - meshgrid[0]).unsqueeze(-1), + (pos[:,:,1] - meshgrid[1]).unsqueeze(-1), + mask_0], -1) +# target_0[:,:,2].unsqueeze(-1) ], -1) + inputs = [np.asarray(inputs_0), np.asarray(inputs_1)] + target = np.asarray(target) + + return inputs,target + + +class pseudoPCAAug(object): + """ + Chromatic Eigen Augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu + This version is faster. + """ + def __init__(self, schedule_coeff=1): + self.augcolor = torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.5, hue=0.5/3.14) + + def __call__(self, inputs, target): + inputs[0] = np.asarray(self.augcolor(Image.fromarray(np.uint8(inputs[0]*255))))/255. + inputs[1] = np.asarray(self.augcolor(Image.fromarray(np.uint8(inputs[1]*255))))/255. 
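+        # note: unlike the depth_transforms variant, each frame here is passed to
+        # ColorJitter separately, so the two frames receive independently sampled
+        # photometric parameters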
+ return inputs,target + + +class PCAAug(object): + """ + Chromatic Eigen Augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu + """ + def __init__(self, lmult_pow =[0.4, 0,-0.2], + lmult_mult =[0.4, 0,0, ], + lmult_add =[0.03,0,0, ], + sat_pow =[0.4, 0,0, ], + sat_mult =[0.5, 0,-0.3], + sat_add =[0.03,0,0, ], + col_pow =[0.4, 0,0, ], + col_mult =[0.2, 0,0, ], + col_add =[0.02,0,0, ], + ladd_pow =[0.4, 0,0, ], + ladd_mult =[0.4, 0,0, ], + ladd_add =[0.04,0,0, ], + col_rotate =[1., 0,0, ], + schedule_coeff=1): + # no mean + self.pow_nomean = [1,1,1] + self.add_nomean = [0,0,0] + self.mult_nomean = [1,1,1] + self.pow_withmean = [1,1,1] + self.add_withmean = [0,0,0] + self.mult_withmean = [1,1,1] + self.lmult_pow = 1 + self.lmult_mult = 1 + self.lmult_add = 0 + self.col_angle = 0 + if not ladd_pow is None: + self.pow_nomean[0] =np.exp(np.random.normal(ladd_pow[2], ladd_pow[0])) + if not col_pow is None: + self.pow_nomean[1] =np.exp(np.random.normal(col_pow[2], col_pow[0])) + self.pow_nomean[2] =np.exp(np.random.normal(col_pow[2], col_pow[0])) + + if not ladd_add is None: + self.add_nomean[0] =np.random.normal(ladd_add[2], ladd_add[0]) + if not col_add is None: + self.add_nomean[1] =np.random.normal(col_add[2], col_add[0]) + self.add_nomean[2] =np.random.normal(col_add[2], col_add[0]) + + if not ladd_mult is None: + self.mult_nomean[0] =np.exp(np.random.normal(ladd_mult[2], ladd_mult[0])) + if not col_mult is None: + self.mult_nomean[1] =np.exp(np.random.normal(col_mult[2], col_mult[0])) + self.mult_nomean[2] =np.exp(np.random.normal(col_mult[2], col_mult[0])) + + # with mean + if not sat_pow is None: + self.pow_withmean[1] =np.exp(np.random.uniform(sat_pow[2]-sat_pow[0], sat_pow[2]+sat_pow[0])) + self.pow_withmean[2] =self.pow_withmean[1] + if not sat_add is None: + self.add_withmean[1] =np.random.uniform(sat_add[2]-sat_add[0], sat_add[2]+sat_add[0]) + self.add_withmean[2] =self.add_withmean[1] + if not sat_mult is None: + self.mult_withmean[1] = np.exp(np.random.uniform(sat_mult[2]-sat_mult[0], sat_mult[2]+sat_mult[0])) + self.mult_withmean[2] = self.mult_withmean[1] + + if not lmult_pow is None: + self.lmult_pow = np.exp(np.random.uniform(lmult_pow[2]-lmult_pow[0], lmult_pow[2]+lmult_pow[0])) + if not lmult_mult is None: + self.lmult_mult= np.exp(np.random.uniform(lmult_mult[2]-lmult_mult[0], lmult_mult[2]+lmult_mult[0])) + if not lmult_add is None: + self.lmult_add = np.random.uniform(lmult_add[2]-lmult_add[0], lmult_add[2]+lmult_add[0]) + if not col_rotate is None: + self.col_angle= np.random.uniform(col_rotate[2]-col_rotate[0], col_rotate[2]+col_rotate[0]) + + # eigen vectors + self.eigvec = np.reshape([0.51,0.56,0.65,0.79,0.01,-0.62,0.35,-0.83,0.44],[3,3]).transpose() + + + def __call__(self, inputs, target): + inputs[0] = self.pca_image(inputs[0]) + inputs[1] = self.pca_image(inputs[1]) + return inputs,target + + def pca_image(self, rgb): + eig = np.dot(rgb, self.eigvec) + max_rgb = np.clip(rgb,0,np.inf).max((0,1)) + min_rgb = rgb.min((0,1)) + mean_rgb = rgb.mean((0,1)) + max_abs_eig = np.abs(eig).max((0,1)) + max_l = np.sqrt(np.sum(max_abs_eig*max_abs_eig)) + mean_eig = np.dot(mean_rgb, self.eigvec) + + # no-mean stuff + eig -= mean_eig[np.newaxis, np.newaxis] + + for c in range(3): + if max_abs_eig[c] > 1e-2: + mean_eig[c] /= max_abs_eig[c] + eig[:,:,c] = eig[:,:,c] / max_abs_eig[c]; + eig[:,:,c] = np.power(np.abs(eig[:,:,c]),self.pow_nomean[c]) *\ + ((eig[:,:,c] > 0) -0.5)*2 + eig[:,:,c] = eig[:,:,c] + self.add_nomean[c] + 
eig[:,:,c] = eig[:,:,c] * self.mult_nomean[c] + eig += mean_eig[np.newaxis,np.newaxis] + + # withmean stuff + if max_abs_eig[0] > 1e-2: + eig[:,:,0] = np.power(np.abs(eig[:,:,0]),self.pow_withmean[0]) * \ + ((eig[:,:,0]>0)-0.5)*2; + eig[:,:,0] = eig[:,:,0] + self.add_withmean[0]; + eig[:,:,0] = eig[:,:,0] * self.mult_withmean[0]; + + s = np.sqrt(eig[:,:,1]*eig[:,:,1] + eig[:,:,2] * eig[:,:,2]) + smask = s > 1e-2 + s1 = np.power(s, self.pow_withmean[1]); + s1 = np.clip(s1 + self.add_withmean[1], 0,np.inf) + s1 = s1 * self.mult_withmean[1] + s1 = s1 * smask + s*(1-smask) + + # color angle + if self.col_angle!=0: + temp1 = np.cos(self.col_angle) * eig[:,:,1] - np.sin(self.col_angle) * eig[:,:,2] + temp2 = np.sin(self.col_angle) * eig[:,:,1] + np.cos(self.col_angle) * eig[:,:,2] + eig[:,:,1] = temp1 + eig[:,:,2] = temp2 + + # to origin magnitude + for c in range(3): + if max_abs_eig[c] > 1e-2: + eig[:,:,c] = eig[:,:,c] * max_abs_eig[c] + + if max_l > 1e-2: + l1 = np.sqrt(eig[:,:,0]*eig[:,:,0] + eig[:,:,1]*eig[:,:,1] + eig[:,:,2]*eig[:,:,2]) + l1 = l1 / max_l + + eig[:,:,1][smask] = (eig[:,:,1] / s * s1)[smask] + eig[:,:,2][smask] = (eig[:,:,2] / s * s1)[smask] + #eig[:,:,1] = (eig[:,:,1] / s * s1) * smask + eig[:,:,1] * (1-smask) + #eig[:,:,2] = (eig[:,:,2] / s * s1) * smask + eig[:,:,2] * (1-smask) + + if max_l > 1e-2: + l = np.sqrt(eig[:,:,0]*eig[:,:,0] + eig[:,:,1]*eig[:,:,1] + eig[:,:,2]*eig[:,:,2]) + l1 = np.power(l1, self.lmult_pow) + l1 = np.clip(l1 + self.lmult_add, 0, np.inf) + l1 = l1 * self.lmult_mult + l1 = l1 * max_l + lmask = l > 1e-2 + eig[lmask] = (eig / l[:,:,np.newaxis] * l1[:,:,np.newaxis])[lmask] + for c in range(3): + eig[:,:,c][lmask] = (np.clip(eig[:,:,c], -np.inf, max_abs_eig[c]))[lmask] + # for c in range(3): +# # eig[:,:,c][lmask] = (eig[:,:,c] / l * l1)[lmask] * lmask + eig[:,:,c] * (1-lmask) + # eig[:,:,c][lmask] = (eig[:,:,c] / l * l1)[lmask] + # eig[:,:,c] = (np.clip(eig[:,:,c], -np.inf, max_abs_eig[c])) * lmask + eig[:,:,c] * (1-lmask) + + return np.clip(np.dot(eig, self.eigvec.transpose()), 0, 1) + + +class ChromaticAug(object): + """ + Chromatic augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu + """ + def __init__(self, noise = 0.06, + gamma = 0.02, + brightness = 0.02, + contrast = 0.02, + color = 0.02, + schedule_coeff=1): + + self.noise = np.random.uniform(0,noise) + self.gamma = np.exp(np.random.normal(0, gamma*schedule_coeff)) + self.brightness = np.random.normal(0, brightness*schedule_coeff) + self.contrast = np.exp(np.random.normal(0, contrast*schedule_coeff)) + self.color = np.exp(np.random.normal(0, color*schedule_coeff,3)) + + def __call__(self, inputs, target): + inputs[1] = self.chrom_aug(inputs[1]) + # noise + inputs[0]+=np.random.normal(0, self.noise, inputs[0].shape) + inputs[1]+=np.random.normal(0, self.noise, inputs[0].shape) + return inputs,target + + def chrom_aug(self, rgb): + # color change + mean_in = rgb.sum(-1) + rgb = rgb*self.color[np.newaxis,np.newaxis] + brightness_coeff = mean_in / (rgb.sum(-1)+0.01) + rgb = np.clip(rgb*brightness_coeff[:,:,np.newaxis],0,1) + # gamma + rgb = np.power(rgb,self.gamma) + # brightness + rgb += self.brightness + # contrast + rgb = 0.5 + ( rgb-0.5)*self.contrast + rgb = np.clip(rgb, 0, 1) + return rgb diff --git a/expansion/dataloader/hd1klist.py b/expansion/dataloader/hd1klist.py new file mode 100755 index 0000000000000000000000000000000000000000..9b4754986a850007f43345068d42648c2d6ecc24 --- /dev/null +++ b/expansion/dataloader/hd1klist.py @@ 
-0,0 +1,29 @@ +import torch.utils.data as data + +from PIL import Image +import os +import os.path +import numpy as np +import pdb + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + +def dataloader(filepath): + + left_fold = 'image_2/' + train = [img for img in os.listdir(filepath+left_fold) if img.find('HD1K2018') > -1] + train = sorted(train) + + l0_train = [filepath+left_fold+img for img in train] + l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%04d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ] + l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%04d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train] + flow_train = [img.replace('image_2','flow_occ') for img in l0_train] + + return l0_train, l1_train, flow_train diff --git a/expansion/dataloader/kitti12list.py b/expansion/dataloader/kitti12list.py new file mode 100755 index 0000000000000000000000000000000000000000..af1e7f2e9438f79dd7219d01274786c5c4f0e794 --- /dev/null +++ b/expansion/dataloader/kitti12list.py @@ -0,0 +1,29 @@ +import torch.utils.data as data + +from PIL import Image +import os +import os.path +import numpy as np + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + +def dataloader(filepath): + + left_fold = 'colored_0/' + flow_noc = 'flow_occ/' + + train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1] + + l0_train = [filepath+left_fold+img for img in train] + l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train] + flow_train = [filepath+flow_noc+img for img in train] + + + return l0_train, l1_train, flow_train diff --git a/expansion/dataloader/kitti15list.py b/expansion/dataloader/kitti15list.py new file mode 100755 index 0000000000000000000000000000000000000000..ff44e5cbb43d64e99bda79aaac4bdb8f7cdd9d2b --- /dev/null +++ b/expansion/dataloader/kitti15list.py @@ -0,0 +1,29 @@ +import torch.utils.data as data + +from PIL import Image +import os +import os.path +import numpy as np + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + +def dataloader(filepath): + + left_fold = 'image_2/' + flow_noc = 'flow_occ/' + + train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1] + + l0_train = [filepath+left_fold+img for img in train] + l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train] + flow_train = [filepath+flow_noc+img for img in train] + + + return sorted(l0_train), sorted(l1_train), sorted(flow_train) diff --git a/expansion/dataloader/kitti15list_train.py b/expansion/dataloader/kitti15list_train.py new file mode 100755 index 0000000000000000000000000000000000000000..e1eca1af37426274490e916b884271281024cc47 --- /dev/null +++ b/expansion/dataloader/kitti15list_train.py @@ -0,0 +1,31 @@ +import torch.utils.data as data + +from PIL import Image +import os +import os.path +import numpy as np + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension 
in IMG_EXTENSIONS) + +def dataloader(filepath): + + left_fold = 'image_2/' + flow_noc = 'flow_occ/' + + train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1] + + train = [i for i in train if int(i.split('_')[0])%5!=0] + + l0_train = [filepath+left_fold+img for img in train] + l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train] + flow_train = [filepath+flow_noc+img for img in train] + + + return sorted(l0_train), sorted(l1_train), sorted(flow_train) diff --git a/expansion/dataloader/kitti15list_train_lidar.py b/expansion/dataloader/kitti15list_train_lidar.py new file mode 100755 index 0000000000000000000000000000000000000000..aa77c139455c195d0532aa890c5c8b8f137923cb --- /dev/null +++ b/expansion/dataloader/kitti15list_train_lidar.py @@ -0,0 +1,34 @@ +import torch.utils.data as data + +from PIL import Image +import os +import os.path +import numpy as np + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + +def dataloader(filepath): + + left_fold = 'image_2/' + flow_noc = 'flow_occ/' + + train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1] + +# train = [i for i in train if int(i.split('_')[0])%5!=0] + with open('/data/gengshay/kitti_scene/devkit/mapping/train_mapping.txt','r') as f: + flags = [True if len(i)>1 else False for i in f.readlines()] + train = [fn for (it,fn) in enumerate(sorted(train)) if flags[it] ][:100] + + l0_train = [filepath+left_fold+img for img in train] + l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train] + flow_train = [filepath+flow_noc+img for img in train] + + + return sorted(l0_train), sorted(l1_train), sorted(flow_train) diff --git a/expansion/dataloader/kitti15list_val.py b/expansion/dataloader/kitti15list_val.py new file mode 100755 index 0000000000000000000000000000000000000000..3d5e39e245dd8878f1f6270dcc2b28ff2c85ed5d --- /dev/null +++ b/expansion/dataloader/kitti15list_val.py @@ -0,0 +1,31 @@ +import torch.utils.data as data + +from PIL import Image +import os +import os.path +import numpy as np + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + +def dataloader(filepath): + + left_fold = 'image_2/' + flow_noc = 'flow_occ/' + + train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1] + + train = [i for i in train if int(i.split('_')[0])%5==0] + + l0_train = [filepath+left_fold+img for img in train] + l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train] + flow_train = [filepath+flow_noc+img for img in train] + + + return sorted(l0_train), sorted(l1_train), sorted(flow_train) diff --git a/expansion/dataloader/kitti15list_val_lidar.py b/expansion/dataloader/kitti15list_val_lidar.py new file mode 100755 index 0000000000000000000000000000000000000000..12420446c328bfa45c283a80a9d406b54e7cf346 --- /dev/null +++ b/expansion/dataloader/kitti15list_val_lidar.py @@ -0,0 +1,34 @@ +import torch.utils.data as data + +from PIL import Image +import os +import os.path +import numpy as np + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + 
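# kitti15list_val_lidar: built the same way as the *_train_lidar list above, but keeps the
# mapped frames from index 100 onward ([100:]) as a held-out validation split instead of the
# first 100. The hard-coded path to train_mapping.txt below is machine-specific and has to
# be adapted to the local KITTI scene flow devkit location.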
+def dataloader(filepath): + + left_fold = 'image_2/' + flow_noc = 'flow_occ/' + + train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1] + +# train = [i for i in train if int(i.split('_')[0])%5!=0] + with open('/data/gengshay/kitti_scene/devkit/mapping/train_mapping.txt','r') as f: + flags = [True if len(i)>1 else False for i in f.readlines()] + train = [fn for (it,fn) in enumerate(sorted(train)) if flags[it] ][100:] + + l0_train = [filepath+left_fold+img for img in train] + l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train] + flow_train = [filepath+flow_noc+img for img in train] + + + return sorted(l0_train), sorted(l1_train), sorted(flow_train) diff --git a/expansion/dataloader/kitti15list_val_mr.py b/expansion/dataloader/kitti15list_val_mr.py new file mode 100755 index 0000000000000000000000000000000000000000..56c209f65a734b9874738456a1bebc843f6fecb1 --- /dev/null +++ b/expansion/dataloader/kitti15list_val_mr.py @@ -0,0 +1,41 @@ +import torch.utils.data as data + +from PIL import Image +import os +import os.path +import numpy as np + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + +def dataloader(filepath): + + left_fold = 'image_2/' + flow_noc = 'flow_occ/' + + train = [img for img in os.listdir(filepath+left_fold) if 'Kitti' in img and img.find('_10') > -1] + +# train = [i for i in train if int(i.split('_')[1])%5==0] + #import pdb; pdb.set_trace() # leftover debug breakpoint, disabled + train = sorted([i for i in train if int(i.split('_')[1])%5==0])[0:1] + + l0_train = [filepath+left_fold+img for img in train] + l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train] + flow_train = [filepath+flow_noc+img for img in train] + + l0_train += [filepath+left_fold+img.replace('_10','_09') for img in train] + l1_train += [filepath+left_fold+img for img in train] + flow_train += flow_train + + tmp = l0_train + l0_train = l0_train+ [i.replace('rob_flow', 'kitti_scene').replace('Kitti2015_','') for i in l1_train] + l1_train = l1_train+tmp + flow_train += flow_train + + return l0_train, l1_train, flow_train diff --git a/expansion/dataloader/robloader.py b/expansion/dataloader/robloader.py new file mode 100755 index 0000000000000000000000000000000000000000..471f2262c77d7c71b48b0fb187159e9d2dbc517a --- /dev/null +++ b/expansion/dataloader/robloader.py @@ -0,0 +1,133 @@ +import os +import numbers +import torch +import torch.utils.data as data +import torch +import torchvision.transforms as transforms +import random +from PIL import Image, ImageOps +import numpy as np +import torchvision +from .
import flow_transforms +import pdb +import cv2 +from utils.flowlib import read_flow +from utils.util_flow import readPFM + + +def default_loader(path): + return Image.open(path).convert('RGB') + +def flow_loader(path): + if '.pfm' in path: + data = readPFM(path)[0] + data[:,:,2] = 1 + return data + else: + return read_flow(path) + + +def disparity_loader(path): + if '.png' in path: + data = Image.open(path) + data = np.ascontiguousarray(data,dtype=np.float32)/256 + return data + else: + return readPFM(path)[0] + +class myImageFloder(data.Dataset): + def __init__(self, iml0, iml1, flowl0, loader=default_loader, dploader= flow_loader, scale=1.,shape=[320,448], order=1, noise=0.06, pca_augmentor=True, prob = 1., cover=False, black=False, scale_aug=[0.4,0.2]): + self.iml0 = iml0 + self.iml1 = iml1 + self.flowl0 = flowl0 + self.loader = loader + self.dploader = dploader + self.scale=scale + self.shape=shape + self.order=order + self.noise = noise + self.pca_augmentor = pca_augmentor + self.prob = prob + self.cover = cover + self.black = black + self.scale_aug = scale_aug + + def __getitem__(self, index): + iml0 = self.iml0[index] + iml1 = self.iml1[index] + flowl0= self.flowl0[index] + th, tw = self.shape + + iml0 = self.loader(iml0) + iml1 = self.loader(iml1) + iml1 = np.asarray(iml1)/255. + iml0 = np.asarray(iml0)/255. + iml0 = iml0[:,:,::-1].copy() + iml1 = iml1[:,:,::-1].copy() + flowl0 = self.dploader(flowl0) + #flowl0[:,:,-1][flowl0[:,:,0]==np.inf]=0 # for gtav window pfm files + #flowl0[:,:,0][~flowl0[:,:,2].astype(bool)]=0 + #flowl0[:,:,1][~flowl0[:,:,2].astype(bool)]=0 # avoid nan in grad + flowl0 = np.ascontiguousarray(flowl0,dtype=np.float32) + flowl0[np.isnan(flowl0)] = 1e6 # set to max + + ## following data augmentation procedure in PWCNet + ## https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu + import __main__ # a workaround for "discount_coeff" + try: + with open('iter_counts-%d.txt'%int(__main__.args.logname.split('-')[-1]), 'r') as f: + iter_counts = int(f.readline()) + except: + iter_counts = 0 + schedule = [0.5, 1., 50000.] # initial coeff, final_coeff, half life + schedule_coeff = schedule[0] + (schedule[1] - schedule[0]) * \ + (2/(1+np.exp(-1.0986*iter_counts/schedule[2])) - 1) + + if self.pca_augmentor: + pca_augmentor = flow_transforms.pseudoPCAAug( schedule_coeff=schedule_coeff) + else: + pca_augmentor = flow_transforms.Scale(1., order=0) + + if np.random.binomial(1,self.prob): + co_transform = flow_transforms.Compose([ + flow_transforms.Scale(self.scale, order=self.order), + #flow_transforms.SpatialAug([th,tw], trans=[0.2,0.03], order=self.order, black=self.black), + flow_transforms.SpatialAug([th,tw],scale=[self.scale_aug[0],0.03,self.scale_aug[1]], + rot=[0.4,0.03], + trans=[0.4,0.03], + squeeze=[0.3,0.], schedule_coeff=schedule_coeff, order=self.order, black=self.black), + #flow_transforms.pseudoPCAAug(schedule_coeff=schedule_coeff), + flow_transforms.PCAAug(schedule_coeff=schedule_coeff), + flow_transforms.ChromaticAug( schedule_coeff=schedule_coeff, noise=self.noise), + ]) + else: + co_transform = flow_transforms.Compose([ + flow_transforms.Scale(self.scale, order=self.order), + flow_transforms.SpatialAug([th,tw], trans=[0.4,0.03], order=self.order, black=self.black) + ]) + + augmented,flowl0 = co_transform([iml0, iml1], flowl0) + iml0 = augmented[0] + iml1 = augmented[1] + + if self.cover: + ## randomly cover a region + # following sec. 
3.2 of http://openaccess.thecvf.com/content_CVPR_2019/html/Yang_Hierarchical_Deep_Stereo_Matching_on_High-Resolution_Images_CVPR_2019_paper.html + if np.random.binomial(1,0.5): + #sx = int(np.random.uniform(25,100)) + #sy = int(np.random.uniform(25,100)) + sx = int(np.random.uniform(50,125)) + sy = int(np.random.uniform(50,125)) + #sx = int(np.random.uniform(50,150)) + #sy = int(np.random.uniform(50,150)) + cx = int(np.random.uniform(sx,iml1.shape[0]-sx)) + cy = int(np.random.uniform(sy,iml1.shape[1]-sy)) + iml1[cx-sx:cx+sx,cy-sy:cy+sy] = np.mean(np.mean(iml1,0),0)[np.newaxis,np.newaxis] + + iml0 = torch.Tensor(np.transpose(iml0,(2,0,1))) + iml1 = torch.Tensor(np.transpose(iml1,(2,0,1))) + + return iml0, iml1, flowl0 + + def __len__(self): + return len(self.iml0) diff --git a/expansion/dataloader/sceneflowlist.py b/expansion/dataloader/sceneflowlist.py new file mode 100755 index 0000000000000000000000000000000000000000..0ba93e6caa33214f0a2bf7a46c74566c86414d0b --- /dev/null +++ b/expansion/dataloader/sceneflowlist.py @@ -0,0 +1,51 @@ +import os +import os.path +import glob + +def dataloader(filepath, level=6): + iml0 = [] + iml1 = [] + flowl0 = [] + disp0 = [] + dispc = [] + calib = [] + level_stars = '/*'*level + candidate_pool = glob.glob('%s/optical_flow%s'%(filepath,level_stars)) + for flow_path in sorted(candidate_pool): + if 'TEST' in flow_path: continue + if 'flower_storm_x2/into_future/right/OpticalFlowIntoFuture_0023_R.pfm' in flow_path: + continue + if 'flower_storm_x2/into_future/left/OpticalFlowIntoFuture_0023_L.pfm' in flow_path: + continue + if 'flower_storm_augmented0_x2/into_future/right/OpticalFlowIntoFuture_0023_R.pfm' in flow_path: + continue + if 'flower_storm_augmented0_x2/into_future/left/OpticalFlowIntoFuture_0023_L.pfm' in flow_path: + continue + if 'FlyingThings' in flow_path and '_0014_' in flow_path: + continue + if 'FlyingThings' in flow_path and '_0015_' in flow_path: + continue + idd = flow_path.split('/')[-1].split('_')[-2] + if 'into_future' in flow_path: + idd_p1 = '%04d'%(int(idd)+1) + else: + idd_p1 = '%04d'%(int(idd)-1) + if os.path.exists(flow_path.replace(idd,idd_p1)): + d0_path = flow_path.replace('/into_future/','/').replace('/into_past/','/').replace('optical_flow','disparity') + d0_path = '%s/%s.pfm'%(d0_path.rsplit('/',1)[0],idd) + dc_path = flow_path.replace('optical_flow','disparity_change') + dc_path = '%s/%s.pfm'%(dc_path.rsplit('/',1)[0],idd) + im_path = flow_path.replace('/into_future/','/').replace('/into_past/','/').replace('optical_flow','frames_cleanpass') + im0_path = '%s/%s.png'%(im_path.rsplit('/',1)[0],idd) + im1_path = '%s/%s.png'%(im_path.rsplit('/',1)[0],idd_p1) + #with open('%s/camera_data.txt'%(im0_path.replace('frames_cleanpass','camera_data').rsplit('/',2)[0]),'r') as f: + # if 'FlyingThings' in flow_path and len(f.readlines())!=40: + # print(flow_path) + # continue + iml0.append(im0_path) + iml1.append(im1_path) + flowl0.append(flow_path) + disp0.append(d0_path) + dispc.append(dc_path) + calib.append('%s/camera_data.txt'%(im0_path.replace('frames_cleanpass','camera_data').rsplit('/',2)[0])) + return iml0, iml1, flowl0, disp0, dispc, calib diff --git a/expansion/dataloader/seqlist.py b/expansion/dataloader/seqlist.py new file mode 100755 index 0000000000000000000000000000000000000000..2e0e8bda391a7ddfd09b3268065dc02712bc7575 --- /dev/null +++ b/expansion/dataloader/seqlist.py @@ -0,0 +1,26 @@ +import torch.utils.data as data + +from PIL import Image +import os +import os.path +import numpy as np +import glob + +IMG_EXTENSIONS 
= [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + +def dataloader(filepath): + + train = [img for img in sorted(glob.glob('%s/*'%filepath))] + + l0_train = train[:-1] + l1_train = train[1:] + + + return sorted(l0_train), sorted(l1_train), sorted(l0_train) diff --git a/expansion/dataloader/sintellist.py b/expansion/dataloader/sintellist.py new file mode 100755 index 0000000000000000000000000000000000000000..44bc1ab5d466d605b7bc695adb41f2431aa0f790 --- /dev/null +++ b/expansion/dataloader/sintellist.py @@ -0,0 +1,32 @@ +import torch.utils.data as data + +from PIL import Image +import os +import os.path +import numpy as np +import pdb + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + +def dataloader(filepath): + + left_fold = 'image_2/' + train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel') > -1] + + l0_train = [filepath+left_fold+img for img in train] + l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ] + + #l0_train = [i for i in l0_train if not '10.png' in i] # remove 10 as val + + l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train] + flow_train = [img.replace('image_2','flow_occ') for img in l0_train] + + + return l0_train, l1_train, flow_train diff --git a/expansion/dataloader/sintellist_clean.py b/expansion/dataloader/sintellist_clean.py new file mode 100755 index 0000000000000000000000000000000000000000..c008399e94e150856e921bef12199af9910400f6 --- /dev/null +++ b/expansion/dataloader/sintellist_clean.py @@ -0,0 +1,31 @@ +import torch.utils.data as data + +from PIL import Image +import os +import os.path +import numpy as np +import pdb + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + +def dataloader(filepath): + + left_fold = 'image_2/' + train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel_clean') > -1] + + l0_train = [filepath+left_fold+img for img in train] + l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ] + + #l0_train = [i for i in l0_train if not '10.png' in i] # remove 10 as val + + l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train] + flow_train = [img.replace('image_2','flow_occ') for img in l0_train] + + return l0_train, l1_train, flow_train diff --git a/expansion/dataloader/sintellist_final.py b/expansion/dataloader/sintellist_final.py new file mode 100755 index 0000000000000000000000000000000000000000..b8585d594b02984bbe13005f680aaf1e859864d3 --- /dev/null +++ b/expansion/dataloader/sintellist_final.py @@ -0,0 +1,32 @@ +import torch.utils.data as data + +from PIL import Image +import os +import os.path +import numpy as np +import pdb + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in 
IMG_EXTENSIONS) + +def dataloader(filepath): + + left_fold = 'image_2/' + train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel_final') > -1] + + l0_train = [filepath+left_fold+img for img in train] + l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ] + + #l0_train = [i for i in l0_train if not '10.png' in i] # remove 10 as val + + l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train] + flow_train = [img.replace('image_2','flow_occ') for img in l0_train] + + #pdb.set_trace() # leftover debug breakpoint, disabled + return l0_train, l1_train, flow_train diff --git a/expansion/dataloader/sintellist_train.py b/expansion/dataloader/sintellist_train.py new file mode 100755 index 0000000000000000000000000000000000000000..81ff9393fb5896983648885186f5fb7c2e6907de --- /dev/null +++ b/expansion/dataloader/sintellist_train.py @@ -0,0 +1,32 @@ +import torch.utils.data as data + +from PIL import Image +import os +import os.path +import numpy as np +import pdb + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + +def dataloader(filepath): + + left_fold = 'image_2/' + train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel') > -1] + + l0_train = [filepath+left_fold+img for img in train] + l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ] + + l0_train = [i for i in l0_train if not(('_2_' in i) and ('alley' not in i) and ('bandage' not in i) and ('sleeping' not in i))] # remove 10 as val + + l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train] + flow_train = [img.replace('image_2','flow_occ') for img in l0_train] + + + return l0_train, l1_train, flow_train diff --git a/expansion/dataloader/sintellist_val.py b/expansion/dataloader/sintellist_val.py new file mode 100755 index 0000000000000000000000000000000000000000..452446556bc18aee9ff2b47ab5d2ff81c7b5a2ec --- /dev/null +++ b/expansion/dataloader/sintellist_val.py @@ -0,0 +1,34 @@ +import torch.utils.data as data + +from PIL import Image +import os +import os.path +import numpy as np +import pdb + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + +def dataloader(filepath): + + left_fold = 'image_2/' + train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel') > -1] + + l0_train = [filepath+left_fold+img for img in train] + l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ] + + l0_train = [i for i in l0_train if ('_2_' in i) and ('alley' not in i) and ('bandage' not in i) and ('sleeping' not in i)] # remove 10 as val + #l0_train = [i for i in l0_train if not(('_2_' in i) and ('alley' not in i) and ('bandage' not in i) and ('sleeping' not in i))] # remove 10 as val + + l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train] + flow_train = [img.replace('image_2','flow_occ') for img in l0_train] + + + return sorted(l0_train)[::3], sorted(l1_train)[::3],
sorted(flow_train)[::3] +# return sorted(l0_train)[::10], sorted(l1_train)[::10], sorted(flow_train)[::10] diff --git a/expansion/dataloader/thingslist.py b/expansion/dataloader/thingslist.py new file mode 100755 index 0000000000000000000000000000000000000000..cfe3976bce5738dc2000b2dfc736b41a28d8624d --- /dev/null +++ b/expansion/dataloader/thingslist.py @@ -0,0 +1,122 @@ +import torch.utils.data as data + +from PIL import Image +import os +import os.path +import numpy as np + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + +def dataloader(filepath): + exc_list = [ +'0004117.flo', +'0003149.flo', +'0001203.flo', +'0003147.flo', +'0003666.flo', +'0006337.flo', +'0006336.flo', +'0007126.flo', +'0004118.flo', +] + + left_fold = 'image_clean/left/' + flow_noc = 'flow/left/into_future/' + train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0] + + l0_trainlf = [filepath+left_fold+img.replace('flo','png') for img in train] + l1_trainlf = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainlf] + flow_trainlf = [filepath+flow_noc+img for img in train] + + + exc_list = [ +'0003148.flo', +'0004117.flo', +'0002890.flo', +'0003149.flo', +'0001203.flo', +'0003666.flo', +'0006337.flo', +'0006336.flo', +'0004118.flo', +] + + left_fold = 'image_clean/right/' + flow_noc = 'flow/right/into_future/' + train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0] + + l0_trainrf = [filepath+left_fold+img.replace('flo','png') for img in train] + l1_trainrf = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainrf] + flow_trainrf = [filepath+flow_noc+img for img in train] + + + exc_list = [ +'0004237.flo', +'0004705.flo', +'0004045.flo', +'0004346.flo', +'0000161.flo', +'0000931.flo', +'0000121.flo', +'0010822.flo', +'0004117.flo', +'0006023.flo', +'0005034.flo', +'0005054.flo', +'0000162.flo', +'0000053.flo', +'0005055.flo', +'0003147.flo', +'0004876.flo', +'0000163.flo', +'0006878.flo', +] + + left_fold = 'image_clean/left/' + flow_noc = 'flow/left/into_past/' + train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0] + + l0_trainlp = [filepath+left_fold+img.replace('flo','png') for img in train] + l1_trainlp = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(-1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainlp] + flow_trainlp = [filepath+flow_noc+img for img in train] + + exc_list = [ +'0003148.flo', +'0004705.flo', +'0000161.flo', +'0000121.flo', +'0004117.flo', +'0000160.flo', +'0005034.flo', +'0005054.flo', +'0000162.flo', +'0000053.flo', +'0005055.flo', +'0003147.flo', +'0001549.flo', +'0000163.flo', +'0006336.flo', +'0001648.flo', +'0006878.flo', +] + + left_fold = 'image_clean/right/' + flow_noc = 'flow/right/into_past/' + train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0] + + l0_trainrp = [filepath+left_fold+img.replace('flo','png') for img in train] + l1_trainrp = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(-1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainrp] + flow_trainrp = [filepath+flow_noc+img for img in train] + + + l0_train = l0_trainlf + l0_trainrf + l0_trainlp + l0_trainrp + l1_train = l1_trainlf + l1_trainrf + l1_trainlp + 
l1_trainrp + flow_train = flow_trainlf + flow_trainrf + flow_trainlp + flow_trainrp + return l0_train, l1_train, flow_train diff --git a/expansion/models/VCN_exp.py b/expansion/models/VCN_exp.py new file mode 100755 index 0000000000000000000000000000000000000000..6115ef9a68c63a12fae83bec92846de173c829b5 --- /dev/null +++ b/expansion/models/VCN_exp.py @@ -0,0 +1,561 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +import os +os.environ['PYTHON_EGG_CACHE'] = 'tmp/' # a writable directory +import numpy as np +import math +import pdb +import time + +from .submodule import pspnet, bfmodule, conv +from .conv4d import sepConv4d, sepConv4dBlock, butterfly4D + +class flow_reg(nn.Module): + """ + Soft winner-take-all that selects the most likely diplacement. + Set ent=True to enable entropy output. + Set maxdisp to adjust maximum allowed displacement towards one side. + maxdisp=4 searches for a 9x9 region. + Set fac to squeeze search window. + maxdisp=4 and fac=2 gives search window of 9x5 + """ + def __init__(self, size, ent=False, maxdisp = int(4), fac=1): + B,W,H = size + super(flow_reg, self).__init__() + self.ent = ent + self.md = maxdisp + self.fac = fac + self.truncated = True + self.wsize = 3 # by default using truncation 7x7 + + flowrangey = range(-maxdisp,maxdisp+1) + flowrangex = range(-int(maxdisp//self.fac),int(maxdisp//self.fac)+1) + meshgrid = np.meshgrid(flowrangex,flowrangey) + flowy = np.tile( np.reshape(meshgrid[0],[1,2*maxdisp+1,2*int(maxdisp//self.fac)+1,1,1]), (B,1,1,H,W) ) + flowx = np.tile( np.reshape(meshgrid[1],[1,2*maxdisp+1,2*int(maxdisp//self.fac)+1,1,1]), (B,1,1,H,W) ) + self.register_buffer('flowx',torch.Tensor(flowx)) + self.register_buffer('flowy',torch.Tensor(flowy)) + + self.pool3d = nn.MaxPool3d((self.wsize*2+1,self.wsize*2+1,1),stride=1,padding=(self.wsize,self.wsize,0)) + + def forward(self, x): + b,u,v,h,w = x.shape + oldx = x + + if self.truncated: + # truncated softmax + x = x.view(b,u*v,h,w) + + idx = x.argmax(1)[:,np.newaxis] + if x.is_cuda: + mask = Variable(torch.cuda.HalfTensor(b,u*v,h,w)).fill_(0) + else: + mask = Variable(torch.FloatTensor(b,u*v,h,w)).fill_(0) + mask.scatter_(1,idx,1) + mask = mask.view(b,1,u,v,-1) + mask = self.pool3d(mask)[:,0].view(b,u,v,h,w) + + ninf = x.clone().fill_(-np.inf).view(b,u,v,h,w) + x = torch.where(mask.byte(),oldx,ninf) + else: + self.wsize = (np.sqrt(u*v)-1)/2 + + b,u,v,h,w = x.shape + x = F.softmax(x.view(b,-1,h,w),1).view(b,u,v,h,w) + outx = torch.sum(torch.sum(x*self.flowx,1),1,keepdim=True) + outy = torch.sum(torch.sum(x*self.flowy,1),1,keepdim=True) + + if self.ent: + # local + local_entropy = (-x*torch.clamp(x,1e-9,1-1e-9).log()).sum(1).sum(1)[:,np.newaxis] + if self.wsize == 0: + local_entropy[:] = 1. 
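# The local entropy is normalized to [0, 1]: a degenerate truncation window is assigned
# maximal uncertainty (1.), otherwise it is divided by its maximum possible value
# log((2*wsize+1)**2). The global entropy computed below is normalized analogously over
# the full cost volume; both serve as per-pixel matching-uncertainty features.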
+ else: + local_entropy /= np.log((self.wsize*2+1)**2) + + # global + x = F.softmax(oldx.view(b,-1,h,w),1).view(b,u,v,h,w) + global_entropy = (-x*torch.clamp(x,1e-9,1-1e-9).log()).sum(1).sum(1)[:,np.newaxis] + global_entropy /= np.log(x.shape[1]*x.shape[2]) + return torch.cat([outx,outy],1),torch.cat([local_entropy, global_entropy],1) + else: + return torch.cat([outx,outy],1),None + + +class WarpModule(nn.Module): + """ + taken from https://github.com/NVlabs/PWC-Net/blob/master/PyTorch/models/PWCNet.py + """ + def __init__(self, size): + super(WarpModule, self).__init__() + B,W,H = size + # mesh grid + xx = torch.arange(0, W).view(1,-1).repeat(H,1) + yy = torch.arange(0, H).view(-1,1).repeat(1,W) + xx = xx.view(1,1,H,W).repeat(B,1,1,1) + yy = yy.view(1,1,H,W).repeat(B,1,1,1) + self.register_buffer('grid',torch.cat((xx,yy),1).float()) + + def forward(self, x, flo): + """ + warp an image/tensor (im2) back to im1, according to the optical flow + + x: [B, C, H, W] (im2) + flo: [B, 2, H, W] flow + + """ + B, C, H, W = x.size() + vgrid = self.grid + flo + + # scale grid to [-1,1] + vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:]/max(W-1,1)-1.0 + vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:]/max(H-1,1)-1.0 + + vgrid = vgrid.permute(0,2,3,1) + output = nn.functional.grid_sample(x, vgrid,align_corners=True) + mask = ((vgrid[:,:,:,0].abs()<1) * (vgrid[:,:,:,1].abs()<1)) >0 + return output*mask.unsqueeze(1).float(), mask + + +def get_grid(B,H,W): + meshgrid_base = np.meshgrid(range(0,W), range(0,H))[::-1] + basey = np.reshape(meshgrid_base[0],[1,1,1,H,W]) + basex = np.reshape(meshgrid_base[1],[1,1,1,H,W]) + grid = torch.tensor(np.concatenate((basex.reshape((-1,H,W,1)),basey.reshape((-1,H,W,1))),-1)).cuda().float() + return grid.view(1,1,H,W,2) + + +class VCN(nn.Module): + """ + VCN. 
+ md defines maximum displacement for each level, following a coarse-to-fine-warping scheme + fac defines squeeze parameter for the coarsest level + """ + def __init__(self, size, md=[4,4,4,4,4], fac=1.,exp_unc=False): # exp_uncertainty + super(VCN,self).__init__() + self.md = md + self.fac = fac + use_entropy = True + withbn = True + + ## pspnet + self.pspnet = pspnet(is_proj=False) + + ### Volumetric-UNet + fdima1 = 128 # 6/5/4 + fdima2 = 64 # 3/2 + fdimb1 = 16 # 6/5/4/3 + fdimb2 = 12 # 2 + + full=False + self.f6 = butterfly4D(fdima1, fdimb1,withbn=withbn,full=full) + self.p6 = sepConv4d(fdimb1,fdimb1, with_bn=False, full=full) + + self.f5 = butterfly4D(fdima1, fdimb1,withbn=withbn, full=full) + self.p5 = sepConv4d(fdimb1,fdimb1, with_bn=False,full=full) + + self.f4 = butterfly4D(fdima1, fdimb1,withbn=withbn,full=full) + self.p4 = sepConv4d(fdimb1,fdimb1, with_bn=False,full=full) + + self.f3 = butterfly4D(fdima2, fdimb1,withbn=withbn,full=full) + self.p3 = sepConv4d(fdimb1,fdimb1, with_bn=False,full=full) + + full=True + self.f2 = butterfly4D(fdima2, fdimb2,withbn=withbn,full=full) + self.p2 = sepConv4d(fdimb2,fdimb2, with_bn=False,full=full) + + self.flow_reg64 = flow_reg([fdimb1*size[0],size[1]//64,size[2]//64], ent=use_entropy, maxdisp=self.md[0], fac=self.fac) + self.flow_reg32 = flow_reg([fdimb1*size[0],size[1]//32,size[2]//32], ent=use_entropy, maxdisp=self.md[1]) + self.flow_reg16 = flow_reg([fdimb1*size[0],size[1]//16,size[2]//16], ent=use_entropy, maxdisp=self.md[2]) + self.flow_reg8 = flow_reg([fdimb1*size[0],size[1]//8,size[2]//8] , ent=use_entropy, maxdisp=self.md[3]) + self.flow_reg4 = flow_reg([fdimb2*size[0],size[1]//4,size[2]//4] , ent=use_entropy, maxdisp=self.md[4]) + + self.warp5 = WarpModule([size[0],size[1]//32,size[2]//32]) + self.warp4 = WarpModule([size[0],size[1]//16,size[2]//16]) + self.warp3 = WarpModule([size[0],size[1]//8,size[2]//8]) + self.warp2 = WarpModule([size[0],size[1]//4,size[2]//4]) + + ## hypotheses fusion modules, adopted from the refinement module of PWCNet + # https://github.com/NVlabs/PWC-Net/blob/master/PyTorch/models/PWCNet.py + # c6 + self.dc6_conv1 = conv(128+4*fdimb1, 128, kernel_size=3, stride=1, padding=1, dilation=1) + self.dc6_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2) + self.dc6_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4) + self.dc6_conv4 = conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8) + self.dc6_conv5 = conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16) + self.dc6_conv6 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1) + self.dc6_conv7 = nn.Conv2d(32,2*fdimb1,kernel_size=3,stride=1,padding=1,bias=True) + + # c5 + self.dc5_conv1 = conv(128+4*fdimb1*2, 128, kernel_size=3, stride=1, padding=1, dilation=1) + self.dc5_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2) + self.dc5_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4) + self.dc5_conv4 = conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8) + self.dc5_conv5 = conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16) + self.dc5_conv6 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1) + self.dc5_conv7 = nn.Conv2d(32,2*fdimb1*2,kernel_size=3,stride=1,padding=1,bias=True) + + # c4 + self.dc4_conv1 = conv(128+4*fdimb1*3, 128, kernel_size=3, stride=1, padding=1, dilation=1) + self.dc4_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2) + self.dc4_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4, 
dilation=4) + self.dc4_conv4 = conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8) + self.dc4_conv5 = conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16) + self.dc4_conv6 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1) + self.dc4_conv7 = nn.Conv2d(32,2*fdimb1*3,kernel_size=3,stride=1,padding=1,bias=True) + + # c3 + self.dc3_conv1 = conv(64+16*fdimb1, 128, kernel_size=3, stride=1, padding=1, dilation=1) + self.dc3_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2) + self.dc3_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4) + self.dc3_conv4 = conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8) + self.dc3_conv5 = conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16) + self.dc3_conv6 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1) + self.dc3_conv7 = nn.Conv2d(32,8*fdimb1,kernel_size=3,stride=1,padding=1,bias=True) + + # c2 + self.dc2_conv1 = conv(64+16*fdimb1+4*fdimb2, 128, kernel_size=3, stride=1, padding=1, dilation=1) + self.dc2_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2) + self.dc2_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4) + self.dc2_conv4 = conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8) + self.dc2_conv5 = conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16) + self.dc2_conv6 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1) + self.dc2_conv7 = nn.Conv2d(32,4*2*fdimb1 + 2*fdimb2,kernel_size=3,stride=1,padding=1,bias=True) + + self.dc6_conv = nn.Sequential( self.dc6_conv1, + self.dc6_conv2, + self.dc6_conv3, + self.dc6_conv4, + self.dc6_conv5, + self.dc6_conv6, + self.dc6_conv7) + self.dc5_conv = nn.Sequential( self.dc5_conv1, + self.dc5_conv2, + self.dc5_conv3, + self.dc5_conv4, + self.dc5_conv5, + self.dc5_conv6, + self.dc5_conv7) + self.dc4_conv = nn.Sequential( self.dc4_conv1, + self.dc4_conv2, + self.dc4_conv3, + self.dc4_conv4, + self.dc4_conv5, + self.dc4_conv6, + self.dc4_conv7) + self.dc3_conv = nn.Sequential( self.dc3_conv1, + self.dc3_conv2, + self.dc3_conv3, + self.dc3_conv4, + self.dc3_conv5, + self.dc3_conv6, + self.dc3_conv7) + self.dc2_conv = nn.Sequential( self.dc2_conv1, + self.dc2_conv2, + self.dc2_conv3, + self.dc2_conv4, + self.dc2_conv5, + self.dc2_conv6, + self.dc2_conv7) + + ## Out-of-range detection + self.dc6_convo = nn.Sequential(conv(128+4*fdimb1, 128, kernel_size=3, stride=1, padding=1, dilation=1), + conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2), + conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4), + conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8), + conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16), + conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1), + nn.Conv2d(32,1,kernel_size=3,stride=1,padding=1,bias=True)) + + self.dc5_convo = nn.Sequential(conv(128+2*4*fdimb1, 128, kernel_size=3, stride=1, padding=1, dilation=1), + conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2), + conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4), + conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8), + conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16), + conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1), + nn.Conv2d(32,1,kernel_size=3,stride=1,padding=1,bias=True)) + + self.dc4_convo = nn.Sequential(conv(128+3*4*fdimb1, 128, kernel_size=3, stride=1, padding=1, dilation=1), + conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2), + conv(128, 128, 
kernel_size=3, stride=1, padding=4, dilation=4), + conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8), + conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16), + conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1), + nn.Conv2d(32,1,kernel_size=3,stride=1,padding=1,bias=True)) + + self.dc3_convo = nn.Sequential(conv(64+16*fdimb1, 128, kernel_size=3, stride=1, padding=1, dilation=1), + conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2), + conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4), + conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8), + conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16), + conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1), + nn.Conv2d(32,1,kernel_size=3,stride=1,padding=1,bias=True)) + + self.dc2_convo = nn.Sequential(conv(64+16*fdimb1+4*fdimb2, 128, kernel_size=3, stride=1, padding=1, dilation=1), + conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2), + conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4), + conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8), + conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16), + conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1), + nn.Conv2d(32,1,kernel_size=3,stride=1,padding=1,bias=True)) + + # affine-exp + self.f3d2v1 = conv(64, 32, kernel_size=3, stride=1, padding=1,dilation=1) # + self.f3d2v2 = conv(1, 32, kernel_size=3, stride=1, padding=1,dilation=1) # + self.f3d2v3 = conv(1, 32, kernel_size=3, stride=1, padding=1,dilation=1) # + self.f3d2v4 = conv(1, 32, kernel_size=3, stride=1, padding=1,dilation=1) # + self.f3d2v5 = conv(64, 32, kernel_size=3, stride=1, padding=1,dilation=1) # + self.f3d2v6 = conv(12*81, 32, kernel_size=3, stride=1, padding=1,dilation=1) # + self.f3d2 = bfmodule(128-64,1) + + # depth change net + self.dcnetv1 = conv(64, 32, kernel_size=3, stride=1, padding=1,dilation=1) # + self.dcnetv2 = conv(1, 32, kernel_size=3, stride=1, padding=1,dilation=1) # + self.dcnetv3 = conv(1, 32, kernel_size=3, stride=1, padding=1,dilation=1) # + self.dcnetv4 = conv(1, 32, kernel_size=3, stride=1, padding=1,dilation=1) # + self.dcnetv5 = conv(12*81, 32, kernel_size=3, stride=1, padding=1,dilation=1) # + self.dcnetv6 = conv(4, 32, kernel_size=3, stride=1, padding=1,dilation=1) # + if exp_unc: + self.dcnet = bfmodule(128,2) + else: + self.dcnet = bfmodule(128,1) + + for m in self.modules(): + if isinstance(m, nn.Conv3d): + n = m.kernel_size[0] * m.kernel_size[1]*m.kernel_size[2] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + if hasattr(m.bias,'data'): + m.bias.data.zero_() + + self.facs = [self.fac,1,1,1,1] + self.warp_modules = nn.ModuleList([None, self.warp5, self.warp4, self.warp3, self.warp2]) + self.f_modules = nn.ModuleList([self.f6, self.f5, self.f4, self.f3, self.f2]) + self.p_modules = nn.ModuleList([self.p6, self.p5, self.p4, self.p3, self.p2]) + self.reg_modules = nn.ModuleList([self.flow_reg64, self.flow_reg32, self.flow_reg16, self.flow_reg8, self.flow_reg4]) + self.oor_modules = nn.ModuleList([self.dc6_convo, self.dc5_convo, self.dc4_convo, self.dc3_convo, self.dc2_convo]) + self.fuse_modules = nn.ModuleList([self.dc6_conv, self.dc5_conv, self.dc4_conv, self.dc3_conv, self.dc2_conv]) + + def corrf(self, refimg_fea, targetimg_fea,maxdisp, fac=1): + """ + slow correlation function + """ + b,c,height,width = refimg_fea.shape + if refimg_fea.is_cuda: + cost = Variable(torch.cuda.FloatTensor(b,c,2*maxdisp+1,2*int(maxdisp//fac)+1,height,width)).fill_(0.) 
# b,c,u,v,h,w + else: + cost = Variable(torch.FloatTensor(b,c,2*maxdisp+1,2*int(maxdisp//fac)+1,height,width)).fill_(0.) # b,c,u,v,h,w + for i in range(2*maxdisp+1): + ind = i-maxdisp + for j in range(2*int(maxdisp//fac)+1): + indd = j-int(maxdisp//fac) + feata = refimg_fea[:,:,max(0,-indd):height-indd,max(0,-ind):width-ind] + featb = targetimg_fea[:,:,max(0,+indd):height+indd,max(0,ind):width+ind] + diff = (feata*featb) + cost[:, :, i,j,max(0,-indd):height-indd,max(0,-ind):width-ind] = diff # standard + cost = F.leaky_relu(cost, 0.1,inplace=True) + return cost + + def cost_matching(self,up_flow, c1, c2, flowh, enth, level): + """ + up_flow: upsample coarse flow + c1: normalized feature of image 1 + c2: normalized feature of image 2 + flowh: flow hypotheses + enth: entropy + """ + + # normalize + c1n = c1 / (c1.norm(dim=1, keepdim=True)+1e-9) + c2n = c2 / (c2.norm(dim=1, keepdim=True)+1e-9) + + # cost volume + if level == 0: + warp = c2n + else: + warp,_ = self.warp_modules[level](c2n, up_flow) + + feat = self.corrf(c1n,warp,self.md[level],fac=self.facs[level]) + feat = self.f_modules[level](feat) + cost = self.p_modules[level](feat) # b, 16, u,v,h,w + + # soft WTA + b,c,u,v,h,w = cost.shape + cost = cost.view(-1,u,v,h,w) # bx16, 9,9,h,w, also predict uncertainty from here + flowhh,enthh = self.reg_modules[level](cost) # bx16, 2, h, w + flowhh = flowhh.view(b,c,2,h,w) + if level > 0: + flowhh = flowhh + up_flow[:,np.newaxis] + flowhh = flowhh.view(b,-1,h,w) # b, 16*2, h, w + enthh = enthh.view(b,-1,h,w) # b, 16*1, h, w + + # append coarse hypotheses + if level == 0: + flowh = flowhh + enth = enthh + else: + flowh = torch.cat((flowhh, F.upsample(flowh.detach()*2, [flowhh.shape[2],flowhh.shape[3]], mode='bilinear')),1) # b, k2--k2, h, w + enth = torch.cat((enthh, F.upsample(enth, [flowhh.shape[2],flowhh.shape[3]], mode='bilinear')),1) + + if self.training or level==4: + x = torch.cat((enth.detach(), flowh.detach(), c1),1) + oor = self.oor_modules[level](x)[:,0] + else: oor = None + + # hypotheses fusion + x = torch.cat((enth.detach(), flowh.detach(), c1),1) + va = self.fuse_modules[level](x) + va = va.view(b,-1,2,h,w) + flow = ( flowh.view(b,-1,2,h,w) * F.softmax(va,1) ).sum(1) # b, 2k, 2, h, w + + return flow, flowh, enth, oor + + def affine(self,pref,flow, pw=1): + b,_,lh,lw=flow.shape + ptar = pref + flow + pw = 1 + pref = F.unfold(pref, (pw*2+1,pw*2+1), padding=(pw)).view(b,2,(pw*2+1)**2,lh,lw)-pref[:,:,np.newaxis] + ptar = F.unfold(ptar, (pw*2+1,pw*2+1), padding=(pw)).view(b,2,(pw*2+1)**2,lh,lw)-ptar[:,:,np.newaxis] # b, 2,9,h,w + pref = pref.permute(0,3,4,1,2).reshape(b*lh*lw,2,(pw*2+1)**2) + ptar = ptar.permute(0,3,4,1,2).reshape(b*lh*lw,2,(pw*2+1)**2) + + prefprefT = pref.matmul(pref.permute(0,2,1)) + ppdet = prefprefT[:,0,0]*prefprefT[:,1,1]-prefprefT[:,1,0]*prefprefT[:,0,1] + ppinv = torch.cat((prefprefT[:,1,1:],-prefprefT[:,0,1:], -prefprefT[:,1:,0], prefprefT[:,0:1,0]),1).view(-1,2,2)/ppdet.clamp(1e-10,np.inf)[:,np.newaxis,np.newaxis] + + Affine = ptar.matmul(pref.permute(0,2,1)).matmul(ppinv) + Error = (Affine.matmul(pref)-ptar).norm(2,1).mean(1).view(b,1,lh,lw) + + Avol = (Affine[:,0,0]*Affine[:,1,1]-Affine[:,1,0]*Affine[:,0,1]).view(b,1,lh,lw).abs().clamp(1e-10,np.inf) + exp = Avol.sqrt() + mask = (exp>0.5) & (exp<2) & (Error<0.1) + mask = mask[:,0] + + exp = exp.clamp(0.5,2) + exp[Error>0.1]=1 + return exp, Error, mask + + def affine_mask(self,pref,flow, pw=3): + """ + pref: reference coordinates + pw: patch width + """ + flmask = flow[:,2:] + flow = flow[:,:2] + 
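# affine_mask fits, for every pixel, a 2x2 affine transform between local patches of the
# reference coordinates and their flow-displaced targets, weighting the normal equations by
# the flow validity mask. The square root of the (absolute) determinant of this transform is
# used as the ground-truth expansion (local scale change) target, and pixels with a large
# fitting error, extreme expansion, or too few valid neighbours are masked out.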
b,_,lh,lw=flow.shape + ptar = pref + flow + pref = F.unfold(pref, (pw*2+1,pw*2+1), padding=(pw)).view(b,2,(pw*2+1)**2,lh,lw)-pref[:,:,np.newaxis] + ptar = F.unfold(ptar, (pw*2+1,pw*2+1), padding=(pw)).view(b,2,(pw*2+1)**2,lh,lw)-ptar[:,:,np.newaxis] # b, 2,9,h,w + + conf_flow = flmask + conf_flow = F.unfold(conf_flow,(pw*2+1,pw*2+1), padding=(pw)).view(b,1,(pw*2+1)**2,lh,lw) + count = conf_flow.sum(2,keepdims=True) + conf_flow = ((pw*2+1)**2)*conf_flow / count + pref = pref * conf_flow + ptar = ptar * conf_flow + + pref = pref.permute(0,3,4,1,2).reshape(b*lh*lw,2,(pw*2+1)**2) + ptar = ptar.permute(0,3,4,1,2).reshape(b*lh*lw,2,(pw*2+1)**2) + + prefprefT = pref.matmul(pref.permute(0,2,1)) + ppdet = prefprefT[:,0,0]*prefprefT[:,1,1]-prefprefT[:,1,0]*prefprefT[:,0,1] + ppinv = torch.cat((prefprefT[:,1,1:],-prefprefT[:,0,1:], -prefprefT[:,1:,0], prefprefT[:,0:1,0]),1).view(-1,2,2)/ppdet.clamp(1e-10,np.inf)[:,np.newaxis,np.newaxis] + + Affine = ptar.matmul(pref.permute(0,2,1)).matmul(ppinv) + Error = (Affine.matmul(pref)-ptar).norm(2,1).mean(1).view(b,1,lh,lw) + + Avol = (Affine[:,0,0]*Affine[:,1,1]-Affine[:,1,0]*Affine[:,0,1]).view(b,1,lh,lw).abs().clamp(1e-10,np.inf) + exp = Avol.sqrt() + mask = (exp>0.5) & (exp<2) & (Error<0.2) & (flmask.bool()) & (count[:,0]>4) + mask = mask[:,0] + + exp = exp.clamp(0.5,2) + exp[Error>0.2]=1 + return exp, Error, mask + + def weight_parameters(self): + return [param for name, param in self.named_parameters() if 'weight' in name] + + def bias_parameters(self): + return [param for name, param in self.named_parameters() if 'bias' in name] + + def forward(self,im,disc_aux=None): + bs = im.shape[0]//2 + + if self.training and disc_aux[-1]: # if only fine-tuning expansion + reset=True + self.eval() + torch.set_grad_enabled(False) + else: reset=False + + c06,c05,c04,c03,c02 = self.pspnet(im) + c16 = c06[:bs]; c26 = c06[bs:] + c15 = c05[:bs]; c25 = c05[bs:] + c14 = c04[:bs]; c24 = c04[bs:] + c13 = c03[:bs]; c23 = c03[bs:] + c12 = c02[:bs]; c22 = c02[bs:] + + ## matching 6 + flow6, flow6h, ent6h, oor6 = self.cost_matching(None, c16, c26, None, None,level=0) + + ## matching 5 + up_flow6 = F.upsample(flow6, [im.size()[2]//32,im.size()[3]//32], mode='bilinear')*2 + flow5, flow5h, ent5h, oor5 = self.cost_matching(up_flow6, c15, c25, flow6h, ent6h,level=1) + + ## matching 4 + up_flow5 = F.upsample(flow5, [im.size()[2]//16,im.size()[3]//16], mode='bilinear')*2 + flow4, flow4h, ent4h, oor4 = self.cost_matching(up_flow5, c14, c24, flow5h, ent5h,level=2) + + ## matching 3 + up_flow4 = F.upsample(flow4, [im.size()[2]//8,im.size()[3]//8], mode='bilinear')*2 + flow3, flow3h, ent3h, oor3 = self.cost_matching(up_flow4, c13, c23, flow4h, ent4h,level=3) + + ## matching 2 + up_flow3 = F.upsample(flow3, [im.size()[2]//4,im.size()[3]//4], mode='bilinear')*2 + flow2, flow2h, ent2h, oor2 = self.cost_matching(up_flow3, c12, c22, flow3h, ent3h,level=4) + + if reset: + torch.set_grad_enabled(True) + self.train() + + # expansion + b,_,h,w = flow2.shape + exp2,err2,_ = self.affine(get_grid(b,h,w)[:,0].permute(0,3,1,2).repeat(b,1,1,1).clone(), flow2.detach(),pw=1) + x = torch.cat(( + self.f3d2v2(-exp2.log()), + self.f3d2v3(err2), + ),1) + dchange2 = -exp2.log()+1./200*self.f3d2(x)[0] + + # depth change net + iexp2 = F.upsample(dchange2.clone(), [im.size()[2],im.size()[3]], mode='bilinear') + + x = torch.cat((self.dcnetv1(c12.detach()), + self.dcnetv2(dchange2.detach()), + self.dcnetv3(-exp2.log()), + self.dcnetv4(err2), + ),1) + dcneto = 1./200*self.dcnet(x)[0] + dchange2 = dchange2.detach() + 
dcneto[:,:1] + + flow2 = F.upsample(flow2.detach(), [im.size()[2],im.size()[3]], mode='bilinear')*4 + dchange2 = F.upsample(dchange2, [im.size()[2],im.size()[3]], mode='bilinear') + + if self.training: + flowl0 = disc_aux[0].permute(0,3,1,2).clone() + gt_depth = disc_aux[2][:,:,:,0] + gt_f3d = disc_aux[2][:,:,:,4:7].permute(0,3,1,2).clone() + gt_dchange = (1+gt_f3d[:,2]/gt_depth) + maskdc = (gt_dchange < 2) & (gt_dchange > 0.5) & disc_aux[1] + + gt_expi,gt_expi_err,maskoe = self.affine_mask(get_grid(b,4*h,4*w)[:,0].permute(0,3,1,2).repeat(b,1,1,1), flowl0,pw=3) + gt_exp = 1./gt_expi[:,0] + + loss = 0.1* (dchange2[:,0]-gt_dchange.log()).abs()[maskdc].mean() + loss += 0.1* (iexp2[:,0]-gt_exp.log()).abs()[maskoe].mean() + return flow2*4, flow3*8,flow4*16,flow5*32,flow6*64,loss, dchange2[:,0], iexp2[:,0] + + else: + return flow2, oor2, dchange2, iexp2 + diff --git a/expansion/models/__init__.py b/expansion/models/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/expansion/models/__pycache__/VCN_exp.cpython-38.pyc b/expansion/models/__pycache__/VCN_exp.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..adab22249d7acf5f2de9c51538b0d2f2779cae49 Binary files /dev/null and b/expansion/models/__pycache__/VCN_exp.cpython-38.pyc differ diff --git a/expansion/models/__pycache__/__init__.cpython-38.pyc b/expansion/models/__pycache__/__init__.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..0be5cf74212f653439ad3462dae141dbe4016ac1 Binary files /dev/null and b/expansion/models/__pycache__/__init__.cpython-38.pyc differ diff --git a/expansion/models/__pycache__/conv4d.cpython-38.pyc b/expansion/models/__pycache__/conv4d.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..254f5581fc8389dd765f0d8008b9bd9f9574d870 Binary files /dev/null and b/expansion/models/__pycache__/conv4d.cpython-38.pyc differ diff --git a/expansion/models/__pycache__/submodule.cpython-38.pyc b/expansion/models/__pycache__/submodule.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..4a42df05b4e4a0431ad436bb8dc9942e6f3224d6 Binary files /dev/null and b/expansion/models/__pycache__/submodule.cpython-38.pyc differ diff --git a/expansion/models/conv4d.py b/expansion/models/conv4d.py new file mode 100755 index 0000000000000000000000000000000000000000..2747f2cf1709cc3b0adb1f0a16583eb01b2e4a1d --- /dev/null +++ b/expansion/models/conv4d.py @@ -0,0 +1,296 @@ +import pdb +import torch.nn as nn +import math +import torch +from torch.nn.parameter import Parameter +import torch.nn.functional as F +from torch.nn import Module +from torch.nn.modules.conv import _ConvNd +from torch.nn.modules.utils import _quadruple +from torch.autograd import Variable +from torch.nn import Conv2d + +def conv4d(data,filters,bias=None,permute_filters=True,use_half=False): + """ + This is done by stacking results of multiple 3D convolutions, and is very slow. 
+ Taken from https://github.com/ignacio-rocco/ncnet + """ + b,c,h,w,d,t=data.size() + + data=data.permute(2,0,1,3,4,5).contiguous() # permute to avoid making contiguous inside loop + + # Same permutation is done with filters, unless already provided with permutation + if permute_filters: + filters=filters.permute(2,0,1,3,4,5).contiguous() # permute to avoid making contiguous inside loop + + c_out=filters.size(1) + if use_half: + output = Variable(torch.HalfTensor(h,b,c_out,w,d,t),requires_grad=data.requires_grad) + else: + output = Variable(torch.zeros(h,b,c_out,w,d,t),requires_grad=data.requires_grad) + + padding=filters.size(0)//2 + if use_half: + Z=Variable(torch.zeros(padding,b,c,w,d,t).half()) + else: + Z=Variable(torch.zeros(padding,b,c,w,d,t)) + + if data.is_cuda: + Z=Z.cuda(data.get_device()) + output=output.cuda(data.get_device()) + + data_padded = torch.cat((Z,data,Z),0) + + + for i in range(output.size(0)): # loop on first feature dimension + # convolve with center channel of filter (at position=padding) + output[i,:,:,:,:,:]=F.conv3d(data_padded[i+padding,:,:,:,:,:], + filters[padding,:,:,:,:,:], bias=bias, stride=1, padding=padding) + # convolve with upper/lower channels of filter (at postions [:padding] [padding+1:]) + for p in range(1,padding+1): + output[i,:,:,:,:,:]=output[i,:,:,:,:,:]+F.conv3d(data_padded[i+padding-p,:,:,:,:,:], + filters[padding-p,:,:,:,:,:], bias=None, stride=1, padding=padding) + output[i,:,:,:,:,:]=output[i,:,:,:,:,:]+F.conv3d(data_padded[i+padding+p,:,:,:,:,:], + filters[padding+p,:,:,:,:,:], bias=None, stride=1, padding=padding) + + output=output.permute(1,2,0,3,4,5).contiguous() + return output + +class Conv4d(_ConvNd): + """Applies a 4D convolution over an input signal composed of several input + planes. + """ + + def __init__(self, in_channels, out_channels, kernel_size, bias=True, pre_permuted_filters=True): + # stride, dilation and groups !=1 functionality not tested + stride=1 + dilation=1 + groups=1 + # zero padding is added automatically in conv4d function to preserve tensor size + padding = 0 + kernel_size = _quadruple(kernel_size) + stride = _quadruple(stride) + padding = _quadruple(padding) + dilation = _quadruple(dilation) + super(Conv4d, self).__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + False, _quadruple(0), groups, bias) + # weights will be sliced along one dimension during convolution loop + # make the looping dimension to be the first one in the tensor, + # so that we don't need to call contiguous() inside the loop + self.pre_permuted_filters=pre_permuted_filters + if self.pre_permuted_filters: + self.weight.data=self.weight.data.permute(2,0,1,3,4,5).contiguous() + self.use_half=False + # self.isbias = bias + # if not self.isbias: + # self.bn = torch.nn.BatchNorm1d(out_channels) + + + def forward(self, input): + out = conv4d(input, self.weight, bias=self.bias,permute_filters=not self.pre_permuted_filters,use_half=self.use_half) # filters pre-permuted in constructor + # if not self.isbias: + # b,c,u,v,h,w = out.shape + # out = self.bn(out.view(b,c,-1)).view(b,c,u,v,h,w) + return out + +class fullConv4d(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, bias=True, pre_permuted_filters=True): + super(fullConv4d, self).__init__() + self.conv = Conv4d(in_channels, out_channels, kernel_size, bias=bias, pre_permuted_filters=pre_permuted_filters) + self.isbias = bias + if not self.isbias: + self.bn = torch.nn.BatchNorm1d(out_channels) + + def forward(self, input): + out = 
self.conv(input) + if not self.isbias: + b,c,u,v,h,w = out.shape + out = self.bn(out.view(b,c,-1)).view(b,c,u,v,h,w) + return out + +class butterfly4D(torch.nn.Module): + ''' + butterfly 4d + ''' + def __init__(self, fdima, fdimb, withbn=True, full=True,groups=1): + super(butterfly4D, self).__init__() + self.proj = nn.Sequential(projfeat4d(fdima, fdimb, 1, with_bn=withbn,groups=groups), + nn.ReLU(inplace=True),) + self.conva1 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(2,1,1),full=full,groups=groups) + self.conva2 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(2,1,1),full=full,groups=groups) + self.convb3 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(1,1,1),full=full,groups=groups) + self.convb2 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(1,1,1),full=full,groups=groups) + self.convb1 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(1,1,1),full=full,groups=groups) + + #@profile + def forward(self,x): + out = self.proj(x) + b,c,u,v,h,w = out.shape # 9x9 + + out1 = self.conva1(out) # 5x5, 3 + _,c1,u1,v1,h1,w1 = out1.shape + + out2 = self.conva2(out1) # 3x3, 9 + _,c2,u2,v2,h2,w2 = out2.shape + + out2 = self.convb3(out2) # 3x3, 9 + + tout1 = F.upsample(out2.view(b,c,u2,v2,-1),(u1,v1,h2*w2),mode='trilinear').view(b,c,u1,v1,h2,w2) # 5x5 + tout1 = F.upsample(tout1.view(b,c,-1,h2,w2),(u1*v1,h1,w1),mode='trilinear').view(b,c,u1,v1,h1,w1) # 5x5 + out1 = tout1 + out1 + out1 = self.convb2(out1) + + tout = F.upsample(out1.view(b,c,u1,v1,-1),(u,v,h1*w1),mode='trilinear').view(b,c,u,v,h1,w1) + tout = F.upsample(tout.view(b,c,-1,h1,w1),(u*v,h,w),mode='trilinear').view(b,c,u,v,h,w) + out = tout + out + out = self.convb1(out) + + return out + + + +class projfeat4d(torch.nn.Module): + ''' + Turn 3d projection into 2d projection + ''' + def __init__(self, in_planes, out_planes, stride, with_bn=True,groups=1): + super(projfeat4d, self).__init__() + self.with_bn = with_bn + self.stride = stride + self.conv1 = nn.Conv3d(in_planes, out_planes, 1, (stride,stride,1), padding=0,bias=not with_bn,groups=groups) + self.bn = nn.BatchNorm3d(out_planes) + + def forward(self,x): + b,c,u,v,h,w = x.size() + x = self.conv1(x.view(b,c,u,v,h*w)) + if self.with_bn: + x = self.bn(x) + _,c,u,v,_ = x.shape + x = x.view(b,c,u,v,h,w) + return x + +class sepConv4d(torch.nn.Module): + ''' + Separable 4d convolution block as 2 3D convolutions + ''' + def __init__(self, in_planes, out_planes, stride=(1,1,1), with_bn=True, ksize=3, full=True,groups=1): + super(sepConv4d, self).__init__() + bias = not with_bn + self.isproj = False + self.stride = stride[0] + expand = 1 + + if with_bn: + if in_planes != out_planes: + self.isproj = True + self.proj = nn.Sequential(nn.Conv2d(in_planes, out_planes, 1, bias=bias, padding=0,groups=groups), + nn.BatchNorm2d(out_planes)) + if full: + self.conv1 = nn.Sequential(nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=(1,self.stride,self.stride), bias=bias, padding=(0,ksize//2,ksize//2),groups=groups), + nn.BatchNorm3d(in_planes)) + else: + self.conv1 = nn.Sequential(nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=1, bias=bias, padding=(0,ksize//2,ksize//2),groups=groups), + nn.BatchNorm3d(in_planes)) + self.conv2 = nn.Sequential(nn.Conv3d(in_planes, in_planes*expand, (ksize,ksize,1), stride=(self.stride,self.stride,1), bias=bias, padding=(ksize//2,ksize//2,0),groups=groups), + nn.BatchNorm3d(in_planes*expand)) + else: + if in_planes != out_planes: + self.isproj = True + self.proj = nn.Conv2d(in_planes, out_planes, 1, bias=bias, 
padding=0,groups=groups) + if full: + self.conv1 = nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=(1,self.stride,self.stride), bias=bias, padding=(0,ksize//2,ksize//2),groups=groups) + else: + self.conv1 = nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=1, bias=bias, padding=(0,ksize//2,ksize//2),groups=groups) + self.conv2 = nn.Conv3d(in_planes, in_planes*expand, (ksize,ksize,1), stride=(self.stride,self.stride,1), bias=bias, padding=(ksize//2,ksize//2,0),groups=groups) + self.relu = nn.ReLU(inplace=True) + + #@profile + def forward(self,x): + b,c,u,v,h,w = x.shape + x = self.conv2(x.view(b,c,u,v,-1)) + b,c,u,v,_ = x.shape + x = self.relu(x) + x = self.conv1(x.view(b,c,-1,h,w)) + b,c,_,h,w = x.shape + + if self.isproj: + x = self.proj(x.view(b,c,-1,w)) + x = x.view(b,-1,u,v,h,w) + return x + + +class sepConv4dBlock(torch.nn.Module): + ''' + Separable 4d convolution block as 2 2D convolutions and a projection + layer + ''' + def __init__(self, in_planes, out_planes, stride=(1,1,1), with_bn=True, full=True,groups=1): + super(sepConv4dBlock, self).__init__() + if in_planes == out_planes and stride==(1,1,1): + self.downsample = None + else: + if full: + self.downsample = sepConv4d(in_planes, out_planes, stride, with_bn=with_bn,ksize=1, full=full,groups=groups) + else: + self.downsample = projfeat4d(in_planes, out_planes,stride[0], with_bn=with_bn,groups=groups) + self.conv1 = sepConv4d(in_planes, out_planes, stride, with_bn=with_bn, full=full ,groups=groups) + self.conv2 = sepConv4d(out_planes, out_planes,(1,1,1), with_bn=with_bn, full=full,groups=groups) + self.relu1 = nn.ReLU(inplace=True) + self.relu2 = nn.ReLU(inplace=True) + + #@profile + def forward(self,x): + out = self.relu1(self.conv1(x)) + if self.downsample: + x = self.downsample(x) + out = self.relu2(x + self.conv2(out)) + return out + + +##import torch.backends.cudnn as cudnn +##cudnn.benchmark = True +#import time +##im = torch.randn(9,64,9,160,224).cuda() +##net = torch.nn.Conv3d(64, 64, 3).cuda() +##net = Conv4d(1,1,3,bias=True,pre_permuted_filters=True).cuda() +##net = sepConv4dBlock(2,2,stride=(1,1,1)).cuda() +# +##im = torch.randn(1,16,9,9,96,320).cuda() +##net = sepConv4d(16,16,with_bn=False).cuda() +# +##im = torch.randn(1,16,81,96,320).cuda() +##net = torch.nn.Conv3d(16,16,(1,3,3),padding=(0,1,1)).cuda() +# +##im = torch.randn(1,16,9,9,96*320).cuda() +##net = torch.nn.Conv3d(16,16,(3,3,1),padding=(1,1,0)).cuda() +# +##im = torch.randn(10000,10,9,9).cuda() +##net = torch.nn.Conv2d(10,10,3,padding=1).cuda() +# +##im = torch.randn(81,16,96,320).cuda() +##net = torch.nn.Conv2d(16,16,3,padding=1).cuda() +#c= int(16 *1) +#cp = int(16 *1) +#h=int(96 *4) +#w=int(320 *4) +#k=3 +#im = torch.randn(1,c,h,w).cuda() +#net = torch.nn.Conv2d(c,cp,k,padding=k//2).cuda() +# +#im2 = torch.randn(cp,k*k*c).cuda() +#im1 = F.unfold(im, (k,k), padding=k//2)[0] +# +# +#net(im) +#net(im) +#torch.mm(im2,im1) +#torch.mm(im2,im1) +#torch.cuda.synchronize() +#beg = time.time() +#for i in range(100): +# net(im) +# #im1 = F.unfold(im, (k,k), padding=k//2)[0] +# torch.mm(im2,im1) +#torch.cuda.synchronize() +#print('%f'%((time.time()-beg)*10.)) diff --git a/expansion/models/submodule.py b/expansion/models/submodule.py new file mode 100755 index 0000000000000000000000000000000000000000..0e9a032c776f320dc35aaa5a9219022232811709 --- /dev/null +++ b/expansion/models/submodule.py @@ -0,0 +1,450 @@ +from __future__ import print_function +import torch +import torch.nn as nn +import torch.utils.data +from torch.autograd import 
Variable +import torch.nn.functional as F +import math +import numpy as np +import pdb + +class residualBlock(nn.Module): + expansion = 1 + + def __init__(self, in_channels, n_filters, stride=1, downsample=None,dilation=1,with_bn=True): + super(residualBlock, self).__init__() + if dilation > 1: + padding = dilation + else: + padding = 1 + + if with_bn: + self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters, 3, stride, padding, dilation=dilation) + self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1) + else: + self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters, 3, stride, padding, dilation=dilation,with_bn=False) + self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1, with_bn=False) + self.downsample = downsample + self.relu = nn.LeakyReLU(0.1, inplace=True) + + def forward(self, x): + residual = x + + out = self.convbnrelu1(x) + out = self.convbn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + return self.relu(out) + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + nn.BatchNorm2d(out_planes), + nn.LeakyReLU(0.1,inplace=True)) + + +class conv2DBatchNorm(nn.Module): + def __init__(self, in_channels, n_filters, k_size, stride, padding, dilation=1, with_bn=True): + super(conv2DBatchNorm, self).__init__() + bias = not with_bn + + if dilation > 1: + conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size, + padding=padding, stride=stride, bias=bias, dilation=dilation) + + else: + conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size, + padding=padding, stride=stride, bias=bias, dilation=1) + + + if with_bn: + self.cb_unit = nn.Sequential(conv_mod, + nn.BatchNorm2d(int(n_filters)),) + else: + self.cb_unit = nn.Sequential(conv_mod,) + + def forward(self, inputs): + outputs = self.cb_unit(inputs) + return outputs + +class conv2DBatchNormRelu(nn.Module): + def __init__(self, in_channels, n_filters, k_size, stride, padding, dilation=1, with_bn=True): + super(conv2DBatchNormRelu, self).__init__() + bias = not with_bn + if dilation > 1: + conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size, + padding=padding, stride=stride, bias=bias, dilation=dilation) + + else: + conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size, + padding=padding, stride=stride, bias=bias, dilation=1) + + if with_bn: + self.cbr_unit = nn.Sequential(conv_mod, + nn.BatchNorm2d(int(n_filters)), + nn.LeakyReLU(0.1, inplace=True),) + else: + self.cbr_unit = nn.Sequential(conv_mod, + nn.LeakyReLU(0.1, inplace=True),) + + def forward(self, inputs): + outputs = self.cbr_unit(inputs) + return outputs + +class pyramidPooling(nn.Module): + + def __init__(self, in_channels, with_bn=True, levels=4): + super(pyramidPooling, self).__init__() + self.levels = levels + + self.paths = [] + for i in range(levels): + self.paths.append(conv2DBatchNormRelu(in_channels, in_channels, 1, 1, 0, with_bn=with_bn)) + self.path_module_list = nn.ModuleList(self.paths) + self.relu = nn.LeakyReLU(0.1, inplace=True) + + def forward(self, x): + h, w = x.shape[2:] + + k_sizes = [] + strides = [] + for pool_size in np.linspace(1,min(h,w)//2,self.levels,dtype=int): + k_sizes.append((int(h/pool_size), int(w/pool_size))) + strides.append((int(h/pool_size), int(w/pool_size))) + k_sizes = k_sizes[::-1] + strides = strides[::-1] 
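For concreteness, the pooling kernels (and identical strides) that this loop produces for a hypothetical 48x64 feature map with levels=3 are shown in the sketch below; the (48, 64) kernel amounts to a global average pool.
```
import numpy as np

# Hypothetical 48x64 feature map, levels=3, mirroring the loop above.
h, w, levels = 48, 64, 3
k_sizes = [(h // int(p), w // int(p))
           for p in np.linspace(1, min(h, w) // 2, levels, dtype=int)]
print(k_sizes)        # [(48, 64), (4, 5), (2, 2)]
print(k_sizes[::-1])  # order used after reversal: finest grid first, global pool last
```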
+ + pp_sum = x + + for i, module in enumerate(self.path_module_list): + out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0) + out = module(out) + out = F.upsample(out, size=(h,w), mode='bilinear') + pp_sum = pp_sum + 1./self.levels*out + pp_sum = self.relu(pp_sum/2.) + + return pp_sum + +class pspnet(nn.Module): + """ + Modified PSPNet. https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/models/pspnet.py + """ + def __init__(self, is_proj=True,groups=1): + super(pspnet, self).__init__() + self.inplanes = 32 + self.is_proj = is_proj + + # Encoder + self.convbnrelu1_1 = conv2DBatchNormRelu(in_channels=3, k_size=3, n_filters=16, + padding=1, stride=2) + self.convbnrelu1_2 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=16, + padding=1, stride=1) + self.convbnrelu1_3 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=32, + padding=1, stride=1) + # Vanilla Residual Blocks + self.res_block3 = self._make_layer(residualBlock,64,1,stride=2) + self.res_block5 = self._make_layer(residualBlock,128,1,stride=2) + self.res_block6 = self._make_layer(residualBlock,128,1,stride=2) + self.res_block7 = self._make_layer(residualBlock,128,1,stride=2) + self.pyramid_pooling = pyramidPooling(128, levels=3) + + # Iconvs + self.upconv6 = nn.Sequential(nn.Upsample(scale_factor=2), + conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64, + padding=1, stride=1)) + self.iconv5 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128, + padding=1, stride=1) + self.upconv5 = nn.Sequential(nn.Upsample(scale_factor=2), + conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64, + padding=1, stride=1)) + self.iconv4 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128, + padding=1, stride=1) + self.upconv4 = nn.Sequential(nn.Upsample(scale_factor=2), + conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64, + padding=1, stride=1)) + self.iconv3 = conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64, + padding=1, stride=1) + self.upconv3 = nn.Sequential(nn.Upsample(scale_factor=2), + conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32, + padding=1, stride=1)) + self.iconv2 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=64, + padding=1, stride=1) + + if self.is_proj: + self.proj6 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1) + self.proj5 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1) + self.proj4 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1) + self.proj3 = conv2DBatchNormRelu(in_channels=64, k_size=1,n_filters=64//groups, padding=0,stride=1) + self.proj2 = conv2DBatchNormRelu(in_channels=64, k_size=1,n_filters=64//groups, padding=0,stride=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + if hasattr(m.bias,'data'): + m.bias.data.zero_() + + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion),) + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + return nn.Sequential(*layers) + + def forward(self, x): + # H, W -> H/2, W/2 + conv1 = self.convbnrelu1_1(x) + conv1 = self.convbnrelu1_2(conv1) + conv1 = self.convbnrelu1_3(conv1) + + ## H/2, W/2 -> H/4, W/4 + pool1 = F.max_pool2d(conv1, 3, 2, 1) + + # H/4, W/4 -> H/16, W/16 + rconv3 = self.res_block3(pool1) + conv4 = self.res_block5(rconv3) + conv5 = self.res_block6(conv4) + conv6 = self.res_block7(conv5) + conv6 = self.pyramid_pooling(conv6) + + conv6x = F.upsample(conv6, [conv5.size()[2],conv5.size()[3]],mode='bilinear') + concat5 = torch.cat((conv5,self.upconv6[1](conv6x)),dim=1) + conv5 = self.iconv5(concat5) + + conv5x = F.upsample(conv5, [conv4.size()[2],conv4.size()[3]],mode='bilinear') + concat4 = torch.cat((conv4,self.upconv5[1](conv5x)),dim=1) + conv4 = self.iconv4(concat4) + + conv4x = F.upsample(conv4, [rconv3.size()[2],rconv3.size()[3]],mode='bilinear') + concat3 = torch.cat((rconv3,self.upconv4[1](conv4x)),dim=1) + conv3 = self.iconv3(concat3) + + conv3x = F.upsample(conv3, [pool1.size()[2],pool1.size()[3]],mode='bilinear') + concat2 = torch.cat((pool1,self.upconv3[1](conv3x)),dim=1) + conv2 = self.iconv2(concat2) + + if self.is_proj: + proj6 = self.proj6(conv6) + proj5 = self.proj5(conv5) + proj4 = self.proj4(conv4) + proj3 = self.proj3(conv3) + proj2 = self.proj2(conv2) + return proj6,proj5,proj4,proj3,proj2 + else: + return conv6, conv5, conv4, conv3, conv2 + + +class pspnet_s(nn.Module): + """ + Modified PSPNet. 
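The pspnet encoder above returns a five-level feature pyramid, from 1/64 up to 1/4 of the input resolution. A hedged sketch of those shapes for a hypothetical 256x256 input (assuming the repository root is on PYTHONPATH):
```
import torch
from expansion.models.submodule import pspnet

net = pspnet(is_proj=True).eval()
with torch.no_grad():
    feats = net(torch.randn(1, 3, 256, 256))
print([tuple(f.shape) for f in feats])
# [(1, 128, 4, 4), (1, 128, 8, 8), (1, 128, 16, 16), (1, 64, 32, 32), (1, 64, 64, 64)]
```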
https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/models/pspnet.py + """ + def __init__(self, is_proj=True,groups=1): + super(pspnet_s, self).__init__() + self.inplanes = 32 + self.is_proj = is_proj + + # Encoder + self.convbnrelu1_1 = conv2DBatchNormRelu(in_channels=3, k_size=3, n_filters=16, + padding=1, stride=2) + self.convbnrelu1_2 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=16, + padding=1, stride=1) + self.convbnrelu1_3 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=32, + padding=1, stride=1) + # Vanilla Residual Blocks + self.res_block3 = self._make_layer(residualBlock,64,1,stride=2) + self.res_block5 = self._make_layer(residualBlock,128,1,stride=2) + self.res_block6 = self._make_layer(residualBlock,128,1,stride=2) + self.res_block7 = self._make_layer(residualBlock,128,1,stride=2) + self.pyramid_pooling = pyramidPooling(128, levels=3) + + # Iconvs + self.upconv6 = nn.Sequential(nn.Upsample(scale_factor=2), + conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64, + padding=1, stride=1)) + self.iconv5 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128, + padding=1, stride=1) + self.upconv5 = nn.Sequential(nn.Upsample(scale_factor=2), + conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64, + padding=1, stride=1)) + self.iconv4 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128, + padding=1, stride=1) + self.upconv4 = nn.Sequential(nn.Upsample(scale_factor=2), + conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64, + padding=1, stride=1)) + self.iconv3 = conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64, + padding=1, stride=1) + #self.upconv3 = nn.Sequential(nn.Upsample(scale_factor=2), + # conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32, + # padding=1, stride=1)) + #self.iconv2 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=64, + # padding=1, stride=1) + + if self.is_proj: + self.proj6 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1) + self.proj5 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1) + self.proj4 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1) + self.proj3 = conv2DBatchNormRelu(in_channels=64, k_size=1,n_filters=64//groups, padding=0,stride=1) + #self.proj2 = conv2DBatchNormRelu(in_channels=64, k_size=1,n_filters=64//groups, padding=0,stride=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + if hasattr(m.bias,'data'): + m.bias.data.zero_() + + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion),) + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + return nn.Sequential(*layers) + + def forward(self, x): + # H, W -> H/2, W/2 + conv1 = self.convbnrelu1_1(x) + conv1 = self.convbnrelu1_2(conv1) + conv1 = self.convbnrelu1_3(conv1) + + ## H/2, W/2 -> H/4, W/4 + pool1 = F.max_pool2d(conv1, 3, 2, 1) + + # H/4, W/4 -> H/16, W/16 + rconv3 = self.res_block3(pool1) + conv4 = self.res_block5(rconv3) + conv5 = self.res_block6(conv4) + conv6 = self.res_block7(conv5) + conv6 = self.pyramid_pooling(conv6) + + conv6x = F.upsample(conv6, [conv5.size()[2],conv5.size()[3]],mode='bilinear') + concat5 = torch.cat((conv5,self.upconv6[1](conv6x)),dim=1) + conv5 = self.iconv5(concat5) + + conv5x = F.upsample(conv5, [conv4.size()[2],conv4.size()[3]],mode='bilinear') + concat4 = torch.cat((conv4,self.upconv5[1](conv5x)),dim=1) + conv4 = self.iconv4(concat4) + + conv4x = F.upsample(conv4, [rconv3.size()[2],rconv3.size()[3]],mode='bilinear') + concat3 = torch.cat((rconv3,self.upconv4[1](conv4x)),dim=1) + conv3 = self.iconv3(concat3) + + #conv3x = F.upsample(conv3, [pool1.size()[2],pool1.size()[3]],mode='bilinear') + #concat2 = torch.cat((pool1,self.upconv3[1](conv3x)),dim=1) + #conv2 = self.iconv2(concat2) + + if self.is_proj: + proj6 = self.proj6(conv6) + proj5 = self.proj5(conv5) + proj4 = self.proj4(conv4) + proj3 = self.proj3(conv3) + # proj2 = self.proj2(conv2) + # return proj6,proj5,proj4,proj3,proj2 + return proj6,proj5,proj4,proj3 + else: + # return conv6, conv5, conv4, conv3, conv2 + return conv6, conv5, conv4, conv3 + +class bfmodule(nn.Module): + def __init__(self, inplanes, outplanes): + super(bfmodule, self).__init__() + self.proj = conv2DBatchNormRelu(in_channels=inplanes,k_size=1,n_filters=64,padding=0,stride=1) + self.inplanes = 64 + # Vanilla Residual Blocks + self.res_block3 = self._make_layer(residualBlock,64,1,stride=2) + self.res_block5 = self._make_layer(residualBlock,64,1,stride=2) + self.res_block6 = self._make_layer(residualBlock,64,1,stride=2) + self.res_block7 = self._make_layer(residualBlock,128,1,stride=2) + self.pyramid_pooling = pyramidPooling(128, levels=3) + # Iconvs + self.upconv6 = conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64, + padding=1, stride=1) + self.upconv5 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32, + padding=1, stride=1) + self.upconv4 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32, + padding=1, stride=1) + self.upconv3 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32, + padding=1, stride=1) + self.iconv5 = conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64, + padding=1, stride=1) + self.iconv4 = conv2DBatchNormRelu(in_channels=96, k_size=3, n_filters=64, + padding=1, stride=1) + self.iconv3 = conv2DBatchNormRelu(in_channels=96, k_size=3, n_filters=64, + padding=1, stride=1) + self.iconv2 = nn.Sequential(conv2DBatchNormRelu(in_channels=96, k_size=3, n_filters=64, + padding=1, stride=1), + nn.Conv2d(64, outplanes,kernel_size=3, stride=1, padding=1, bias=True)) + + self.proj6 = nn.Conv2d(128, 
outplanes,kernel_size=3, stride=1, padding=1, bias=True) + self.proj5 = nn.Conv2d(64, outplanes,kernel_size=3, stride=1, padding=1, bias=True) + self.proj4 = nn.Conv2d(64, outplanes,kernel_size=3, stride=1, padding=1, bias=True) + self.proj3 = nn.Conv2d(64, outplanes,kernel_size=3, stride=1, padding=1, bias=True) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + if hasattr(m.bias,'data'): + m.bias.data.zero_() + + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion),) + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + return nn.Sequential(*layers) + + def forward(self, x): + proj = self.proj(x) # 4x + rconv3 = self.res_block3(proj) #8x + conv4 = self.res_block5(rconv3) #16x + conv5 = self.res_block6(conv4) #32x + conv6 = self.res_block7(conv5) #64x + conv6 = self.pyramid_pooling(conv6) #64x + pred6 = self.proj6(conv6) + + conv6u = F.upsample(conv6, [conv5.size()[2],conv5.size()[3]], mode='bilinear') + concat5 = torch.cat((conv5,self.upconv6(conv6u)),dim=1) + conv5 = self.iconv5(concat5) #32x + pred5 = self.proj5(conv5) + + conv5u = F.upsample(conv5, [conv4.size()[2],conv4.size()[3]], mode='bilinear') + concat4 = torch.cat((conv4,self.upconv5(conv5u)),dim=1) + conv4 = self.iconv4(concat4) #16x + pred4 = self.proj4(conv4) + + conv4u = F.upsample(conv4, [rconv3.size()[2],rconv3.size()[3]], mode='bilinear') + concat3 = torch.cat((rconv3,self.upconv4(conv4u)),dim=1) + conv3 = self.iconv3(concat3) # 8x + pred3 = self.proj3(conv3) + + conv3u = F.upsample(conv3, [x.size()[2],x.size()[3]], mode='bilinear') + concat2 = torch.cat((proj,self.upconv3(conv3u)),dim=1) + pred2 = self.iconv2(concat2) # 4x + + return pred2, pred3, pred4, pred5, pred6 + diff --git a/expansion/submission.py b/expansion/submission.py new file mode 100755 index 0000000000000000000000000000000000000000..21ebb64061b13b9b00e3c52f606d42612de3812c --- /dev/null +++ b/expansion/submission.py @@ -0,0 +1,95 @@ +from __future__ import print_function +import sys +import cv2 +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.backends.cudnn as cudnn +import torch.optim as optim +import torch.nn.functional as F +cudnn.benchmark = False + +class Expansion(): + + def __init__(self, loadmodel = 'pretrained_models/optical_expansion/robust.pth', testres = 1, maxdisp = 256, fac = 1): + + maxw,maxh = [int(testres*1280), int(testres*384)] + + max_h = int(maxh // 64 * 64) + max_w = int(maxw // 64 * 64) + if max_h < maxh: max_h += 64 + if max_w < maxw: max_w += 64 + maxh = max_h + maxw = max_w + + mean_L = [[0.33,0.33,0.33]] + mean_R = [[0.33,0.33,0.33]] + + # construct model, VCN-expansion + from expansion.models.VCN_exp import VCN + model = VCN([1, maxw, maxh], md=[int(4*(maxdisp/256)),4,4,4,4], fac=fac, + exp_unc=('robust' in loadmodel)) # expansion uncertainty only in the new model + model = nn.DataParallel(model, device_ids=[0]) + model.cuda() + + if loadmodel is not None: + pretrained_dict = torch.load(loadmodel) + mean_L=pretrained_dict['mean_L'] + mean_R=pretrained_dict['mean_R'] + 
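Once constructed, the wrapper is driven with two consecutive frames given as (1, 3, H, W) CUDA tensors scaled to [-1, 1]; it returns the optical flow and the log optical-expansion map. A hedged usage sketch with random placeholder frames and the default checkpoint path:
```
import torch
from expansion.submission import Expansion

# Placeholder frames in [-1, 1]; a CUDA device is assumed.
expansion = Expansion(loadmodel='pretrained_models/optical_expansion/robust.pth')
frame0 = torch.rand(1, 3, 256, 256).cuda() * 2 - 1   # stands in for frame t
frame1 = torch.rand(1, 3, 256, 256).cuda() * 2 - 1   # stands in for frame t+1
flow, logexp = expansion.run(frame0, frame1)          # 2D flow and log optical expansion
```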
pretrained_dict['state_dict'] = {k:v for k,v in pretrained_dict['state_dict'].items()} + model.load_state_dict(pretrained_dict['state_dict'],strict=False) + else: + print('dry run') + + model.eval() + # resize + maxh = 256 + maxw = 256 + max_h = int(maxh // 64 * 64) + max_w = int(maxw // 64 * 64) + if max_h < maxh: max_h += 64 + if max_w < maxw: max_w += 64 + + # modify module according to inputs + from expansion.models.VCN_exp import WarpModule, flow_reg + for i in range(len(model.module.reg_modules)): + model.module.reg_modules[i] = flow_reg([1,max_w//(2**(6-i)), max_h//(2**(6-i))], + ent=getattr(model.module, 'flow_reg%d'%2**(6-i)).ent,\ + maxdisp=getattr(model.module, 'flow_reg%d'%2**(6-i)).md,\ + fac=getattr(model.module, 'flow_reg%d'%2**(6-i)).fac).cuda() + for i in range(len(model.module.warp_modules)): + model.module.warp_modules[i] = WarpModule([1,max_w//(2**(6-i)), max_h//(2**(6-i))]).cuda() + + mean_L = torch.from_numpy(np.asarray(mean_L).astype(np.float32).mean(0)[np.newaxis,:,np.newaxis,np.newaxis]).cuda() + mean_R = torch.from_numpy(np.asarray(mean_R).astype(np.float32).mean(0)[np.newaxis,:,np.newaxis,np.newaxis]).cuda() + + self.max_h = max_h + self.max_w = max_w + self.model = model + self.mean_L = mean_L + self.mean_R = mean_R + + def run(self, imgL_o, imgR_o): + model = self.model + mean_L = self.mean_L + mean_R = self.mean_R + + imgL_o[imgL_o<-1] = -1 + imgL_o[imgL_o>1] = 1 + imgR_o[imgR_o<-1] = -1 + imgR_o[imgR_o>1] = 1 + imgL = (imgL_o+1.)*0.5-mean_L + imgR = (imgR_o*1.)*0.5-mean_R + + with torch.no_grad(): + imgLR = torch.cat([imgL,imgR],0) + model.eval() + torch.cuda.synchronize() + rts = model(imgLR) + torch.cuda.synchronize() + flow, occ, logmid, logexp = rts + + torch.cuda.empty_cache() + + return flow, logexp diff --git a/expansion/utils/__init__.py b/expansion/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/expansion/utils/__pycache__/__init__.cpython-38.pyc b/expansion/utils/__pycache__/__init__.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..2d86925f5b75d1d7208c8e21a5d8820c67d1396a Binary files /dev/null and b/expansion/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/expansion/utils/__pycache__/flowlib.cpython-38.pyc b/expansion/utils/__pycache__/flowlib.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..098a490b5ccbdea2a9fa168b66c74dba9c9eef95 Binary files /dev/null and b/expansion/utils/__pycache__/flowlib.cpython-38.pyc differ diff --git a/expansion/utils/__pycache__/io.cpython-38.pyc b/expansion/utils/__pycache__/io.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..e065c426452bde7ddd36ae2952aa4b392728eea6 Binary files /dev/null and b/expansion/utils/__pycache__/io.cpython-38.pyc differ diff --git a/expansion/utils/__pycache__/pfm.cpython-38.pyc b/expansion/utils/__pycache__/pfm.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..ed1807440f8c97584409a5335134a826ae6d83ea Binary files /dev/null and b/expansion/utils/__pycache__/pfm.cpython-38.pyc differ diff --git a/expansion/utils/__pycache__/util_flow.cpython-38.pyc b/expansion/utils/__pycache__/util_flow.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..510d6d73c7122b777e80935b6d9ac0b567da2cb0 Binary files /dev/null and b/expansion/utils/__pycache__/util_flow.cpython-38.pyc differ diff --git a/expansion/utils/flowlib.py 
b/expansion/utils/flowlib.py new file mode 100755 index 0000000000000000000000000000000000000000..59096c15da5e529ecb2d85be4881ee467f0c838f --- /dev/null +++ b/expansion/utils/flowlib.py @@ -0,0 +1,656 @@ +""" +# ============================== +# flowlib.py +# library for optical flow processing +# Author: Ruoteng Li +# Date: 6th Aug 2016 +# ============================== +""" +import png +from . import pfm +import numpy as np +import matplotlib.colors as cl +import matplotlib.pyplot as plt +from PIL import Image +import cv2 +import pdb + + +UNKNOWN_FLOW_THRESH = 1e7 +SMALLFLOW = 0.0 +LARGEFLOW = 1e8 + +""" +============= +Flow Section +============= +""" + +def point_vec(img,flow,skip=16): + #img[:] = 255 + maxsize=256 + extendfac=2. + resize_factor = max(1,int(max(maxsize/img.shape[0], maxsize/img.shape[1]))) + meshgrid = np.meshgrid(range(img.shape[1]),range(img.shape[0])) + dispimg = cv2.resize(img[:,:,::-1].copy(), None,fx=resize_factor,fy=resize_factor) + colorflow = flow_to_image(flow).astype(int) + + for i in range(img.shape[1]): # x + for j in range(img.shape[0]): # y + if flow[j,i,2] != 1: continue + if j%skip!=0 or i%skip!=0: continue + xend = int((meshgrid[0][j,i]+extendfac*flow[j,i,0])*resize_factor) + yend = int((meshgrid[1][j,i]+extendfac*flow[j,i,1])*resize_factor) + leng = np.linalg.norm(flow[j,i,:2]*extendfac) + if leng<3:continue + dispimg = cv2.arrowedLine(dispimg, (meshgrid[0][j,i]*resize_factor,meshgrid[1][j,i]*resize_factor),\ + (xend,yend), + (int(colorflow[j,i,2]),int(colorflow[j,i,1]),int(colorflow[j,i,0])),4,tipLength=2/leng,line_type=cv2.LINE_AA) + return dispimg + + +def show_flow(filename): + """ + visualize optical flow map using matplotlib + :param filename: optical flow file + :return: None + """ + flow = read_flow(filename) + img = flow_to_image(flow) + plt.imshow(img) + plt.show() + + +def visualize_flow(flow, mode='Y'): + """ + this function visualize the input flow + :param flow: input flow in array + :param mode: choose which color mode to visualize the flow (Y: Ccbcr, RGB: RGB color) + :return: None + """ + if mode == 'Y': + # Ccbcr color wheel + img = flow_to_image(flow) + plt.imshow(img) + plt.show() + elif mode == 'RGB': + (h, w) = flow.shape[0:2] + du = flow[:, :, 0] + dv = flow[:, :, 1] + valid = flow[:, :, 2] + max_flow = max(np.max(du), np.max(dv)) + img = np.zeros((h, w, 3), dtype=np.float64) + # angle layer + img[:, :, 0] = np.arctan2(dv, du) / (2 * np.pi) + # magnitude layer, normalized to 1 + img[:, :, 1] = np.sqrt(du * du + dv * dv) * 8 / max_flow + # phase layer + img[:, :, 2] = 8 - img[:, :, 1] + # clip to [0,1] + small_idx = img[:, :, 0:3] < 0 + large_idx = img[:, :, 0:3] > 1 + img[small_idx] = 0 + img[large_idx] = 1 + # convert to rgb + img = cl.hsv_to_rgb(img) + # remove invalid point + import pdb; pdb.set_trace() + img[:, :, 0] = img[:, :, 0] * valid + img[:, :, 1] = img[:, :, 1] * valid + img[:, :, 2] = img[:, :, 2] * valid + # show + plt.imshow(img) + plt.show() + + return None + + +def read_flow(filename): + """ + read optical flow data from flow file + :param filename: name of the flow file + :return: optical flow data in numpy array + """ + if filename.endswith('.flo'): + flow = read_flo_file(filename) + elif filename.endswith('.png'): + flow = read_png_file(filename) + elif filename.endswith('.pfm'): + flow = read_pfm_file(filename) + else: + raise Exception('Invalid flow file format!') + + return flow + + +def write_flow(flow, filename): + """ + write optical flow in Middlebury .flo format + :param flow: optical flow map + 
:param filename: optical flow file path to be saved + :return: None + """ + f = open(filename, 'wb') + magic = np.array([202021.25], dtype=np.float32) + (height, width) = flow.shape[0:2] + w = np.array([width], dtype=np.int32) + h = np.array([height], dtype=np.int32) + magic.tofile(f) + w.tofile(f) + h.tofile(f) + flow.tofile(f) + f.close() + + +def save_flow_image(flow, image_file): + """ + save flow visualization into image file + :param flow: optical flow data + :param flow_fil + :return: None + """ + flow_img = flow_to_image(flow) + img_out = Image.fromarray(flow_img) + img_out.save(image_file) + + +def flowfile_to_imagefile(flow_file, image_file): + """ + convert flowfile into image file + :param flow: optical flow data + :param flow_fil + :return: None + """ + flow = read_flow(flow_file) + save_flow_image(flow, image_file) + + +def segment_flow(flow): + h = flow.shape[0] + w = flow.shape[1] + u = flow[:, :, 0] + v = flow[:, :, 1] + + idx = ((abs(u) > LARGEFLOW) | (abs(v) > LARGEFLOW)) + idx2 = (abs(u) == SMALLFLOW) + class0 = (v == 0) & (u == 0) + u[idx2] = 0.00001 + tan_value = v / u + + class1 = (tan_value < 1) & (tan_value >= 0) & (u > 0) & (v >= 0) + class2 = (tan_value >= 1) & (u >= 0) & (v >= 0) + class3 = (tan_value < -1) & (u <= 0) & (v >= 0) + class4 = (tan_value < 0) & (tan_value >= -1) & (u < 0) & (v >= 0) + class8 = (tan_value >= -1) & (tan_value < 0) & (u > 0) & (v <= 0) + class7 = (tan_value < -1) & (u >= 0) & (v <= 0) + class6 = (tan_value >= 1) & (u <= 0) & (v <= 0) + class5 = (tan_value >= 0) & (tan_value < 1) & (u < 0) & (v <= 0) + + seg = np.zeros((h, w)) + + seg[class1] = 1 + seg[class2] = 2 + seg[class3] = 3 + seg[class4] = 4 + seg[class5] = 5 + seg[class6] = 6 + seg[class7] = 7 + seg[class8] = 8 + seg[class0] = 0 + seg[idx] = 0 + + return seg + + +def flow_error(tu, tv, u, v): + """ + Calculate average end point error + :param tu: ground-truth horizontal flow map + :param tv: ground-truth vertical flow map + :param u: estimated horizontal flow map + :param v: estimated vertical flow map + :return: End point error of the estimated flow + """ + smallflow = 0.0 + ''' + stu = tu[bord+1:end-bord,bord+1:end-bord] + stv = tv[bord+1:end-bord,bord+1:end-bord] + su = u[bord+1:end-bord,bord+1:end-bord] + sv = v[bord+1:end-bord,bord+1:end-bord] + ''' + stu = tu[:] + stv = tv[:] + su = u[:] + sv = v[:] + + idxUnknow = (abs(stu) > UNKNOWN_FLOW_THRESH) | (abs(stv) > UNKNOWN_FLOW_THRESH) + stu[idxUnknow] = 0 + stv[idxUnknow] = 0 + su[idxUnknow] = 0 + sv[idxUnknow] = 0 + + ind2 = [(np.absolute(stu) > smallflow) | (np.absolute(stv) > smallflow)] + index_su = su[ind2] + index_sv = sv[ind2] + an = 1.0 / np.sqrt(index_su ** 2 + index_sv ** 2 + 1) + un = index_su * an + vn = index_sv * an + + index_stu = stu[ind2] + index_stv = stv[ind2] + tn = 1.0 / np.sqrt(index_stu ** 2 + index_stv ** 2 + 1) + tun = index_stu * tn + tvn = index_stv * tn + + ''' + angle = un * tun + vn * tvn + (an * tn) + index = [angle == 1.0] + angle[index] = 0.999 + ang = np.arccos(angle) + mang = np.mean(ang) + mang = mang * 180 / np.pi + ''' + + epe = np.sqrt((stu - su) ** 2 + (stv - sv) ** 2) + epe = epe[ind2] + mepe = np.mean(epe) + return mepe + + +def flow_to_image(flow): + """ + Convert flow into middlebury color code image + :param flow: optical flow map + :return: optical flow image in middlebury color + """ + u = flow[:, :, 0] + v = flow[:, :, 1] + + maxu = -999. + maxv = -999. + minu = 999. + minv = 999. 
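Once the module is imported, this Middlebury colour coding can be exercised on a synthetic field; a sketch with a hypothetical radial flow:
```
import numpy as np
from expansion.utils.flowlib import flow_to_image

# Hypothetical radial flow field: pixels move away from the image centre.
h, w = 64, 64
gy, gx = np.mgrid[0:h, 0:w].astype(np.float64)
flow = np.dstack([gx - w / 2, gy - h / 2])
rgb = flow_to_image(flow)    # uint8 HxWx3; hue encodes direction, saturation encodes magnitude
```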
+ + idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH) + u[idxUnknow] = 0 + v[idxUnknow] = 0 + + maxu = max(maxu, np.max(u)) + minu = min(minu, np.min(u)) + + maxv = max(maxv, np.max(v)) + minv = min(minv, np.min(v)) + + rad = np.sqrt(u ** 2 + v ** 2) + maxrad = max(-1, np.max(rad)) + + u = u/(maxrad + np.finfo(float).eps) + v = v/(maxrad + np.finfo(float).eps) + + img = compute_color(u, v) + + idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2) + img[idx] = 0 + + return np.uint8(img) + + +def evaluate_flow_file(gt_file, pred_file): + """ + evaluate the estimated optical flow end point error according to ground truth provided + :param gt_file: ground truth file path + :param pred_file: estimated optical flow file path + :return: end point error, float32 + """ + # Read flow files and calculate the errors + gt_flow = read_flow(gt_file) # ground truth flow + eva_flow = read_flow(pred_file) # predicted flow + # Calculate errors + average_pe = flow_error(gt_flow[:, :, 0], gt_flow[:, :, 1], eva_flow[:, :, 0], eva_flow[:, :, 1]) + return average_pe + + +def evaluate_flow(gt_flow, pred_flow): + """ + gt: ground-truth flow + pred: estimated flow + """ + average_pe = flow_error(gt_flow[:, :, 0], gt_flow[:, :, 1], pred_flow[:, :, 0], pred_flow[:, :, 1]) + return average_pe + + +""" +============== +Disparity Section +============== +""" + + +def read_disp_png(file_name): + """ + Read optical flow from KITTI .png file + :param file_name: name of the flow file + :return: optical flow data in matrix + """ + image_object = png.Reader(filename=file_name) + image_direct = image_object.asDirect() + image_data = list(image_direct[2]) + (w, h) = image_direct[3]['size'] + channel = len(image_data[0]) / w + flow = np.zeros((h, w, channel), dtype=np.uint16) + for i in range(len(image_data)): + for j in range(channel): + flow[i, :, j] = image_data[i][j::channel] + return flow[:, :, 0] / 256 + + +def disp_to_flowfile(disp, filename): + """ + Read KITTI disparity file in png format + :param disp: disparity matrix + :param filename: the flow file name to save + :return: None + """ + f = open(filename, 'wb') + magic = np.array([202021.25], dtype=np.float32) + (height, width) = disp.shape[0:2] + w = np.array([width], dtype=np.int32) + h = np.array([height], dtype=np.int32) + empty_map = np.zeros((height, width), dtype=np.float32) + data = np.dstack((disp, empty_map)) + magic.tofile(f) + w.tofile(f) + h.tofile(f) + data.tofile(f) + f.close() + + +""" +============== +Image Section +============== +""" + + +def read_image(filename): + """ + Read normal image of any format + :param filename: name of the image file + :return: image data in matrix uint8 type + """ + img = Image.open(filename) + im = np.array(img) + return im + +def warp_flow(img, flow): + h, w = flow.shape[:2] + flow = flow.copy().astype(np.float32) + flow[:,:,0] += np.arange(w) + flow[:,:,1] += np.arange(h)[:,np.newaxis] + res = cv2.remap(img, flow, None, cv2.INTER_LINEAR) + return res + +def warp_image(im, flow): + """ + Use optical flow to warp image to the next + :param im: image to warp + :param flow: optical flow + :return: warped image + """ + from scipy import interpolate + image_height = im.shape[0] + image_width = im.shape[1] + flow_height = flow.shape[0] + flow_width = flow.shape[1] + n = image_height * image_width + (iy, ix) = np.mgrid[0:image_height, 0:image_width] + (fy, fx) = np.mgrid[0:flow_height, 0:flow_width] + fx = fx.astype(np.float64) + fy = fy.astype(np.float64) + fx += flow[:,:,0] + fy += flow[:,:,1] + 
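The lighter-weight warp_flow above (backward warping via cv2.remap) is usually enough for quick visual checks; a minimal sketch with placeholder data:
```
import numpy as np
from expansion.utils.flowlib import warp_flow

img = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)   # placeholder frame
flow = np.zeros((64, 64, 2), dtype=np.float32)
flow[..., 0] = 3.0                                          # uniform 3 px shift along x
warped = warp_flow(img, flow)                               # backward-warped frame, same shape
```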
mask = np.logical_or(fx <0 , fx > flow_width) + mask = np.logical_or(mask, fy < 0) + mask = np.logical_or(mask, fy > flow_height) + fx = np.minimum(np.maximum(fx, 0), flow_width) + fy = np.minimum(np.maximum(fy, 0), flow_height) + points = np.concatenate((ix.reshape(n,1), iy.reshape(n,1)), axis=1) + xi = np.concatenate((fx.reshape(n, 1), fy.reshape(n,1)), axis=1) + warp = np.zeros((image_height, image_width, im.shape[2])) + for i in range(im.shape[2]): + channel = im[:, :, i] + plt.imshow(channel, cmap='gray') + values = channel.reshape(n, 1) + new_channel = interpolate.griddata(points, values, xi, method='cubic') + new_channel = np.reshape(new_channel, [flow_height, flow_width]) + new_channel[mask] = 1 + warp[:, :, i] = new_channel.astype(np.uint8) + + return warp.astype(np.uint8) + + +""" +============== +Others +============== +""" + +def pfm_to_flo(pfm_file): + flow_filename = pfm_file[0:pfm_file.find('.pfm')] + '.flo' + (data, scale) = pfm.readPFM(pfm_file) + flow = data[:, :, 0:2] + write_flow(flow, flow_filename) + + +def scale_image(image, new_range): + """ + Linearly scale the image into desired range + :param image: input image + :param new_range: the new range to be aligned + :return: image normalized in new range + """ + min_val = np.min(image).astype(np.float32) + max_val = np.max(image).astype(np.float32) + min_val_new = np.array(min(new_range), dtype=np.float32) + max_val_new = np.array(max(new_range), dtype=np.float32) + scaled_image = (image - min_val) / (max_val - min_val) * (max_val_new - min_val_new) + min_val_new + return scaled_image.astype(np.uint8) + + +def compute_color(u, v): + """ + compute optical flow color map + :param u: optical flow horizontal map + :param v: optical flow vertical map + :return: optical flow in color code + """ + [h, w] = u.shape + img = np.zeros([h, w, 3]) + nanIdx = np.isnan(u) | np.isnan(v) + u[nanIdx] = 0 + v[nanIdx] = 0 + + colorwheel = make_color_wheel() + ncols = np.size(colorwheel, 0) + + rad = np.sqrt(u**2+v**2) + + a = np.arctan2(-v, -u) / np.pi + + fk = (a+1) / 2 * (ncols - 1) + 1 + + k0 = np.floor(fk).astype(int) + + k1 = k0 + 1 + k1[k1 == ncols+1] = 1 + f = fk - k0 + + for i in range(0, np.size(colorwheel,1)): + tmp = colorwheel[:, i] + col0 = tmp[k0-1] / 255 + col1 = tmp[k1-1] / 255 + col = (1-f) * col0 + f * col1 + + idx = rad <= 1 + col[idx] = 1-rad[idx]*(1-col[idx]) + notidx = np.logical_not(idx) + + col[notidx] *= 0.75 + img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx))) + + return img + + +def make_color_wheel(): + """ + Generate color wheel according Middlebury color code + :return: Color wheel + """ + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + + colorwheel = np.zeros([ncols, 3]) + + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY)) + col += RY + + # YG + colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG)) + colorwheel[col:col+YG, 1] = 255 + col += YG + + # GC + colorwheel[col:col+GC, 1] = 255 + colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC)) + col += GC + + # CB + colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB)) + colorwheel[col:col+CB, 2] = 255 + col += CB + + # BM + colorwheel[col:col+BM, 2] = 255 + colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM)) + col += + BM + + # MR + colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR)) + 
colorwheel[col:col+MR, 0] = 255 + + return colorwheel + + +def read_flo_file(filename): + """ + Read from Middlebury .flo file + :param flow_file: name of the flow file + :return: optical flow data in matrix + """ + f = open(filename, 'rb') + magic = np.fromfile(f, np.float32, count=1) + data2d = None + + if 202021.25 != magic: + print('Magic number incorrect. Invalid .flo file') + else: + w = np.fromfile(f, np.int32, count=1) + h = np.fromfile(f, np.int32, count=1) + #print("Reading %d x %d flow file in .flo format" % (h, w)) + flow = np.ones((h[0],w[0],3)) + data2d = np.fromfile(f, np.float32, count=2 * w[0] * h[0]) + # reshape data into 3D array (columns, rows, channels) + data2d = np.resize(data2d, (h[0], w[0], 2)) + flow[:,:,:2] = data2d + f.close() + return flow + + +def read_png_file(flow_file): + """ + Read from KITTI .png file + :param flow_file: name of the flow file + :return: optical flow data in matrix + """ + flow = cv2.imread(flow_file,-1)[:,:,::-1].astype(np.float64) + # flow_object = png.Reader(filename=flow_file) + # flow_direct = flow_object.asDirect() + # flow_data = list(flow_direct[2]) + # (w, h) = flow_direct[3]['size'] + # #print("Reading %d x %d flow file in .png format" % (h, w)) + # flow = np.zeros((h, w, 3), dtype=np.float64) + # for i in range(len(flow_data)): + # flow[i, :, 0] = flow_data[i][0::3] + # flow[i, :, 1] = flow_data[i][1::3] + # flow[i, :, 2] = flow_data[i][2::3] + + invalid_idx = (flow[:, :, 2] == 0) + flow[:, :, 0:2] = (flow[:, :, 0:2] - 2 ** 15) / 64.0 + flow[invalid_idx, 0] = 0 + flow[invalid_idx, 1] = 0 + return flow + + +def read_pfm_file(flow_file): + """ + Read from .pfm file + :param flow_file: name of the flow file + :return: optical flow data in matrix + """ + (data, scale) = pfm.readPFM(flow_file) + return data + + +# fast resample layer +def resample(img, sz): + """ + img: flow map to be resampled + sz: new flow map size. Must be [height,weight] + """ + original_image_size = img.shape + in_height = img.shape[0] + in_width = img.shape[1] + out_height = sz[0] + out_width = sz[1] + out_flow = np.zeros((out_height, out_width, 2)) + # find scale + height_scale = float(in_height) / float(out_height) + width_scale = float(in_width) / float(out_width) + + [x,y] = np.meshgrid(range(out_width), range(out_height)) + xx = x * width_scale + yy = y * height_scale + x0 = np.floor(xx).astype(np.int32) + x1 = x0 + 1 + y0 = np.floor(yy).astype(np.int32) + y1 = y0 + 1 + + x0 = np.clip(x0,0,in_width-1) + x1 = np.clip(x1,0,in_width-1) + y0 = np.clip(y0,0,in_height-1) + y1 = np.clip(y1,0,in_height-1) + + Ia = img[y0,x0,:] + Ib = img[y1,x0,:] + Ic = img[y0,x1,:] + Id = img[y1,x1,:] + + wa = (y1-yy) * (x1-xx) + wb = (yy-y0) * (x1-xx) + wc = (y1-yy) * (xx-x0) + wd = (yy-y0) * (xx-x0) + out_flow[:,:,0] = (Ia[:,:,0]*wa + Ib[:,:,0]*wb + Ic[:,:,0]*wc + Id[:,:,0]*wd) * out_width / in_width + out_flow[:,:,1] = (Ia[:,:,1]*wa + Ib[:,:,1]*wb + Ic[:,:,1]*wc + Id[:,:,1]*wd) * out_height / in_height + + return out_flow + diff --git a/expansion/utils/io.py b/expansion/utils/io.py new file mode 100755 index 0000000000000000000000000000000000000000..4677d32651b9f017142f4047f938052b1901ee4f --- /dev/null +++ b/expansion/utils/io.py @@ -0,0 +1,164 @@ +import errno +import os +import shutil +import sys +import traceback +import zipfile + +if sys.version_info[0] == 2: + import urllib2 +else: + import urllib.request + + +# Converts a string to bytes (for writing the string into a file). Provided for +# compatibility with Python 2 and 3. 
+def StrToBytes(text): + if sys.version_info[0] == 2: + return text + else: + return bytes(text, 'UTF-8') + + +# Outputs the given text and lets the user input a response (submitted by +# pressing the return key). Provided for compatibility with Python 2 and 3. +def GetUserInput(text): + if sys.version_info[0] == 2: + return raw_input(text) + else: + return input(text) + + +# Creates the given directory (hierarchy), which may already exist. Provided for +# compatibility with Python 2 and 3. +def MakeDirsExistOk(directory_path): + try: + os.makedirs(directory_path) + except OSError as exception: + if exception.errno != errno.EEXIST: + raise + + +# Deletes all files and folders within the given folder. +def DeleteFolderContents(folder_path): + for file_name in os.listdir(folder_path): + file_path = os.path.join(folder_path, file_name) + try: + if os.path.isfile(file_path): + os.unlink(file_path) + else: #if os.path.isdir(file_path): + shutil.rmtree(file_path) + except Exception as e: + print('Exception in DeleteFolderContents():') + print(e) + print('Stack trace:') + print(traceback.format_exc()) + + +# Creates the given directory, respectively deletes all content of the directory +# in case it already exists. +def MakeCleanDirectory(folder_path): + if os.path.isdir(folder_path): + DeleteFolderContents(folder_path) + else: + MakeDirsExistOk(folder_path) + + +# Downloads the given URL to a file in the given directory. Returns the +# path to the downloaded file. +# In part adapted from: https://stackoverflow.com/questions/22676 +def DownloadFile(url, dest_dir_path): + file_name = url.split('/')[-1] + dest_file_path = os.path.join(dest_dir_path, file_name) + + if os.path.isfile(dest_file_path): + print('The following file already exists:') + print(dest_file_path) + print('Please choose whether to re-download and overwrite the file [o] or to skip downloading this file [s] by entering o or s.') + while True: + response = GetUserInput("> ") + if response == 's': + return dest_file_path + elif response == 'o': + break + else: + print('Please enter o or s.') + + url_object = None + if sys.version_info[0] == 2: + url_object = urllib2.urlopen(url) + else: + url_object = urllib.request.urlopen(url) + + with open(dest_file_path, 'wb') as outfile: + meta = url_object.info() + file_size = 0 + if sys.version_info[0] == 2: + file_size = int(meta.getheaders("Content-Length")[0]) + else: + file_size = int(meta["Content-Length"]) + print("Downloading: %s (size [bytes]: %s)" % (url, file_size)) + + file_size_downloaded = 0 + block_size = 8192 + while True: + buffer = url_object.read(block_size) + if not buffer: + break + + file_size_downloaded += len(buffer) + outfile.write(buffer) + + sys.stdout.write("%d / %d (%3f%%)\r" % (file_size_downloaded, file_size, file_size_downloaded * 100. / file_size)) + sys.stdout.flush() + + return dest_file_path + + +# Unzips the given zip file into the given directory. +def UnzipFile(file_path, unzip_dir_path, overwrite=True): + zip_ref = zipfile.ZipFile(open(file_path, 'rb')) + + if not overwrite: + for f in zip_ref.namelist(): + if not os.path.isfile(os.path.join(unzip_dir_path, f)): + zip_ref.extract(f, path=unzip_dir_path) + else: + print('Not overwriting {}'.format(f)) + else: + zip_ref.extractall(unzip_dir_path) + zip_ref.close() + + +# Creates a zip file with the contents of the given directory. +# The archive_base_path must not include the extension .zip. The full, final +# path of the archive is returned by the function. 
+def ZipDirectory(archive_base_path, root_dir_path): + # return shutil.make_archive(archive_base_path, 'zip', root_dir_path) # THIS WILL ALWAYS HAVE ./ FOLDER INCLUDED + with zipfile.ZipFile(archive_base_path+'.zip', "w", compression=zipfile.ZIP_DEFLATED) as zf: + base_path = os.path.normpath(root_dir_path) + for dirpath, dirnames, filenames in os.walk(root_dir_path): + for name in sorted(dirnames): + path = os.path.normpath(os.path.join(dirpath, name)) + zf.write(path, os.path.relpath(path, base_path)) + for name in filenames: + path = os.path.normpath(os.path.join(dirpath, name)) + if os.path.isfile(path): + zf.write(path, os.path.relpath(path, base_path)) + + return archive_base_path+'.zip' + + +# Downloads a zip file and directly unzips it. +def DownloadAndUnzipFile(url, archive_dir_path, unzip_dir_path, overwrite=True): + archive_path = DownloadFile(url, archive_dir_path) + UnzipFile(archive_path, unzip_dir_path, overwrite=overwrite) + +def mkdir_p(path): + try: + os.makedirs(path) + except OSError as exc: # Python >2.5 + if exc.errno == errno.EEXIST and os.path.isdir(path): + pass + else: + raise diff --git a/expansion/utils/logger.py b/expansion/utils/logger.py new file mode 100755 index 0000000000000000000000000000000000000000..9bd31e79407f8e7e94d236c9b0e620403d1e3d85 --- /dev/null +++ b/expansion/utils/logger.py @@ -0,0 +1,113 @@ +""" +File: logger.py +Modified by: Senthil Purushwalkam +Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 +Email: spurushwandrewcmuedu +Github: https://github.com/senthilps8 +Description: +""" +import pdb +import tensorflow as tf +from torch.autograd import Variable +import numpy as np +import scipy.misc +import os +try: + from StringIO import StringIO # Python 2.7 +except ImportError: + from io import BytesIO # Python 3.x + + +class Logger(object): + + def __init__(self, log_dir, name=None): + """Create a summary writer logging to log_dir.""" + if name is None: + name = 'temp' + self.name = name + if name is not None: + try: + os.makedirs(os.path.join(log_dir, name)) + except: + pass + self.writer = tf.summary.FileWriter(os.path.join(log_dir, name), + filename_suffix=name) + else: + self.writer = tf.summary.FileWriter(log_dir, filename_suffix=name) + + def scalar_summary(self, tag, value, step): + """Log a scalar variable.""" + summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) + self.writer.add_summary(summary, step) + + def image_summary(self, tag, images, step): + """Log a list of images.""" + + img_summaries = [] + for i, img in enumerate(images): + # Write the image to a string + try: + s = StringIO() + except: + s = BytesIO() + scipy.misc.toimage(img).save(s, format="png") + + # Create an Image object + img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), + height=img.shape[0], + width=img.shape[1]) + # Create a Summary value + img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) + + # Create and write Summary + summary = tf.Summary(value=img_summaries) + self.writer.add_summary(summary, step) + + def histo_summary(self, tag, values, step, bins=1000): + """Log a histogram of the tensor of values.""" + + # Create a histogram using numpy + counts, bin_edges = np.histogram(values, bins=bins) + + # Fill the fields of the histogram proto + hist = tf.HistogramProto() + hist.min = float(np.min(values)) + hist.max = float(np.max(values)) + hist.num = int(np.prod(values.shape)) + hist.sum = float(np.sum(values)) + hist.sum_squares = float(np.sum(values**2)) + + # 
Drop the start of the first bin + bin_edges = bin_edges[1:] + + # Add bin edges and counts + for edge in bin_edges: + hist.bucket_limit.append(edge) + for c in counts: + hist.bucket.append(c) + + # Create and write Summary + summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) + self.writer.add_summary(summary, step) + self.writer.flush() + + def to_np(self, x): + return x.data.cpu().numpy() + + def to_var(self, x): + if torch.cuda.is_available(): + x = x.cuda() + return Variable(x) + + def model_param_histo_summary(self, model, step): + """log histogram summary of model's parameters + and parameter gradients + """ + for tag, value in model.named_parameters(): + if value.grad is None: + continue + tag = tag.replace('.', '/') + tag = self.name+'/'+tag + self.histo_summary(tag, self.to_np(value), step) + self.histo_summary(tag+'/grad', self.to_np(value.grad), step) + diff --git a/expansion/utils/multiscaleloss.py b/expansion/utils/multiscaleloss.py new file mode 100755 index 0000000000000000000000000000000000000000..b3240c7d94090c9c1447b0b844d0c16c36204796 --- /dev/null +++ b/expansion/utils/multiscaleloss.py @@ -0,0 +1,86 @@ +""" +Taken from https://github.com/ClementPinard/FlowNetPytorch +""" +import pdb +import torch +import torch.nn.functional as F + + +def EPE(input_flow, target_flow, mask, sparse=False, mean=True): + #mask = target_flow[:,2]>0 + target_flow = target_flow[:,:2] + EPE_map = torch.norm(target_flow-input_flow,2,1) + batch_size = EPE_map.size(0) + if sparse: + # invalid flow is defined with both flow coordinates to be exactly 0 + mask = (target_flow[:,0] == 0) & (target_flow[:,1] == 0) + + EPE_map = EPE_map[~mask] + if mean: + return EPE_map[mask].mean() + else: + return EPE_map[mask].sum()/batch_size + +def rob_EPE(input_flow, target_flow, mask, sparse=False, mean=True): + #mask = target_flow[:,2]>0 + target_flow = target_flow[:,:2] + #TODO +# EPE_map = torch.norm(target_flow-input_flow,2,1) + EPE_map = (torch.norm(target_flow-input_flow,1,1)+0.01).pow(0.4) + batch_size = EPE_map.size(0) + if sparse: + # invalid flow is defined with both flow coordinates to be exactly 0 + mask = (target_flow[:,0] == 0) & (target_flow[:,1] == 0) + + EPE_map = EPE_map[~mask] + if mean: + return EPE_map[mask].mean() + else: + return EPE_map[mask].sum()/batch_size + +def sparse_max_pool(input, size): + '''Downsample the input by considering 0 values as invalid. + + Unfortunately, no generic interpolation mode can resize a sparse map correctly, + the strategy here is to use max pooling for positive values and "min pooling" + for negative values, the two results are then summed. 
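Concretely, a small sketch with a hypothetical 4x4 sparse map downsampled to 2x2:
```
import torch
from expansion.utils.multiscaleloss import sparse_max_pool

x = torch.zeros(1, 1, 4, 4)
x[0, 0, 0, 1] = 2.0      # isolated positive sample
x[0, 0, 3, 2] = -1.5     # isolated negative sample
print(sparse_max_pool(x, (2, 2)))
# -> [[2.0, 0.0], [0.0, -1.5]]: both sparse samples survive the 2x downsampling
```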
+ This technique allows sparsity to be minized, contrary to nearest interpolation, + which could potentially lose information for isolated data points.''' + + positive = (input > 0).float() + negative = (input < 0).float() + output = F.adaptive_max_pool2d(input * positive, size) - F.adaptive_max_pool2d(-input * negative, size) + return output + + +def multiscaleEPE(network_output, target_flow, mask, weights=None, sparse=False, rob_loss = False): + def one_scale(output, target, mask, sparse): + + b, _, h, w = output.size() + + if sparse: + target_scaled = sparse_max_pool(target, (h, w)) + else: + target_scaled = F.interpolate(target, (h, w), mode='area') + mask = F.interpolate(mask.float().unsqueeze(1), (h, w), mode='bilinear').squeeze(1)==1 + if rob_loss: + return rob_EPE(output, target_scaled, mask, sparse, mean=False) + else: + return EPE(output, target_scaled, mask, sparse, mean=False) + + if type(network_output) not in [tuple, list]: + network_output = [network_output] + if weights is None: + weights = [0.005, 0.01, 0.02, 0.08, 0.32] # as in original article + assert(len(weights) == len(network_output)) + + loss = 0 + for output, weight in zip(network_output, weights): + loss += weight * one_scale(output, target_flow, mask, sparse) + return loss + + +def realEPE(output, target, mask, sparse=False): + b, _, h, w = target.size() + upsampled_output = F.interpolate(output, (h,w), mode='bilinear', align_corners=False) + return EPE(upsampled_output, target,mask, sparse, mean=True) diff --git a/expansion/utils/pfm.py b/expansion/utils/pfm.py new file mode 100755 index 0000000000000000000000000000000000000000..65a3dbb10ae96a72110294feb7cdc8df9d9c8b0d --- /dev/null +++ b/expansion/utils/pfm.py @@ -0,0 +1,79 @@ +import re +import numpy as np +import sys + +def readPFM(file): + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if (sys.version[0]) == '3': + header = header.decode('utf-8') + if header == 'PF': + color = True + elif header == 'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + if (sys.version[0]) == '3': + dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) + else: + dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline()) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + + if (sys.version[0]) == '3': + scale = float(file.readline().rstrip().decode('utf-8')) + else: + scale = float(file.readline().rstrip()) + + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data, scale + + +def writePFM(file, image, scale=1): + file = open(file, 'wb') + + color = None + + if image.dtype.name != 'float32': + raise Exception('Image dtype must be float32.') + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale + color = False + else: + raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') + + file.write('PF\n' if color else 'Pf\n') + file.write('%d %d\n' % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == '<' or endian == '=' and sys.byteorder == 'little': + scale = -scale + + 
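Note that the str-based header writes in writePFM only work under Python 2, where str and bytes coincide; under Python 3 they raise a TypeError on a file opened in binary mode. A minimal Python-3-safe equivalent (a sketch, not the repository's code):
```
import sys
import numpy as np

def write_pfm_py3(path, image, scale=1.0):
    # image: float32 array, HxW (greyscale) or HxWx3 (colour)
    color = image.ndim == 3 and image.shape[2] == 3
    data = np.flipud(image).astype(np.float32)
    if data.dtype.byteorder == '<' or (data.dtype.byteorder == '=' and sys.byteorder == 'little'):
        scale = -scale                                   # negative scale flags little-endian data
    with open(path, 'wb') as f:
        f.write(b'PF\n' if color else b'Pf\n')
        f.write(('%d %d\n' % (image.shape[1], image.shape[0])).encode('ascii'))
        f.write(('%f\n' % scale).encode('ascii'))
        data.tofile(f)
```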
file.write('%f\n' % scale) + + image.tofile(file) diff --git a/expansion/utils/readpfm.py b/expansion/utils/readpfm.py new file mode 100755 index 0000000000000000000000000000000000000000..a79837854408d80bb033f5b794e3294b2da50897 --- /dev/null +++ b/expansion/utils/readpfm.py @@ -0,0 +1,51 @@ +import re +import numpy as np +import sys + + +def readPFM(file): + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if (sys.version[0]) == '3': + header = header.decode('utf-8') + if header == 'PF': + color = True + elif header == 'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + if (sys.version[0]) == '3': + dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) + else: + dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline()) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + + if (sys.version[0]) == '3': + scale = float(file.readline().rstrip().decode('utf-8')) + else: + scale = float(file.readline().rstrip()) + + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data, scale + diff --git a/expansion/utils/sintel_io.py b/expansion/utils/sintel_io.py new file mode 100755 index 0000000000000000000000000000000000000000..c3633c9d08a8026e8cd652a2df263c6e893c0976 --- /dev/null +++ b/expansion/utils/sintel_io.py @@ -0,0 +1,214 @@ +#! /usr/bin/env python2 + +""" +I/O script to save and load the data coming with the MPI-Sintel low-level +computer vision benchmark. + +For more details about the benchmark, please visit www.mpi-sintel.de + +CHANGELOG: +v1.0 (2015/02/03): First release + +Copyright (c) 2015 Jonas Wulff +Max Planck Institute for Intelligent Systems, Tuebingen, Germany + +""" + +# Requirements: Numpy as PIL/Pillow +import numpy as np +from PIL import Image + +# Check for endianness, based on Daniel Scharstein's optical flow code. +# Using little-endian architecture, these two should be equal. +TAG_FLOAT = 202021.25 +TAG_CHAR = 'PIEH' + +def flow_read(filename): + """ Read optical flow from file, return (U,V) tuple. + + Original code by Deqing Sun, adapted from Daniel Scharstein. + """ + f = open(filename,'rb') + check = np.fromfile(f,dtype=np.float32,count=1)[0] + assert check == TAG_FLOAT, ' flow_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check) + width = np.fromfile(f,dtype=np.int32,count=1)[0] + height = np.fromfile(f,dtype=np.int32,count=1)[0] + size = width*height + assert width > 0 and height > 0 and size > 1 and size < 100000000, ' flow_read:: Wrong input size (width = {0}, height = {1}).'.format(width,height) + tmp = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width*2)) + u = tmp[:,np.arange(width)*2] + v = tmp[:,np.arange(width)*2 + 1] + return u,v + +def flow_write(filename,uv,v=None): + """ Write optical flow to file. + + If v is None, uv is assumed to contain both u and v channels, + stacked in depth. + + Original code by Deqing Sun, adapted from Daniel Scharstein. 
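+
+    File layout: 4-byte tag 'PIEH', int32 width, int32 height, then the u and v
+    components interleaved per pixel as float32, written row by row.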
+ """ + nBands = 2 + + if v is None: + assert(uv.ndim == 3) + assert(uv.shape[2] == 2) + u = uv[:,:,0] + v = uv[:,:,1] + else: + u = uv + + assert(u.shape == v.shape) + height,width = u.shape + f = open(filename,'wb') + # write the header + f.write(TAG_CHAR) + np.array(width).astype(np.int32).tofile(f) + np.array(height).astype(np.int32).tofile(f) + # arrange into matrix form + tmp = np.zeros((height, width*nBands)) + tmp[:,np.arange(width)*2] = u + tmp[:,np.arange(width)*2 + 1] = v + tmp.astype(np.float32).tofile(f) + f.close() + + +def depth_read(filename): + """ Read depth data from file, return as numpy array. """ + f = open(filename,'rb') + check = np.fromfile(f,dtype=np.float32,count=1)[0] + assert check == TAG_FLOAT, ' depth_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check) + width = np.fromfile(f,dtype=np.int32,count=1)[0] + height = np.fromfile(f,dtype=np.int32,count=1)[0] + size = width*height + assert width > 0 and height > 0 and size > 1 and size < 100000000, ' depth_read:: Wrong input size (width = {0}, height = {1}).'.format(width,height) + depth = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width)) + return depth + +def depth_write(filename, depth): + """ Write depth to file. """ + height,width = depth.shape[:2] + f = open(filename,'wb') + # write the header + f.write(TAG_CHAR) + np.array(width).astype(np.int32).tofile(f) + np.array(height).astype(np.int32).tofile(f) + + depth.astype(np.float32).tofile(f) + f.close() + + +def disparity_write(filename,disparity,bitdepth=16): + """ Write disparity to file. + + bitdepth can be either 16 (default) or 32. + + The maximum disparity is 1024, since the image width in Sintel + is 1024. + """ + d = disparity.copy() + + # Clip disparity. + d[d>1024] = 1024 + d[d<0] = 0 + + d_r = (d / 4.0).astype('uint8') + d_g = ((d * (2.0**6)) % 256).astype('uint8') + + out = np.zeros((d.shape[0],d.shape[1],3),dtype='uint8') + out[:,:,0] = d_r + out[:,:,1] = d_g + + if bitdepth > 16: + d_b = (d * (2**14) % 256).astype('uint8') + out[:,:,2] = d_b + + Image.fromarray(out,'RGB').save(filename,'PNG') + + +def disparity_read(filename): + """ Return disparity read from filename. """ + f_in = np.array(Image.open(filename)) + d_r = f_in[:,:,0].astype('float64') + d_g = f_in[:,:,1].astype('float64') + d_b = f_in[:,:,2].astype('float64') + + depth = d_r * 4 + d_g / (2**6) + d_b / (2**14) + return depth + + +#def cam_read(filename): +# """ Read camera data, return (M,N) tuple. +# +# M is the intrinsic matrix, N is the extrinsic matrix, so that +# +# x = M*N*X, +# with x being a point in homogeneous image pixel coordinates, X being a +# point in homogeneous world coordinates. +# """ +# txtdata = np.loadtxt(filename) +# intrinsic = txtdata[0,:9].reshape((3,3)) +# extrinsic = textdata[1,:12].reshape((3,4)) +# return intrinsic,extrinsic +# +# +#def cam_write(filename,M,N): +# """ Write intrinsic matrix M and extrinsic matrix N to file. """ +# Z = np.zeros((2,12)) +# Z[0,:9] = M.ravel() +# Z[1,:12] = N.ravel() +# np.savetxt(filename,Z) + +def cam_read(filename): + """ Read camera data, return (M,N) tuple. + + M is the intrinsic matrix, N is the extrinsic matrix, so that + + x = M*N*X, + with x being a point in homogeneous image pixel coordinates, X being a + point in homogeneous world coordinates. + """ + f = open(filename,'rb') + check = np.fromfile(f,dtype=np.float32,count=1)[0] + assert check == TAG_FLOAT, ' cam_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? 
'.format(TAG_FLOAT,check) + M = np.fromfile(f,dtype='float64',count=9).reshape((3,3)) + N = np.fromfile(f,dtype='float64',count=12).reshape((3,4)) + return M,N + +def cam_write(filename, M, N): + """ Write intrinsic matrix M and extrinsic matrix N to file. """ + f = open(filename,'wb') + # write the header + f.write(TAG_CHAR) + M.astype('float64').tofile(f) + N.astype('float64').tofile(f) + f.close() + + +def segmentation_write(filename,segmentation): + """ Write segmentation to file. """ + + segmentation_ = segmentation.astype('int32') + seg_r = np.floor(segmentation_ / (256**2)).astype('uint8') + seg_g = np.floor((segmentation_ % (256**2)) / 256).astype('uint8') + seg_b = np.floor(segmentation_ % 256).astype('uint8') + + out = np.zeros((segmentation.shape[0],segmentation.shape[1],3),dtype='uint8') + out[:,:,0] = seg_r + out[:,:,1] = seg_g + out[:,:,2] = seg_b + + Image.fromarray(out,'RGB').save(filename,'PNG') + + +def segmentation_read(filename): + """ Return disparity read from filename. """ + f_in = np.array(Image.open(filename)) + seg_r = f_in[:,:,0].astype('int32') + seg_g = f_in[:,:,1].astype('int32') + seg_b = f_in[:,:,2].astype('int32') + + segmentation = (seg_r * 256 + seg_g) * 256 + seg_b + return segmentation + + diff --git a/expansion/utils/util_flow.py b/expansion/utils/util_flow.py new file mode 100755 index 0000000000000000000000000000000000000000..13c683370f8f2b4b6ac6b077d05b0964753821bb --- /dev/null +++ b/expansion/utils/util_flow.py @@ -0,0 +1,272 @@ +import math +import png +import struct +import array +import numpy as np +import cv2 +import pdb + +from io import * + +UNKNOWN_FLOW_THRESH = 1e9; +UNKNOWN_FLOW = 1e10; + +# Middlebury checks +TAG_STRING = 'PIEH' # use this when WRITING the file +TAG_FLOAT = 202021.25 # check for this when READING the file + +def readPFM(file): + import re + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header == b'PF': + color = True + elif header == b'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(b'^(\d+)\s(\d+)\s$', file.readline()) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().rstrip()) + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data, scale + + +def save_pfm(file, image, scale = 1): + import sys + color = None + + if image.dtype.name != 'float32': + raise Exception('Image dtype must be float32.') + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale + color = False + else: + raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') + + file.write('PF\n' if color else 'Pf\n') + file.write('%d %d\n' % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == '<' or endian == '=' and sys.byteorder == 'little': + scale = -scale + + file.write('%f\n' % scale) + + image.tofile(file) + + +def ReadMiddleburyFloFile(path): + """ Read .FLO file as specified by Middlebury. + + Returns tuple (width, height, u, v, mask), where u, v, mask are flat + arrays of values. 
+ """ + + with open(path, 'rb') as fil: + tag = struct.unpack('f', fil.read(4))[0] + width = struct.unpack('i', fil.read(4))[0] + height = struct.unpack('i', fil.read(4))[0] + + assert tag == TAG_FLOAT + + #data = np.fromfile(path, dtype=np.float, count=-1) + #data = data[3:] + + fmt = 'f' * width*height*2 + data = struct.unpack(fmt, fil.read(4*width*height*2)) + + u = data[::2] + v = data[1::2] + + mask = map(lambda x,y: abs(x) 0: + # print(u[ind], v[ind], mask[ind], row[3*x], row[3*x+1], row[3*x+2]) + + #png_reader.close() + + return (width, height, u, v, mask) + + +def WriteMiddleburyFloFile(path, width, height, u, v, mask=None): + """ Write .FLO file as specified by Middlebury. + """ + + if mask is not None: + u_masked = map(lambda x,y: x if y else UNKNOWN_FLOW, u, mask) + v_masked = map(lambda x,y: x if y else UNKNOWN_FLOW, v, mask) + else: + u_masked = u + v_masked = v + + fmt = 'f' * width*height*2 + # Interleave lists + data = [x for t in zip(u_masked,v_masked) for x in t] + + with open(path, 'wb') as fil: + fil.write(str.encode(TAG_STRING)) + fil.write(struct.pack('i', width)) + fil.write(struct.pack('i', height)) + fil.write(struct.pack(fmt, *data)) + + +def write_flow(path,flow): + + invalid_idx = (flow[:, :, 2] == 0) + flow[:, :, 0:2] = flow[:, :, 0:2]*64.+ 2 ** 15 + flow[invalid_idx, 0] = 0 + flow[invalid_idx, 1] = 0 + + flow = flow.astype(np.uint16) + flow = cv2.imwrite(path, flow[:,:,::-1]) + + #WriteKittiPngFile(path, + # flow.shape[1], flow.shape[0], flow[:,:,0].flatten(), + # flow[:,:,1].flatten(), flow[:,:,2].flatten()) + + + +def WriteKittiPngFile(path, width, height, u, v, mask=None): + """ Write 16-bit .PNG file as specified by KITTI-2015 (flow). + + u, v are lists of float values + mask is a list of floats, denoting the *valid* pixels. + """ + + data = array.array('H',[0])*width*height*3 + + for i,(u_,v_,mask_) in enumerate(zip(u,v,mask)): + data[3*i] = int(u_*64.0+2**15) + data[3*i+1] = int(v_*64.0+2**15) + data[3*i+2] = int(mask_) + + # if mask_ > 0: + # print(data[3*i], data[3*i+1],data[3*i+2]) + + with open(path, 'wb') as png_file: + png_writer = png.Writer(width=width, height=height, bitdepth=16, compression=3, greyscale=False) + png_writer.write_array(png_file, data) + + +def ConvertMiddleburyFloToKittiPng(src_path, dest_path): + width, height, u, v, mask = ReadMiddleburyFloFile(src_path) + WriteKittiPngFile(dest_path, width, height, u, v, mask=mask) + +def ConvertKittiPngToMiddleburyFlo(src_path, dest_path): + width, height, u, v, mask = ReadKittiPngFile(src_path) + WriteMiddleburyFloFile(dest_path, width, height, u, v, mask=mask) + + +def ParseFilenameKitti(filename): + # Parse kitti filename (seq_frameno.xx), + # return seq, frameno, ext. 
+ # Be aware that seq might contain the dataset name (if contained as prefix) + ext = filename[filename.rfind('.'):] + frameno = filename[filename.rfind('_')+1:filename.rfind('.')] + frameno = int(frameno) + seq = filename[:filename.rfind('_')] + return seq, frameno, ext + + +def read_calib_file(filepath): + """Read in a calibration file and parse into a dictionary.""" + data = {} + + with open(filepath, 'r') as f: + for line in f.readlines(): + key, value = line.split(':', 1) + # The only non-float values in these files are dates, which + # we don't care about anyway + try: + data[key] = np.array([float(x) for x in value.split()]) + except ValueError: + pass + + return data + +def load_calib_cam_to_cam(cam_to_cam_file): + # We'll return the camera calibration as a dictionary + data = {} + + # Load and parse the cam-to-cam calibration data + filedata = read_calib_file(cam_to_cam_file) + + # Create 3x4 projection matrices + P_rect_00 = np.reshape(filedata['P_rect_00'], (3, 4)) + P_rect_10 = np.reshape(filedata['P_rect_01'], (3, 4)) + P_rect_20 = np.reshape(filedata['P_rect_02'], (3, 4)) + P_rect_30 = np.reshape(filedata['P_rect_03'], (3, 4)) + + # Compute the camera intrinsics + data['K_cam0'] = P_rect_00[0:3, 0:3] + data['K_cam1'] = P_rect_10[0:3, 0:3] + data['K_cam2'] = P_rect_20[0:3, 0:3] + data['K_cam3'] = P_rect_30[0:3, 0:3] + + data['b00'] = P_rect_00[0, 3] / P_rect_00[0, 0] + data['b10'] = P_rect_10[0, 3] / P_rect_10[0, 0] + data['b20'] = P_rect_20[0, 3] / P_rect_20[0, 0] + data['b30'] = P_rect_30[0, 3] / P_rect_30[0, 0] + + return data + diff --git a/interface/flask_app.py b/interface/flask_app.py new file mode 100755 index 0000000000000000000000000000000000000000..f4b28432fcf8439dad33abe2d5cb6b439025ea2b --- /dev/null +++ b/interface/flask_app.py @@ -0,0 +1,107 @@ +from flask import Flask, render_template, request, redirect, url_for, abort +import json + +app = Flask(__name__) + +import sys +sys.path.append(".") +sys.path.append("..") + +import argparse +from PIL import Image, ImageOps +import numpy as np +import base64 +import cv2 +from inference import demo + +def Base64ToNdarry(img_base64): + img_data = base64.b64decode(img_base64) + img_np = np.fromstring(img_data, np.uint8) + src = cv2.imdecode(img_np, cv2.IMREAD_ANYCOLOR) + + return src + +def NdarrayToBase64(dst): + result, dst_data = cv2.imencode('.png', dst) + dst_base64 = base64.b64encode(dst_data) + + return dst_base64 + +parser = argparse.ArgumentParser(description='User controllable latent transformer') +parser.add_argument('--checkpoint_path', default='pretrained_models/latent_transformer/cat.pt') +args = parser.parse_args() + +demo = demo(args.checkpoint_path) + +@app.route("/", methods=["GET", "POST"]) +#@auth.login_required +def init(): + if request.method == "GET": + input_img = demo.run() + input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode() + return render_template("index.html", filepath1=input_base64, canvas_img=input_base64, result=True) + if request.method == "POST": + if 'zi' in request.form.keys(): + input_img = demo.move(z=-0.05) + elif 'zo' in request.form.keys(): + input_img = demo.move(z=0.05) + elif 'u' in request.form.keys(): + input_img = demo.move(y=-0.5, z=-0.0) + elif 'd' in request.form.keys(): + input_img = demo.move(y=0.5, z=-0.0) + elif 'l' in request.form.keys(): + input_img = demo.move(x=-0.5, z=-0.0) + elif 'r' in request.form.keys(): + input_img = demo.move(x=0.5, z=-0.0) + else: + input_img = demo.run() + + input_base64 = 
"data:image/png;base64,"+NdarrayToBase64(input_img).decode() + return render_template("index.html", filepath1=input_base64, canvas_img=input_base64, result=True) + +@app.route('/zoom', methods=["POST"]) +def zoom_func(): + + dz = json.loads(request.form['dz']) + sx = json.loads(request.form['sx']) + sy = json.loads(request.form['sy']) + stop_points = json.loads(request.form['stop_points']) + + input_img = demo.zoom(dz,sxsy=[sx,sy],stop_points=stop_points) + input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode() + res = {'img':input_base64} + return json.dumps(res) + +@app.route('/translate', methods=["POST"]) +def translate_func(): + + dx = json.loads(request.form['dx']) + dy = json.loads(request.form['dy']) + dz = json.loads(request.form['dz']) + sx = json.loads(request.form['sx']) + sy = json.loads(request.form['sy']) + stop_points = json.loads(request.form['stop_points']) + zi = json.loads(request.form['zi']) + zo = json.loads(request.form['zo']) + + input_img = demo.translate([dx,dy],sxsy=[sx,sy],stop_points=stop_points,zoom_in=zi,zoom_out=zo) + input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode() + res = {'img':input_base64} + return json.dumps(res) + +@app.route('/changestyle', methods=["POST"]) +def changestyle_func(): + input_img = demo.change_style() + input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode() + res = {'img':input_base64} + return json.dumps(res) + +@app.route('/reset', methods=["POST"]) +def reset_func(): + input_img = demo.reset() + input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode() + res = {'img':input_base64} + return json.dumps(res) + +if __name__ == "__main__": + app.run(debug=False, host='0.0.0.0', port=8000) \ No newline at end of file diff --git a/interface/inference.py b/interface/inference.py new file mode 100755 index 0000000000000000000000000000000000000000..568f8f8be5903b53505b30ca416bfbb57b3d7287 --- /dev/null +++ b/interface/inference.py @@ -0,0 +1,117 @@ +import os +from argparse import Namespace +import numpy as np +import torch +import sys + +sys.path.append(".") +sys.path.append("..") + +from models.StyleGANControler import StyleGANControler + +class demo(): + + def __init__(self, checkpoint_path, truncation = 0.5, use_average_code_as_input = False): + self.truncation = truncation + self.use_average_code_as_input = use_average_code_as_input + ckpt = torch.load(checkpoint_path, map_location='cpu') + opts = ckpt['opts'] + opts['checkpoint_path'] = checkpoint_path + self.opts = Namespace(**ckpt['opts']) + + self.net = StyleGANControler(self.opts) + self.net.eval() + self.net.cuda() + self.target_layers = [0,1,2,3,4,5] + + self.w1 = None + self.w1_after = None + self.f1 = None + + def run(self): + z1 = torch.randn(1,512).to("cuda") + x1, self.w1, self.f1 = self.net.decoder([z1],input_is_latent=False,randomize_noise=False,return_feature_map=True,return_latents=True,truncation=self.truncation, truncation_latent=self.net.latent_avg[0]) + self.w1_after = self.w1.clone() + x1 = self.net.face_pool(x1) + result = ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1] + return result + + def translate(self, dxy, sxsy=[0,0], stop_points=[], zoom_in=False, zoom_out=False): + dz = -5. if zoom_in else 0. + dz = 5. 
if zoom_out else dz + + dxyz = np.array([dxy[0],dxy[1],dz], dtype=np.float32) + dxy_norm = np.linalg.norm(dxyz[:2], ord=2) + dxyz[:2] = dxyz[:2]/dxy_norm + vec_num = dxy_norm/10 + + x = torch.from_numpy(np.array([[dxyz]],dtype=np.float32)).cuda() + f1 = torch.nn.functional.interpolate(self.f1, (256,256)) + y = f1[:,:,sxsy[1],sxsy[0]].unsqueeze(0) + + if len(stop_points)>0: + x = torch.cat([x, torch.zeros(x.shape[0],len(stop_points),x.shape[2]).cuda()], dim=1) + tmp = [] + for sp in stop_points: + tmp.append(f1[:,:,sp[1],sp[0]].unsqueeze(1)) + y = torch.cat([y,torch.cat(tmp, dim=1)],dim=1) + + if not self.use_average_code_as_input: + w_hat = self.net.encoder(self.w1[:,self.target_layers].detach(), x.detach(), y.detach(), alpha=vec_num) + w1 = self.w1.clone() + w1[:,self.target_layers] = w_hat + else: + w_hat = self.net.encoder(self.net.latent_avg.unsqueeze(0)[:,self.target_layers].detach(), x.detach(), y.detach(), alpha=vec_num) + w1 = self.w1.clone() + w1[:,self.target_layers] = self.w1.clone()[:,self.target_layers] + w_hat - self.net.latent_avg.unsqueeze(0)[:,self.target_layers] + + x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False) + + self.w1_after = w1.clone() + x1 = self.net.face_pool(x1) + result = ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1] + return result + + def zoom(self, dz, sxsy=[0,0], stop_points=[]): + vec_num = abs(dz)/5 + dz = 100*np.sign(dz) + x = torch.from_numpy(np.array([[[1.,0,dz]]],dtype=np.float32)).cuda() + f1 = torch.nn.functional.interpolate(self.f1, (256,256)) + y = f1[:,:,sxsy[1],sxsy[0]].unsqueeze(0) + + if len(stop_points)>0: + x = torch.cat([x, torch.zeros(x.shape[0],len(stop_points),x.shape[2]).cuda()], dim=1) + tmp = [] + for sp in stop_points: + tmp.append(f1[:,:,sp[1],sp[0]].unsqueeze(1)) + y = torch.cat([y,torch.cat(tmp, dim=1)],dim=1) + + if not self.use_average_code_as_input: + w_hat = self.net.encoder(self.w1[:,self.target_layers].detach(), x.detach(), y.detach(), alpha=vec_num) + w1 = self.w1.clone() + w1[:,self.target_layers] = w_hat + else: + w_hat = self.net.encoder(self.net.latent_avg.unsqueeze(0)[:,self.target_layers].detach(), x.detach(), y.detach(), alpha=vec_num) + w1 = self.w1.clone() + w1[:,self.target_layers] = self.w1.clone()[:,self.target_layers] + w_hat - self.net.latent_avg.unsqueeze(0)[:,self.target_layers] + + + x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False) + + x1 = self.net.face_pool(x1) + result = ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1] + return result + + def change_style(self): + z1 = torch.randn(1,512).to("cuda") + x1, w2 = self.net.decoder([z1],input_is_latent=False,randomize_noise=False,return_latents=True, truncation=self.truncation, truncation_latent=self.net.latent_avg[0]) + self.w1_after[:,6:] = w2.detach()[:,0] + x1, _ = self.net.decoder([self.w1_after], input_is_latent=True, randomize_noise=False, return_latents=False) + result = ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1] + return result + + def reset(self): + x1, _ = self.net.decoder([self.w1], input_is_latent=True, randomize_noise=False, return_latents=False) + result = ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1] + return result + \ No newline at end of file diff --git a/interface/templates/index.html b/interface/templates/index.html new file mode 100755 index 0000000000000000000000000000000000000000..422afd27e84c1b28c030a4f480776d53de9fc5ee --- /dev/null +++ b/interface/templates/index.html @@ -0,0 +1,195 @@ + + + + + 
+  <!-- Interface page: shows the generated image and posts edit requests to the Flask endpoints; only the on-page help text is reproduced below -->
+  Mouse drag: Translation
+  Middle mouse button: Set anchor point
+  Mouse wheel: Zoom in & out
+  'i' or 'o' key + mouse drag: Translation with zooming in & out
+  's' key: Style mixing
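The page's client-side script is not reproduced above. Purely as an illustrative sketch (not part of the repository), the `/translate` endpoint defined in `interface/flask_app.py` can be exercised from Python roughly as follows; the host/port, the drag values, and the anchor coordinates are assumptions:
```
import base64
import json

import requests  # assumed available; any HTTP client works

# Field values are JSON-encoded strings because flask_app.py decodes each
# form field with json.loads().
payload = {
    'dx': '10', 'dy': '0', 'dz': '0',     # drag vector
    'sx': '128', 'sy': '128',             # anchor point on the 256x256 feature map
    'stop_points': '[]',                  # optional fixed points
    'zi': 'false', 'zo': 'false',         # zoom-in / zoom-out flags
}
resp = requests.post('http://localhost:8000/translate', data=payload)
data_url = json.loads(resp.text)['img']   # "data:image/png;base64,..."
with open('edited.png', 'wb') as f:
    f.write(base64.b64decode(data_url.split(',', 1)[1]))
```
The other endpoints (`/zoom`, `/changestyle`, `/reset`) follow the same request/response pattern.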
+ + + + \ No newline at end of file diff --git a/licenses/LICENSE_ gengshan-y_expansion b/licenses/LICENSE_ gengshan-y_expansion new file mode 100755 index 0000000000000000000000000000000000000000..e685a844ef325e623ba048fee0b6de7ad1c57553 --- /dev/null +++ b/licenses/LICENSE_ gengshan-y_expansion @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Carnegie Mellon University + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/licenses/LICENSE_HuangYG123 b/licenses/LICENSE_HuangYG123 new file mode 100755 index 0000000000000000000000000000000000000000..c539fea307a624b0941fb808d6ae3ab4db552529 --- /dev/null +++ b/licenses/LICENSE_HuangYG123 @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 HuangYG123 + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE_S-aiueo32 b/licenses/LICENSE_S-aiueo32 new file mode 100755 index 0000000000000000000000000000000000000000..81e7b18bd6fcfd5a81e08d0bcb192be28cd6723c --- /dev/null +++ b/licenses/LICENSE_S-aiueo32 @@ -0,0 +1,25 @@ +BSD 2-Clause License + +Copyright (c) 2020, Sou Uchida +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. 
Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE_TreB1eN b/licenses/LICENSE_TreB1eN new file mode 100755 index 0000000000000000000000000000000000000000..1c7d3585c795c41d2334036b01a8d660a5235671 --- /dev/null +++ b/licenses/LICENSE_TreB1eN @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2018 TreB1eN + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE_lessw2020 b/licenses/LICENSE_lessw2020 new file mode 100755 index 0000000000000000000000000000000000000000..f49a4e16e68b128803cc2dcea614603632b04eac --- /dev/null +++ b/licenses/LICENSE_lessw2020 @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. 
+ + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
\ No newline at end of file diff --git a/licenses/LICENSE_pixel2style2pixel b/licenses/LICENSE_pixel2style2pixel new file mode 100755 index 0000000000000000000000000000000000000000..272e5a8959a88dce922045124af98cf74b9ee9a4 --- /dev/null +++ b/licenses/LICENSE_pixel2style2pixel @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Elad Richardson, Yuval Alaluf + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/licenses/LICENSE_rosinality b/licenses/LICENSE_rosinality new file mode 100755 index 0000000000000000000000000000000000000000..81da3fce025084b7005be5405d3842fbea29b5ba --- /dev/null +++ b/licenses/LICENSE_rosinality @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Kim Seonghyeon + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/models/StyleGANControler.py b/models/StyleGANControler.py new file mode 100755 index 0000000000000000000000000000000000000000..b0766fd8fe51056ee548ac690c435abd554b301a --- /dev/null +++ b/models/StyleGANControler.py @@ -0,0 +1,70 @@ +import torch +from torch import nn +from models.networks import latent_transformer +from models.stylegan2.model import Generator +import numpy as np + +def get_keys(d, name): + if 'state_dict' in d: + d = d['state_dict'] + d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name} + return d_filt + + +class StyleGANControler(nn.Module): + + def __init__(self, opts): + super(StyleGANControler, self).__init__() + self.set_opts(opts) + # Define architecture + + if 'ffhq' in self.opts.stylegan_weights: + self.style_num = 18 + elif 'car' in self.opts.stylegan_weights: + self.style_num = 16 + elif 'cat' in self.opts.stylegan_weights: + self.style_num = 14 + elif 'church' in self.opts.stylegan_weights: + self.style_num = 14 + elif 'anime' in self.opts.stylegan_weights: + self.style_num = 16 + + self.encoder = self.set_encoder() + if self.style_num==18: + self.decoder = Generator(1024, 512, 8, channel_multiplier=2) + elif self.style_num==16: + self.decoder = Generator(512, 512, 8, channel_multiplier=2) + elif self.style_num==14: + self.decoder = Generator(256, 512, 8, channel_multiplier=2) + self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) + + # Load weights if needed + self.load_weights() + + def set_encoder(self): + encoder = latent_transformer.Network(self.opts) + return encoder + + def load_weights(self): + if self.opts.checkpoint_path is not None: + print('Loading from checkpoint: {}'.format(self.opts.checkpoint_path)) + ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu') + self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True) + self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True) + self.__load_latent_avg(ckpt) + else: + print('Loading decoder weights from pretrained!') + ckpt = torch.load(self.opts.stylegan_weights) + self.decoder.load_state_dict(ckpt['g_ema'], strict=True) + self.__load_latent_avg(ckpt, repeat=self.opts.style_num) + + def set_opts(self, opts): + self.opts = opts + + def __load_latent_avg(self, ckpt, repeat=None): + if 'latent_avg' in ckpt: + self.latent_avg = ckpt['latent_avg'].to(self.opts.device) + if repeat is not None: + self.latent_avg = self.latent_avg.repeat(repeat, 1) + else: + self.latent_avg = None diff --git a/models/__init__.py b/models/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/networks/__init__.py b/models/networks/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/networks/latent_transformer.py b/models/networks/latent_transformer.py new file mode 100755 index 0000000000000000000000000000000000000000..f465ce7285397ed3460e668f1c734e02483b2903 --- /dev/null +++ b/models/networks/latent_transformer.py @@ -0,0 +1,162 @@ +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange + +# classes +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 
0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class CrossAttention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.to_k = nn.Linear(dim, inner_dim , bias=False) + self.to_v = nn.Linear(dim, inner_dim , bias = False) + self.to_q = nn.Linear(dim, inner_dim, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x_qkv, query_length=1): + h = self.heads + + k = self.to_k(x_qkv)[:, query_length:] + k = rearrange(k, 'b n (h d) -> b h n d', h = h) + + v = self.to_v(x_qkv)[:, query_length:] + v = rearrange(v, 'b n (h d) -> b h n d', h = h) + + q = self.to_q(x_qkv)[:, :query_length] + q = rearrange(q, 'b n (h d) -> b h n d', h = h) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + + attn = dots.softmax(dim=-1) + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + + return out + +class TransformerEncoder(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return x + +class TransformerDecoder(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.pos_embedding = nn.Parameter(torch.randn(1, 6, dim)) + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, CrossAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + def forward(self, x, y): + x = x + self.pos_embedding[:, :x.shape[1]] + for sattn, cattn, ff in self.layers: + x = sattn(x) + x + xy = torch.cat((x,y), dim=1) + x = cattn(xy, query_length=x.shape[1]) + x + x = ff(x) + x + return x + +class Network(nn.Module): + def __init__(self, opts): + super(Network, self).__init__() + + 
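+        # Encoder: each user edit becomes a token built from a flow vector (3 -> 256
+        # via layer1) concatenated with the StyleGAN feature sampled at that point
+        # (512 -> 256 via layer2), mixed by self-attention. Decoder: the latent
+        # tokens w cross-attend to these edit tokens, and the result is applied
+        # residually in forward() as w_hat = w + alpha * h.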
self.transformer_encoder = TransformerEncoder(dim=512, depth=6, heads=8, dim_head=64, mlp_dim=512, dropout=0) + self.transformer_decoder = TransformerDecoder(dim=512, depth=6, heads=8, dim_head=64, mlp_dim=512, dropout=0) + self.layer1 = nn.Linear(3, 256) + self.layer2 = nn.Linear(512, 256) + self.layer3 = nn.Linear(512, 512) + self.layer4 = nn.Linear(512, 512) + self.mlp_head = nn.Sequential( + nn.Linear(512, 512) + ) + + def forward(self, w, x, y, alpha=1.): + #w: latent vectors + #x: flow vectors + #y: StyleGAN features + xh = F.relu(self.layer1(x)) + yh = F.relu(self.layer2(y)) + xyh = torch.cat([xh,yh], dim=2) + xyh = F.relu(self.layer3(xyh)) + xyh = self.transformer_encoder(xyh) + + wh = F.relu(self.layer4(w)) + + h = self.transformer_decoder(wh, xyh) + h = self.mlp_head(h) + w_hat = w+alpha*h + return w_hat diff --git a/models/stylegan2/__init__.py b/models/stylegan2/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/stylegan2/model.py b/models/stylegan2/model.py new file mode 100755 index 0000000000000000000000000000000000000000..f6e4c3441f4d12294f3961be3e7ed932a93090e8 --- /dev/null +++ b/models/stylegan2/model.py @@ -0,0 +1,714 @@ +import math +import random +import torch +from torch import nn +from torch.nn import functional as F + +from models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d + + +class PixelNorm(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + + if k.ndim == 1: + k = k[None, :] * k[:, None] + + k /= k.sum() + + return k + + +class Upsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) * (factor ** 2) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) + + return out + + +class Downsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) + + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.register_buffer('kernel', kernel) + + self.pad = pad + + def forward(self, input): + out = upfirdn2d(input, self.kernel, pad=self.pad) + + return out + + +class EqualConv2d(nn.Module): + def __init__( + self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True + ): + super().__init__() + + self.weight = nn.Parameter( + torch.randn(out_channel, in_channel, kernel_size, kernel_size) + ) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + + else: + self.bias = None + + def forward(self, input): + out = F.conv2d( + input, + self.weight * self.scale, + bias=self.bias, + 
stride=self.stride, + padding=self.padding, + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + def __init__( + self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None + ): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + + else: + out = F.linear( + input, self.weight * self.scale, bias=self.bias * self.lr_mul + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' + ) + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + out = F.leaky_relu(input, negative_slope=self.negative_slope) + + return out * math.sqrt(2) + + +class ModulatedConv2d(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + demodulate=True, + upsample=False, + downsample=False, + blur_kernel=[1, 3, 3, 1], + ): + super().__init__() + + self.eps = 1e-8 + self.kernel_size = kernel_size + self.in_channel = in_channel + self.out_channel = out_channel + self.upsample = upsample + self.downsample = downsample + + if upsample: + factor = 2 + p = (len(blur_kernel) - factor) - (kernel_size - 1) + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + 1 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1)) + + fan_in = in_channel * kernel_size ** 2 + self.scale = 1 / math.sqrt(fan_in) + self.padding = kernel_size // 2 + + self.weight = nn.Parameter( + torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) + ) + + self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) + + self.demodulate = demodulate + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' + f'upsample={self.upsample}, downsample={self.downsample})' + ) + + def forward(self, input, style, input_is_stylespace=False): + batch, in_channel, height, width = input.shape + + if not input_is_stylespace: + style = self.modulation(style).view(batch, 1, in_channel, 1, 1) + weight = self.scale * self.weight * style + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) + weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) + + weight = weight.view( + batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + + if self.upsample: + input = input.view(1, batch * in_channel, height, width) + weight = weight.view( + batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + weight = weight.transpose(1, 2).reshape( + batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size + ) + out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) + _, 
_, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + out = self.blur(out) + + elif self.downsample: + input = self.blur(input) + _, _, height, width = input.shape + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + else: + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=self.padding, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + return out, style + + +class NoiseInjection(nn.Module): + def __init__(self): + super().__init__() + + self.weight = nn.Parameter(torch.zeros(1)) + + def forward(self, image, noise=None): + if noise is None: + batch, _, height, width = image.shape + noise = image.new_empty(batch, 1, height, width).normal_() + + return image + self.weight * noise + + +class ConstantInput(nn.Module): + def __init__(self, channel, size=4): + super().__init__() + + self.input = nn.Parameter(torch.randn(1, channel, size, size)) + + def forward(self, input): + batch = input.shape[0] + out = self.input.repeat(batch, 1, 1, 1) + + return out + + +class StyledConv(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=False, + blur_kernel=[1, 3, 3, 1], + demodulate=True, + ): + super().__init__() + + self.conv = ModulatedConv2d( + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=upsample, + blur_kernel=blur_kernel, + demodulate=demodulate, + ) + + self.noise = NoiseInjection() + # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) + # self.activate = ScaledLeakyReLU(0.2) + self.activate = FusedLeakyReLU(out_channel) + + def forward(self, input, style, noise=None, input_is_stylespace=False): + out, style = self.conv(input, style, input_is_stylespace=input_is_stylespace) + out = self.noise(out, noise=noise) + # out = out + self.bias + out = self.activate(out) + + return out, style + +class ToRGB(nn.Module): + def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + if upsample: + self.upsample = Upsample(blur_kernel) + + self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, input, style, skip=None, input_is_stylespace=False): + out, style = self.conv(input, style, input_is_stylespace=input_is_stylespace) + out = out + self.bias + + if skip is not None: + skip = self.upsample(skip) + + out = out + skip + + return out, style + +class Generator(nn.Module): + def __init__( + self, + size, + style_dim, + n_mlp, + channel_multiplier=2, + blur_kernel=[1, 3, 3, 1], + lr_mlp=0.01, + ): + super().__init__() + + self.size = size + + self.style_dim = style_dim + + layers = [PixelNorm()] + + for i in range(n_mlp): + layers.append( + EqualLinear( + style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu' + ) + ) + + self.style = nn.Sequential(*layers) + + self.channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + self.input = ConstantInput(self.channels[4]) + self.conv1 = StyledConv( + self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel + ) + self.to_rgb1 = ToRGB(self.channels[4], 
style_dim, upsample=False) + + self.log_size = int(math.log(size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + + self.convs = nn.ModuleList() + self.upsamples = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channel = self.channels[4] + + for layer_idx in range(self.num_layers): + res = (layer_idx + 5) // 2 + shape = [1, 1, 2 ** res, 2 ** res] + self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape)) + + for i in range(3, self.log_size + 1): + out_channel = self.channels[2 ** i] + + self.convs.append( + StyledConv( + in_channel, + out_channel, + 3, + style_dim, + upsample=True, + blur_kernel=blur_kernel, + ) + ) + + self.convs.append( + StyledConv( + out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel + ) + ) + + self.to_rgbs.append(ToRGB(out_channel, style_dim)) + + in_channel = out_channel + + self.n_latent = self.log_size * 2 - 2 + + def make_noise(self): + device = self.input.input.device + + noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) + + return noises + + def mean_latent(self, n_latent): + latent_in = torch.randn( + n_latent, self.style_dim, device=self.input.input.device + ) + latent = self.style(latent_in).mean(0, keepdim=True) + + return latent + + def get_latent(self, input): + return self.style(input) + + def forward( + self, + styles, + return_latents=False, + return_features=False, + inject_index=None, + truncation=1, + truncation_latent=None, + input_is_latent=False, + input_is_stylespace=False, + noise=None, + randomize_noise=True, + return_feature_map=False, + return_s=False + ): + + if not input_is_latent and not input_is_stylespace: + styles = [self.style(s) for s in styles] + + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers + else: + noise = [ + getattr(self.noises, f'noise_{i}') for i in range(self.num_layers) + ] + + if truncation < 1 and not input_is_stylespace: + style_t = [] + + for style in styles: + style_t.append( + truncation_latent + truncation * (style - truncation_latent) + ) + + styles = style_t + + if input_is_stylespace: + latent = styles[0] + elif len(styles) < 2: + inject_index = self.n_latent + + if styles[0].ndim < 3: + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: + latent = styles[0] + + else: + if inject_index is None: + inject_index = random.randint(1, self.n_latent - 1) + + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) + + latent = torch.cat([latent, latent2], 1) + + style_vector = [] + + if not input_is_stylespace: + out = self.input(latent) + out, out_style = self.conv1(out, latent[:, 0], noise=noise[0]) + style_vector.append(out_style) + + skip, out_style = self.to_rgb1(out, latent[:, 1]) + style_vector.append(out_style) + + i = 1 + else: + out = self.input(latent[0]) + out, out_style = self.conv1(out, latent[0], noise=noise[0], input_is_stylespace=input_is_stylespace) + style_vector.append(out_style) + + skip, out_style = self.to_rgb1(out, latent[1], input_is_stylespace=input_is_stylespace) + style_vector.append(out_style) + + i = 2 + + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs + ): + if not input_is_stylespace: + out, out_style1 = conv1(out, latent[:, i], noise=noise1) + out, out_style2 = conv2(out, latent[:, 
i + 1], noise=noise2) + skip, rgb_style = to_rgb(out, latent[:, i + 2], skip) + if i==7: + feature_map = out + style_vector.extend([out_style1, out_style2, rgb_style]) + i += 2 + else: + out, out_style1 = conv1(out, latent[i], noise=noise1, input_is_stylespace=input_is_stylespace) + out, out_style2 = conv2(out, latent[i + 1], noise=noise2, input_is_stylespace=input_is_stylespace) + skip, rgb_style = to_rgb(out, latent[i + 2], skip, input_is_stylespace=input_is_stylespace) + + style_vector.extend([out_style1, out_style2, rgb_style]) + + i += 3 + + image = skip + + if return_feature_map: + if return_latents: + return image, latent, feature_map + elif return_s: + return image, style_vector, feature_map + else: + return image, feature_map + + if return_latents: + return image, latent + elif return_s: + return image, style_vector + elif return_features: + return image, out + else: + return image, None + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append( + EqualConv2d( + in_channel, + out_channel, + kernel_size, + padding=self.padding, + stride=stride, + bias=bias and not activate, + ) + ) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer( + in_channel, out_channel, 1, downsample=True, activate=False, bias=False + ) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class Discriminator(nn.Module): + def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + convs = [ConvLayer(3, channels[size], 1)] + + log_size = int(math.log(size, 2)) + + in_channel = channels[size] + + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + + convs.append(ResBlock(in_channel, out_channel, blur_kernel)) + + in_channel = out_channel + + self.convs = nn.Sequential(*convs) + + self.stddev_group = 4 + self.stddev_feat = 1 + + self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) + self.final_linear = nn.Sequential( + EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), + EqualLinear(channels[4], 1), + ) + + def forward(self, input): + out = self.convs(input) + + batch, channel, height, width = out.shape + group = min(batch, self.stddev_group) + stddev = out.view( + group, -1, self.stddev_feat, channel // self.stddev_feat, height, width + ) + stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) + stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) + stddev = stddev.repeat(group, 1, height, 
width) + out = torch.cat([out, stddev], 1) + + out = self.final_conv(out) + + out = out.view(batch, -1) + out = self.final_linear(out) + + return out diff --git a/models/stylegan2/op/__init__.py b/models/stylegan2/op/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..d0918d92285955855be89f00096b888ee5597ce3 --- /dev/null +++ b/models/stylegan2/op/__init__.py @@ -0,0 +1,2 @@ +from .fused_act import FusedLeakyReLU, fused_leaky_relu +from .upfirdn2d import upfirdn2d diff --git a/models/stylegan2/op/fused_act.py b/models/stylegan2/op/fused_act.py new file mode 100755 index 0000000000000000000000000000000000000000..76ae78e49570971ee3fe303844b2c3b3fee77fa0 --- /dev/null +++ b/models/stylegan2/op/fused_act.py @@ -0,0 +1,85 @@ +import os + +import torch +from torch import nn +from torch.autograd import Function +from torch.utils.cpp_extension import load + +module_path = os.path.dirname(__file__) +fused = load( + 'fused', + sources=[ + os.path.join(module_path, 'fused_bias_act.cpp'), + os.path.join(module_path, 'fused_bias_act_kernel.cu'), + ], +) + + +class FusedLeakyReLUFunctionBackward(Function): + @staticmethod + def forward(ctx, grad_output, out, negative_slope, scale): + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + empty = grad_output.new_empty(0) + + grad_input = fused.fused_bias_act( + grad_output, empty, out, 3, 1, negative_slope, scale + ) + + dim = [0] + + if grad_input.ndim > 2: + dim += list(range(2, grad_input.ndim)) + + grad_bias = grad_input.sum(dim).detach() + + return grad_input, grad_bias + + @staticmethod + def backward(ctx, gradgrad_input, gradgrad_bias): + out, = ctx.saved_tensors + gradgrad_out = fused.fused_bias_act( + gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale + ) + + return gradgrad_out, None, None, None + + +class FusedLeakyReLUFunction(Function): + @staticmethod + def forward(ctx, input, bias, negative_slope, scale): + empty = input.new_empty(0) + out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + return out + + @staticmethod + def backward(ctx, grad_output): + out, = ctx.saved_tensors + + grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( + grad_output, out, ctx.negative_slope, ctx.scale + ) + + return grad_input, grad_bias, None, None + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + + self.bias = nn.Parameter(torch.zeros(channel)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) diff --git a/models/stylegan2/op/fused_bias_act.cpp b/models/stylegan2/op/fused_bias_act.cpp new file mode 100755 index 0000000000000000000000000000000000000000..02be898f970bcc8ea297867fcaa4e71b24b3d949 --- /dev/null +++ b/models/stylegan2/op/fused_bias_act.cpp @@ -0,0 +1,21 @@ +#include + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + CHECK_CUDA(input); + CHECK_CUDA(bias); + + return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); +} \ No newline at end of file diff --git a/models/stylegan2/op/fused_bias_act_kernel.cu b/models/stylegan2/op/fused_bias_act_kernel.cu new file mode 100755 index 0000000000000000000000000000000000000000..c9fa56fea7ede7072dc8925cfb0148f136eb85b8 --- /dev/null +++ b/models/stylegan2/op/fused_bias_act_kernel.cu @@ -0,0 +1,99 @@ +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + + +template +static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, + int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { + int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; + + scalar_t zero = 0.0; + + for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { + scalar_t x = p_x[xi]; + + if (use_bias) { + x += p_b[(xi / step_b) % size_b]; + } + + scalar_t ref = use_ref ? p_ref[xi] : zero; + + scalar_t y; + + switch (act * 10 + grad) { + default: + case 10: y = x; break; + case 11: y = x; break; + case 12: y = 0.0; break; + + case 30: y = (x > 0.0) ? x : x * alpha; break; + case 31: y = (ref > 0.0) ? x : x * alpha; break; + case 32: y = 0.0; break; + } + + out[xi] = y * scale; + } +} + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + auto x = input.contiguous(); + auto b = bias.contiguous(); + auto ref = refer.contiguous(); + + int use_bias = b.numel() ? 1 : 0; + int use_ref = ref.numel() ? 
1 : 0; + + int size_x = x.numel(); + int size_b = b.numel(); + int step_b = 1; + + for (int i = 1 + 1; i < x.dim(); i++) { + step_b *= x.size(i); + } + + int loop_x = 4; + int block_size = 4 * 32; + int grid_size = (size_x - 1) / (loop_x * block_size) + 1; + + auto y = torch::empty_like(x); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { + fused_bias_act_kernel<<>>( + y.data_ptr(), + x.data_ptr(), + b.data_ptr(), + ref.data_ptr(), + act, + grad, + alpha, + scale, + loop_x, + size_x, + step_b, + size_b, + use_bias, + use_ref + ); + }); + + return y; +} \ No newline at end of file diff --git a/models/stylegan2/op/upfirdn2d.cpp b/models/stylegan2/op/upfirdn2d.cpp new file mode 100755 index 0000000000000000000000000000000000000000..d2e633dc896433c205e18bc3e455539192ff968e --- /dev/null +++ b/models/stylegan2/op/upfirdn2d.cpp @@ -0,0 +1,23 @@ +#include + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + CHECK_CUDA(input); + CHECK_CUDA(kernel); + + return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); +} \ No newline at end of file diff --git a/models/stylegan2/op/upfirdn2d.py b/models/stylegan2/op/upfirdn2d.py new file mode 100755 index 0000000000000000000000000000000000000000..7bc5a1e331c2bbb1893ac748cfd0f144ff0651b4 --- /dev/null +++ b/models/stylegan2/op/upfirdn2d.py @@ -0,0 +1,184 @@ +import os + +import torch +from torch.autograd import Function +from torch.utils.cpp_extension import load + +module_path = os.path.dirname(__file__) +upfirdn2d_op = load( + 'upfirdn2d', + sources=[ + os.path.join(module_path, 'upfirdn2d.cpp'), + os.path.join(module_path, 'upfirdn2d_kernel.cu'), + ], +) + + +class UpFirDn2dBackward(Function): + @staticmethod + def forward( + ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size + ): + up_x, up_y = up + down_x, down_y = down + g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad + + grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) + + grad_input = upfirdn2d_op.upfirdn2d( + grad_output, + grad_kernel, + down_x, + down_y, + up_x, + up_y, + g_pad_x0, + g_pad_x1, + g_pad_y0, + g_pad_y1, + ) + grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) + + ctx.save_for_backward(kernel) + + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + ctx.up_x = up_x + ctx.up_y = up_y + ctx.down_x = down_x + ctx.down_y = down_y + ctx.pad_x0 = pad_x0 + ctx.pad_x1 = pad_x1 + ctx.pad_y0 = pad_y0 + ctx.pad_y1 = pad_y1 + ctx.in_size = in_size + ctx.out_size = out_size + + return grad_input + + @staticmethod + def backward(ctx, gradgrad_input): + kernel, = ctx.saved_tensors + + gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) + + gradgrad_out = upfirdn2d_op.upfirdn2d( + gradgrad_input, + kernel, + ctx.up_x, + ctx.up_y, + ctx.down_x, + ctx.down_y, + ctx.pad_x0, + ctx.pad_x1, + ctx.pad_y0, + 
ctx.pad_y1, + ) + # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view( + ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] + ) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + batch, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_op.upfirdn2d( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 + ) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = ctx.saved_tensors + + grad_input = UpFirDn2dBackward.apply( + grad_output, + kernel, + grad_kernel, + ctx.up, + ctx.down, + ctx.pad, + ctx.g_pad, + ctx.in_size, + ctx.out_size, + ) + + return grad_input, None, None, None, None + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + out = UpFirDn2d.apply( + input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) + ) + + return out + + +def upfirdn2d_native( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 +): + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad( + out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] + ) + out = out[ + :, + max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] + ) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + + return out[:, ::down_y, ::down_x, :] diff --git a/models/stylegan2/op/upfirdn2d_kernel.cu b/models/stylegan2/op/upfirdn2d_kernel.cu new file mode 100755 index 0000000000000000000000000000000000000000..2a710aa6adc3d43ac93136a1814e3c39970e1c7e --- /dev/null +++ b/models/stylegan2/op/upfirdn2d_kernel.cu @@ -0,0 +1,272 @@ +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. 
+// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + + +static __host__ __device__ __forceinline__ int floor_div(int a, int b) { + int c = a / b; + + if (c * b > a) { + c--; + } + + return c; +} + + +struct UpFirDn2DKernelParams { + int up_x; + int up_y; + int down_x; + int down_y; + int pad_x0; + int pad_x1; + int pad_y0; + int pad_y1; + + int major_dim; + int in_h; + int in_w; + int minor_dim; + int kernel_h; + int kernel_w; + int out_h; + int out_w; + int loop_major; + int loop_x; +}; + + +template +__global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) { + const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; + const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; + + __shared__ volatile float sk[kernel_h][kernel_w]; + __shared__ volatile float sx[tile_in_h][tile_in_w]; + + int minor_idx = blockIdx.x; + int tile_out_y = minor_idx / p.minor_dim; + minor_idx -= tile_out_y * p.minor_dim; + tile_out_y *= tile_out_h; + int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; + int major_idx_base = blockIdx.z * p.loop_major; + + if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) { + return; + } + + for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) { + int ky = tap_idx / kernel_w; + int kx = tap_idx - ky * kernel_w; + scalar_t v = 0.0; + + if (kx < p.kernel_w & ky < p.kernel_h) { + v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; + } + + sk[ky][kx] = v; + } + + for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) { + for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) { + int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; + int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; + int tile_in_x = floor_div(tile_mid_x, up_x); + int tile_in_y = floor_div(tile_mid_y, up_y); + + __syncthreads(); + + for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) { + int rel_in_y = in_idx / tile_in_w; + int rel_in_x = in_idx - rel_in_y * tile_in_w; + int in_x = rel_in_x + tile_in_x; + int in_y = rel_in_y + tile_in_y; + + scalar_t v = 0.0; + + if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { + v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx]; + } + + sx[rel_in_y][rel_in_x] = v; + } + + __syncthreads(); + for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) { + int rel_out_y = out_idx / tile_out_w; + int rel_out_x = out_idx - rel_out_y * tile_out_w; + int out_x = rel_out_x + tile_out_x; + int out_y = rel_out_y + tile_out_y; + + int mid_x = tile_mid_x + rel_out_x * down_x; + int mid_y = tile_mid_y + rel_out_y * down_y; + int in_x = floor_div(mid_x, up_x); + int in_y = floor_div(mid_y, up_y); + int rel_in_x = in_x - tile_in_x; + int rel_in_y = in_y - tile_in_y; + int kernel_x = (in_x + 1) * up_x - mid_x - 1; + int kernel_y = (in_y + 1) * up_y - mid_y - 1; + + scalar_t v = 0.0; + + #pragma unroll + for (int y = 0; y < kernel_h / up_y; y++) + #pragma unroll + for (int x = 0; x < kernel_w / up_x; x++) + v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x]; + + if (out_x < p.out_w 
& out_y < p.out_h) { + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v; + } + } + } + } +} + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + UpFirDn2DKernelParams p; + + auto x = input.contiguous(); + auto k = kernel.contiguous(); + + p.major_dim = x.size(0); + p.in_h = x.size(1); + p.in_w = x.size(2); + p.minor_dim = x.size(3); + p.kernel_h = k.size(0); + p.kernel_w = k.size(1); + p.up_x = up_x; + p.up_y = up_y; + p.down_x = down_x; + p.down_y = down_y; + p.pad_x0 = pad_x0; + p.pad_x1 = pad_x1; + p.pad_y0 = pad_y0; + p.pad_y1 = pad_y1; + + p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y; + p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x; + + auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); + + int mode = -1; + + int tile_out_h; + int tile_out_w; + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 1; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) { + mode = 2; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 3; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 4; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 5; + tile_out_h = 8; + tile_out_w = 32; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 6; + tile_out_h = 8; + tile_out_w = 32; + } + + dim3 block_size; + dim3 grid_size; + + if (tile_out_h > 0 && tile_out_w) { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 1; + block_size = dim3(32 * 8, 1, 1); + grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, + (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { + switch (mode) { + case 1: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 2: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 3: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 4: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 5: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 6: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + } + }); + + return out; +} \ No newline at end of file diff --git a/options/__init__.py b/options/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/options/train_options.py b/options/train_options.py new file mode 100755 index 
0000000000000000000000000000000000000000..c33c9217be048f5574bde91eb1c6ebd186a265bd --- /dev/null +++ b/options/train_options.py @@ -0,0 +1,33 @@ +from argparse import ArgumentParser + +class TrainOptions: + + def __init__(self): + self.parser = ArgumentParser() + self.initialize() + + def initialize(self): + self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory') + + self.parser.add_argument('--batch_size', default=1, type=int, help='Batch size for training') + self.parser.add_argument('--learning_rate', default=0.001, type=float, help='Optimizer learning rate') + self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use') + self.parser.add_argument('--train_decoder', default=False, type=bool, help='Whether to train the decoder model') + + self.parser.add_argument('--lpips_lambda', default=0., type=float, help='LPIPS loss multiplier factor') + self.parser.add_argument('--l2_lambda', default=0, type=float, help='L2 loss multiplier factor') + self.parser.add_argument('--l2latent_lambda', default=1.0, type=float, help='L2 loss multiplier factor') + + self.parser.add_argument('--stylegan_weights', default='pretrained_models/stylegan2-cat-config-f.pt', type=str, help='Path to StyleGAN model weights') + self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to pSp model checkpoint') + + self.parser.add_argument('--max_steps', default=60100, type=int, help='Maximum number of training steps') + self.parser.add_argument('--image_interval', default=100, type=int, help='Interval for logging train images during training') + self.parser.add_argument('--save_interval', default=10000, type=int, help='Model checkpoint interval') + + self.parser.add_argument('--style_num', default=14, type=int, help='The number of StyleGAN layers get latent codes ') + self.parser.add_argument('--channel_multiplier', default=2, type=int, help='StyleGAN parameter') + + def parse(self): + opts = self.parser.parse_args() + return opts \ No newline at end of file diff --git a/scripts/train.py b/scripts/train.py new file mode 100755 index 0000000000000000000000000000000000000000..7450aef4a544d3f36568d86a34870fe63ec521d3 --- /dev/null +++ b/scripts/train.py @@ -0,0 +1,29 @@ +""" +This file runs the main training/val loop +""" +import os +import json +import sys +import pprint + +sys.path.append(".") +sys.path.append("..") + +from options.train_options import TrainOptions +from training.coach import Coach + + +def main(): + opts = TrainOptions().parse() + os.makedirs(opts.exp_dir, exist_ok=True) + + opts_dict = vars(opts) + pprint.pprint(opts_dict) + with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f: + json.dump(opts_dict, f, indent=4, sort_keys=True) + + coach = Coach(opts) + coach.train() + +if __name__ == '__main__': + main() diff --git a/training/__init__.py b/training/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/training/coach.py b/training/coach.py new file mode 100755 index 0000000000000000000000000000000000000000..30e7009653b8ddecb2bc53bfd4e5aa6f1f23b6ef --- /dev/null +++ b/training/coach.py @@ -0,0 +1,211 @@ +import os +import math, random +import numpy as np +import matplotlib +import matplotlib.pyplot as plt +matplotlib.use('Agg') + +import torch +from torch import nn +from torch.utils.tensorboard import SummaryWriter +import torch.nn.functional as F + +from utils import common +from criteria.lpips.lpips import LPIPS 
+from models.StyleGANControler import StyleGANControler +from training.ranger import Ranger + +from expansion.submission import Expansion +from expansion.utils.flowlib import point_vec + +class Coach: + def __init__(self, opts): + self.opts = opts + if self.opts.checkpoint_path is None: + self.global_step = 0 + else: + self.global_step = int(os.path.splitext(os.path.basename(self.opts.checkpoint_path))[0].split('_')[-1]) + + self.device = 'cuda:0' # TODO: Allow multiple GPU? currently using CUDA_VISIBLE_DEVICES + self.opts.device = self.device + + # Initialize network + self.net = StyleGANControler(self.opts).to(self.device) + + # Initialize loss + if self.opts.lpips_lambda > 0: + self.lpips_loss = LPIPS(net_type='alex').to(self.device).eval() + self.mse_loss = nn.MSELoss().to(self.device).eval() + + # Initialize optimizer + self.optimizer = self.configure_optimizers() + + # Initialize logger + log_dir = os.path.join(opts.exp_dir, 'logs') + os.makedirs(log_dir, exist_ok=True) + self.logger = SummaryWriter(log_dir=log_dir) + + # Initialize checkpoint dir + self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints') + os.makedirs(self.checkpoint_dir, exist_ok=True) + self.best_val_loss = None + if self.opts.save_interval is None: + self.opts.save_interval = self.opts.max_steps + + # Initialize optical flow estimator + self.ex = Expansion() + + # Set flow normalization values + if 'ffhq' in self.opts.stylegan_weights: + self.sigma_f = 4 + self.sigma_e = 0.02 + elif 'car' in self.opts.stylegan_weights: + self.sigma_f = 5 + self.sigma_e = 0.03 + elif 'cat' in self.opts.stylegan_weights: + self.sigma_f = 12 + self.sigma_e = 0.04 + elif 'church' in self.opts.stylegan_weights: + self.sigma_f = 8 + self.sigma_e = 0.02 + elif 'anime' in self.opts.stylegan_weights: + self.sigma_f = 7 + self.sigma_e = 0.025 + + def train(self, truncation = 0.3, sigma = 0.1, target_layers = [0,1,2,3,4,5]): + + x = np.array(range(0,256,16)).astype(np.float32)/127.5-1. + y = np.array(range(0,256,16)).astype(np.float32)/127.5-1. 
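+        # Build a 16x16 grid of normalized (x, y) coordinates (every 16th pixel of a 256x256 image, mapped into [-1, 1)); it is used below with F.grid_sample to sparsely sample the estimated flow features and the StyleGAN feature map.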
+ xx, yy = np.meshgrid(x,y) + grid = np.concatenate([xx[:,:,None],yy[:,:,None]], axis=2) + grid = torch.from_numpy(grid[None,:]).cuda() + grid = grid.repeat(self.opts.batch_size,1,1,1) + + while self.global_step < self.opts.max_steps: + with torch.no_grad(): + z1 = torch.randn(self.opts.batch_size,512).to("cuda") + z2 = torch.randn(self.opts.batch_size,self.net.style_num, 512).to("cuda") + + x1, w1, f1 = self.net.decoder([z1],input_is_latent=False,randomize_noise=False,return_feature_map=True,return_latents=True,truncation=truncation, truncation_latent=self.net.latent_avg[0]) + x1 = self.net.face_pool(x1) + x2, w2 = self.net.decoder([z2],input_is_latent=False,randomize_noise=False,return_latents=True, truncation_latent=self.net.latent_avg[0]) + x2 = self.net.face_pool(x2) + w_mid = w1.clone() + w_mid[:,target_layers] = w_mid[:,target_layers]+sigma*(w2[:,target_layers]-w_mid[:,target_layers]) + x_mid, _ = self.net.decoder([w_mid], input_is_latent=True, randomize_noise=False, return_latents=False) + x_mid = self.net.face_pool(x_mid) + + flow, logexp = self.ex.run(x1.detach(),x_mid.detach()) + flow_feature = torch.cat([flow/self.sigma_f, logexp/self.sigma_e], dim=1) + f1 = F.interpolate(f1, (flow_feature.shape[2:])) + f1 = F.grid_sample(f1, grid, mode='nearest', align_corners=True) + flow_feature = F.grid_sample(flow_feature, grid, mode='nearest', align_corners=True) + flow_feature = flow_feature.view(flow_feature.shape[0], flow_feature.shape[1], -1).permute(0,2,1) + f1 = f1.view(f1.shape[0], f1.shape[1], -1).permute(0,2,1) + + self.net.train() + self.optimizer.zero_grad() + w_hat = self.net.encoder(w1[:,target_layers].detach(), flow_feature.detach(), f1.detach()) + loss, loss_dict, id_logs = self.calc_loss(w_hat, w_mid[:,target_layers].detach()) + loss.backward() + self.optimizer.step() + + w_mid[:,target_layers] = w_hat.detach() + x_hat, _ = self.net.decoder([w_mid], input_is_latent=True, randomize_noise=False) + x_hat = self.net.face_pool(x_hat) + if self.global_step % self.opts.image_interval == 0 or ( + self.global_step < 1000 and self.global_step % 100 == 0): + imgL_o = ((x1.detach()+1.)*127.5)[0].permute(1,2,0).cpu().numpy() + flow = torch.cat((flow,torch.ones_like(flow)[:,:1]), dim=1)[0].permute(1,2,0).cpu().numpy() + flowvis = point_vec(imgL_o, flow) + flowvis = torch.from_numpy(flowvis[:,:,::-1].copy()).permute(2,0,1).unsqueeze(0)/127.5-1. 
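+                    # Log the flow visualization together with the target image (x_mid) and the image generated from the predicted latent (x_hat).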
+ self.parse_and_log_images(None, flowvis, x_mid, x_hat, title='trained_images') + print(loss_dict) + + if self.global_step % self.opts.save_interval == 0 or self.global_step == self.opts.max_steps: + self.checkpoint_me(loss_dict, is_best=False) + + if self.global_step == self.opts.max_steps: + print('OMG, finished training!') + break + + self.global_step += 1 + + def checkpoint_me(self, loss_dict, is_best): + save_name = 'best_model.pt' if is_best else 'iteration_{}.pt'.format(self.global_step) + save_dict = self.__get_save_dict() + checkpoint_path = os.path.join(self.checkpoint_dir, save_name) + torch.save(save_dict, checkpoint_path) + with open(os.path.join(self.checkpoint_dir, 'timestamp.txt'), 'a') as f: + if is_best: + f.write('**Best**: Step - {}, Loss - {:.3f} \n{}\n'.format(self.global_step, self.best_val_loss, loss_dict)) + else: + f.write('Step - {}, \n{}\n'.format(self.global_step, loss_dict)) + + def configure_optimizers(self): + params = list(self.net.encoder.parameters()) + if self.opts.train_decoder: + params += list(self.net.decoder.parameters()) + if self.opts.optim_name == 'adam': + optimizer = torch.optim.Adam(params, lr=self.opts.learning_rate) + else: + optimizer = Ranger(params, lr=self.opts.learning_rate) + return optimizer + + def calc_loss(self, latent, w, y_hat=None, y=None): + loss_dict = {} + loss = 0.0 + id_logs = None + + if self.opts.l2_lambda > 0 and (y_hat is not None) and (y is not None): + loss_l2 = F.mse_loss(y_hat, y) + loss_dict['loss_l2'] = float(loss_l2) + loss += loss_l2 * self.opts.l2_lambda + if self.opts.lpips_lambda > 0 and (y_hat is not None) and (y is not None): + loss_lpips = self.lpips_loss(y_hat, y) + loss_dict['loss_lpips'] = float(loss_lpips) + loss += loss_lpips * self.opts.lpips_lambda + if self.opts.l2latent_lambda > 0: + loss_l2 = F.mse_loss(latent, w) + loss_dict['loss_l2latent'] = float(loss_l2) + loss += loss_l2 * self.opts.l2latent_lambda + + loss_dict['loss'] = float(loss) + return loss, loss_dict, id_logs + + def parse_and_log_images(self, id_logs, x, y, y_hat, title, subscript=None, display_count=1): + im_data = [] + for i in range(display_count): + cur_im_data = { + 'input_face': common.tensor2im(x[i]), + 'target_face': common.tensor2im(y[i]), + 'output_face': common.tensor2im(y_hat[i]), + } + if id_logs is not None: + for key in id_logs[i]: + cur_im_data[key] = id_logs[i][key] + im_data.append(cur_im_data) + self.log_images(title, im_data=im_data, subscript=subscript) + + + def log_images(self, name, im_data, subscript=None, log_latest=False): + fig = common.vis_faces(im_data) + step = self.global_step + if log_latest: + step = 0 + if subscript: + path = os.path.join(self.logger.log_dir, name, '{}_{:04d}.jpg'.format(subscript, step)) + else: + path = os.path.join(self.logger.log_dir, name, '{:04d}.jpg'.format(step)) + os.makedirs(os.path.dirname(path), exist_ok=True) + fig.savefig(path) + plt.close(fig) + + def __get_save_dict(self): + save_dict = { + 'state_dict': self.net.state_dict(), + 'opts': vars(self.opts) + } + + save_dict['latent_avg'] = self.net.latent_avg + return save_dict \ No newline at end of file diff --git a/training/ranger.py b/training/ranger.py new file mode 100755 index 0000000000000000000000000000000000000000..441cded317542a82229ddfbd3073639f8d3e706b --- /dev/null +++ b/training/ranger.py @@ -0,0 +1,164 @@ +# Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer. 
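+# In short: RAdam supplies the adaptive per-parameter update with a variance-rectification term, Lookahead keeps a set of slow weights that are interpolated toward the fast weights every k steps, and gradient centralization subtracts the per-filter gradient mean before the moment estimates are updated.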
+ +# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer +# and/or +# https://github.com/lessw2020/Best-Deep-Learning-Optimizers + +# Ranger has now been used to capture 12 records on the FastAI leaderboard. + +# This version = 20.4.11 + +# Credits: +# Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization +# RAdam --> https://github.com/LiyuanLucasLiu/RAdam +# Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code. +# Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610 + +# summary of changes: +# 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init. +# full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights), +# supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues. +# changes 8/31/19 - fix references to *self*.N_sma_threshold; +# changed eps to 1e-5 as better default than 1e-8. + +import math +import torch +from torch.optim.optimizer import Optimizer + + +class Ranger(Optimizer): + + def __init__(self, params, lr=1e-3, # lr + alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options + betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options + use_gc=True, gc_conv_only=False + # Gradient centralization on or off, applied to conv layers only or conv + fc layers + ): + + # parameter checks + if not 0.0 <= alpha <= 1.0: + raise ValueError(f'Invalid slow update rate: {alpha}') + if not 1 <= k: + raise ValueError(f'Invalid lookahead steps: {k}') + if not lr > 0: + raise ValueError(f'Invalid Learning Rate: {lr}') + if not eps > 0: + raise ValueError(f'Invalid eps: {eps}') + + # parameter comments: + # beta1 (momentum) of .95 seems to work better than .90... + # N_sma_threshold of 5 seems better in testing than 4. + # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you. 
+ + # prep defaults and init torch.optim base + defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold, + eps=eps, weight_decay=weight_decay) + super().__init__(params, defaults) + + # adjustable threshold + self.N_sma_threshhold = N_sma_threshhold + + # look ahead params + + self.alpha = alpha + self.k = k + + # radam buffer for state + self.radam_buffer = [[None, None, None] for ind in range(10)] + + # gc on or off + self.use_gc = use_gc + + # level of gradient centralization + self.gc_gradient_threshold = 3 if gc_conv_only else 1 + + def __setstate__(self, state): + super(Ranger, self).__setstate__(state) + + def step(self, closure=None): + loss = None + + # Evaluate averages and grad, update param tensors + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + + if grad.is_sparse: + raise RuntimeError('Ranger optimizer does not support sparse gradients') + + p_data_fp32 = p.data.float() + + state = self.state[p] # get state dict for this param + + if len(state) == 0: # if first time to run...init dictionary with our desired entries + # if self.first_run_check==0: + # self.first_run_check=1 + # print("Initializing slow buffer...should not see this at load from saved model!") + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + + # look ahead weight storage now in state dict + state['slow_buffer'] = torch.empty_like(p.data) + state['slow_buffer'].copy_(p.data) + + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + # begin computations + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + # GC operation for Conv layers and FC layers + if grad.dim() > self.gc_gradient_threshold: + grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True)) + + state['step'] += 1 + + # compute variance mov avg + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + # compute mean moving avg + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + buffered = self.radam_buffer[int(state['step'] % 10)] + + if state['step'] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + if N_sma > self.N_sma_threshhold: + step_size = math.sqrt( + (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( + N_sma_max - 2)) / (1 - beta1 ** state['step']) + else: + step_size = 1.0 / (1 - beta1 ** state['step']) + buffered[2] = step_size + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + + # apply lr + if N_sma > self.N_sma_threshhold: + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr']) + else: + p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr']) + + p.data.copy_(p_data_fp32) + + # integrated look ahead... 
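+                # every k steps, pull the slow weights toward the current fast weights by a factor alpha and reset the fast weights to the interpolated values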
+ # we do it at the param level instead of group level + if state['step'] % group['k'] == 0: + slow_p = state['slow_buffer'] # get access to slow param tensor + slow_p.add_(p.data - slow_p, alpha=self.alpha) # (fast weights - slow weights) * alpha + p.data.copy_(slow_p) # copy interpolated weights to RAdam param tensor + + return loss \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/common.py b/utils/common.py new file mode 100755 index 0000000000000000000000000000000000000000..5b11560ef7538dd12ec7d1d51ab650efb3f15ea4 --- /dev/null +++ b/utils/common.py @@ -0,0 +1,93 @@ +import cv2 +import numpy as np +from PIL import Image +import matplotlib.pyplot as plt +import random + + +# Log images +def log_input_image(x, opts): + if opts.label_nc == 0: + return tensor2im(x) + elif opts.label_nc == 1: + return tensor2sketch(x) + else: + return tensor2map(x) + + +def tensor2im(var): + var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy() + var = ((var + 1) / 2) + var[var < 0] = 0 + var[var > 1] = 1 + var = var * 255 + return Image.fromarray(var.astype('uint8')) + + +def tensor2map(var): + mask = np.argmax(var.data.cpu().numpy(), axis=0) + colors = get_colors() + mask_image = np.ones(shape=(mask.shape[0], mask.shape[1], 3)) + for class_idx in np.unique(mask): + mask_image[mask == class_idx] = colors[class_idx] + mask_image = mask_image.astype('uint8') + return Image.fromarray(mask_image) + + +def tensor2sketch(var): + im = var[0].cpu().detach().numpy() + im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) + im = (im * 255).astype(np.uint8) + return Image.fromarray(im) + + +# Visualization utils +def get_colors(): + # currently support up to 19 classes (for the celebs-hq-mask dataset) + colors = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], + [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], + [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]] + + # asign random colors to more 200 classes + random.seed(0) + for i in range(200): + colors.append([random.randint(0,255),random.randint(0,255),random.randint(0,255)]) + return colors + + +def vis_faces(log_hooks): + display_count = len(log_hooks) + fig = plt.figure(figsize=(8, 4 * display_count)) + gs = fig.add_gridspec(display_count, 3) + for i in range(display_count): + hooks_dict = log_hooks[i] + fig.add_subplot(gs[i, 0]) + if 'diff_input' in hooks_dict: + vis_faces_with_id(hooks_dict, fig, gs, i) + else: + vis_faces_no_id(hooks_dict, fig, gs, i) + plt.tight_layout() + return fig + + +def vis_faces_with_id(hooks_dict, fig, gs, i): + plt.imshow(hooks_dict['input_face']) + plt.title('Input\nOut Sim={:.2f}'.format(float(hooks_dict['diff_input']))) + fig.add_subplot(gs[i, 1]) + plt.imshow(hooks_dict['target_face']) + plt.title('Target\nIn={:.2f}, Out={:.2f}'.format(float(hooks_dict['diff_views']), + float(hooks_dict['diff_target']))) + fig.add_subplot(gs[i, 2]) + plt.imshow(hooks_dict['output_face']) + plt.title('Output\n Target Sim={:.2f}'.format(float(hooks_dict['diff_target']))) + + +def vis_faces_no_id(hooks_dict, fig, gs, i): + plt.imshow(hooks_dict['input_face'], cmap="gray") + plt.title('Input') + fig.add_subplot(gs[i, 1]) + plt.imshow(hooks_dict['target_face']) + plt.title('Target') + fig.add_subplot(gs[i, 2]) + plt.imshow(hooks_dict['output_face']) + plt.title('Output')
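
For reference, below is a minimal usage sketch (not part of the repository) of the StyleGAN2 Generator defined in models/stylegan2/model.py. The 256/512/8 sizes are illustrative (256px corresponds to the default --style_num of 14), and a CUDA device is required because the fused ops are compiled as CUDA extensions.

```
import torch
from models.stylegan2.model import Generator

# size, style_dim, n_mlp — illustrative values for a 256px, 14-latent model
generator = Generator(256, 512, 8).cuda().eval()

with torch.no_grad():
    mean_w = generator.mean_latent(4096)          # average W latent used for truncation
    z = torch.randn(4, 512, device='cuda')
    images, w = generator(
        [z],
        truncation=0.7,
        truncation_latent=mean_w,
        return_latents=True,
    )
    # images: (4, 3, 256, 256) roughly in [-1, 1]; w: (4, 14, 512) per-layer latents
```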