Spaces: Running on L4

init

- .gitignore +166 -0
- LICENSE +21 -0
- app.py +103 -0
- depth/README.md +63 -0
- depth/configs/base_options.py +56 -0
- depth/configs/test_options.py +27 -0
- depth/configs/train_options.py +50 -0
- depth/inference.py +53 -0
- depth/models_depth/attractor.py +208 -0
- depth/models_depth/checkpoint.py +608 -0
- depth/models_depth/dist_layers.py +121 -0
- depth/models_depth/layers.py +36 -0
- depth/models_depth/localbins_layers.py +169 -0
- depth/models_depth/miniViT.py +45 -0
- depth/models_depth/model.py +666 -0
- depth/models_depth/model_vpd.py +252 -0
- depth/models_depth/optimizer.py +154 -0
- depth/requirements.txt +8 -0
- depth/test_img.jpg +0 -0
- depth/utils.py +525 -0
- depth/utils_depth/criterion.py +22 -0
- depth/utils_depth/logging.py +161 -0
- depth/utils_depth/metrics.py +79 -0
- depth/utils_depth/misc.py +73 -0
- depth/v1-inference.yaml +70 -0
- evp/__init__.py +1 -0
- evp/models.py +349 -0
- refer/README.md +78 -0
- refer/args.py +42 -0
- refer/inference.py +60 -0
- refer/models_refer/__init__.py +1 -0
- refer/models_refer/model.py +301 -0
- refer/requirements.txt +12 -0
- refer/transforms.py +126 -0
- refer/utils.py +222 -0
- refer/v1-inference.yaml +70 -0
.gitignore
ADDED
@@ -0,0 +1,166 @@

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

*.ckpt
*.pth
refer/refer/data/
depth/kitti_dataset/
depth/nyu_depth_v2/

# C extensions
.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
LICENSE
ADDED
@@ -0,0 +1,21 @@

MIT License

Copyright (c) 2023 Mykola Lavreniuk

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
app.py
ADDED
@@ -0,0 +1,103 @@

import os
import sys

depth_directory = os.path.abspath(os.path.join(os.path.dirname(__file__), 'depth'))
sys.path.append(depth_directory)
os.chdir(depth_directory)

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from collections import OrderedDict  # needed below when stripping the 'module.' prefix from DataParallel checkpoints
from depth.models_depth.model import EVPDepth
from depth.configs.train_options import TrainOptions
from depth.configs.test_options import TestOptions
import glob
import utils
import torchvision.transforms as transforms
from utils_depth.misc import colorize
from PIL import Image
import torch.nn.functional as F
import gradio as gr
import tempfile


css = """
#img-display-container {
    max-height: 50vh;
}
#img-display-input {
    max-height: 40vh;
}
#img-display-output {
    max-height: 40vh;
}
"""


def create_demo(model, device):
    gr.Markdown("### Depth Prediction demo")
    with gr.Row():
        input_image = gr.Image(label="Input Image", type='pil', elem_id='img-display-input')
        depth_image = gr.Image(label="Depth Map", elem_id='img-display-output')
    raw_file = gr.File(label="16-bit raw depth, multiplier:256")
    submit = gr.Button("Submit")

    def on_submit(image):
        transform = transforms.ToTensor()
        image = transform(image).unsqueeze(0).to(device)
        shape = image.shape
        image = torch.nn.functional.interpolate(image, (440, 480), mode='bilinear', align_corners=True)
        image = F.pad(image, (0, 0, 40, 0))
        with torch.no_grad():
            pred = model(image)['pred_d']

        pred = pred[:, :, 40:, :]
        pred = torch.nn.functional.interpolate(pred, shape[2:], mode='bilinear', align_corners=True)
        pred_d_numpy = pred.squeeze().cpu().numpy()
        colored_depth, _, _ = colorize(pred_d_numpy, cmap='gray_r')

        tmp = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
        raw_depth = Image.fromarray((pred_d_numpy * 256).astype('uint16'))
        raw_depth.save(tmp.name)
        return [colored_depth, tmp.name]

    submit.click(on_submit, inputs=[input_image], outputs=[depth_image, raw_file])
    examples = gr.Examples(examples=["test_img.jpg"],
                           inputs=[input_image])


def main():
    opt = TestOptions().initialize()
    opt.add_argument('--img_path', type=str)
    args = opt.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EVPDepth(args=args, caption_aggregation=True)
    cudnn.benchmark = True
    model.to(device)
    model_weight = torch.load(args.ckpt_dir)['model']
    if 'module' in next(iter(model_weight.items()))[0]:
        model_weight = OrderedDict((k[7:], v) for k, v in model_weight.items())
    model.load_state_dict(model_weight, strict=False)
    model.eval()

    title = "# EVP"
    description = """Official demo for **EVP: Enhanced Visual Perception using Inverse Multi-Attentive Feature
    Refinement and Regularized Image-Text Alignment**.
    EVP is a deep learning model for metric depth estimation from a single image.
    Please refer to our [paper](https://arxiv.org/abs/2312.08548) or [github](https://github.com/Lavreniuk/EVP) for more details."""

    with gr.Blocks() as demo:
        gr.Markdown(title)
        gr.Markdown(description)
        with gr.Tab("Depth Prediction"):
            create_demo(model, device)
        gr.HTML('''<br><br><br><center>You can duplicate this Space to skip the queue:<a href="https://huggingface.co/spaces/shariqfarooq/ZoeDepth?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a><br>
        <p><img src="https://visitor-badge.glitch.me/badge?page_id=shariqfarooq.zoedepth_demo_hf" alt="visitors"></p></center>''')

    demo.queue().launch(share=True)


if __name__ == '__main__':
    main()
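The raw depth file produced by `on_submit` stores metric depth scaled by 256 in a 16-bit PNG. A minimal reader-side sketch (not part of the repository) of how to turn the downloaded file back into metres; the path `raw_depth.png` is illustrative:

```python
import numpy as np
from PIL import Image

# Load the 16-bit PNG written by the demo and undo the x256 scaling
# to recover metric depth in metres. 'raw_depth.png' is a placeholder path.
raw = np.asarray(Image.open('raw_depth.png')).astype(np.float32)
depth_m = raw / 256.0
print(depth_m.min(), depth_m.max())
```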
depth/README.md
ADDED
@@ -0,0 +1,63 @@

# Depth Estimation
## Getting Started

1. Install the [mmcv-full](https://github.com/open-mmlab/mmcv) library and some required packages.

```bash
pip install openmim
mim install mmcv-full
pip install -r requirements.txt
```

2. Prepare NYUDepthV2 datasets following [GLPDepth](https://github.com/vinvino02/GLPDepth) and [BTS](https://github.com/cleinc/bts/tree/master).

```
mkdir nyu_depth_v2
wget http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat
python extract_official_train_test_set_from_mat.py nyu_depth_v2_labeled.mat splits.mat ./nyu_depth_v2/official_splits/
```

Download sync.zip provided by the authors of BTS from this [url](https://drive.google.com/file/d/1AysroWpfISmm-yRFGBgFTrLy6FjQwvwP/view) and unzip it in the `./nyu_depth_v2` folder.

Your dataset directory should be:

```
│nyu_depth_v2/
├──official_splits/
│  ├── test
│  ├── train
├──sync/
```

## Results and Fine-tuned Models

EVP obtains 0.224 RMSE on the NYUv2 depth estimation benchmark, establishing a new state of the art.

|         | RMSE  | d1    | d2    | d3    | REL   | log_10 |
|---------|-------|-------|-------|-------|-------|--------|
| **EVP** | 0.224 | 0.976 | 0.997 | 0.999 | 0.061 | 0.027  |

EVP obtains 0.048 REL and 0.136 SqREL on the KITTI depth estimation benchmark, establishing a new state of the art.

|         | REL   | SqREL | RMSE  | RMSE log | d1    | d2    | d3    |
|---------|-------|-------|-------|----------|-------|-------|-------|
| **EVP** | 0.048 | 0.136 | 2.015 | 0.073    | 0.980 | 0.998 | 1.000 |

## Training

Run the following instruction to train the EVP-Depth model.

```
bash train.sh <LOG_DIR>
```

## Evaluation
Command format:
```
bash test.sh <CHECKPOINT_PATH>
```

## Custom inference
```
PYTHONPATH="../":$PYTHONPATH python inference.py --img_path test_img.jpg --ckpt_dir nyu.ckpt
```
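The tables above use the standard single-image depth metrics: threshold accuracies d1/d2/d3, relative error REL, squared relative error SqREL, RMSE, and log_10 error. As a reference, a minimal sketch of the conventional definitions (this is not the repository's `utils_depth/metrics.py`, just the usual formulas computed on valid ground-truth pixels):

```python
import numpy as np

def depth_metrics(pred, gt):
    """Conventional depth-estimation metrics on valid (gt > 0) pixels."""
    mask = gt > 0
    pred, gt = pred[mask], gt[mask]
    thresh = np.maximum(gt / pred, pred / gt)
    d1 = (thresh < 1.25).mean()
    d2 = (thresh < 1.25 ** 2).mean()
    d3 = (thresh < 1.25 ** 3).mean()
    rel = np.mean(np.abs(gt - pred) / gt)          # REL
    sq_rel = np.mean(((gt - pred) ** 2) / gt)      # SqREL
    rmse = np.sqrt(np.mean((gt - pred) ** 2))      # RMSE
    log10 = np.mean(np.abs(np.log10(pred) - np.log10(gt)))
    return dict(d1=d1, d2=d2, d3=d3, rel=rel, sq_rel=sq_rel, rmse=rmse, log10=log10)
```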
depth/configs/base_options.py
ADDED
@@ -0,0 +1,56 @@

# ------------------------------------------------------------------------------
# The code is from GLPDepth (https://github.com/vinvino02/GLPDepth).
# For non-commercial purpose only (research, evaluation etc).
# Modified by Zigang Geng ([email protected]).
# ------------------------------------------------------------------------------

import argparse


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


class BaseOptions():
    def __init__(self):
        pass

    def initialize(self):
        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        # base configs
        parser.add_argument('--resume_from', type=str, default='')
        parser.add_argument('--exp_name', type=str, default='')
        parser.add_argument('--gpu_or_cpu', type=str, default='gpu')
        parser.add_argument('--data_path', type=str, default='/data/ssd1/')
        parser.add_argument('--dataset', type=str, default='nyudepthv2',
                            choices=['nyudepthv2', 'kitti', 'imagepath'])
        parser.add_argument('--batch_size', type=int, default=8)
        parser.add_argument('--workers', type=int, default=8)

        # depth configs
        parser.add_argument('--max_depth', type=float, default=10.0)
        parser.add_argument('--max_depth_eval', type=float, default=10.0)
        parser.add_argument('--min_depth_eval', type=float, default=1e-3)
        parser.add_argument('--do_kb_crop', type=int, default=1)
        parser.add_argument('--kitti_crop', type=str, default=None,
                            choices=['garg_crop', 'eigen_crop'])

        parser.add_argument('--pretrained', type=str, default='')
        parser.add_argument('--drop_path_rate', type=float, default=0.3)
        parser.add_argument('--use_checkpoint', type=str2bool, default='False')
        parser.add_argument('--num_deconv', type=int, default=3)
        parser.add_argument('--num_filters', nargs='+', type=int, default=[32, 32, 32])
        parser.add_argument('--deconv_kernels', nargs='+', type=int, default=[2, 2, 2])

        parser.add_argument('--shift_window_test', action='store_true')
        parser.add_argument('--shift_size', type=int, default=2)
        parser.add_argument('--flip_test', action='store_true')

        return parser
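`BaseOptions.initialize` returns a plain `argparse.ArgumentParser`, so subclasses such as `TestOptions` in the next file only add arguments to it and scripts call `parse_args` themselves. A minimal usage sketch (argument values here are illustrative, mirroring what `app.py` and `inference.py` do):

```python
# Hypothetical usage of the options classes defined in this repo.
from configs.test_options import TestOptions

parser = TestOptions().initialize()          # ArgumentParser with base + test arguments
parser.add_argument('--img_path', type=str)  # scripts may append extra args, as app.py does
args = parser.parse_args(['--ckpt_dir', 'nyu.ckpt', '--img_path', 'test_img.jpg'])
print(args.max_depth, args.ckpt_dir)
```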
depth/configs/test_options.py
ADDED
@@ -0,0 +1,27 @@

# ------------------------------------------------------------------------------
# The code is from GLPDepth (https://github.com/vinvino02/GLPDepth).
# For non-commercial purpose only (research, evaluation etc).
# ------------------------------------------------------------------------------

from configs.base_options import BaseOptions

class TestOptions(BaseOptions):
    def initialize(self):
        parser = BaseOptions.initialize(self)

        # experiment configs
        parser.add_argument('--ckpt_dir', type=str,
                            default='./ckpt/best_model_nyu.ckpt',
                            help='load ckpt path')
        parser.add_argument('--result_dir', type=str, default='./results',
                            help='save result images into result_dir/exp_name')
        parser.add_argument('--crop_h', type=int, default=448)
        parser.add_argument('--crop_w', type=int, default=576)

        parser.add_argument('--save_eval_pngs', action='store_true',
                            help='save result image into evaluation form')
        parser.add_argument('--save_visualize', action='store_true',
                            help='save result image into visualized form')
        return parser
depth/configs/train_options.py
ADDED
@@ -0,0 +1,50 @@

# ------------------------------------------------------------------------------
# The code is from GLPDepth (https://github.com/vinvino02/GLPDepth).
# For non-commercial purpose only (research, evaluation etc).
# Modified by Zigang Geng ([email protected]).
# ------------------------------------------------------------------------------

from configs.base_options import BaseOptions
import argparse


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


class TrainOptions(BaseOptions):
    def initialize(self):
        parser = BaseOptions.initialize(self)

        # experiment configs
        parser.add_argument('--epochs', type=int, default=25)
        parser.add_argument('--max_lr', type=float, default=5e-4)
        parser.add_argument('--min_lr', type=float, default=3e-5)
        parser.add_argument('--weight_decay', type=float, default=5e-2)
        parser.add_argument('--layer_decay', type=float, default=0.9)

        parser.add_argument('--crop_h', type=int, default=448)
        parser.add_argument('--crop_w', type=int, default=576)
        parser.add_argument('--log_dir', type=str, default='./logs')

        # logging options
        parser.add_argument('--val_freq', type=int, default=1)
        parser.add_argument('--pro_bar', type=str2bool, default='False')
        parser.add_argument('--save_freq', type=int, default=1)
        parser.add_argument('--print_freq', type=int, default=100)
        parser.add_argument('--save_model', action='store_true')
        parser.add_argument(
            '--resume-from', help='the checkpoint file to resume from')
        parser.add_argument('--auto_resume', action='store_true')
        parser.add_argument('--save_result', action='store_true')

        return parser
depth/inference.py
ADDED
@@ -0,0 +1,53 @@

import os
import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from collections import OrderedDict  # needed below when stripping the 'module.' prefix from DataParallel checkpoints
from models_depth.model import EVPDepth
from configs.train_options import TrainOptions
from configs.test_options import TestOptions
import glob
import utils
import torchvision.transforms as transforms
from utils_depth.misc import colorize
from PIL import Image
import torch.nn.functional as F


def main():
    opt = TestOptions().initialize()
    opt.add_argument('--img_path', type=str)
    args = opt.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EVPDepth(args=args, caption_aggregation=True)
    cudnn.benchmark = True
    model.to(device)
    model_weight = torch.load(args.ckpt_dir)['model']
    if 'module' in next(iter(model_weight.items()))[0]:
        model_weight = OrderedDict((k[7:], v) for k, v in model_weight.items())
    model.load_state_dict(model_weight, strict=False)
    model.eval()

    img_path = args.img_path
    image = cv2.imread(img_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    transform = transforms.ToTensor()
    image = transform(image).unsqueeze(0).to(device)
    shape = image.shape
    image = torch.nn.functional.interpolate(image, (440, 480), mode='bilinear', align_corners=True)
    image = F.pad(image, (0, 0, 40, 0))

    with torch.no_grad():
        pred = model(image)['pred_d']

    pred = pred[:, :, 40:, :]
    pred = torch.nn.functional.interpolate(pred, shape[2:], mode='bilinear', align_corners=True)
    pred_d_numpy = pred.squeeze().cpu().numpy()
    pred_d_color, _, _ = colorize(pred_d_numpy, cmap='gray_r')
    Image.fromarray(pred_d_color).save('res.png')

    return 0

if __name__ == '__main__':
    main()
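Both `inference.py` and `app.py` resize the input to 440x480, pad it to 480x480 before the forward pass, and then discard the padded rows from the prediction. For a 4D tensor, `F.pad(x, (0, 0, 40, 0))` pads the width by (0, 0) and the height by 40 rows at the top, which `pred[:, :, 40:, :]` later undoes. A small self-contained check of that round trip:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 440, 480)
padded = F.pad(x, (0, 0, 40, 0))   # pad order is (left, right, top, bottom) -> 1x3x480x480
assert padded.shape == (1, 3, 480, 480)
cropped = padded[:, :, 40:, :]     # drop the 40 zero rows added at the top
assert torch.equal(cropped, x)
```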
depth/models_depth/attractor.py
ADDED
@@ -0,0 +1,208 @@

# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

import torch
import torch.nn as nn


@torch.jit.script
def exp_attractor(dx, alpha: float = 300, gamma: int = 2):
    """Exponential attractor: dc = exp(-alpha*|dx|^gamma) * dx , where dx = a - c, a = attractor point, c = bin center, dc = shift in bin center

    Args:
        dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center.
        alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300.
        gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2.

    Returns:
        torch.Tensor : Delta shifts - dc; New bin centers = Old bin centers + dc
    """
    return torch.exp(-alpha*(torch.abs(dx)**gamma)) * (dx)


@torch.jit.script
def inv_attractor(dx, alpha: float = 300, gamma: int = 2):
    """Inverse attractor: dc = dx / (1 + alpha*dx^gamma), where dx = a - c, a = attractor point, c = bin center, dc = shift in bin center
    This is the default one according to the accompanying paper.

    Args:
        dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center.
        alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300.
        gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2.

    Returns:
        torch.Tensor: Delta shifts - dc; New bin centers = Old bin centers + dc
    """
    return dx.div(1+alpha*dx.pow(gamma))


class AttractorLayer(nn.Module):
    def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10,
                 alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False):
        """
        Attractor layer for bin centers. Bin centers are bounded on the interval (min_depth, max_depth)
        """
        super().__init__()

        self.n_attractors = n_attractors
        self.n_bins = n_bins
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.alpha = alpha
        self.gamma = gamma
        self.kind = kind
        self.attractor_type = attractor_type
        self.memory_efficient = memory_efficient

        self._net = nn.Sequential(
            nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(mlp_dim, n_attractors*2, 1, 1, 0),  # x2 for linear norm
            nn.ReLU(inplace=True)
        )

    def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False):
        """
        Args:
            x (torch.Tensor) : feature block; shape - n, c, h, w
            b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w

        Returns:
            tuple(torch.Tensor,torch.Tensor) : new bin centers normed and scaled; shape - n, nbins, h, w
        """
        if prev_b_embedding is not None:
            if interpolate:
                prev_b_embedding = nn.functional.interpolate(
                    prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True)
            x = x + prev_b_embedding

        A = self._net(x)
        eps = 1e-3
        A = A + eps
        n, c, h, w = A.shape
        A = A.view(n, self.n_attractors, 2, h, w)
        A_normed = A / A.sum(dim=2, keepdim=True)  # n, a, 2, h, w
        A_normed = A[:, :, 0, ...]  # n, na, h, w

        b_prev = nn.functional.interpolate(
            b_prev, (h, w), mode='bilinear', align_corners=True)
        b_centers = b_prev

        if self.attractor_type == 'exp':
            dist = exp_attractor
        else:
            dist = inv_attractor

        if not self.memory_efficient:
            func = {'mean': torch.mean, 'sum': torch.sum}[self.kind]
            # .shape N, nbins, h, w
            delta_c = func(dist(A_normed.unsqueeze(
                2) - b_centers.unsqueeze(1)), dim=1)
        else:
            delta_c = torch.zeros_like(b_centers, device=b_centers.device)
            for i in range(self.n_attractors):
                # .shape N, nbins, h, w
                delta_c += dist(A_normed[:, i, ...].unsqueeze(1) - b_centers)

            if self.kind == 'mean':
                delta_c = delta_c / self.n_attractors

        b_new_centers = b_centers + delta_c
        B_centers = (self.max_depth - self.min_depth) * \
            b_new_centers + self.min_depth
        B_centers, _ = torch.sort(B_centers, dim=1)
        B_centers = torch.clip(B_centers, self.min_depth, self.max_depth)
        return b_new_centers, B_centers


class AttractorLayerUnnormed(nn.Module):
    def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10,
                 alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False):
        """
        Attractor layer for bin centers. Bin centers are unbounded
        """
        super().__init__()

        self.n_attractors = n_attractors
        self.n_bins = n_bins
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.alpha = alpha
        self.gamma = gamma
        self.kind = kind
        self.attractor_type = attractor_type
        self.memory_efficient = memory_efficient

        self._net = nn.Sequential(
            nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(mlp_dim, n_attractors, 1, 1, 0),
            nn.Softplus()
        )

    def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False):
        """
        Args:
            x (torch.Tensor) : feature block; shape - n, c, h, w
            b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w

        Returns:
            tuple(torch.Tensor,torch.Tensor) : new bin centers unbounded; shape - n, nbins, h, w. Two outputs just to keep the API consistent with the normed version
        """
        if prev_b_embedding is not None:
            if interpolate:
                prev_b_embedding = nn.functional.interpolate(
                    prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True)
            x = x + prev_b_embedding

        A = self._net(x)
        n, c, h, w = A.shape

        b_prev = nn.functional.interpolate(
            b_prev, (h, w), mode='bilinear', align_corners=True)
        b_centers = b_prev

        if self.attractor_type == 'exp':
            dist = exp_attractor
        else:
            dist = inv_attractor

        if not self.memory_efficient:
            func = {'mean': torch.mean, 'sum': torch.sum}[self.kind]
            # .shape N, nbins, h, w
            delta_c = func(
                dist(A.unsqueeze(2) - b_centers.unsqueeze(1)), dim=1)
        else:
            delta_c = torch.zeros_like(b_centers, device=b_centers.device)
            for i in range(self.n_attractors):
                delta_c += dist(A[:, i, ...].unsqueeze(1) -
                                b_centers)  # .shape N, nbins, h, w

            if self.kind == 'mean':
                delta_c = delta_c / self.n_attractors

        b_new_centers = b_centers + delta_c
        B_centers = b_new_centers

        return b_new_centers, B_centers
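Both attractor variants compute a per-pixel shift `delta_c` by summing (or averaging) the attraction each attractor point exerts on each bin center. A tiny toy sketch of the `inv_attractor` formula (restated here standalone, with illustrative normalized values) shows the behaviour: bins near an attractor move almost all the way onto it, while distant bins barely move.

```python
import torch

alpha, gamma = 300, 2

def inv_attractor(dx):
    # dc = dx / (1 + alpha * dx^gamma), same formula as in attractor.py
    return dx / (1 + alpha * dx.pow(gamma))

attractor = torch.tensor(0.50)                    # normalized attractor position
centers = torch.tensor([0.10, 0.45, 0.49, 0.90])  # normalized bin centers
shift = inv_attractor(attractor - centers)
print(centers + shift)  # ~[0.11, 0.48, 0.50, 0.89]: nearby bins snap toward 0.5, far bins barely move
```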
depth/models_depth/checkpoint.py
ADDED
@@ -0,0 +1,608 @@

# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# The code is from Swin Transformer.
# (https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmcv_custom/checkpoint.py)
# ------------------------------------------------------------------------------

import io
import os
import os.path as osp
import pkgutil
import time
import warnings
import numpy as np
from collections import OrderedDict
from importlib import import_module
from tempfile import TemporaryDirectory
from scipy import interpolate

import torch
import torchvision
import torch.distributed as dist
from torch.optim import Optimizer
from torch.utils import model_zoo
from torch.nn import functional as F

import mmcv
from mmcv.fileio import FileClient
from mmcv.fileio import load as load_file
from mmcv.parallel import is_module_wrapper
from mmcv.utils import mkdir_or_exist
from mmcv.runner import get_dist_info
from mmcv.utils import get_logger

import logging


def get_root_logger(log_file=None, log_level=logging.INFO):
    """Get the root logger.

    The logger will be initialized if it has not been initialized. By default a
    StreamHandler will be added. If `log_file` is specified, a FileHandler will
    also be added. The name of the root logger is the top-level package name,
    e.g., "mmseg".

    Args:
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the root logger.
        log_level (int): The root logger level. Note that only the process of
            rank 0 is affected, while other processes will set the level to
            "Error" and be silent most of the time.

    Returns:
        logging.Logger: The root logger.
    """

    logger = get_logger(name='mmpose', log_file=log_file, log_level=log_level)

    return logger


def _get_mmcv_home():
    mmcv_home = os.path.expanduser(
        os.getenv(
            ENV_MMCV_HOME,
            os.path.join(
                os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))

    mkdir_or_exist(mmcv_home)
    return mmcv_home


def load_state_dict(module, state_dict, strict=False, logger=None):
    """Load state_dict to a module.

    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, print function will be used.
    """
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # use _load_from_state_dict to enable checkpoint version control
    def load(module, prefix=''):
        # recursively check parallel module in case that the model has a
        # complicated structure, e.g., nn.Module(nn.Module(DDP))
        if is_module_wrapper(module):
            module = module.module
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(module)
    load = None  # break load->load reference cycle

    # ignore "num_batches_tracked" of BN layers
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    if unexpected_keys:
        err_msg.append('unexpected key in source '
                       f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msg.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

    rank, _ = get_dist_info()
    if len(err_msg) > 0 and rank == 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)


def load_url_dist(url, model_dir=None):
    """In distributed setting, this function only download checkpoint at local
    rank 0."""
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    if rank == 0:
        checkpoint = model_zoo.load_url(url, model_dir=model_dir)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            checkpoint = model_zoo.load_url(url, model_dir=model_dir)
    return checkpoint


def load_pavimodel_dist(model_path, map_location=None):
    """In distributed setting, this function only download checkpoint at local
    rank 0."""
    try:
        from pavi import modelcloud
    except ImportError:
        raise ImportError(
            'Please install pavi to load checkpoint from modelcloud.')
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    if rank == 0:
        model = modelcloud.get(model_path)
        with TemporaryDirectory() as tmp_dir:
            downloaded_file = osp.join(tmp_dir, model.name)
            model.download(downloaded_file)
            checkpoint = torch.load(downloaded_file, map_location=map_location)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            model = modelcloud.get(model_path)
            with TemporaryDirectory() as tmp_dir:
                downloaded_file = osp.join(tmp_dir, model.name)
                model.download(downloaded_file)
                checkpoint = torch.load(
                    downloaded_file, map_location=map_location)
    return checkpoint


def load_fileclient_dist(filename, backend, map_location):
    """In distributed setting, this function only download checkpoint at local
    rank 0."""
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    allowed_backends = ['ceph']
    if backend not in allowed_backends:
        raise ValueError(f'Load from Backend {backend} is not supported.')
    if rank == 0:
        fileclient = FileClient(backend=backend)
        buffer = io.BytesIO(fileclient.get(filename))
        checkpoint = torch.load(buffer, map_location=map_location)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            fileclient = FileClient(backend=backend)
            buffer = io.BytesIO(fileclient.get(filename))
            checkpoint = torch.load(buffer, map_location=map_location)
    return checkpoint


def get_torchvision_models():
    model_urls = dict()
    for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
        if ispkg:
            continue
        _zoo = import_module(f'torchvision.models.{name}')
        if hasattr(_zoo, 'model_urls'):
            _urls = getattr(_zoo, 'model_urls')
            model_urls.update(_urls)
    return model_urls


def get_external_models():
    mmcv_home = _get_mmcv_home()
    default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
    default_urls = load_file(default_json_path)
    assert isinstance(default_urls, dict)
    external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
    if osp.exists(external_json_path):
        external_urls = load_file(external_json_path)
        assert isinstance(external_urls, dict)
        default_urls.update(external_urls)

    return default_urls


def get_mmcls_models():
    mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
    mmcls_urls = load_file(mmcls_json_path)

    return mmcls_urls


def get_deprecated_model_names():
    deprecate_json_path = osp.join(mmcv.__path__[0],
                                   'model_zoo/deprecated.json')
    deprecate_urls = load_file(deprecate_json_path)
    assert isinstance(deprecate_urls, dict)

    return deprecate_urls


def _process_mmcls_checkpoint(checkpoint):
    state_dict = checkpoint['state_dict']
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith('backbone.'):
            new_state_dict[k[9:]] = v
    new_checkpoint = dict(state_dict=new_state_dict)

    return new_checkpoint


def _load_checkpoint(filename, map_location=None):
    """Load checkpoint from somewhere (modelzoo, file, url).

    Args:
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str | None): Same as :func:`torch.load`. Default: None.

    Returns:
        dict | OrderedDict: The loaded checkpoint. It can be either an
            OrderedDict storing model weights or a dict containing other
            information, which depends on the checkpoint.
    """
    if filename.startswith('modelzoo://'):
        warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
                      'use "torchvision://" instead')
        model_urls = get_torchvision_models()
        model_name = filename[11:]
        checkpoint = load_url_dist(model_urls[model_name])
    elif filename.startswith('torchvision://'):
        model_urls = get_torchvision_models()
        model_name = filename[14:]
        checkpoint = load_url_dist(model_urls[model_name])
    elif filename.startswith('open-mmlab://'):
        model_urls = get_external_models()
        model_name = filename[13:]
        deprecated_urls = get_deprecated_model_names()
        if model_name in deprecated_urls:
            warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
                          f'of open-mmlab://{deprecated_urls[model_name]}')
            model_name = deprecated_urls[model_name]
        model_url = model_urls[model_name]
        # check if is url
        if model_url.startswith(('http://', 'https://')):
            checkpoint = load_url_dist(model_url)
        else:
            filename = osp.join(_get_mmcv_home(), model_url)
            if not osp.isfile(filename):
                raise IOError(f'{filename} is not a checkpoint file')
            checkpoint = torch.load(filename, map_location=map_location)
    elif filename.startswith('mmcls://'):
        model_urls = get_mmcls_models()
        model_name = filename[8:]
        checkpoint = load_url_dist(model_urls[model_name])
        checkpoint = _process_mmcls_checkpoint(checkpoint)
    elif filename.startswith(('http://', 'https://')):
        checkpoint = load_url_dist(filename)
    elif filename.startswith('pavi://'):
        model_path = filename[7:]
        checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
    elif filename.startswith('s3://'):
        checkpoint = load_fileclient_dist(
            filename, backend='ceph', map_location=map_location)
    else:
        if not osp.isfile(filename):
            raise IOError(f'{filename} is not a checkpoint file')
        checkpoint = torch.load(filename, map_location=map_location)
    return checkpoint


def load_checkpoint_swin(model,
                         filename,
                         map_location='cpu',
                         strict=False,
                         rpe_interpolation='outer_mask',
                         logger=None):
    """Load checkpoint from a file or URI.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for error message.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    checkpoint = _load_checkpoint(filename, map_location)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')
    # get state_dict from checkpoint
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    elif 'module' in checkpoint:
        state_dict = checkpoint['module']
    else:
        state_dict = checkpoint
    # strip prefix of state_dict
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}

    # for MoBY, load model of online branch
    if sorted(list(state_dict.keys()))[2].startswith('encoder'):
        state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}

    # reshape absolute position embedding for Swin
    if state_dict.get('absolute_pos_embed') is not None:
        absolute_pos_embed = state_dict['absolute_pos_embed']
        N1, L, C1 = absolute_pos_embed.size()
        N2, C2, H, W = model.absolute_pos_embed.size()
        if N1 != N2 or C1 != C2 or L != H * W:
            logger.warning("Error in loading absolute_pos_embed, pass")
        else:
            state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)

    # interpolate position bias table if needed
    relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
    for k in relative_position_bias_table_keys:
        table_pretrained = state_dict[k]
        table_current = model.state_dict()[k]
        L1, nH1 = table_pretrained.size()
        L2, nH2 = table_current.size()
        if nH1 != nH2:
            logger.warning(f"Error in loading {k}, pass")
        else:
            if L1 != L2:
                if rpe_interpolation in ['bicubic', 'bilinear', 'nearest']:
                    logger.info(f"Interpolate relative_position_bias_table using {rpe_interpolation}")
                    S1 = int(L1 ** 0.5)
                    S2 = int(L2 ** 0.5)
                    table_pretrained_resized = F.interpolate(
                        table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
                        size=(S2, S2), mode=rpe_interpolation)
                    state_dict[k] = table_pretrained_resized.view(nH2, L2).permute(1, 0)
                elif rpe_interpolation == 'geo':
                    logger.info("Interpolate relative_position_bias_table using geo.")
                    src_size = int(L1 ** 0.5)
                    dst_size = int(L2 ** 0.5)

                    def geometric_progression(a, r, n):
                        return a * (1.0 - r ** n) / (1.0 - r)

                    left, right = 1.01, 1.5
                    while right - left > 1e-6:
                        q = (left + right) / 2.0
                        gp = geometric_progression(1, q, src_size // 2)
                        if gp > dst_size // 2:
                            right = q
                        else:
                            left = q

                    # if q > 1.13492:
                    #     q = 1.13492

                    dis = []
                    cur = 1
                    for i in range(src_size // 2):
                        dis.append(cur)
                        cur += q ** (i + 1)

                    r_ids = [-_ for _ in reversed(dis)]

                    x = r_ids + [0] + dis
                    y = r_ids + [0] + dis

                    t = dst_size // 2.0
                    dx = np.arange(-t, t + 0.1, 1.0)
                    dy = np.arange(-t, t + 0.1, 1.0)

                    logger.info("Original positions = %s" % str(x))
                    logger.info("Target positions = %s" % str(dx))

                    all_rel_pos_bias = []

                    for i in range(nH1):
                        z = table_pretrained[:, i].view(src_size, src_size).float().numpy()
                        f_cubic = interpolate.interp2d(x, y, z, kind='cubic')
                        all_rel_pos_bias.append(torch.Tensor(f_cubic(dx, dy)).contiguous().view(-1, 1).to(
                            table_pretrained.device))

                    new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
                    state_dict[k] = new_rel_pos_bias

    if 'pos_embed' in state_dict:
        pos_embed_checkpoint = state_dict['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
|
| 445 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
| 446 |
+
# height (== width) for the new position embedding
|
| 447 |
+
new_size = int(num_patches ** 0.5)
|
| 448 |
+
# class_token and dist_token are kept unchanged
|
| 449 |
+
if orig_size != new_size:
|
| 450 |
+
if dist.get_rank() == 0:
|
| 451 |
+
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
| 452 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
| 453 |
+
# only the position tokens are interpolated
|
| 454 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
| 455 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
| 456 |
+
pos_tokens = torch.nn.functional.interpolate(
|
| 457 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
| 458 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
| 459 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
| 460 |
+
state_dict['pos_embed'] = new_pos_embed
|
| 461 |
+
|
| 462 |
+
# load state_dict
|
| 463 |
+
load_state_dict(model, state_dict, strict, logger)
|
| 464 |
+
return checkpoint
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
def weights_to_cpu(state_dict):
|
| 468 |
+
"""Copy a model state_dict to cpu.
|
| 469 |
+
|
| 470 |
+
Args:
|
| 471 |
+
state_dict (OrderedDict): Model weights on GPU.
|
| 472 |
+
|
| 473 |
+
Returns:
|
| 474 |
+
OrderedDict: Model weights on GPU.
|
| 475 |
+
"""
|
| 476 |
+
state_dict_cpu = OrderedDict()
|
| 477 |
+
for key, val in state_dict.items():
|
| 478 |
+
state_dict_cpu[key] = val.cpu()
|
| 479 |
+
return state_dict_cpu
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def _save_to_state_dict(module, destination, prefix, keep_vars):
|
| 483 |
+
"""Saves module state to `destination` dictionary.
|
| 484 |
+
|
| 485 |
+
This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
|
| 486 |
+
|
| 487 |
+
Args:
|
| 488 |
+
module (nn.Module): The module to generate state_dict.
|
| 489 |
+
destination (dict): A dict where state will be stored.
|
| 490 |
+
prefix (str): The prefix for parameters and buffers used in this
|
| 491 |
+
module.
|
| 492 |
+
"""
|
| 493 |
+
for name, param in module._parameters.items():
|
| 494 |
+
if param is not None:
|
| 495 |
+
destination[prefix + name] = param if keep_vars else param.detach()
|
| 496 |
+
for name, buf in module._buffers.items():
|
| 497 |
+
# remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
|
| 498 |
+
if buf is not None:
|
| 499 |
+
destination[prefix + name] = buf if keep_vars else buf.detach()
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
def get_state_dict(module, destination=None, prefix='', keep_vars=False):
|
| 503 |
+
"""Returns a dictionary containing a whole state of the module.
|
| 504 |
+
|
| 505 |
+
Both parameters and persistent buffers (e.g. running averages) are
|
| 506 |
+
included. Keys are corresponding parameter and buffer names.
|
| 507 |
+
|
| 508 |
+
This method is modified from :meth:`torch.nn.Module.state_dict` to
|
| 509 |
+
recursively check parallel module in case that the model has a complicated
|
| 510 |
+
structure, e.g., nn.Module(nn.Module(DDP)).
|
| 511 |
+
|
| 512 |
+
Args:
|
| 513 |
+
module (nn.Module): The module to generate state_dict.
|
| 514 |
+
destination (OrderedDict): Returned dict for the state of the
|
| 515 |
+
module.
|
| 516 |
+
prefix (str): Prefix of the key.
|
| 517 |
+
keep_vars (bool): Whether to keep the variable property of the
|
| 518 |
+
parameters. Default: False.
|
| 519 |
+
|
| 520 |
+
Returns:
|
| 521 |
+
dict: A dictionary containing a whole state of the module.
|
| 522 |
+
"""
|
| 523 |
+
# recursively check parallel module in case that the model has a
|
| 524 |
+
# complicated structure, e.g., nn.Module(nn.Module(DDP))
|
| 525 |
+
if is_module_wrapper(module):
|
| 526 |
+
module = module.module
|
| 527 |
+
|
| 528 |
+
# below is the same as torch.nn.Module.state_dict()
|
| 529 |
+
if destination is None:
|
| 530 |
+
destination = OrderedDict()
|
| 531 |
+
destination._metadata = OrderedDict()
|
| 532 |
+
destination._metadata[prefix[:-1]] = local_metadata = dict(
|
| 533 |
+
version=module._version)
|
| 534 |
+
_save_to_state_dict(module, destination, prefix, keep_vars)
|
| 535 |
+
for name, child in module._modules.items():
|
| 536 |
+
if child is not None:
|
| 537 |
+
get_state_dict(
|
| 538 |
+
child, destination, prefix + name + '.', keep_vars=keep_vars)
|
| 539 |
+
for hook in module._state_dict_hooks.values():
|
| 540 |
+
hook_result = hook(module, destination, prefix, local_metadata)
|
| 541 |
+
if hook_result is not None:
|
| 542 |
+
destination = hook_result
|
| 543 |
+
return destination
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
def save_checkpoint(model, filename, optimizer=None, meta=None):
|
| 547 |
+
"""Save checkpoint to file.
|
| 548 |
+
|
| 549 |
+
The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
|
| 550 |
+
``optimizer``. By default ``meta`` will contain version and time info.
|
| 551 |
+
|
| 552 |
+
Args:
|
| 553 |
+
model (Module): Module whose params are to be saved.
|
| 554 |
+
filename (str): Checkpoint filename.
|
| 555 |
+
optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
|
| 556 |
+
meta (dict, optional): Metadata to be saved in checkpoint.
|
| 557 |
+
"""
|
| 558 |
+
if meta is None:
|
| 559 |
+
meta = {}
|
| 560 |
+
elif not isinstance(meta, dict):
|
| 561 |
+
raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
|
| 562 |
+
meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
|
| 563 |
+
|
| 564 |
+
if is_module_wrapper(model):
|
| 565 |
+
model = model.module
|
| 566 |
+
|
| 567 |
+
if hasattr(model, 'CLASSES') and model.CLASSES is not None:
|
| 568 |
+
# save class name to the meta
|
| 569 |
+
meta.update(CLASSES=model.CLASSES)
|
| 570 |
+
|
| 571 |
+
checkpoint = {
|
| 572 |
+
'meta': meta,
|
| 573 |
+
'state_dict': weights_to_cpu(get_state_dict(model))
|
| 574 |
+
}
|
| 575 |
+
# save optimizer state dict in the checkpoint
|
| 576 |
+
if isinstance(optimizer, Optimizer):
|
| 577 |
+
checkpoint['optimizer'] = optimizer.state_dict()
|
| 578 |
+
elif isinstance(optimizer, dict):
|
| 579 |
+
checkpoint['optimizer'] = {}
|
| 580 |
+
for name, optim in optimizer.items():
|
| 581 |
+
checkpoint['optimizer'][name] = optim.state_dict()
|
| 582 |
+
|
| 583 |
+
if filename.startswith('pavi://'):
|
| 584 |
+
try:
|
| 585 |
+
from pavi import modelcloud
|
| 586 |
+
from pavi.exception import NodeNotFoundError
|
| 587 |
+
except ImportError:
|
| 588 |
+
raise ImportError(
|
| 589 |
+
'Please install pavi to load checkpoint from modelcloud.')
|
| 590 |
+
model_path = filename[7:]
|
| 591 |
+
root = modelcloud.Folder()
|
| 592 |
+
model_dir, model_name = osp.split(model_path)
|
| 593 |
+
try:
|
| 594 |
+
model = modelcloud.get(model_dir)
|
| 595 |
+
except NodeNotFoundError:
|
| 596 |
+
model = root.create_training_model(model_dir)
|
| 597 |
+
with TemporaryDirectory() as tmp_dir:
|
| 598 |
+
checkpoint_file = osp.join(tmp_dir, model_name)
|
| 599 |
+
with open(checkpoint_file, 'wb') as f:
|
| 600 |
+
torch.save(checkpoint, f)
|
| 601 |
+
f.flush()
|
| 602 |
+
model.create_file(checkpoint_file, name=model_name)
|
| 603 |
+
else:
|
| 604 |
+
mmcv.mkdir_or_exist(osp.dirname(filename))
|
| 605 |
+
# immediately flush buffer
|
| 606 |
+
with open(filename, 'wb') as f:
|
| 607 |
+
torch.save(checkpoint, f)
|
| 608 |
+
f.flush()
|
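Note (editor): in the 'geo' branch above, the pretrained relative-position-bias table is interpolated onto a larger uniform grid using source coordinates spaced as a geometric progression whose total extent matches the target half-width. A minimal standalone sketch of just the coordinate construction follows (not part of the commit; the sizes 13 and 23 are illustrative, e.g. a 7x7 -> 12x12 window table):

import numpy as np

def geo_positions(src_size, dst_size):
    # bisection for the progression ratio q, mirroring load_checkpoint_swin
    def geometric_progression(a, r, n):
        return a * (1.0 - r ** n) / (1.0 - r)

    left, right = 1.01, 1.5
    while right - left > 1e-6:
        q = (left + right) / 2.0
        if geometric_progression(1, q, src_size // 2) > dst_size // 2:
            right = q
        else:
            left = q

    # cumulative geometric steps give the positive half of the source grid
    dis, cur = [], 1
    for i in range(src_size // 2):
        dis.append(cur)
        cur += q ** (i + 1)
    r_ids = [-d for d in reversed(dis)]

    src = np.array(r_ids + [0] + dis)   # geometrically spaced source coordinates
    t = dst_size // 2.0
    dst = np.arange(-t, t + 0.1, 1.0)   # uniform target coordinates
    return src, dst

src, dst = geo_positions(13, 23)
print(len(src), len(dst))               # 13 23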
depth/models_depth/dist_layers.py
ADDED
@@ -0,0 +1,121 @@
# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

import torch
import torch.nn as nn


def log_binom(n, k, eps=1e-7):
    """ log(nCk) using stirling approximation """
    n = n + eps
    k = k + eps
    return n * torch.log(n) - k * torch.log(k) - (n-k) * torch.log(n-k+eps)


class LogBinomial(nn.Module):
    def __init__(self, n_classes=256, act=torch.softmax):
        """Compute log binomial distribution for n_classes

        Args:
            n_classes (int, optional): number of output classes. Defaults to 256.
        """
        super().__init__()
        self.K = n_classes
        self.act = act
        self.register_buffer('k_idx', torch.arange(
            0, n_classes).view(1, -1, 1, 1))
        self.register_buffer('K_minus_1', torch.Tensor(
            [self.K-1]).view(1, -1, 1, 1))

    def forward(self, x, t=1., eps=1e-4):
        """Compute log binomial distribution for x

        Args:
            x (torch.Tensor - NCHW): probabilities
            t (float, torch.Tensor - NCHW, optional): Temperature of distribution. Defaults to 1..
            eps (float, optional): Small number for numerical stability. Defaults to 1e-4.

        Returns:
            torch.Tensor -NCHW: log binomial distribution logbinomial(p;t)
        """
        if x.ndim == 3:
            x = x.unsqueeze(1)  # make it nchw

        one_minus_x = torch.clamp(1 - x, eps, 1)
        x = torch.clamp(x, eps, 1)
        y = log_binom(self.K_minus_1, self.k_idx) + self.k_idx * \
            torch.log(x) + (self.K - 1 - self.k_idx) * torch.log(one_minus_x)
        return self.act(y/t, dim=1)


class ConditionalLogBinomial(nn.Module):
    def __init__(self, in_features, condition_dim, n_classes=256, bottleneck_factor=2, p_eps=1e-4, max_temp=50, min_temp=1e-7, act=torch.softmax):
        """Conditional Log Binomial distribution

        Args:
            in_features (int): number of input channels in main feature
            condition_dim (int): number of input channels in condition feature
            n_classes (int, optional): Number of classes. Defaults to 256.
            bottleneck_factor (int, optional): Hidden dim factor. Defaults to 2.
            p_eps (float, optional): small eps value. Defaults to 1e-4.
            max_temp (float, optional): Maximum temperature of output distribution. Defaults to 50.
            min_temp (float, optional): Minimum temperature of output distribution. Defaults to 1e-7.
        """
        super().__init__()
        self.p_eps = p_eps
        self.max_temp = max_temp
        self.min_temp = min_temp
        self.log_binomial_transform = LogBinomial(n_classes, act=act)
        bottleneck = (in_features + condition_dim) // bottleneck_factor
        self.mlp = nn.Sequential(
            nn.Conv2d(in_features + condition_dim, bottleneck,
                      kernel_size=1, stride=1, padding=0),
            nn.GELU(),
            # 2 for p linear norm, 2 for t linear norm
            nn.Conv2d(bottleneck, 2+2, kernel_size=1, stride=1, padding=0),
            nn.Softplus()
        )

    def forward(self, x, cond):
        """Forward pass

        Args:
            x (torch.Tensor - NCHW): Main feature
            cond (torch.Tensor - NCHW): condition feature

        Returns:
            torch.Tensor: Output log binomial distribution
        """
        pt = self.mlp(torch.concat((x, cond), dim=1))
        p, t = pt[:, :2, ...], pt[:, 2:, ...]

        p = p + self.p_eps
        p = p[:, 0, ...] / (p[:, 0, ...] + p[:, 1, ...])

        t = t + self.p_eps
        t = t[:, 0, ...] / (t[:, 0, ...] + t[:, 1, ...])
        t = t.unsqueeze(1)
        t = (self.max_temp - self.min_temp) * t + self.min_temp

        return self.log_binomial_transform(p, t)
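Note (editor): a minimal usage sketch for ConditionalLogBinomial, not part of the commit. The import path, channel counts and spatial size are illustrative assumptions; the point is that a main feature plus a condition feature yields a per-pixel probability distribution over n_classes depth bins.

import torch
from depth.models_depth.dist_layers import ConditionalLogBinomial  # assumed import path

clb = ConditionalLogBinomial(in_features=193, condition_dim=128, n_classes=64)
feat = torch.randn(2, 193, 24, 32)     # main feature (e.g. decoder output concatenated with relative depth)
cond = torch.randn(2, 128, 24, 32)     # bin-embedding condition feature
probs = clb(feat, cond)
print(probs.shape)                     # torch.Size([2, 64, 24, 32])
print(float(probs.sum(dim=1).mean()))  # ~1.0: a distribution over the 64 bins at every pixel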
depth/models_depth/layers.py
ADDED
@@ -0,0 +1,36 @@
import torch
import torch.nn as nn


class PatchTransformerEncoder(nn.Module):
    def __init__(self, in_channels, patch_size=10, embedding_dim=128, num_heads=4):
        super(PatchTransformerEncoder, self).__init__()
        encoder_layers = nn.TransformerEncoderLayer(embedding_dim, num_heads, dim_feedforward=1024)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=4)  # takes shape S,N,E

        self.embedding_convPxP = nn.Conv2d(in_channels, embedding_dim,
                                           kernel_size=patch_size, stride=patch_size, padding=0)

        self.positional_encodings = nn.Parameter(torch.rand(900, embedding_dim), requires_grad=True)

    def forward(self, x):
        embeddings = self.embedding_convPxP(x).flatten(2)  # .shape = n,c,s = n, embedding_dim, s
        # embeddings = nn.functional.pad(embeddings, (1,0))  # extra special token at start ?
        embeddings = embeddings + self.positional_encodings[:embeddings.shape[2], :].T.unsqueeze(0)

        # change to S,N,E format required by transformer
        embeddings = embeddings.permute(2, 0, 1)
        x = self.transformer_encoder(embeddings)  # .shape = S, N, E
        return x


class PixelWiseDotProduct(nn.Module):
    def __init__(self):
        super(PixelWiseDotProduct, self).__init__()

    def forward(self, x, K):
        n, c, h, w = x.size()
        _, cout, ck = K.size()
        assert c == ck, "Number of channels in x and Embedding dimension (at dim 2) of K matrix must match"
        y = torch.matmul(x.view(n, c, h * w).permute(0, 2, 1), K.permute(0, 2, 1))  # .shape = n, hw, cout
        return y.permute(0, 2, 1).view(n, cout, h, w)
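Note (editor): a small shape-only sketch for the two blocks above, not part of the commit; all sizes and the import path are illustrative assumptions. PatchTransformerEncoder tokenizes a feature map into S patch embeddings in (S, N, E) layout, and PixelWiseDotProduct projects a feature map onto a set of query vectors.

import torch
from depth.models_depth.layers import PatchTransformerEncoder, PixelWiseDotProduct  # assumed path

pte = PatchTransformerEncoder(in_channels=64, patch_size=10, embedding_dim=128, num_heads=4)
pwd = PixelWiseDotProduct()

x = torch.randn(2, 64, 120, 160)
tokens = pte(x)                          # (S, N, E) = (12*16, 2, 128)
queries = tokens[:100].permute(1, 0, 2)  # (N, 100, E), used as the K matrix
feat = torch.randn(2, 128, 120, 160)     # any feature map with E channels
maps = pwd(feat, queries)
print(tokens.shape, maps.shape)          # torch.Size([192, 2, 128]) torch.Size([2, 100, 120, 160])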
depth/models_depth/localbins_layers.py
ADDED
@@ -0,0 +1,169 @@
# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

import torch
import torch.nn as nn


class SeedBinRegressor(nn.Module):
    def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10):
        """Bin center regressor network. Bin centers are bounded on (min_depth, max_depth) interval.

        Args:
            in_features (int): input channels
            n_bins (int, optional): Number of bin centers. Defaults to 16.
            mlp_dim (int, optional): Hidden dimension. Defaults to 256.
            min_depth (float, optional): Min depth value. Defaults to 1e-3.
            max_depth (float, optional): Max depth value. Defaults to 10.
        """
        super().__init__()
        self.version = "1_1"
        self.min_depth = min_depth
        self.max_depth = max_depth

        self._net = nn.Sequential(
            nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(mlp_dim, n_bins, 1, 1, 0),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """
        Returns tensor of bin_width vectors (centers). One vector b for every pixel
        """
        B = self._net(x)
        eps = 1e-3
        B = B + eps
        B_widths_normed = B / B.sum(dim=1, keepdim=True)
        B_widths = (self.max_depth - self.min_depth) * \
            B_widths_normed  # .shape NCHW
        # pad has the form (left, right, top, bottom, front, back)
        B_widths = nn.functional.pad(
            B_widths, (0, 0, 0, 0, 1, 0), mode='constant', value=self.min_depth)
        B_edges = torch.cumsum(B_widths, dim=1)  # .shape NCHW

        B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:, 1:, ...])
        return B_widths_normed, B_centers


class SeedBinRegressorUnnormed(nn.Module):
    def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10):
        """Bin center regressor network. Bin centers are unbounded

        Args:
            in_features (int): input channels
            n_bins (int, optional): Number of bin centers. Defaults to 16.
            mlp_dim (int, optional): Hidden dimension. Defaults to 256.
            min_depth (float, optional): Not used. (for compatibility with SeedBinRegressor)
            max_depth (float, optional): Not used. (for compatibility with SeedBinRegressor)
        """
        super().__init__()
        self.version = "1_1"
        self._net = nn.Sequential(
            nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(mlp_dim, n_bins, 1, 1, 0),
            nn.Softplus()
        )

    def forward(self, x):
        """
        Returns tensor of bin_width vectors (centers). One vector b for every pixel
        """
        B_centers = self._net(x)
        return B_centers, B_centers


class Projector(nn.Module):
    def __init__(self, in_features, out_features, mlp_dim=128):
        """Projector MLP

        Args:
            in_features (int): input channels
            out_features (int): output channels
            mlp_dim (int, optional): hidden dimension. Defaults to 128.
        """
        super().__init__()

        self._net = nn.Sequential(
            nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(mlp_dim, out_features, 1, 1, 0),
        )

    def forward(self, x):
        return self._net(x)


class LinearSplitter(nn.Module):
    def __init__(self, in_features, prev_nbins, split_factor=2, mlp_dim=128, min_depth=1e-3, max_depth=10):
        super().__init__()

        self.prev_nbins = prev_nbins
        self.split_factor = split_factor
        self.min_depth = min_depth
        self.max_depth = max_depth

        self._net = nn.Sequential(
            nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
            nn.GELU(),
            nn.Conv2d(mlp_dim, prev_nbins * split_factor, 1, 1, 0),
            nn.ReLU()
        )

    def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False):
        """
        x : feature block; shape - n, c, h, w
        b_prev : previous bin widths normed; shape - n, prev_nbins, h, w
        """
        if prev_b_embedding is not None:
            if interpolate:
                prev_b_embedding = nn.functional.interpolate(prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True)
            x = x + prev_b_embedding
        S = self._net(x)
        eps = 1e-3
        S = S + eps
        n, c, h, w = S.shape
        S = S.view(n, self.prev_nbins, self.split_factor, h, w)
        S_normed = S / S.sum(dim=2, keepdim=True)  # fractional splits

        b_prev = nn.functional.interpolate(b_prev, (h, w), mode='bilinear', align_corners=True)

        b_prev = b_prev / b_prev.sum(dim=1, keepdim=True)  # renormalize for guarantees
        # print(b_prev.shape, S_normed.shape)
        # if is_for_query:(1).expand(-1, b_prev.size(0)//n, -1, -1, -1, -1).flatten(0,1) # TODO ? can replace all this with a single torch.repeat?
        b = b_prev.unsqueeze(2) * S_normed
        b = b.flatten(1, 2)  # .shape n, prev_nbins * split_factor, h, w

        # calculate bin centers for loss calculation
        B_widths = (self.max_depth - self.min_depth) * b  # .shape N, nprev * splitfactor, H, W
        # pad has the form (left, right, top, bottom, front, back)
        B_widths = nn.functional.pad(B_widths, (0, 0, 0, 0, 1, 0), mode='constant', value=self.min_depth)
        B_edges = torch.cumsum(B_widths, dim=1)  # .shape NCHW

        B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:, 1:, ...])
        return b, B_centers
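Note (editor): a minimal sketch of SeedBinRegressor, not part of the commit; sizes and the import path are illustrative assumptions. It shows that the predicted bin widths form a per-pixel partition of (min_depth, max_depth), so the returned centers always stay inside the depth range.

import torch
from depth.models_depth.localbins_layers import SeedBinRegressor  # assumed path

seed = SeedBinRegressor(in_features=256, n_bins=16, min_depth=1e-3, max_depth=10)
feat = torch.randn(2, 256, 12, 16)
widths_normed, centers = seed(feat)
print(widths_normed.shape, centers.shape)          # (2, 16, 12, 16) for both
print(float(widths_normed.sum(dim=1).mean()))      # 1.0: widths sum to one at every pixel
print(float(centers.min()), float(centers.max()))  # centers lie between 1e-3 and 10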
depth/models_depth/miniViT.py
ADDED
@@ -0,0 +1,45 @@
import torch
import torch.nn as nn

from .layers import PatchTransformerEncoder, PixelWiseDotProduct


class mViT(nn.Module):
    def __init__(self, in_channels, n_query_channels=128, patch_size=16, dim_out=256,
                 embedding_dim=128, num_heads=4, norm='linear'):
        super(mViT, self).__init__()
        self.norm = norm
        self.n_query_channels = n_query_channels
        self.patch_transformer = PatchTransformerEncoder(in_channels, patch_size, embedding_dim, num_heads)
        self.dot_product_layer = PixelWiseDotProduct()

        self.conv3x3 = nn.Conv2d(in_channels, embedding_dim, kernel_size=3, stride=1, padding=1)
        self.regressor = nn.Sequential(nn.Linear(embedding_dim, 256),
                                       nn.LeakyReLU(),
                                       nn.Linear(256, 256),
                                       nn.LeakyReLU(),
                                       nn.Linear(256, dim_out))

    def forward(self, x):
        # n, c, h, w = x.size()
        tgt = self.patch_transformer(x.clone())  # .shape = S, N, E

        x = self.conv3x3(x)

        regression_head, queries = tgt[0, ...], tgt[1:self.n_query_channels + 1, ...]

        # Change from S, N, E to N, S, E
        queries = queries.permute(1, 0, 2)
        range_attention_maps = self.dot_product_layer(x, queries)  # .shape = n, n_query_channels, h, w

        y = self.regressor(regression_head)  # .shape = N, dim_out
        if self.norm == 'linear':
            y = torch.relu(y)
            eps = 0.1
            y = y + eps
        elif self.norm == 'softmax':
            return torch.softmax(y, dim=1), range_attention_maps
        else:
            y = torch.sigmoid(y)
        y = y / y.sum(dim=1, keepdim=True)
        return y, range_attention_maps
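Note (editor): a minimal usage sketch of the mViT bin head, not part of the commit; sizes and the import path are illustrative assumptions. With norm='linear' it returns normalized bin widths per image plus range-attention maps at the input resolution.

import torch
from depth.models_depth.miniViT import mViT  # assumed path

head = mViT(in_channels=128, n_query_channels=128, patch_size=16, dim_out=256,
            embedding_dim=128, num_heads=4, norm='linear')
feat = torch.randn(2, 128, 240, 320)
bin_widths_normed, range_attention_maps = head(feat)
print(bin_widths_normed.shape)         # torch.Size([2, 256]); each row sums to ~1
print(range_attention_maps.shape)      # torch.Size([2, 128, 240, 320])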
depth/models_depth/model.py
ADDED
@@ -0,0 +1,666 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# The deconvolution code is based on Simple Baseline.
# (https://github.com/microsoft/human-pose-estimation.pytorch/blob/master/lib/models/pose_resnet.py)
# Modified by Zigang Geng ([email protected]).
# ------------------------------------------------------------------------------

import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_, DropPath
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer,
                      constant_init, normal_init)
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
import torch.nn.functional as F

from evp.models import UNetWrapper, TextAdapterRefer, FrozenCLIPEmbedder
from .miniViT import mViT
from .attractor import AttractorLayer, AttractorLayerUnnormed
from .dist_layers import ConditionalLogBinomial
from .localbins_layers import (Projector, SeedBinRegressor, SeedBinRegressorUnnormed)
import os


def icnr(x, scale=2, init=nn.init.kaiming_normal_):
    """
    Checkerboard artifact free sub-pixel convolution
    https://arxiv.org/abs/1707.02937
    """
    ni, nf, h, w = x.shape
    ni2 = int(ni/(scale**2))
    k = init(torch.zeros([ni2, nf, h, w])).transpose(0, 1)
    k = k.contiguous().view(ni2, nf, -1)
    k = k.repeat(1, 1, scale**2)
    k = k.contiguous().view([nf, ni, h, w]).transpose(0, 1)
    x.data.copy_(k)


class PixelShuffle(nn.Module):
    """
    Real-Time Single Image and Video Super-Resolution
    https://arxiv.org/abs/1609.05158
    """
    def __init__(self, n_channels, scale):
        super(PixelShuffle, self).__init__()
        self.conv = nn.Conv2d(n_channels, n_channels*(scale**2), kernel_size=1)
        icnr(self.conv.weight)
        self.shuf = nn.PixelShuffle(scale)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.shuf(self.relu(self.conv(x)))
        return x


class AttentionModule(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(AttentionModule, self).__init__()

        # Convolutional Layers
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # Group Normalization
        self.group_norm = nn.GroupNorm(20, out_channels)

        # ReLU Activation
        self.relu = nn.ReLU()

        # Spatial Attention
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Apply spatial attention
        spatial_attention = self.spatial_attention(x)
        x = x * spatial_attention

        # Apply convolutional layer
        x = self.conv1(x)
        x = self.group_norm(x)
        x = self.relu(x)

        return x


class AttentionDownsamplingModule(nn.Module):
    def __init__(self, in_channels, out_channels, scale_factor=2):
        super(AttentionDownsamplingModule, self).__init__()

        # Spatial Attention
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, 1, kernel_size=1),
            nn.Sigmoid()
        )

        # Channel Attention
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // 8, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // 8, in_channels, kernel_size=1),
            nn.Sigmoid()
        )

        # Convolutional Layers
        if scale_factor == 2:
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        elif scale_factor == 4:
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)

        # Group Normalization
        self.group_norm = nn.GroupNorm(20, out_channels)

        # ReLU Activation
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Apply spatial attention
        spatial_attention = self.spatial_attention(x)
        x = x * spatial_attention

        # Apply channel attention
        channel_attention = self.channel_attention(x)
        x = x * channel_attention

        # Apply convolutional layers
        x = self.conv1(x)
        x = self.group_norm(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.group_norm(x)
        x = self.relu(x)

        return x


class AttentionUpsamplingModule(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(AttentionUpsamplingModule, self).__init__()

        # Spatial Attention for outs[2]
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, 1, kernel_size=1),
            nn.Sigmoid()
        )

        # Channel Attention for outs[2]
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // 8, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(in_channels // 8, in_channels, kernel_size=1),
            nn.Sigmoid()
        )

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # Group Normalization
        self.group_norm = nn.GroupNorm(20, out_channels)

        # ReLU Activation
        self.relu = nn.ReLU()
        self.upscale = PixelShuffle(in_channels, 2)

    def forward(self, x):
        # Apply spatial attention
        spatial_attention = self.spatial_attention(x)
        x = x * spatial_attention

        # Apply channel attention
        channel_attention = self.channel_attention(x)
        x = x * channel_attention

        # Apply convolutional layers
        x = self.conv1(x)
        x = self.group_norm(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.group_norm(x)
        x = self.relu(x)

        # Upsample
        x = self.upscale(x)

        return x


class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvLayer, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1),
            nn.GroupNorm(20, out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.conv1(x)

        return x


class InverseMultiAttentiveFeatureRefinement(nn.Module):
    def __init__(self, in_channels_list):
        super(InverseMultiAttentiveFeatureRefinement, self).__init__()

        self.layer1 = AttentionModule(in_channels_list[0], in_channels_list[0])
        self.layer2 = AttentionDownsamplingModule(in_channels_list[0], in_channels_list[0]//2, scale_factor=2)
        self.layer3 = ConvLayer(in_channels_list[0]//2 + in_channels_list[1], in_channels_list[1])
        self.layer4 = AttentionDownsamplingModule(in_channels_list[1], in_channels_list[1]//2, scale_factor=2)
        self.layer5 = ConvLayer(in_channels_list[1]//2 + in_channels_list[2], in_channels_list[2])
        self.layer6 = AttentionDownsamplingModule(in_channels_list[2], in_channels_list[2]//2, scale_factor=2)
        self.layer7 = ConvLayer(in_channels_list[2]//2 + in_channels_list[3], in_channels_list[3])

        '''
        self.layer8 = AttentionUpsamplingModule(in_channels_list[3], in_channels_list[3])
        self.layer9 = ConvLayer(in_channels_list[2] + in_channels_list[3], in_channels_list[2])
        self.layer10 = AttentionUpsamplingModule(in_channels_list[2], in_channels_list[2])
        self.layer11 = ConvLayer(in_channels_list[1] + in_channels_list[2], in_channels_list[1])
        self.layer12 = AttentionUpsamplingModule(in_channels_list[1], in_channels_list[1])
        self.layer13 = ConvLayer(in_channels_list[0] + in_channels_list[1], in_channels_list[0])
        '''

    def forward(self, inputs):
        x_c4, x_c3, x_c2, x_c1 = inputs
        x_c4 = self.layer1(x_c4)
        x_c4_3 = self.layer2(x_c4)
        x_c3 = torch.cat([x_c4_3, x_c3], dim=1)
        x_c3 = self.layer3(x_c3)
        x_c3_2 = self.layer4(x_c3)
        x_c2 = torch.cat([x_c3_2, x_c2], dim=1)
        x_c2 = self.layer5(x_c2)
        x_c2_1 = self.layer6(x_c2)
        x_c1 = torch.cat([x_c2_1, x_c1], dim=1)
        x_c1 = self.layer7(x_c1)
        '''
        x_c1_2 = self.layer8(x_c1)
        x_c2 = torch.cat([x_c1_2, x_c2], dim=1)
        x_c2 = self.layer9(x_c2)
        x_c2_3 = self.layer10(x_c2)
        x_c3 = torch.cat([x_c2_3, x_c3], dim=1)
        x_c3 = self.layer11(x_c3)
        x_c3_4 = self.layer12(x_c3)
        x_c4 = torch.cat([x_c3_4, x_c4], dim=1)
        x_c4 = self.layer13(x_c4)
        '''
        return [x_c4, x_c3, x_c2, x_c1]


class EVPDepthEncoder(nn.Module):
    def __init__(self, out_dim=1024, ldm_prior=[320, 680, 1320+1280], sd_path=None, text_dim=768,
                 dataset='nyu', caption_aggregation=False
                 ):
        super().__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(ldm_prior[0], ldm_prior[0], 3, stride=2, padding=1),
            nn.GroupNorm(16, ldm_prior[0]),
            nn.ReLU(),
            nn.Conv2d(ldm_prior[0], ldm_prior[0], 3, stride=2, padding=1),
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(ldm_prior[1], ldm_prior[1], 3, stride=2, padding=1),
        )

        self.out_layer = nn.Sequential(
            nn.Conv2d(sum(ldm_prior), out_dim, 1),
            nn.GroupNorm(16, out_dim),
            nn.ReLU(),
        )

        self.aggregation = InverseMultiAttentiveFeatureRefinement([320, 680, 1320, 1280])

        self.apply(self._init_weights)

        ### stable diffusion layers

        config = OmegaConf.load('./v1-inference.yaml')
        if sd_path is None:
            if os.path.exists('../checkpoints/v1-5-pruned-emaonly.ckpt'):
                config.model.params.ckpt_path = '../checkpoints/v1-5-pruned-emaonly.ckpt'
            else:
                config.model.params.ckpt_path = None
        else:
            config.model.params.ckpt_path = f'../{sd_path}'

        sd_model = instantiate_from_config(config.model)
        self.encoder_vq = sd_model.first_stage_model

        self.unet = UNetWrapper(sd_model.model, use_attn=True)
        if dataset == 'kitti':
            self.unet = UNetWrapper(sd_model.model, use_attn=True, base_size=384)

        del sd_model.cond_stage_model
        del self.encoder_vq.decoder
        del self.unet.unet.diffusion_model.out
        del self.encoder_vq.post_quant_conv.weight
        del self.encoder_vq.post_quant_conv.bias

        for param in self.encoder_vq.parameters():
            param.requires_grad = True

        self.text_adapter = TextAdapterRefer(text_dim=text_dim)
        self.gamma = nn.Parameter(torch.ones(text_dim) * 1e-4)

        if caption_aggregation:
            class_embeddings = torch.load(f'{dataset}_class_embeddings_my_captions.pth')
            #class_embeddings_list = [value['class_embeddings'] for key, value in class_embeddings.items()]
            #stacked_embeddings = torch.stack(class_embeddings_list, dim=0)
            #class_embeddings = torch.mean(stacked_embeddings, dim=0).unsqueeze(0)

            if 'aggregated' in class_embeddings:
                class_embeddings = class_embeddings['aggregated']
            else:
                clip_model = FrozenCLIPEmbedder(max_length=40, pool=False).cuda()
                class_embeddings_new = [clip_model.encode(value['caption'][0]) for key, value in class_embeddings.items()]
                class_embeddings_new = torch.mean(torch.stack(class_embeddings_new, dim=0), dim=0)
                class_embeddings['aggregated'] = class_embeddings_new
                torch.save(class_embeddings, f'{dataset}_class_embeddings_my_captions.pth')
                class_embeddings = class_embeddings['aggregated']
            self.register_buffer('class_embeddings', class_embeddings)
        else:
            self.class_embeddings = torch.load(f'{dataset}_class_embeddings_my_captions.pth')

            self.clip_model = FrozenCLIPEmbedder(max_length=40, pool=False)
            for param in self.clip_model.parameters():
                param.requires_grad = True

        #if dataset == 'kitti':
        #    self.text_adapter_ = TextAdapterRefer(text_dim=text_dim)
        #    self.gamma_ = nn.Parameter(torch.ones(text_dim) * 1e-4)

        self.caption_aggregation = caption_aggregation
        self.dataset = dataset

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, feats):
        x = self.ldm_to_net[0](feats[0])
        for i in range(3):
            if i > 0:
                x = x + self.ldm_to_net[i](feats[i])
            x = self.layers[i](x)
            x = self.upsample_layers[i](x)
        return self.out_conv(x)

    def forward(self, x, class_ids=None, img_paths=None):
        latents = self.encoder_vq.encode(x).mode()

        # add division by std
        if self.dataset == 'nyu':
            latents = latents / 5.07543
        elif self.dataset == 'kitti':
            latents = latents / 4.6211
        else:
            print('Please calculate the STD for the dataset!')

        if class_ids is not None:
            if self.caption_aggregation:
                class_embeddings = self.class_embeddings[[0]*len(class_ids.tolist())]#[class_ids.tolist()]
            else:
                class_embeddings = []

                for img_path in img_paths:
                    class_embeddings.extend([value['caption'][0] for key, value in self.class_embeddings.items() if key in img_path.replace('//', '/')])

                class_embeddings = self.clip_model.encode(class_embeddings)
        else:
            class_embeddings = self.class_embeddings

        c_crossattn = self.text_adapter(latents, class_embeddings, self.gamma)
        t = torch.ones((x.shape[0],), device=x.device).long()

        #if self.dataset == 'kitti':
        #    c_crossattn_last = self.text_adapter_(latents, class_embeddings, self.gamma_)
        #    outs = self.unet(latents, t, c_crossattn=[c_crossattn, c_crossattn_last])
        #else:
        outs = self.unet(latents, t, c_crossattn=[c_crossattn])
        outs = self.aggregation(outs)

        feats = [outs[0], outs[1], torch.cat([outs[2], F.interpolate(outs[3], scale_factor=2)], dim=1)]
        x = torch.cat([self.layer1(feats[0]), self.layer2(feats[1]), feats[2]], dim=1)
        return self.out_layer(x)

    def get_latent(self, x):
        return self.encoder_vq.encode(x).mode()


class EVPDepth(nn.Module):
    def __init__(self, args=None, caption_aggregation=False):
        super().__init__()
        self.max_depth = args.max_depth
        self.min_depth = args.min_depth_eval

        embed_dim = 192

        channels_in = embed_dim*8
        channels_out = embed_dim

        if args.dataset == 'nyudepthv2':
            self.encoder = EVPDepthEncoder(out_dim=channels_in, dataset='nyu', caption_aggregation=caption_aggregation)
        else:
            self.encoder = EVPDepthEncoder(out_dim=channels_in, dataset='kitti', caption_aggregation=caption_aggregation)

        self.decoder = Decoder(channels_in, channels_out, args)
        self.decoder.init_weights()
        self.mViT = False
        self.custom = False

        if not self.mViT and not self.custom:
            n_bins = 64
            bin_embedding_dim = 128
            num_out_features = [32, 32, 32, 192]
            min_temp = 0.0212
            max_temp = 50
            btlnck_features = 256
            n_attractors = [16, 8, 4, 1]
            attractor_alpha = 1000
            attractor_gamma = 2
            attractor_kind = "mean"
            attractor_type = "inv"
            self.bin_centers_type = "softplus"

            self.bottle_neck = nn.Sequential(
                nn.Conv2d(channels_in, btlnck_features, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=False),
                nn.Conv2d(btlnck_features, btlnck_features, kernel_size=3, stride=1, padding=1))

            for m in self.bottle_neck.modules():
                if isinstance(m, nn.Conv2d):
                    normal_init(m, std=0.001, bias=0)

            SeedBinRegressorLayer = SeedBinRegressorUnnormed
            Attractor = AttractorLayerUnnormed
            self.seed_bin_regressor = SeedBinRegressorLayer(
                btlnck_features, n_bins=n_bins, min_depth=self.min_depth, max_depth=self.max_depth)
            self.seed_projector = Projector(btlnck_features, bin_embedding_dim)
            self.projectors = nn.ModuleList([
                Projector(num_out, bin_embedding_dim)
                for num_out in num_out_features
            ])
            self.attractors = nn.ModuleList([
                Attractor(bin_embedding_dim, n_bins, n_attractors=n_attractors[i], min_depth=self.min_depth, max_depth=self.max_depth,
                          alpha=attractor_alpha, gamma=attractor_gamma, kind=attractor_kind, attractor_type=attractor_type)
                for i in range(len(num_out_features))
            ])

            last_in = 192 + 1
            self.conditional_log_binomial = ConditionalLogBinomial(
                last_in, bin_embedding_dim, n_classes=n_bins, min_temp=min_temp, max_temp=max_temp)
        elif self.mViT and not self.custom:
            n_bins = 256
            self.adaptive_bins_layer = mViT(192, n_query_channels=192, patch_size=16,
                                            dim_out=n_bins,
                                            embedding_dim=192, norm='linear')
            self.conv_out = nn.Sequential(nn.Conv2d(192, n_bins, kernel_size=1, stride=1, padding=0),
                                          nn.Softmax(dim=1))

    def forward(self, x, class_ids=None, img_paths=None):
        b, c, h, w = x.shape
        x = x*2.0 - 1.0  # normalize to [-1, 1]
        if h == 480 and w == 480:
            new_x = torch.zeros(b, c, 512, 512, device=x.device)
            new_x[:, :, 0:480, 0:480] = x
            x = new_x
        elif h == 352 and w == 352:
            new_x = torch.zeros(b, c, 384, 384, device=x.device)
            new_x[:, :, 0:352, 0:352] = x
            x = new_x
        elif h == 512 and w == 512:
            pass
        else:
            print(h, w)
            raise NotImplementedError
        conv_feats = self.encoder(x, class_ids, img_paths)

        if h == 480 or h == 352:
            conv_feats = conv_feats[:, :, :-1, :-1]

        self.decoder.remove_hooks()
        out_depth, out, x_blocks = self.decoder([conv_feats])

        if not self.mViT and not self.custom:
            x = self.bottle_neck(conv_feats)
            _, seed_b_centers = self.seed_bin_regressor(x)

            if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2':
                b_prev = (seed_b_centers - self.min_depth) / \
                    (self.max_depth - self.min_depth)
            else:
                b_prev = seed_b_centers

            prev_b_embedding = self.seed_projector(x)

            for projector, attractor, x in zip(self.projectors, self.attractors, x_blocks):
                b_embedding = projector(x)
                b, b_centers = attractor(
                    b_embedding, b_prev, prev_b_embedding, interpolate=True)
                b_prev = b.clone()
                prev_b_embedding = b_embedding.clone()

            rel_cond = torch.sigmoid(out_depth) * self.max_depth

            # concat rel depth with last. First interpolate rel depth to last size
            rel_cond = nn.functional.interpolate(
                rel_cond, size=out.shape[2:], mode='bilinear', align_corners=True)
            last = torch.cat([out, rel_cond], dim=1)

            b_embedding = nn.functional.interpolate(
                b_embedding, last.shape[-2:], mode='bilinear', align_corners=True)
            x = self.conditional_log_binomial(last, b_embedding)

            # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor
            b_centers = nn.functional.interpolate(
                b_centers, x.shape[-2:], mode='bilinear', align_corners=True)
            out_depth = torch.sum(x * b_centers, dim=1, keepdim=True)

        elif self.mViT and not self.custom:
            bin_widths_normed, range_attention_maps = self.adaptive_bins_layer(out)
            out = self.conv_out(range_attention_maps)

            bin_widths = (self.max_depth - self.min_depth) * bin_widths_normed  # .shape = N, dim_out
            bin_widths = nn.functional.pad(bin_widths, (1, 0), mode='constant', value=self.min_depth)
            bin_edges = torch.cumsum(bin_widths, dim=1)

            centers = 0.5 * (bin_edges[:, :-1] + bin_edges[:, 1:])
            n, dout = centers.size()
            centers = centers.view(n, dout, 1, 1)

            out_depth = torch.sum(out * centers, dim=1, keepdim=True)
        else:
            out_depth = torch.sigmoid(out_depth) * self.max_depth

        return {'pred_d': out_depth}


class Decoder(nn.Module):
    def __init__(self, in_channels, out_channels, args):
        super().__init__()
        self.deconv = args.num_deconv
        self.in_channels = in_channels

        embed_dim = 192

        channels_in = embed_dim*8
        channels_out = embed_dim

        self.deconv_layers, self.intermediate_results = self._make_deconv_layer(
            args.num_deconv,
            args.num_filters,
            args.deconv_kernels,
        )
        self.last_layer_depth = nn.Sequential(
            nn.Conv2d(channels_out, channels_out, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=False),
            nn.Conv2d(channels_out, 1, kernel_size=3, stride=1, padding=1))

        for m in self.last_layer_depth.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001, bias=0)

        conv_layers = []
        conv_layers.append(
            build_conv_layer(
                dict(type='Conv2d'),
                in_channels=args.num_filters[-1],
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding=1))
        conv_layers.append(
            build_norm_layer(dict(type='BN'), out_channels)[1])
        conv_layers.append(nn.ReLU())
        self.conv_layers = nn.Sequential(*conv_layers)

        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, conv_feats):
        out = self.deconv_layers(conv_feats[0])
        out = self.conv_layers(out)
        out = self.up(out)
        self.intermediate_results.append(out)
        out = self.up(out)
        out_depth = self.last_layer_depth(out)

        return out_depth, out, self.intermediate_results

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """Make deconv layers."""

        layers = []
        in_planes = self.in_channels
        intermediate_results = []  # List to store intermediate feature maps

        for i in range(num_layers):
            kernel, padding, output_padding = \
                self._get_deconv_cfg(num_kernels[i])

            planes = num_filters[i]
            layers.append(
                build_upsample_layer(
                    dict(type='deconv'),
                    in_channels=in_planes,
|
| 619 |
+
out_channels=planes,
|
| 620 |
+
kernel_size=kernel,
|
| 621 |
+
stride=2,
|
| 622 |
+
padding=padding,
|
| 623 |
+
output_padding=output_padding,
|
| 624 |
+
bias=False))
|
| 625 |
+
layers.append(nn.BatchNorm2d(planes))
|
| 626 |
+
layers.append(nn.ReLU())
|
| 627 |
+
in_planes = planes
|
| 628 |
+
|
| 629 |
+
# Add a hook to store the intermediate result
|
| 630 |
+
layers[-1].register_forward_hook(self._hook_fn(intermediate_results))
|
| 631 |
+
|
| 632 |
+
return nn.Sequential(*layers), intermediate_results
|
| 633 |
+
|
| 634 |
+
def _hook_fn(self, intermediate_results):
|
| 635 |
+
def hook(module, input, output):
|
| 636 |
+
intermediate_results.append(output)
|
| 637 |
+
return hook
|
| 638 |
+
|
| 639 |
+
def remove_hooks(self):
|
| 640 |
+
self.intermediate_results.clear()
|
| 641 |
+
|
| 642 |
+
def _get_deconv_cfg(self, deconv_kernel):
|
| 643 |
+
"""Get configurations for deconv layers."""
|
| 644 |
+
if deconv_kernel == 4:
|
| 645 |
+
padding = 1
|
| 646 |
+
output_padding = 0
|
| 647 |
+
elif deconv_kernel == 3:
|
| 648 |
+
padding = 1
|
| 649 |
+
output_padding = 1
|
| 650 |
+
elif deconv_kernel == 2:
|
| 651 |
+
padding = 0
|
| 652 |
+
output_padding = 0
|
| 653 |
+
else:
|
| 654 |
+
raise ValueError(f'Not supported num_kernels ({deconv_kernel}).')
|
| 655 |
+
|
| 656 |
+
return deconv_kernel, padding, output_padding
|
| 657 |
+
|
| 658 |
+
def init_weights(self):
|
| 659 |
+
"""Initialize model weights."""
|
| 660 |
+
for m in self.modules():
|
| 661 |
+
if isinstance(m, nn.Conv2d):
|
| 662 |
+
normal_init(m, std=0.001, bias=0)
|
| 663 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 664 |
+
constant_init(m, 1)
|
| 665 |
+
elif isinstance(m, nn.ConvTranspose2d):
|
| 666 |
+
normal_init(m, std=0.001)
|
depth/models_depth/model_vpd.py
ADDED
@@ -0,0 +1,252 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# The deconvolution code is based on Simple Baseline.
# (https://github.com/microsoft/human-pose-estimation.pytorch/blob/master/lib/models/pose_resnet.py)
# Modified by Zigang Geng ([email protected]).
# ------------------------------------------------------------------------------

import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_, DropPath
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer,
                      constant_init, normal_init)
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
import torch.nn.functional as F

from evp.models import UNetWrapper, TextAdapterDepth

class VPDDepthEncoder(nn.Module):
    def __init__(self, out_dim=1024, ldm_prior=[320, 640, 1280+1280], sd_path=None, text_dim=768,
                 dataset='nyu'
                 ):
        super().__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(ldm_prior[0], ldm_prior[0], 3, stride=2, padding=1),
            nn.GroupNorm(16, ldm_prior[0]),
            nn.ReLU(),
            nn.Conv2d(ldm_prior[0], ldm_prior[0], 3, stride=2, padding=1),
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(ldm_prior[1], ldm_prior[1], 3, stride=2, padding=1),
        )

        self.out_layer = nn.Sequential(
            nn.Conv2d(sum(ldm_prior), out_dim, 1),
            nn.GroupNorm(16, out_dim),
            nn.ReLU(),
        )

        self.apply(self._init_weights)

        ### stable diffusion layers

        config = OmegaConf.load('./v1-inference.yaml')
        if sd_path is None:
            config.model.params.ckpt_path = '../checkpoints/v1-5-pruned-emaonly.ckpt'
        else:
            config.model.params.ckpt_path = f'../{sd_path}'

        sd_model = instantiate_from_config(config.model)
        self.encoder_vq = sd_model.first_stage_model

        self.unet = UNetWrapper(sd_model.model, use_attn=False)

        del sd_model.cond_stage_model
        del self.encoder_vq.decoder
        del self.unet.unet.diffusion_model.out

        for param in self.encoder_vq.parameters():
            param.requires_grad = False

        if dataset == 'nyu':
            self.text_adapter = TextAdapterDepth(text_dim=text_dim)
            class_embeddings = torch.load('nyu_class_embeddings.pth')
        else:
            raise NotImplementedError

        self.register_buffer('class_embeddings', class_embeddings)
        self.gamma = nn.Parameter(torch.ones(text_dim) * 1e-4)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, feats):
        x = self.ldm_to_net[0](feats[0])
        for i in range(3):
            if i > 0:
                x = x + self.ldm_to_net[i](feats[i])
            x = self.layers[i](x)
            x = self.upsample_layers[i](x)
        return self.out_conv(x)

    def forward(self, x, class_ids=None, img_paths=None):
        with torch.no_grad():
            latents = self.encoder_vq.encode(x).mode().detach()

        if class_ids is not None:
            class_embeddings = self.class_embeddings[class_ids.tolist()]
        else:
            class_embeddings = self.class_embeddings

        c_crossattn = self.text_adapter(latents, class_embeddings, self.gamma)  # NOTE: here the c_crossattn should be expand_dim as latents
        t = torch.ones((x.shape[0],), device=x.device).long()
        # import pdb; pdb.set_trace()
        outs = self.unet(latents, t, c_crossattn=[c_crossattn])
        feats = [outs[0], outs[1], torch.cat([outs[2], F.interpolate(outs[3], scale_factor=2)], dim=1)]
        x = torch.cat([self.layer1(feats[0]), self.layer2(feats[1]), feats[2]], dim=1)
        return self.out_layer(x)

class VPDDepth(nn.Module):
    def __init__(self, args=None):
        super().__init__()
        self.max_depth = args.max_depth

        embed_dim = 192

        channels_in = embed_dim * 8
        channels_out = embed_dim

        if args.dataset == 'nyudepthv2':
            self.encoder = VPDDepthEncoder(out_dim=channels_in, dataset='nyu')
        else:
            raise NotImplementedError

        self.decoder = Decoder(channels_in, channels_out, args)
        self.decoder.init_weights()

        self.last_layer_depth = nn.Sequential(
            nn.Conv2d(channels_out, channels_out, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=False),
            nn.Conv2d(channels_out, 1, kernel_size=3, stride=1, padding=1))

        for m in self.last_layer_depth.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001, bias=0)

    def forward(self, x, class_ids=None, img_paths=None):
        # import pdb; pdb.set_trace()
        b, c, h, w = x.shape
        x = x * 2.0 - 1.0  # normalize to [-1, 1]
        if h == 480 and w == 480:
            new_x = torch.zeros(b, c, 512, 512, device=x.device)
            new_x[:, :, 0:480, 0:480] = x
            x = new_x
        elif h == 352 and w == 352:
            new_x = torch.zeros(b, c, 384, 384, device=x.device)
            new_x[:, :, 0:352, 0:352] = x
            x = new_x
        elif h == 512 and w == 512:
            pass
        else:
            raise NotImplementedError
        conv_feats = self.encoder(x, class_ids)

        if h == 480 or h == 352:
            conv_feats = conv_feats[:, :, :-1, :-1]

        out = self.decoder([conv_feats])
        out_depth = self.last_layer_depth(out)
        out_depth = torch.sigmoid(out_depth) * self.max_depth

        return {'pred_d': out_depth}


class Decoder(nn.Module):
    def __init__(self, in_channels, out_channels, args):
        super().__init__()
        self.deconv = args.num_deconv
        self.in_channels = in_channels

        # import pdb; pdb.set_trace()

        self.deconv_layers = self._make_deconv_layer(
            args.num_deconv,
            args.num_filters,
            args.deconv_kernels,
        )

        conv_layers = []
        conv_layers.append(
            build_conv_layer(
                dict(type='Conv2d'),
                in_channels=args.num_filters[-1],
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding=1))
        conv_layers.append(
            build_norm_layer(dict(type='BN'), out_channels)[1])
        conv_layers.append(nn.ReLU(inplace=True))
        self.conv_layers = nn.Sequential(*conv_layers)

        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, conv_feats):
        # import pdb; pdb.set_trace()
        out = self.deconv_layers(conv_feats[0])
        out = self.conv_layers(out)

        out = self.up(out)
        out = self.up(out)

        return out

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """Make deconv layers."""

        layers = []
        in_planes = self.in_channels
        for i in range(num_layers):
            kernel, padding, output_padding = \
                self._get_deconv_cfg(num_kernels[i])

            planes = num_filters[i]
            layers.append(
                build_upsample_layer(
                    dict(type='deconv'),
                    in_channels=in_planes,
                    out_channels=planes,
                    kernel_size=kernel,
                    stride=2,
                    padding=padding,
                    output_padding=output_padding,
                    bias=False))
            layers.append(nn.BatchNorm2d(planes))
            layers.append(nn.ReLU(inplace=True))
            in_planes = planes

        return nn.Sequential(*layers)

    def _get_deconv_cfg(self, deconv_kernel):
        """Get configurations for deconv layers."""
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        else:
            raise ValueError(f'Not supported num_kernels ({deconv_kernel}).')

        return deconv_kernel, padding, output_padding

    def init_weights(self):
        """Initialize model weights."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001, bias=0)
            elif isinstance(m, nn.BatchNorm2d):
                constant_init(m, 1)
            elif isinstance(m, nn.ConvTranspose2d):
                normal_init(m, std=0.001)

depth/models_depth/optimizer.py
ADDED
@@ -0,0 +1,154 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# The code is from SimMIM.
# (https://github.com/microsoft/SimMIM)
# ------------------------------------------------------------------------------

import json
from mmcv.runner import OPTIMIZER_BUILDERS, DefaultOptimizerConstructor
from mmcv.runner import build_optimizer
from mmcv.runner import get_dist_info


def get_num_layer_for_swin(var_name, num_max_layer, layers_per_stage):
    var_name = var_name.replace('encoder', 'backbone') if var_name.startswith('encoder') else var_name
    if var_name in ("backbone.cls_token", "backbone.mask_token",
                    "backbone.pos_embed", "backbone.absolute_pos_embed"):
        return 0
    elif var_name.startswith("backbone.patch_embed"):
        return 0
    elif var_name.startswith("backbone.layers"):
        if var_name.split('.')[3] == "blocks":
            stage_id = int(var_name.split('.')[2])
            layer_id = int(var_name.split('.')[4]) \
                + sum(layers_per_stage[:stage_id])
            return layer_id + 1
        elif var_name.split('.')[3] == "downsample":
            stage_id = int(var_name.split('.')[2])
            layer_id = sum(layers_per_stage[:stage_id + 1])
            return layer_id
    else:
        return num_max_layer - 1

@OPTIMIZER_BUILDERS.register_module()
class LDMOptimizerConstructor(DefaultOptimizerConstructor):
    def add_params(self, params, module, prefix='', is_dcn_module=None):
        """Add all parameters of module to the params list.
        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg.
        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
            prefix (str): The prefix of the module
            is_dcn_module (int|float|None): If the current module is a
                submodule of DCN, `is_dcn_module` will be passed to
                control conv_offset layer's learning rate. Defaults to None.
        """
        parameter_groups = {}
        no_decay_names = self.paramwise_cfg.get('no_decay_names', [])
        print("Build LDMOptimizerConstructor")
        weight_decay = self.base_wd

        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue  # frozen weights
            if len(param.shape) == 1 or name.endswith(".bias") or name in ('absolute_pos_embed'):
                group_name = "no_decay"
                this_weight_decay = 0.
            else:
                group_name = "decay"
                this_weight_decay = weight_decay

            for nd_name in no_decay_names:
                if nd_name in name:
                    group_name = "no_decay"
                    this_weight_decay = 0.
                    break

            if 'unet' in name or 'cond_stage_model' in name or 'encoder_vq' in name or 'clip_model' in name:
                layer_id = 0
            else:
                layer_id = 1
            group_name = "layer_%d_%s" % (layer_id, group_name)

            if group_name not in parameter_groups:
                if layer_id == 0:
                    scale = 0.01
                else:
                    scale = 1.0

                parameter_groups[group_name] = {
                    "weight_decay": this_weight_decay,
                    "params": [],
                    "param_names": [],
                    "lr_scale": scale,
                    "group_name": group_name,
                    "lr": scale * self.base_lr,
                }

            parameter_groups[group_name]["params"].append(param)
            parameter_groups[group_name]["param_names"].append(name)
        rank, _ = get_dist_info()
        if rank == 0:
            to_display = {}
            for key in parameter_groups:
                to_display[key] = {
                    "param_names": parameter_groups[key]["param_names"],
                    "lr_scale": parameter_groups[key]["lr_scale"],
                    "lr": parameter_groups[key]["lr"],
                    "weight_decay": parameter_groups[key]["weight_decay"],
                }

        params.extend(parameter_groups.values())

def build_optimizers(model, cfgs):
    """Build multiple optimizers from configs.

    If `cfgs` contains several dicts for optimizers, then a dict for each
    constructed optimizers will be returned.
    If `cfgs` only contains one optimizer config, the constructed optimizer
    itself will be returned.

    For example,

    1) Multiple optimizer configs:

    .. code-block:: python

        optimizer_cfg = dict(
            model1=dict(type='SGD', lr=lr),
            model2=dict(type='SGD', lr=lr))

    The return dict is
    ``dict('model1': torch.optim.Optimizer, 'model2': torch.optim.Optimizer)``

    2) Single optimizer config:

    .. code-block:: python

        optimizer_cfg = dict(type='SGD', lr=lr)

    The return is ``torch.optim.Optimizer``.

    Args:
        model (:obj:`nn.Module`): The model with parameters to be optimized.
        cfgs (dict): The config dict of the optimizer.

    Returns:
        dict[:obj:`torch.optim.Optimizer`] | :obj:`torch.optim.Optimizer`:
            The initialized optimizers.
    """
    optimizers = {}
    if hasattr(model, 'module'):
        model = model.module
    # determine whether 'cfgs' has several dicts for optimizers
    if all(isinstance(v, dict) for v in cfgs.values()):
        for key, cfg in cfgs.items():
            cfg_ = cfg.copy()
            module = getattr(model, key)
            optimizers[key] = build_optimizer(module, cfg_)
        return optimizers

    return build_optimizer(model, cfgs)

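A minimal usage sketch for `build_optimizers` above. The model, the config keys, and the optimizer hyper-parameters are illustrative assumptions, not taken from this repo's configs; it only assumes mmcv (the version pinned by this repo) is installed and that the config keys match sub-module attribute names.

import torch.nn as nn

class TinyModel(nn.Module):  # hypothetical two-headed model for illustration
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8)
        self.decoder = nn.Linear(8, 1)

# one config dict per sub-module -> build_optimizers returns a dict of optimizers
cfgs = dict(encoder=dict(type='AdamW', lr=1e-4, weight_decay=0.05),
            decoder=dict(type='AdamW', lr=1e-3, weight_decay=0.05))
optimizers = build_optimizers(TinyModel(), cfgs)  # {'encoder': AdamW, 'decoder': AdamW}
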
depth/requirements.txt
ADDED
@@ -0,0 +1,8 @@
torch>=1.6.0
h5py>=3.6.0
scipy>=1.7.3
opencv-python>=4.5.5
timm>=0.5.4
albumentations>=1.1.0
tensorboardX>=2.4.1
gdown>=4.2.1
depth/test_img.jpg
ADDED
depth/utils.py
ADDED
@@ -0,0 +1,525 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import os
import math
import time
from collections import defaultdict, deque
import datetime
import numpy as np
from timm.utils import get_state_dict

from pathlib import Path

import torch
import torch.distributed as dist
from torch._six import inf

from tensorboardX import SummaryWriter

class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))


class TensorboardLogger(object):
    def __init__(self, log_dir):
        self.writer = SummaryWriter(logdir=log_dir)
        self.step = 0

    def set_step(self, step=None):
        if step is not None:
            self.step = step
        else:
            self.step += 1

    def update(self, head='scalar', step=None, **kwargs):
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step)

    def flush(self):
        self.writer.flush()


class WandbLogger(object):
    def __init__(self, args):
        self.args = args

        try:
            import wandb
            self._wandb = wandb
        except ImportError:
            raise ImportError(
                "To use the Weights and Biases Logger please install wandb."
                "Run `pip install wandb` to install it."
            )

        # Initialize a W&B run
        if self._wandb.run is None:
            self._wandb.init(
                project=args.project,
                config=args
            )

    def log_epoch_metrics(self, metrics, commit=True):
        """
        Log train/test metrics onto W&B.
        """
        # Log number of model parameters as W&B summary
        self._wandb.summary['n_parameters'] = metrics.get('n_parameters', None)
        metrics.pop('n_parameters', None)

        # Log current epoch
        self._wandb.log({'epoch': metrics.get('epoch')}, commit=False)
        metrics.pop('epoch')

        for k, v in metrics.items():
            if 'train' in k:
                self._wandb.log({f'Global Train/{k}': v}, commit=False)
            elif 'test' in k:
                self._wandb.log({f'Global Test/{k}': v}, commit=False)

        self._wandb.log({})

    def log_checkpoints(self):
        output_dir = self.args.output_dir
        model_artifact = self._wandb.Artifact(
            self._wandb.run.id + "_model", type="model"
        )

        model_artifact.add_dir(output_dir)
        self._wandb.log_artifact(model_artifact, aliases=["latest", "best"])

    def set_steps(self):
        # Set global training step
        self._wandb.define_metric('Rank-0 Batch Wise/*', step_metric='Rank-0 Batch Wise/global_train_step')
        # Set epoch-wise step
        self._wandb.define_metric('Global Train/*', step_metric='epoch')
        self._wandb.define_metric('Global Test/*', step_metric='epoch')


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):

    if args.dist_on_itp:
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()

        os.environ['RANK'] = str(args.rank)
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['WORLD_SIZE'] = str(args.world_size)
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


def init_distributed_mode_simple(args):

    args.rank = int(os.environ["RANK"])
    args.world_size = int(os.environ['WORLD_SIZE'])
    args.gpu = int(os.environ['LOCAL_RANK'])
    args.dist_url = 'env://'

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)

def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=''):
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(model, prefix=prefix)

    warn_missing_keys = []
    ignore_missing_keys = []
    for key in missing_keys:
        keep_flag = True
        for ignore_key in ignore_missing.split('|'):
            if ignore_key in key:
                keep_flag = False
                break
        if keep_flag:
            warn_missing_keys.append(key)
        else:
            ignore_missing_keys.append(key)

    missing_keys = warn_missing_keys

    if len(missing_keys) > 0:
        print("Weights of {} not initialized from pretrained model: {}".format(
            model.__class__.__name__, missing_keys))
    if len(unexpected_keys) > 0:
        print("Weights from pretrained model not used in {}: {}".format(
            model.__class__.__name__, unexpected_keys))
    if len(ignore_missing_keys) > 0:
        print("Ignored weights of {} not initialized from pretrained model: {}".format(
            model.__class__.__name__, ignore_missing_keys))
    if len(error_msgs) > 0:
        print('\n'.join(error_msgs))


class NativeScalerWithGradNormCount:
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)


def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    if norm_type == inf:
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
    return total_norm


def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
                     start_warmup_value=0, warmup_steps=-1):
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_steps > 0:
        warmup_iters = warmup_steps
    print("Set warmup steps = %d" % warmup_iters)
    if warmup_epochs > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = np.array(
        [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])

    schedule = np.concatenate((warmup_schedule, schedule))

    assert len(schedule) == epochs * niter_per_ep
    return schedule

def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
    output_dir = Path(args.output_dir)
    epoch_name = str(epoch)
    checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
    for checkpoint_path in checkpoint_paths:
        to_save = {
            'model': model_without_ddp.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
            'scaler': loss_scaler.state_dict(),
            'args': args,
        }

        if model_ema is not None:
            to_save['model_ema'] = get_state_dict(model_ema)

        save_on_master(to_save, checkpoint_path)

    if is_main_process() and isinstance(epoch, int):
        to_del = epoch - args.save_ckpt_num * args.save_ckpt_freq
        old_ckpt = output_dir / ('checkpoint-%s.pth' % to_del)
        if os.path.exists(old_ckpt):
            os.remove(old_ckpt)


def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
    output_dir = Path(args.output_dir)
    if args.auto_resume and len(args.resume) == 0:
        import glob
        all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
        latest_ckpt = -1
        for ckpt in all_checkpoints:
            t = ckpt.split('-')[-1].split('.')[0]
            if t.isdigit():
                latest_ckpt = max(int(t), latest_ckpt)
        if latest_ckpt >= 0:
            args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
        print("Auto resume checkpoint: %s" % args.resume)

    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        print("Resume checkpoint %s" % args.resume)
        if 'optimizer' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            if not isinstance(checkpoint['epoch'], str):  # does not support resuming with 'best', 'best-ema'
                args.start_epoch = checkpoint['epoch'] + 1
            else:
                assert args.eval, 'Does not support resuming with checkpoint-best'
            if hasattr(args, 'model_ema') and args.model_ema:
                if 'model_ema' in checkpoint.keys():
                    model_ema.ema.load_state_dict(checkpoint['model_ema'])
                else:
                    model_ema.ema.load_state_dict(checkpoint['model'])
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
        print("With optim & sched!")

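A small sketch of how `cosine_scheduler` above is typically used to drive a per-iteration learning rate. The epoch count and dataset size are made-up values for illustration, and the snippet assumes the utilities above are importable (e.g. from this repo's `depth/utils.py`).

# hypothetical import path for illustration
# from depth.utils import cosine_scheduler

lr_schedule = cosine_scheduler(base_value=5e-4, final_value=1e-6,
                               epochs=25, niter_per_ep=1000, warmup_epochs=1)
assert len(lr_schedule) == 25 * 1000  # one learning-rate value per optimizer step
for it, lr in enumerate(lr_schedule[:3]):
    print(it, lr)  # warmup ramps linearly from start_warmup_value to base_value
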
depth/utils_depth/criterion.py
ADDED
@@ -0,0 +1,22 @@
# ------------------------------------------------------------------------------
# The code is from GLPDepth (https://github.com/vinvino02/GLPDepth).
# For non-commercial purpose only (research, evaluation etc).
# ------------------------------------------------------------------------------

import torch
import torch.nn as nn


class SiLogLoss(nn.Module):
    def __init__(self, lambd=0.5):
        super().__init__()
        self.lambd = lambd

    def forward(self, pred, target):
        valid_mask = (target > 0).detach()
        diff_log = torch.log(target[valid_mask]) - torch.log(pred[valid_mask])
        loss = torch.sqrt(torch.pow(diff_log, 2).mean() -
                          self.lambd * torch.pow(diff_log.mean(), 2))

        return loss

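A brief, hedged example of calling `SiLogLoss`: only pixels with target > 0 contribute to the loss. The tensor shapes and the import path are assumptions made for illustration, not part of this commit.

import torch
from depth.utils_depth.criterion import SiLogLoss  # hypothetical import path

criterion = SiLogLoss(lambd=0.5)
pred = torch.rand(2, 1, 480, 480) * 10 + 0.1   # predicted depth, strictly positive
target = torch.rand(2, 1, 480, 480) * 10       # ground truth; zeros mean "no depth"
loss = criterion(pred, target)
print(loss.item())
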
depth/utils_depth/logging.py
ADDED
@@ -0,0 +1,161 @@
# ------------------------------------------------------------------------------
# The code is from GLPDepth (https://github.com/vinvino02/GLPDepth).
# For non-commercial purpose only (research, evaluation etc).
# ------------------------------------------------------------------------------

import os
import cv2
import sys
import time
import numpy as np

import torch


TOTAL_BAR_LENGTH = 30.
last_time = time.time()
begin_time = last_time


def progress_bar(current, total, epochs, cur_epoch, msg=None):
    _, term_width = os.popen('stty size', 'r').read().split()
    term_width = int(term_width)
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH * current / total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time
    remain_time = step_time * (total - current) + \
        (epochs - cur_epoch) * step_time * total

    L = []
    L.append(' Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    L.append(' | Rem: %s' % format_time(remain_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    for i in range(157 - int(TOTAL_BAR_LENGTH) - len(msg) - 3):
        sys.stdout.write(' ')

    # Go back to the center of the bar.
    for i in range(157 - int(TOTAL_BAR_LENGTH / 2) + 2):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current + 1, total))

    if current < total - 1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()


class AverageMeter():
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def format_time(seconds):
    days = int(seconds / 3600 / 24)
    seconds = seconds - days * 3600 * 24
    hours = int(seconds / 3600)
    seconds = seconds - hours * 3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes * 60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds * 1000)

    f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes).zfill(2) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf).zfill(2) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis).zfill(3) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f


def display_result(result_dict):
    line = "\n"
    line += "=" * 100 + '\n'
    for metric, value in result_dict.items():
        line += "{:>10} ".format(metric)
    line += "\n"
    for metric, value in result_dict.items():
        line += "{:10.4f} ".format(value)
    line += "\n"
    line += "=" * 100 + '\n'

    return line


def save_images(pred, save_path):
    if len(pred.shape) > 3:
        pred = pred.squeeze()

    if isinstance(pred, torch.Tensor):
        pred = pred.cpu().numpy().astype(np.uint8)

    if pred.shape[0] < 4:
        pred = np.transpose(pred, (1, 2, 0))
    cv2.imwrite(save_path, pred, [cv2.IMWRITE_PNG_COMPRESSION, 0])


def check_and_make_dirs(paths):
    if not isinstance(paths, list):
        paths = [paths]
    for path in paths:
        if not os.path.exists(path):
            os.makedirs(path)

def log_args_to_txt(log_txt, args):
    if not os.path.exists(log_txt):
        with open(log_txt, 'w') as txtfile:
            args_ = vars(args)
            args_str = ''
            for k, v in args_.items():
                args_str = args_str + str(k) + ':' + str(v) + ',\t\n'
            txtfile.write(args_str + '\n')

depth/utils_depth/metrics.py
ADDED
@@ -0,0 +1,79 @@
| 1 |
+
# ------------------------------------------------------------------------------
|
| 2 |
+
# The code is from GLPDepth (https://github.com/vinvino02/GLPDepth).
|
| 3 |
+
# For non-commercial purpose only (research, evaluation etc).
|
| 4 |
+
# ------------------------------------------------------------------------------
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def eval_depth(pred, target):
|
| 10 |
+
assert pred.shape == target.shape
|
| 11 |
+
|
| 12 |
+
thresh = torch.max((target / pred), (pred / target))
|
| 13 |
+
|
| 14 |
+
d1 = torch.sum(thresh < 1.25).float() / len(thresh)
|
| 15 |
+
d2 = torch.sum(thresh < 1.25 ** 2).float() / len(thresh)
|
| 16 |
+
d3 = torch.sum(thresh < 1.25 ** 3).float() / len(thresh)
|
| 17 |
+
|
| 18 |
+
diff = pred - target
|
| 19 |
+
diff_log = torch.log(pred) - torch.log(target)
|
| 20 |
+
|
| 21 |
+
abs_rel = torch.mean(torch.abs(diff) / target)
|
| 22 |
+
sq_rel = torch.mean(torch.pow(diff, 2) / target)
|
| 23 |
+
|
| 24 |
+
rmse = torch.sqrt(torch.mean(torch.pow(diff, 2)))
|
| 25 |
+
|
| 26 |
+
rmse_log = torch.sqrt(torch.mean(torch.pow(diff_log , 2)))
|
| 27 |
+
|
| 28 |
+
log10 = torch.mean(torch.abs(torch.log10(pred) - torch.log10(target)))
|
| 29 |
+
silog = torch.sqrt(torch.pow(diff_log, 2).mean() - 0.5 * torch.pow(diff_log.mean(), 2))
|
| 30 |
+
|
| 31 |
+
return {'d1': d1.item(), 'd2': d2.item(), 'd3': d3.item(), 'abs_rel': abs_rel.item(),
|
| 32 |
+
'sq_rel': sq_rel.item(), 'rmse': rmse.item(), 'rmse_log': rmse_log.item(),
|
| 33 |
+
'log10':log10.item(), 'silog':silog.item()}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def cropping_img(args, pred, gt_depth):
|
| 37 |
+
min_depth_eval = args.min_depth_eval
|
| 38 |
+
|
| 39 |
+
max_depth_eval = args.max_depth_eval
|
| 40 |
+
|
| 41 |
+
pred[torch.isinf(pred)] = max_depth_eval
|
| 42 |
+
pred[torch.isnan(pred)] = min_depth_eval
|
| 43 |
+
|
| 44 |
+
valid_mask = torch.logical_and(
|
| 45 |
+
gt_depth > min_depth_eval, gt_depth < max_depth_eval)
|
| 46 |
+
|
| 47 |
+
if args.dataset == 'kitti':
|
| 48 |
+
if args.do_kb_crop:
|
| 49 |
+
height, width = gt_depth.shape
|
| 50 |
+
top_margin = int(height - 352)
|
| 51 |
+
left_margin = int((width - 1216) / 2)
|
| 52 |
+
gt_depth = gt_depth[top_margin:top_margin +
|
| 53 |
+
352, left_margin:left_margin + 1216]
|
| 54 |
+
|
| 55 |
+
if args.kitti_crop:
|
| 56 |
+
gt_height, gt_width = gt_depth.shape
|
| 57 |
+
eval_mask = torch.zeros(valid_mask.shape).to(
|
| 58 |
+
device=valid_mask.device)
|
| 59 |
+
|
| 60 |
+
if args.kitti_crop == 'garg_crop':
|
| 61 |
+
eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height),
|
| 62 |
+
int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1
|
| 63 |
+
|
| 64 |
+
elif args.kitti_crop == 'eigen_crop':
|
| 65 |
+
eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height),
|
| 66 |
+
int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1
|
| 67 |
+
else:
|
| 68 |
+
eval_mask = valid_mask
|
| 69 |
+
|
| 70 |
+
elif args.dataset == 'nyudepthv2':
|
| 71 |
+
eval_mask = torch.zeros(valid_mask.shape).to(device=valid_mask.device)
|
| 72 |
+
eval_mask[45:471, 41:601] = 1
|
| 73 |
+
else:
|
| 74 |
+
eval_mask = valid_mask
|
| 75 |
+
|
| 76 |
+
valid_mask = torch.logical_and(valid_mask, eval_mask)
|
| 77 |
+
|
| 78 |
+
return pred[valid_mask], gt_depth[valid_mask]
|
| 79 |
+
|
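A small usage sketch, assuming the module is importable as `utils_depth.metrics` and that `args` carries the flags referenced above. `cropping_img` returns the valid pixels as flattened 1-D tensors, which is the shape `eval_depth` expects (so `len(thresh)` is the number of evaluated pixels):

```python
import argparse
import torch
from utils_depth.metrics import eval_depth, cropping_img   # import path assumed

args = argparse.Namespace(dataset='nyudepthv2', min_depth_eval=1e-3, max_depth_eval=10.0)

pred = torch.rand(480, 640) * 9 + 0.5       # fake prediction, metres
gt_depth = torch.rand(480, 640) * 9 + 0.5   # fake ground truth, metres

pred_valid, gt_valid = cropping_img(args, pred, gt_depth)   # 1-D tensors over the NYUv2 eval crop
print(eval_depth(pred_valid, gt_valid))                     # dict with d1/d2/d3, abs_rel, rmse, silog, ...
```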
depth/utils_depth/misc.py
ADDED
@@ -0,0 +1,73 @@
# ------------------------------------------------------------------------------
# The code is from ZoeDepth (https://github.com/isl-org/ZoeDepth).
# For non-commercial purpose only (research, evaluation etc).
# ------------------------------------------------------------------------------
from scipy import ndimage

import math

import matplotlib
import matplotlib.cm
import numpy as np
import requests
import torch
from PIL import Image
from torchvision.transforms import ToTensor


def colorize(value, vmin=None, vmax=None, cmap='gray_r', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None):
    """Converts a depth map to a color image.

    Args:
        value (torch.Tensor, numpy.ndarray): Input depth map. Shape: (H, W) or (1, H, W) or (1, 1, H, W). All singular dimensions are squeezed
        vmin (float, optional): vmin-valued entries are mapped to start color of cmap. If None, value.min() is used. Defaults to None.
        vmax (float, optional): vmax-valued entries are mapped to end color of cmap. If None, value.max() is used. Defaults to None.
        cmap (str, optional): matplotlib colormap to use. Defaults to 'magma_r'.
        invalid_val (int, optional): Specifies value of invalid pixels that should be colored as 'background_color'. Defaults to -99.
        invalid_mask (numpy.ndarray, optional): Boolean mask for invalid regions. Defaults to None.
        background_color (tuple[int], optional): 4-tuple RGB color to give to invalid pixels. Defaults to (128, 128, 128, 255).
        gamma_corrected (bool, optional): Apply gamma correction to colored image. Defaults to False.
        value_transform (Callable, optional): Apply transform function to valid pixels before coloring. Defaults to None.

    Returns:
        numpy.ndarray, dtype - uint8: Colored depth map. Shape: (H, W, 4)
    """
    if isinstance(value, torch.Tensor):
        value = value.detach().cpu().numpy()

    value = value.squeeze()
    if invalid_mask is None:
        invalid_mask = value == invalid_val
    mask = np.logical_not(invalid_mask)

    # normalize
    vmin = np.percentile(value[mask], 2) if vmin is None else vmin
    vmax = np.percentile(value[mask], 85) if vmax is None else vmax
    if vmin != vmax:
        value = (value - vmin) / (vmax - vmin)  # vmin..vmax
    else:
        # Avoid 0-division
        value = value * 0.

    # squeeze last dim if it exists
    # grey out the invalid values

    value[invalid_mask] = np.nan
    cmapper = matplotlib.colormaps.get_cmap(cmap)
    if value_transform:
        value = value_transform(value)
    # value = value / value.max()
    value = cmapper(value, bytes=True)  # (nxmx4)

    # img = value[:, :, :]
    img = value[...]
    img[invalid_mask] = background_color

    # return img.transpose((2, 0, 1))
    if gamma_corrected:
        # gamma correction
        img = img / 255
        img = np.power(img, 2.2)
        img = img * 255
        img = img.astype(np.uint8)
    return img, vmin, vmax
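A minimal sketch of turning a depth map into a colour PNG with `colorize`; the import path and file name are assumptions.

```python
import numpy as np
from PIL import Image
from utils_depth.misc import colorize   # import path assumed

depth = np.random.uniform(0.5, 10.0, size=(480, 640)).astype(np.float32)
colored, vmin, vmax = colorize(depth, cmap='magma_r')   # (H, W, 4) uint8 plus the value range actually used
Image.fromarray(colored).save('depth_vis.png')
```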
depth/v1-inference.yaml
ADDED
@@ -0,0 +1,70 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
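For context, this config is consumed with OmegaConf and `instantiate_from_config`, mirroring `EVPRefer.__init__` later in this commit; the checkpoint path below is illustrative.

```python
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load('./v1-inference.yaml')
config.model.params.ckpt_path = '../checkpoints/v1-5-pruned-emaonly.ckpt'   # illustrative path
sd_model = instantiate_from_config(config.model)   # builds the LatentDiffusion model described above
encoder_vq = sd_model.first_stage_model            # the AutoencoderKL from first_stage_config
unet = sd_model.model                              # DiffusionWrapper around the UNetModel; fed to UNetWrapper below
```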
evp/__init__.py
ADDED
@@ -0,0 +1 @@
from .models import UNetWrapper, TextAdapter
evp/models.py
ADDED
@@ -0,0 +1,349 @@
from omegaconf import OmegaConf

import torch as th
import torch
import math
import abc

from torch import nn, einsum

from einops import rearrange, repeat
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from transformers import CLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPTextConfig, CLIPTextModel, CLIPTextTransformer  # , _expand_mask
from inspect import isfunction


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def register_attention_control(model, controller):
    def ca_forward(self, place_in_unet):
        def forward(x, context=None, mask=None):
            h = self.heads

            q = self.to_q(x)
            is_cross = context is not None
            context = default(context, x)
            k = self.to_k(context)
            v = self.to_v(context)

            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

            if exists(mask):
                mask = rearrange(mask, 'b ... -> b (...)')
                max_neg_value = -torch.finfo(sim.dtype).max
                mask = repeat(mask, 'b j -> (b h) () j', h=h)
                sim.masked_fill_(~mask, max_neg_value)

            # attention, what we cannot get enough of
            attn = sim.softmax(dim=-1)

            attn2 = rearrange(attn, '(b h) k c -> h b k c', h=h).mean(0)
            controller(attn2, is_cross, place_in_unet)

            out = einsum('b i j, b j d -> b i d', attn, v)
            out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
            return self.to_out(out)

        return forward

    class DummyController:
        def __call__(self, *args):
            return args[0]

        def __init__(self):
            self.num_att_layers = 0

    if controller is None:
        controller = DummyController()

    def register_recr(net_, count, place_in_unet):
        if net_.__class__.__name__ == 'CrossAttention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                count = register_recr(net__, count, place_in_unet)
        return count

    cross_att_count = 0
    sub_nets = model.diffusion_model.named_children()

    for net in sub_nets:
        if "input_blocks" in net[0]:
            cross_att_count += register_recr(net[1], 0, "down")
        elif "output_blocks" in net[0]:
            cross_att_count += register_recr(net[1], 0, "up")
        elif "middle_block" in net[0]:
            cross_att_count += register_recr(net[1], 0, "mid")

    controller.num_att_layers = cross_att_count


class AttentionControl(abc.ABC):

    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    @property
    def num_uncond_att_layers(self):
        return 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        attn = self.forward(attn, is_cross, place_in_unet)
        return attn

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0


class AttentionStore(AttentionControl):
    @staticmethod
    def get_empty_store():
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= (self.max_size) ** 2:  # avoid memory overhead
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        if len(self.attention_store) == 0:
            self.attention_store = self.step_store
        else:
            for key in self.attention_store:
                for i in range(len(self.attention_store[key])):
                    self.attention_store[key][i] += self.step_store[key][i]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        average_attention = {key: [item for item in self.step_store[key]] for key in self.step_store}
        return average_attention

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def __init__(self, base_size=64, max_size=None):
        super(AttentionStore, self).__init__()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.base_size = base_size
        if max_size is None:
            self.max_size = self.base_size // 2
        else:
            self.max_size = max_size


def register_hier_output(model):
    self = model.diffusion_model
    from ldm.modules.diffusionmodules.util import checkpoint, timestep_embedding

    def forward(x, timesteps=None, context=None, y=None, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            # import pdb; pdb.set_trace()
            if context.shape[1] == 2:
                h = module(h, emb, context[:, 0, :].unsqueeze(1))
            else:
                h = module(h, emb, context)
            hs.append(h)
        if context.shape[1] == 2:
            h = self.middle_block(h, emb, context[:, 0, :].unsqueeze(1))
        else:
            h = self.middle_block(h, emb, context)
        out_list = []

        for i_out, module in enumerate(self.output_blocks):
            h = th.cat([h, hs.pop()], dim=1)
            if context.shape[1] == 2:
                h = module(h, emb, context[:, 1, :].unsqueeze(1))
            else:
                h = module(h, emb, context)
            if i_out in [1, 4, 7]:
                out_list.append(h)
        h = h.type(x.dtype)

        out_list.append(h)
        return out_list

    self.forward = forward


class UNetWrapper(nn.Module):
    def __init__(self, unet, use_attn=True, base_size=512, max_attn_size=None, attn_selector='up_cross+down_cross') -> None:
        super().__init__()
        self.unet = unet
        self.attention_store = AttentionStore(base_size=base_size // 8, max_size=max_attn_size)
        self.size16 = base_size // 32
        self.size32 = base_size // 16
        self.size64 = base_size // 8
        self.use_attn = use_attn
        if self.use_attn:
            register_attention_control(unet, self.attention_store)
        register_hier_output(unet)
        self.attn_selector = attn_selector.split('+')

    def forward(self, *args, **kwargs):
        if self.use_attn:
            self.attention_store.reset()
        out_list = self.unet(*args, **kwargs)
        if self.use_attn:
            avg_attn = self.attention_store.get_average_attention()
            attn16, attn32, attn64 = self.process_attn(avg_attn)
            out_list[1] = torch.cat([out_list[1], attn16], dim=1)
            out_list[2] = torch.cat([out_list[2], attn32], dim=1)
            if attn64 is not None:
                out_list[3] = torch.cat([out_list[3], attn64], dim=1)
        return out_list[::-1]

    def process_attn(self, avg_attn):
        attns = {self.size16: [], self.size32: [], self.size64: []}
        for k in self.attn_selector:
            for up_attn in avg_attn[k]:
                size = int(math.sqrt(up_attn.shape[1]))
                attns[size].append(rearrange(up_attn, 'b (h w) c -> b c h w', h=size))
        attn16 = torch.stack(attns[self.size16]).mean(0)
        attn32 = torch.stack(attns[self.size32]).mean(0)
        if len(attns[self.size64]) > 0:
            attn64 = torch.stack(attns[self.size64]).mean(0)
        else:
            attn64 = None
        return attn16, attn32, attn64


class TextAdapter(nn.Module):
    def __init__(self, text_dim=768, hidden_dim=None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = text_dim
        self.fc = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, text_dim)
        )

    def forward(self, latents, texts, gamma):
        n_class, channel = texts.shape
        bs = latents.shape[0]

        texts_after = self.fc(texts)
        texts = texts + gamma * texts_after
        texts = repeat(texts, 'n c -> b n c', b=bs)
        return texts


class TextAdapterRefer(nn.Module):
    def __init__(self, text_dim=768):
        super().__init__()

        self.fc = nn.Sequential(
            nn.Linear(text_dim, text_dim),
            nn.GELU(),
            nn.Linear(text_dim, text_dim)
        )

    def forward(self, latents, texts, gamma):
        texts_after = self.fc(texts)
        texts = texts + gamma * texts_after
        return texts


class TextAdapterDepth(nn.Module):
    def __init__(self, text_dim=768):
        super().__init__()

        self.fc = nn.Sequential(
            nn.Linear(text_dim, text_dim),
            nn.GELU(),
            nn.Linear(text_dim, text_dim)
        )

    def forward(self, latents, texts, gamma):
        # use the gamma to blend
        n_sen, channel = texts.shape
        bs = latents.shape[0]

        texts_after = self.fc(texts)
        texts = texts + gamma * texts_after
        texts = repeat(texts, 'n c -> n b c', b=1)
        return texts


class FrozenCLIPEmbedder(nn.Module):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""
    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, pool=True):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        self.freeze()

        self.pool = pool

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)

        if self.pool:
            z = outputs.pooler_output
        else:
            z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)
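A rough wiring sketch for `UNetWrapper` and `TextAdapter` (not a standalone script): it assumes `sd_model` was built from `v1-inference.yaml` as in the snippet above and that real SD weights are available; shapes correspond to a 512x512 input.

```python
import torch
from evp.models import UNetWrapper, TextAdapter

unet = UNetWrapper(sd_model.model, base_size=512)    # hooks CrossAttention layers and multi-scale outputs
adapter = TextAdapter(text_dim=768)

latents = torch.randn(1, 4, 64, 64)                  # VAE latents of a 512x512 image
class_embeds = torch.randn(10, 768)                  # e.g. CLIP embeddings of 10 prompts
gamma = torch.ones(768) * 1e-4
c_crossattn = adapter(latents, class_embeds, gamma)  # (1, 10, 768), repeated over the batch
t = torch.ones(1, dtype=torch.long)                  # single timestep; the UNet is used as a feature extractor
feats = unet(latents, t, c_crossattn=[c_crossattn])  # feature pyramid, finest level first (see forward above)
```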
refer/README.md
ADDED
@@ -0,0 +1,78 @@
# Referring Image Segmentation
## Getting Started

1. Install the required packages.

```
pip install -r requirements.txt
```

2. Prepare RefCOCO datasets following [LAVT](https://github.com/yz93/LAVT-RIS).

* Download COCO 2014 Train Images [83K/13GB] from [COCO](https://cocodataset.org/#download), and extract `train2014.zip` to `./refer/data/images/mscoco/images`

* Follow the instructions in `./refer` to download and extract `refclef.zip, refcoco.zip, refcoco+.zip, refcocog.zip` to `./refer/data`

Your dataset directory should be:

```
refer/
├──data/
│  ├── images/mscoco/images/
│  ├── refclef
│  ├── refcoco
│  ├── refcoco+
│  ├── refcocog
├──evaluation/
├──...
```

## Results and Fine-tuned Models of EVP
EVP achieves 76.35 overall IoU and 77.61 mean IoU on the validation set of RefCOCO.

## Training

We count the maximum length of the referring sentences and set the token length of the language model accordingly. The checkpoint of the best epoch is saved at `./checkpoints/`.

* Train on RefCOCO

```
bash train.sh refcoco /path/to/logdir <NUM_GPUS> --token_length 40
```

* Train on RefCOCO+

```
bash train.sh refcoco+ /path/to/logdir <NUM_GPUS> --token_length 40
```

* Train on RefCOCOg

```
bash train.sh refcocog /path/to/logdir <NUM_GPUS> --token_length 77 --splitBy umd
```

## Evaluation

* Evaluate on RefCOCO

```
bash test.sh refcoco /path/to/evp_ris_refcoco.pth --token_length 40
```

* Evaluate on RefCOCO+

```
bash test.sh refcoco+ /path/to/evp_ris_refcoco+.pth --token_length 40
```

* Evaluate on RefCOCOg

```
bash test.sh refcocog /path/to/evp_ris_gref.pth --token_length 77 --splitBy umd
```

## Custom inference
```
PYTHONPATH="../":$PYTHONPATH python inference.py --img_path test_img.jpg --resume refcoco.pth --token_length 40 --prompt 'green plant'
```
refer/args.py
ADDED
@@ -0,0 +1,42 @@
import argparse


def get_parser():
    parser = argparse.ArgumentParser(description='EVP training and testing')
    parser.add_argument('--amsgrad', action='store_true',
                        help='if true, set amsgrad to True in an Adam or AdamW optimizer.')
    parser.add_argument('-b', '--batch-size', default=8, type=int)
    parser.add_argument('--ck_bert', default='bert-base-uncased', help='pre-trained BERT weights')
    parser.add_argument('--dataset', default='refcoco', help='refcoco, refcoco+, or refcocog')
    parser.add_argument('--ddp_trained_weights', action='store_true',
                        help='Only needs specified when testing,'
                             'whether the weights to be loaded are from a DDP-trained model')
    parser.add_argument('--device', default='cuda:0', help='device')  # only used when testing on a single machine
    parser.add_argument('--epochs', default=40, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('--fusion_drop', default=0.0, type=float, help='dropout rate for PWAMs')
    parser.add_argument('--img_size', default=480, type=int, help='input image size')
    parser.add_argument("--local_rank", type=int, default=0, help='local rank for DistributedDataParallel')
    parser.add_argument("--local-rank", type=int, default=0, help='local rank for DistributedDataParallel')
    parser.add_argument('--lr', default=0.00005, type=float, help='the initial learning rate')
    parser.add_argument('--model_id', default='evp', help='name to identify the model')
    parser.add_argument('--output-dir', default='./checkpoints/', help='path where to save checkpoint weights')
    parser.add_argument('--pin_mem', action='store_true',
                        help='If true, pin memory when using the data loader.')
    parser.add_argument('--pretrained_swin_weights', default='',
                        help='path to pre-trained Swin backbone weights')
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--refer_data_root', default='./refer/data/', help='REFER dataset root directory')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--split', default='val')
    parser.add_argument('--splitBy', default='unc')
    parser.add_argument('--wd', '--weight-decay', default=1e-2, type=float, metavar='W', help='weight decay',
                        dest='weight_decay')
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', help='number of data loading workers')
    parser.add_argument('--token_length', default=77, type=int)

    return parser


if __name__ == "__main__":
    parser = get_parser()
    args_dict = parser.parse_args()
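For reference, the scripts in this folder extend this shared parser instead of defining their own (see `inference.py` below); a minimal sketch:

```python
from args import get_parser

parser = get_parser()
parser.add_argument('--img_path', type=str)   # script-specific extras go on top of the shared flags
args = parser.parse_args(['--dataset', 'refcoco', '--token_length', '40', '--img_path', 'test_img.jpg'])
print(args.dataset, args.token_length, args.lr)
```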
refer/inference.py
ADDED
@@ -0,0 +1,60 @@
import os
import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from models_refer.model import EVPRefer
from args import get_parser
import glob
import utils
import torchvision.transforms as transforms
from PIL import Image
import torch.nn.functional as F
from transformers import CLIPTokenizer
from collections import OrderedDict  # added: needed for the 'module.' prefix stripping below (missing in the original file)


def main():
    parser = get_parser()
    parser.add_argument('--img_path', type=str)
    parser.add_argument('--prompt', type=str)
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    model = EVPRefer(sd_path='../checkpoints/v1-5-pruned-emaonly.ckpt')
    cudnn.benchmark = True
    model.to(device)
    model_weight = torch.load(args.resume)['model']
    if 'module' in next(iter(model_weight.items()))[0]:
        model_weight = OrderedDict((k[7:], v) for k, v in model_weight.items())
    model.load_state_dict(model_weight, strict=False)
    model.eval()

    img_path = args.img_path

    image = cv2.imread(img_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_t = transforms.ToTensor()(image).unsqueeze(0).to(device)
    image_t = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(image_t)
    shape = image_t.shape
    image_t = torch.nn.functional.interpolate(image_t, (512, 512), mode='bilinear', align_corners=True)
    input_ids = tokenizer(text=args.prompt, truncation=True, max_length=args.token_length, return_length=True,
                          return_overflowing_tokens=False, padding="max_length", return_tensors="pt")['input_ids'].to(device)

    with torch.no_grad():
        pred = model(image_t, input_ids)

    pred = torch.nn.functional.interpolate(pred, shape[2:], mode='bilinear', align_corners=True)
    output_mask = pred.cpu().argmax(1).data.numpy().squeeze()

    alpha = 0.65
    image[output_mask == 0] = (image[output_mask == 0] * alpha).astype(np.uint8)
    contours, _ = cv2.findContours(output_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(image, contours, -1, (0, 255, 0), 2)

    Image.fromarray(image.astype(np.uint8)).save('res.png')

    return 0

if __name__ == '__main__':
    main()
refer/models_refer/__init__.py
ADDED
@@ -0,0 +1 @@
from .model import EVPRefer
refer/models_refer/model.py
ADDED
@@ -0,0 +1,301 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import sys
from ldm.util import instantiate_from_config
from transformers.models.clip.modeling_clip import CLIPTextModel
from omegaconf import OmegaConf
from lib.mask_predictor import SimpleDecoding

from evp.models import UNetWrapper, TextAdapterRefer


def icnr(x, scale=2, init=nn.init.kaiming_normal_):
    """
    Checkerboard artifact free sub-pixel convolution
    https://arxiv.org/abs/1707.02937
    """
    ni, nf, h, w = x.shape
    ni2 = int(ni / (scale ** 2))
    k = init(torch.zeros([ni2, nf, h, w])).transpose(0, 1)
    k = k.contiguous().view(ni2, nf, -1)
    k = k.repeat(1, 1, scale ** 2)
    k = k.contiguous().view([nf, ni, h, w]).transpose(0, 1)
    x.data.copy_(k)


class PixelShuffle(nn.Module):
    """
    Real-Time Single Image and Video Super-Resolution
    https://arxiv.org/abs/1609.05158
    """
    def __init__(self, n_channels, scale):
        super(PixelShuffle, self).__init__()
        self.conv = nn.Conv2d(n_channels, n_channels * (scale ** 2), kernel_size=1)
        icnr(self.conv.weight)
        self.shuf = nn.PixelShuffle(scale)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.shuf(self.relu(self.conv(x)))
        return x


class AttentionModule(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(AttentionModule, self).__init__()

        # Convolutional Layers
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # Group Normalization
        self.group_norm = nn.GroupNorm(20, out_channels)

        # ReLU Activation
        self.relu = nn.ReLU()

        # Spatial Attention
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Apply spatial attention
        spatial_attention = self.spatial_attention(x)
        x = x * spatial_attention

        # Apply convolutional layer
        x = self.conv1(x)
        x = self.group_norm(x)
        x = self.relu(x)

        return x


class AttentionDownsamplingModule(nn.Module):
    def __init__(self, in_channels, out_channels, scale_factor=2):
        super(AttentionDownsamplingModule, self).__init__()

        # Spatial Attention
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, 1, kernel_size=1),
            nn.Sigmoid()
        )

        # Channel Attention
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // 8, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // 8, in_channels, kernel_size=1),
            nn.Sigmoid()
        )

        # Convolutional Layers
        if scale_factor == 2:
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        elif scale_factor == 4:
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)

        # Group Normalization
        self.group_norm = nn.GroupNorm(20, out_channels)

        # ReLU Activation
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Apply spatial attention
        spatial_attention = self.spatial_attention(x)
        x = x * spatial_attention

        # Apply channel attention
        channel_attention = self.channel_attention(x)
        x = x * channel_attention

        # Apply convolutional layers
        x = self.conv1(x)
        x = self.group_norm(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.group_norm(x)
        x = self.relu(x)

        return x


class AttentionUpsamplingModule(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(AttentionUpsamplingModule, self).__init__()

        # Spatial Attention for outs[2]
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, 1, kernel_size=1),
            nn.Sigmoid()
        )

        # Channel Attention for outs[2]
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // 8, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(in_channels // 8, in_channels, kernel_size=1),
            nn.Sigmoid()
        )

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # Group Normalization
        self.group_norm = nn.GroupNorm(20, out_channels)

        # ReLU Activation
        self.relu = nn.ReLU()
        self.upscale = PixelShuffle(in_channels, 2)

    def forward(self, x):
        # Apply spatial attention
        spatial_attention = self.spatial_attention(x)
        x = x * spatial_attention

        # Apply channel attention
        channel_attention = self.channel_attention(x)
        x = x * channel_attention

        # Apply convolutional layers
        x = self.conv1(x)
        x = self.group_norm(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.group_norm(x)
        x = self.relu(x)

        # Upsample
        x = self.upscale(x)

        return x


class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvLayer, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1),
            nn.GroupNorm(20, out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.conv1(x)

        return x


class InverseMultiAttentiveFeatureRefinement(nn.Module):
    def __init__(self, in_channels_list):
        super(InverseMultiAttentiveFeatureRefinement, self).__init__()

        self.layer1 = AttentionModule(in_channels_list[0], in_channels_list[0])
        self.layer2 = AttentionDownsamplingModule(in_channels_list[0], in_channels_list[0]//2, scale_factor=2)
        self.layer3 = ConvLayer(in_channels_list[0]//2 + in_channels_list[1], in_channels_list[1])
        self.layer4 = AttentionDownsamplingModule(in_channels_list[1], in_channels_list[1]//2, scale_factor=2)
        self.layer5 = ConvLayer(in_channels_list[1]//2 + in_channels_list[2], in_channels_list[2])
        self.layer6 = AttentionDownsamplingModule(in_channels_list[2], in_channels_list[2]//2, scale_factor=2)
        self.layer7 = ConvLayer(in_channels_list[2]//2 + in_channels_list[3], in_channels_list[3])

        '''
        self.layer8 = AttentionUpsamplingModule(in_channels_list[3], in_channels_list[3])
        self.layer9 = ConvLayer(in_channels_list[2] + in_channels_list[3], in_channels_list[2])
        self.layer10 = AttentionUpsamplingModule(in_channels_list[2], in_channels_list[2])
        self.layer11 = ConvLayer(in_channels_list[1] + in_channels_list[2], in_channels_list[1])
        self.layer12 = AttentionUpsamplingModule(in_channels_list[1], in_channels_list[1])
        self.layer13 = ConvLayer(in_channels_list[0] + in_channels_list[1], in_channels_list[0])
        '''

    def forward(self, inputs):
        x_c4, x_c3, x_c2, x_c1 = inputs
        x_c4 = self.layer1(x_c4)
        x_c4_3 = self.layer2(x_c4)
        x_c3 = torch.cat([x_c4_3, x_c3], dim=1)
        x_c3 = self.layer3(x_c3)
        x_c3_2 = self.layer4(x_c3)
        x_c2 = torch.cat([x_c3_2, x_c2], dim=1)
        x_c2 = self.layer5(x_c2)
        x_c2_1 = self.layer6(x_c2)
        x_c1 = torch.cat([x_c2_1, x_c1], dim=1)
        x_c1 = self.layer7(x_c1)
        '''
        x_c1_2 = self.layer8(x_c1)
        x_c2 = torch.cat([x_c1_2, x_c2], dim=1)
        x_c2 = self.layer9(x_c2)
        x_c2_3 = self.layer10(x_c2)
        x_c3 = torch.cat([x_c2_3, x_c3], dim=1)
        x_c3 = self.layer11(x_c3)
        x_c3_4 = self.layer12(x_c3)
        x_c4 = torch.cat([x_c3_4, x_c4], dim=1)
        x_c4 = self.layer13(x_c4)
        '''
        return [x_c4, x_c3, x_c2, x_c1]


class EVPRefer(nn.Module):
    """Encoder Decoder segmentors.

    EncoderDecoder typically consists of backbone, decode_head, auxiliary_head.
    Note that auxiliary_head is only used for deep supervision during training,
    which could be dumped during inference.
    """

    def __init__(self,
                 sd_path=None,
                 base_size=512,
                 token_embed_dim=768,
                 neck_dim=[320, 680, 1320, 1280],
                 **args):
        super().__init__()
        config = OmegaConf.load('./v1-inference.yaml')
        config.model.params.ckpt_path = f'{sd_path}'
        sd_model = instantiate_from_config(config.model)
        self.encoder_vq = sd_model.first_stage_model
        self.unet = UNetWrapper(sd_model.model, base_size=base_size)
        del sd_model.cond_stage_model
        del self.encoder_vq.decoder
        for param in self.encoder_vq.parameters():
            param.requires_grad = True

        self.text_adapter = TextAdapterRefer(text_dim=token_embed_dim)

        self.classifier = SimpleDecoding(dims=neck_dim)

        self.gamma = nn.Parameter(torch.ones(token_embed_dim) * 1e-4)
        self.aggregation = InverseMultiAttentiveFeatureRefinement([320, 680, 1320, 1280])
        self.clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
        for param in self.clip_model.parameters():
            param.requires_grad = True

    def forward(self, img, sentences):
        input_shape = img.shape[-2:]

        latents = self.encoder_vq.encode(img).mode()
        latents = latents / 4.7164

        l_feats = self.clip_model(input_ids=sentences).last_hidden_state
        c_crossattn = self.text_adapter(latents, l_feats, self.gamma)  # NOTE: here the c_crossattn should be expand_dim as latents
        t = torch.ones((img.shape[0],), device=img.device).long()
        outs = self.unet(latents, t, c_crossattn=[c_crossattn])

        outs = self.aggregation(outs)

        x_c1, x_c2, x_c3, x_c4 = outs
        x = self.classifier(x_c4, x_c3, x_c2, x_c1)
        x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True)

        return x

    def get_latent(self, x):
        return self.encoder_vq.encode(x).mode()
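A shape sketch for the aggregation neck above, using the `neck_dim` channel counts and spatial sizes that correspond to a 512x512 input (the import path and the sizes are assumptions):

```python
import torch
from models_refer.model import InverseMultiAttentiveFeatureRefinement   # import path assumed

neck = InverseMultiAttentiveFeatureRefinement([320, 680, 1320, 1280])
feats = [torch.randn(2, 320, 64, 64),    # finest UNet level (attention maps already concatenated)
         torch.randn(2, 680, 32, 32),
         torch.randn(2, 1320, 16, 16),
         torch.randn(2, 1280, 8, 8)]     # coarsest level
out = neck(feats)
print([tuple(o.shape) for o in out])     # per-level channel counts and resolutions are preserved
```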
refer/requirements.txt
ADDED
@@ -0,0 +1,12 @@
requests
filelock
tqdm
timm
ftfy
regex
scipy
scikit-image
pycocotools==2.0.2
opencv-python==4.5.3.56
tokenizers
h5py
refer/transforms.py
ADDED
@@ -0,0 +1,126 @@
import numpy as np
from PIL import Image
import random

import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F

import warnings
warnings.filterwarnings("ignore")

def pad_if_smaller(img, size, fill=0):
    min_size = min(img.size)
    if min_size < size:
        ow, oh = img.size
        padh = size - oh if oh < size else 0
        padw = size - ow if ow < size else 0
        img = F.pad(img, (0, 0, padw, padh), fill=fill)
    return img


class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target


class Resize(object):
    def __init__(self, h, w):
        self.h = h
        self.w = w

    def __call__(self, image, target):
        image = F.resize(image, (self.h, self.w))
        # If size is a sequence like (h, w), the output size will be matched to this.
        # If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio
        target = F.resize(target, (self.h, self.w))
        return image, target


class RandomResize(object):
    def __init__(self, min_size, max_size=None):
        self.min_size = min_size
        if max_size is None:
            max_size = min_size
        self.max_size = max_size

    def __call__(self, image, target):
        size = random.randint(self.min_size, self.max_size)  # Return a random integer N such that a <= N <= b. Alias for randrange(a, b+1)
        image = F.resize(image, size)
        # If size is a sequence like (h, w), the output size will be matched to this.
        # If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio
        target = F.resize(target, size)
        return image, target


class RandomHorizontalFlip(object):
    def __init__(self, flip_prob):
        self.flip_prob = flip_prob

    def __call__(self, image, target):
        if random.random() < self.flip_prob:
            image = F.hflip(image)
            target = F.hflip(target)
        return image, target


class RandomCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        image = pad_if_smaller(image, self.size)
        target = pad_if_smaller(target, self.size, fill=255)
        crop_params = T.RandomCrop.get_params(image, (self.size, self.size))
        image = F.crop(image, *crop_params)
        target = F.crop(target, *crop_params)
        return image, target


class CenterCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        image = F.center_crop(image, self.size)
        target = F.center_crop(target, self.size)
        return image, target


class ToTensor(object):
    def __call__(self, image, target):
        image = F.to_tensor(image)
        target = torch.as_tensor(np.asarray(target).copy(), dtype=torch.int64)
        return image, target


class RandomAffine(object):
    def __init__(self, angle, translate, scale, shear, resample=0, fillcolor=None):
        self.angle = angle
        self.translate = translate
        self.scale = scale
        self.shear = shear
        self.resample = resample
        self.fillcolor = fillcolor

    def __call__(self, image, target):
        affine_params = T.RandomAffine.get_params(self.angle, self.translate, self.scale, self.shear, image.size)
        image = F.affine(image, *affine_params)
        target = F.affine(target, *affine_params)
        return image, target


class Normalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target):
        image = F.normalize(image, mean=self.mean, std=self.std)
        return image, target
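A small sketch of composing these paired image/mask transforms for a RefCOCO-style sample; the module is assumed to be importable as `transforms`, and the ImageNet mean/std values are an assumption, not taken from this commit.

```python
import numpy as np
from PIL import Image
import transforms as T   # this file

train_tf = T.Compose([
    T.Resize(480, 480),
    T.RandomHorizontalFlip(0.5),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img = Image.fromarray((np.random.rand(512, 640, 3) * 255).astype(np.uint8))
mask = Image.fromarray((np.random.rand(512, 640) > 0.5).astype(np.uint8))
img_t, mask_t = train_tf(img, mask)   # image: float tensor (3, 480, 480); mask: int64 tensor (480, 480)
```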
refer/utils.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import print_function
from collections import defaultdict, deque
import datetime
import math
import time
import torch
import torch.distributed as dist
import torch.backends.cudnn as cudnn

import errno
import os

import sys


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


class MetricLogger(object):
    # Aggregates named SmoothedValue meters and prints periodic progress lines.
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = self.delimiter.join([
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}',
            'max mem: {memory:.0f}'
        ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                print(log_msg.format(
                    i, len(iterable), eta=eta_string,
                    meters=str(self),
                    time=str(iter_time), data=str(data_time),
                    memory=torch.cuda.max_memory_allocated() / MB))
                sys.stdout.flush()

            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {}'.format(header, total_time_str))


def mkdir(path):
    try:
        os.makedirs(path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    # Initialize the NCCL process group from environment variables set by the launcher
    # and silence printing on non-master ranks.
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ['WORLD_SIZE'])
        print(f"RANK and WORLD_SIZE in environment: {rank}/{world_size}")
    else:
        rank = -1
        world_size = -1

    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
    torch.distributed.barrier()
    setup_for_distributed(is_main_process())

    if args.output_dir:
        mkdir(args.output_dir)
    if args.model_id:
        mkdir(os.path.join('./models/', args.model_id))
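For reference, a minimal sketch of how these utilities are usually driven from a training loop. The model, criterion, optimizer and data_loader below are placeholders rather than objects defined in this repository, and log_every assumes a CUDA device because it reports torch.cuda.max_memory_allocated().

def train_one_epoch(model, criterion, optimizer, data_loader, epoch, print_freq=10):
    # Illustration only: the objects passed in are hypothetical.
    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    for image, target in metric_logger.log_every(data_loader, print_freq, header):
        output = model(image)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # update() accepts floats/ints (0-d tensors are converted with .item()).
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
    # Aggregate meter state across processes before reporting epoch-level statistics.
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)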
refer/v1-inference.yaml
ADDED
@@ -0,0 +1,70 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
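As a usage note, configs in this format are normally loaded with OmegaConf and instantiated by resolving each "target" dotted path. The helper below is only a sketch of that pattern: it assumes the Stable Diffusion ldm package is importable and that a compatible checkpoint is available, neither of which this commit provides by itself.

import importlib
from omegaconf import OmegaConf

def instantiate_from_config(config):
    # Resolve the dotted "target" path and construct the class with its "params" block.
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**config.get("params", dict()))

config = OmegaConf.load("refer/v1-inference.yaml")
model = instantiate_from_config(config.model)  # requires the `ldm` package on PYTHONPATH
# state = torch.load("path/to/checkpoint.ckpt", map_location="cpu")["state_dict"]  # hypothetical path
# model.load_state_dict(state, strict=False)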