diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..7e5dccd9d003b824c73728c7569263ce7813ddc1 Binary files /dev/null and b/.DS_Store differ diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..97a202f1fdc037f2b1ebefc453601d74b75af850 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/cars.mp4 filter=lfs diff=lfs merge=lfs -text +assets/cell.mp4 filter=lfs diff=lfs merge=lfs -text +assets/demo_3x2.gif filter=lfs diff=lfs merge=lfs -text +assets/top.gif filter=lfs diff=lfs merge=lfs -text diff --git a/__pycache__/SegTracker.cpython-310.pyc b/__pycache__/SegTracker.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..722e9c159b0e5c18108cc09594ba68798f711fae Binary files /dev/null and b/__pycache__/SegTracker.cpython-310.pyc differ diff --git a/__pycache__/aot_tracker.cpython-310.pyc b/__pycache__/aot_tracker.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7109a7edfd9f58b072a88090ab42dbe4de4fb175 Binary files /dev/null and b/__pycache__/aot_tracker.cpython-310.pyc differ diff --git a/__pycache__/model_args.cpython-310.pyc b/__pycache__/model_args.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3a205ccc85f308d4da903f863796440276e49c0 Binary files /dev/null and b/__pycache__/model_args.cpython-310.pyc differ diff --git a/__pycache__/seg_track_anything.cpython-310.pyc b/__pycache__/seg_track_anything.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3cac4cca7dd75aad4c430b6fe101677055ed49c Binary files /dev/null and b/__pycache__/seg_track_anything.cpython-310.pyc differ diff --git a/aot/.DS_Store b/aot/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..94c8f55ca1cc85089b2dd02e7257edfa07f8cd4d Binary files /dev/null and b/aot/.DS_Store differ diff --git a/aot/LICENSE b/aot/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..5bcf93f499a3f4dab26cd1047196475da07c9336 --- /dev/null +++ b/aot/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2020, z-x-yang +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/aot/MODEL_ZOO.md b/aot/MODEL_ZOO.md new file mode 100644 index 0000000000000000000000000000000000000000..3e1dbb059413820142dba68e246bc0bd53110a37 --- /dev/null +++ b/aot/MODEL_ZOO.md @@ -0,0 +1,115 @@ +## Model Zoo and Results + +### Environment and Settings +- 4/1 NVIDIA V100 GPUs for training/evaluation. +- Auto-mixed precision was enabled in training but disabled in evaluation. +- Test-time augmentations were not used. +- The inference resolution of DAVIS/YouTube-VOS was 480p/1.3x480p as [CFBI](https://github.com/z-x-yang/CFBI). +- Fully online inference. We passed all the modules frame by frame. +- Multi-object FPS was recorded instead of single-object one. + +### Pre-trained Models +Stages: + +- `PRE`: the pre-training stage with static images. + +- `PRE_YTB_DAV`: the main-training stage with YouTube-VOS and DAVIS. All the kinds of evaluation share an **identical** model and the **same** parameters. + + +| Model | Param (M) | PRE | PRE_YTB_DAV | +|:---------- |:---------:|:--------------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------:| +| AOTT | 5.7 | [gdrive](https://drive.google.com/file/d/1_513h8Hok9ySQPMs_dHgX5sPexUhyCmy/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1owPmwV4owd_ll6GuilzklqTyAd0ZvbCu/view?usp=sharing) | +| AOTS | 7.0 | [gdrive](https://drive.google.com/file/d/1QUP0-VED-lOF1oX_ppYWnXyBjvUzJJB7/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1beU5E6Mdnr_pPrgjWvdWurKAIwJSz1xf/view?usp=sharing) | +| AOTB | 8.3 | [gdrive](https://drive.google.com/file/d/11Bx8n_INAha1IdpHjueGpf7BrKmCJDvK/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1hH-GOn4GAxHkV8ARcQzsUy8Ax6ndot-A/view?usp=sharing) | +| AOTL | 8.3 | [gdrive](https://drive.google.com/file/d/1WL6QCsYeT7Bt-Gain9ZIrNNXpR2Hgh29/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1L1N2hkSPqrwGgnW9GyFHuG59_EYYfTG4/view?usp=sharing) | +| R50-AOTL | 14.9 | [gdrive](https://drive.google.com/file/d/1hS4JIvOXeqvbs-CokwV6PwZV-EvzE6x8/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1qJDYn3Ibpquu4ffYoQmVjg1YCbr2JQep/view?usp=sharing) | +| SwinB-AOTL | 65.4 | [gdrive](https://drive.google.com/file/d/1LlhKQiXD8JyZGGs3hZiNzcaCLqyvL9tj/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/192jCGQZdnuTsvX-CVra-KVZl2q1ZR0vW/view?usp=sharing) | + +| Model | Param (M) | PRE | PRE_YTB_DAV | +|:---------- |:---------:|:--------------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------:| +| DeAOTT | 7.2 | [gdrive](https://drive.google.com/file/d/11C1ZBoFpL3ztKtINS8qqwPSldfYXexFK/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1ThWIZQS03cYWx1EKNN8MIMnJS5eRowzr/view?usp=sharing) | +| DeAOTS | 10.2 | [gdrive](https://drive.google.com/file/d/1uUidrWVoaP9A5B5-EzQLbielUnRLRF3j/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1YwIAV5tBtn5spSFxKLBQBEQGwPHyQlHi/view?usp=sharing) | +| DeAOTB | 13.2 | [gdrive](https://drive.google.com/file/d/1bEQr6vIgQMVITrSOtxWTMgycKpS0cor9/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1BHxsonnvJXylqHlZ1zJHHc-ymKyq-CFf/view?usp=sharing) | +| DeAOTL | 13.2 | [gdrive](https://drive.google.com/file/d/1_vBL4KJlmBy0oBE4YFDOvsYL1ZtpEL32/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/18elNz_wi9JyVBcIUYKhRdL08MA-FqHD5/view?usp=sharing) | +| R50-DeAOTL | 19.8 | [gdrive](https://drive.google.com/file/d/1sTRQ1g0WCpqVCdavv7uJiZNkXunBt3-R/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1QoChMkTVxdYZ_eBlZhK2acq9KMQZccPJ/view?usp=sharing) | +| SwinB-DeAOTL | 70.3 | [gdrive](https://drive.google.com/file/d/16BZEE53no8CxT-pPLDC2q1d6Xlg8mWPU/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1g4E-F0RPOx9Nd6J7tU9AE1TjsouL4oZq/view?usp=sharing) | + +To use our pre-trained model to infer, a simple way is to set `--model` and `--ckpt_path` to your downloaded checkpoint's model type and file path when running `eval.py`. + +### YouTube-VOS 2018 val +`ALL-F`: all frames. The default evaluation setting of YouTube-VOS is 6fps, but 30fps sequences (all the frames) are also supplied by the dataset organizers. We noticed that many VOS methods prefer to evaluate with 30fps videos. Thus, we also supply our results here. Denser video sequences can significantly improve VOS performance when using the memory reading strategy (like AOTL, R50-AOTL, and SwinB-AOTL), but the efficiency will be influenced since more memorized frames are stored for object matching. +| Model | Stage | FPS | All-F | Mean | J Seen | F Seen | J Unseen | F Unseen | Predictions | +|:------------ |:-----------:|:--------:|:-----:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------------------------------------------------------------------------------------------:| +| AOTT | PRE_YTB_DAV | 41.0 | | 80.2 | 80.4 | 85.0 | 73.6 | 81.7 | [gdrive](https://drive.google.com/file/d/1u8mvPRT08ENZHsw9Xf_4C6Sv9BoCzENR/view?usp=sharing) | +| AOTT | PRE_YTB_DAV | 41.0 | √ | 80.9 | 80.0 | 84.7 | 75.2 | 83.5 | [gdrive](https://drive.google.com/file/d/1RGMI5-29Z0odq73rt26eCxOUYUd-fvVv/view?usp=sharing) | +| DeAOTT | PRE_YTB_DAV | **53.4** | | **82.0** | **81.6** | **86.3** | **75.8** | **84.2** | - | +| AOTS | PRE_YTB_DAV | 27.1 | | 82.9 | 82.3 | 87.0 | 77.1 | 85.1 | [gdrive](https://drive.google.com/file/d/1a4-rNnxjMuPBq21IKo31WDYZXMPgS7r2/view?usp=sharing) | +| AOTS | PRE_YTB_DAV | 27.1 | √ | 83.0 | 82.2 | 87.0 | 77.3 | 85.7 | [gdrive](https://drive.google.com/file/d/1Z0cndyoCw5Na6u-VFRE8CyiIG2RbMIUO/view?usp=sharing) | +| DeAOTS | PRE_YTB_DAV | **38.7** | | **84.0** | **83.3** | **88.3** | **77.9** | **86.6** | - | +| AOTB | PRE_YTB_DAV | 20.5 | | 84.0 | 83.2 | 88.1 | 78.0 | 86.5 | [gdrive](https://drive.google.com/file/d/1J5nhuQbbjVLYNXViBIgo21ddQy-MiOLG/view?usp=sharing) | +| AOTB | PRE_YTB_DAV | 20.5 | √ | 84.1 | 83.6 | 88.5 | 78.0 | 86.5 | [gdrive](https://drive.google.com/file/d/1gFaweB_GTJjHzSD61v_ZsY9K7UEND30O/view?usp=sharing) | +| DeAOTB | PRE_YTB_DAV | **30.4** | | **84.6** | **83.9** | **88.9** | **78.5** | **87.0** | - | +| AOTL | PRE_YTB_DAV | 16.0 | | 84.1 | 83.2 | 88.2 | 78.2 | 86.8 | [gdrive](https://drive.google.com/file/d/1kS8KWQ2L3wzxt44ROLTxwZOT7ZpT8Igc/view?usp=sharing) | +| AOTL | PRE_YTB_DAV | 6.5 | √ | 84.5 | 83.7 | 88.8 | 78.4 | **87.1** | [gdrive](https://drive.google.com/file/d/1Rpm3e215kJOUvb562lJ2kYg2I3hkrxiM/view?usp=sharing) | +| DeAOTL | PRE_YTB_DAV | **24.7** | | **84.8** | **84.2** | **89.4** | **78.6** | 87.0 | - | +| R50-AOTL | PRE_YTB_DAV | 14.9 | | 84.6 | 83.7 | 88.5 | 78.8 | 87.3 | [gdrive](https://drive.google.com/file/d/1nbJZ1bbmEgyK-bg6HQ8LwCz5gVJ6wzIZ/view?usp=sharing) | +| R50-AOTL | PRE_YTB_DAV | 6.4 | √ | 85.5 | 84.5 | 89.5 | 79.6 | 88.2 | [gdrive](https://drive.google.com/file/d/1NbB54ZhYvfJh38KFOgovYYPjWopd-2TE/view?usp=sharing) | +| R50-DeAOTL | PRE_YTB_DAV | **22.4** | | **86.0** | **84.9** | **89.9** | **80.4** | **88.7** | - | +| SwinB-AOTL | PRE_YTB_DAV | 9.3 | | 84.7 | 84.5 | 89.5 | 78.1 | 86.7 | [gdrive](https://drive.google.com/file/d/1QFowulSY0LHfpsjUV8ZE9rYc55L9DOC7/view?usp=sharing) | +| SwinB-AOTL | PRE_YTB_DAV | 5.2 | √ | 85.1 | 85.1 | 90.1 | 78.4 | 86.9 | [gdrive](https://drive.google.com/file/d/1TulhVOhh01rkssNYbOQASeWKu7CQ5Azx/view?usp=sharing) | +| SwinB-DeAOTL | PRE_YTB_DAV | **11.9** | | **86.2** | **85.6** | **90.6** | **80.0** | **88.4** | - | + +### YouTube-VOS 2019 val +| Model | Stage | FPS | All-F | Mean | J Seen | F Seen | J Unseen | F Unseen | Predictions | +|:------------ |:-----------:|:--------:|:-----:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------------------------------------------------------------------------------------------:| +| AOTT | PRE_YTB_DAV | 41.0 | | 80.0 | 79.8 | 84.2 | 74.1 | 82.1 | [gdrive](https://drive.google.com/file/d/1zzyhN1XYtajte5nbZ7opOdfXeDJgCxC5/view?usp=sharing) | +| AOTT | PRE_YTB_DAV | 41.0 | √ | 80.9 | 79.9 | 84.4 | 75.6 | 83.8 | [gdrive](https://drive.google.com/file/d/1V_5vi9dAXOis_WrDieacSESm7OX20Bv-/view?usp=sharing) | +| DeAOTT | PRE_YTB_DAV | **53.4** | | **82.0** | **81.2** | **85.6** | **76.4** | **84.7** | - | +| AOTS | PRE_YTB_DAV | 27.1 | | 82.7 | 81.9 | 86.5 | 77.3 | 85.2 | [gdrive](https://drive.google.com/file/d/11YdkUeyjkTv8Uw7xMgPCBzJs6v5SDt6n/view?usp=sharing) | +| AOTS | PRE_YTB_DAV | 27.1 | √ | 82.8 | 81.9 | 86.5 | 77.3 | 85.6 | [gdrive](https://drive.google.com/file/d/1UhyurGTJeAw412czU3_ebzNwF8xQ4QG_/view?usp=sharing) | +| DeAOTS | PRE_YTB_DAV | **38.7** | | **83.8** | **82.8** | **87.5** | **78.1** | **86.8** | - | +| AOTB | PRE_YTB_DAV | 20.5 | | 84.0 | 83.1 | 87.7 | 78.5 | 86.8 | [gdrive](https://drive.google.com/file/d/1NeI8cT4kVqTqVWAwtwiga1rkrvksNWaO/view?usp=sharing) | +| AOTB | PRE_YTB_DAV | 20.5 | √ | 84.1 | 83.3 | 88.0 | 78.2 | 86.7 | [gdrive](https://drive.google.com/file/d/1kpYV2XFR0sOfLWD-wMhd-nUO6CFiLjlL/view?usp=sharing) | +| DeAOTB | PRE_YTB_DAV | **30.4** | | **84.6** | **83.5** | **88.3** | **79.1** | **87.5** | - | +| AOTL | PRE_YTB_DAV | 16.0 | | 84.0 | 82.8 | 87.6 | 78.6 | 87.1 | [gdrive](https://drive.google.com/file/d/1qKLlNXxmT31bW0weEHI_zAf4QwU8Lhou/view?usp=sharing) | +| AOTL | PRE_YTB_DAV | 6.5 | √ | 84.2 | 83.0 | 87.8 | 78.7 | 87.3 | [gdrive](https://drive.google.com/file/d/1o3fwZ0cH71bqHSA3bYNjhP4GGv9Vyuwa/view?usp=sharing) | +| DeAOTL | PRE_YTB_DAV | **24.7** | | **84.7** | **83.8** | **88.8** | **79.0** | **87.2** | - | +| R50-AOTL | PRE_YTB_DAV | 14.9 | | 84.4 | 83.4 | 88.1 | 78.7 | 87.2 | [gdrive](https://drive.google.com/file/d/1I7ooSp8EYfU6fvkP6QcCMaxeencA68AH/view?usp=sharing) | +| R50-AOTL | PRE_YTB_DAV | 6.4 | √ | 85.3 | 83.9 | 88.8 | 79.9 | 88.5 | [gdrive](https://drive.google.com/file/d/1OGqlkEu0uXa8QVWIVz_M5pmXXiYR2sh3/view?usp=sharing) | +| R50-DeAOTL | PRE_YTB_DAV | **22.4** | | **85.9** | **84.6** | **89.4** | **80.8** | **88.9** | - | +| SwinB-AOTL | PRE_YTB_DAV | 9.3 | | 84.7 | 84.0 | 88.8 | 78.7 | 87.1 | [gdrive](https://drive.google.com/file/d/1fPzCxi5GM7N2sLKkhoTC2yoY_oTQCHp1/view?usp=sharing) | +| SwinB-AOTL | PRE_YTB_DAV | 5.2 | √ | 85.3 | 84.6 | 89.5 | 79.3 | 87.7 | [gdrive](https://drive.google.com/file/d/1e3D22s_rJ7Y2X2MHo7x5lcNtwmHFlwYB/view?usp=sharing) | +| SwinB-DeAOTL | PRE_YTB_DAV | **11.9** | | **86.1** | **85.3** | **90.2** | **80.4** | **88.6** | - | + +### DAVIS-2017 test + +| Model | Stage | FPS | Mean | J Score | F Score | Predictions | +| ---------- |:-----------:|:----:|:--------:|:--------:|:--------:|:----:| +| AOTT | PRE_YTB_DAV | **51.4** | 73.7 | 70.0 | 77.3 | [gdrive](https://drive.google.com/file/d/14Pu-6Uz4rfmJ_WyL2yl57KTx_pSSUNAf/view?usp=sharing) | +| AOTS | PRE_YTB_DAV | 40.0 | 75.2 | 71.4 | 78.9 | [gdrive](https://drive.google.com/file/d/1zzAPZCRLgnBWuAXqejPPEYLqBxu67Rj1/view?usp=sharing) | +| AOTB | PRE_YTB_DAV | 29.6 | 77.4 | 73.7 | 81.1 | [gdrive](https://drive.google.com/file/d/1WpQ-_Jrs7Ssfw0oekrejM2OVWEx_tBN1/view?usp=sharing) | +| AOTL | PRE_YTB_DAV | 18.7 | 79.3 | 75.5 | 83.2 | [gdrive](https://drive.google.com/file/d/1rP1Zdgc0N1d8RR2EaXMz3F-o5zqcNVe8/view?usp=sharing) | +| R50-AOTL | PRE_YTB_DAV | 18.0 | 79.5 | 76.0 | 83.0 | [gdrive](https://drive.google.com/file/d/1iQ5iNlvlS-In586ZNc4LIZMSdNIWDvle/view?usp=sharing) | +| SwinB-AOTL | PRE_YTB_DAV | 12.1 | **82.1** | **78.2** | **85.9** | [gdrive](https://drive.google.com/file/d/1oVt4FPcZdfVHiOxjYYKef0q7Ovy4f5Q_/view?usp=sharing) | + +### DAVIS-2017 val + +| Model | Stage | FPS | Mean | J Score | F Score | Predictions | +| ---------- |:-----------:|:----:|:--------:|:--------:|:---------:|:----:| +| AOTT | PRE_YTB_DAV | **51.4** | 79.2 | 76.5 | 81.9 | [gdrive](https://drive.google.com/file/d/10OUFhK2Sz-hOJrTDoTI0mA45KO1qodZt/view?usp=sharing) | +| AOTS | PRE_YTB_DAV | 40.0 | 82.1 | 79.3 | 84.8 | [gdrive](https://drive.google.com/file/d/1T-JTYyksWlq45jxcLjnRaBvvYUhWgHFH/view?usp=sharing) | +| AOTB | PRE_YTB_DAV | 29.6 | 83.3 | 80.6 | 85.9 | [gdrive](https://drive.google.com/file/d/1EVUnxQm9TLBTuwK82QyiSKk9R9V8NwRL/view?usp=sharing) | +| AOTL | PRE_YTB_DAV | 18.7 | 83.6 | 80.8 | 86.3 | [gdrive](https://drive.google.com/file/d/1CFauSni2BxAe_fcl8W_6bFByuwJRbDYm/view?usp=sharing) | +| R50-AOTL | PRE_YTB_DAV | 18.0 | 85.2 | 82.5 | 87.9 | [gdrive](https://drive.google.com/file/d/1vjloxnP8R4PZdsH2DDizfU2CrkdRHHyo/view?usp=sharing) | +| SwinB-AOTL | PRE_YTB_DAV | 12.1 | **85.9** | **82.9** | **88.9** | [gdrive](https://drive.google.com/file/d/1tYCbKOas0i7Et2iyUAyDwaXnaD9YWxLr/view?usp=sharing) | + +### DAVIS-2016 val + +| Model | Stage | FPS | Mean | J Score | F Score | Predictions | +| ---------- |:-----------:|:----:|:--------:|:--------:|:--------:|:----:| +| AOTT | PRE_YTB_DAV | **51.4** | 87.5 | 86.5 | 88.4 | [gdrive](https://drive.google.com/file/d/1LeW8WQhnylZ3umT7E379KdII92uUsGA9/view?usp=sharing) | +| AOTS | PRE_YTB_DAV | 40.0 | 89.6 | 88.6 | 90.5 | [gdrive](https://drive.google.com/file/d/1vqGei5tLu1FPVrTi5bwRAsaGy3Upf7B1/view?usp=sharing) | +| AOTB | PRE_YTB_DAV | 29.6 | 90.9 | 89.6 | 92.1 | [gdrive](https://drive.google.com/file/d/1qAppo2uOVu0FbE9t1FBUpymC3yWgw1LM/view?usp=sharing) | +| AOTL | PRE_YTB_DAV | 18.7 | 91.1 | 89.5 | 92.7 | [gdrive](https://drive.google.com/file/d/1g6cjYhgBWjMaY3RGAm31qm3SPEF3QcKV/view?usp=sharing) | +| R50-AOTL | PRE_YTB_DAV | 18.0 | 91.7 | 90.4 | 93.0 | [gdrive](https://drive.google.com/file/d/1QzxojqWKsvRf53K2AgKsK523ZVuYU4O-/view?usp=sharing) | +| SwinB-AOTL | PRE_YTB_DAV | 12.1 | **92.2** | **90.6** | **93.8** | [gdrive](https://drive.google.com/file/d/1RIqUtAyVnopeogfT520d7a0yiULg1obp/view?usp=sharing) | diff --git a/aot/Pytorch-Correlation-extension/.gitignore b/aot/Pytorch-Correlation-extension/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..64efac494f6d5109117c8446a7940a0a446be1db --- /dev/null +++ b/aot/Pytorch-Correlation-extension/.gitignore @@ -0,0 +1 @@ +*.egg* diff --git a/aot/Pytorch-Correlation-extension/Correlation_Module/correlation.cpp b/aot/Pytorch-Correlation-extension/Correlation_Module/correlation.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d0e76f6c6c6793f2b0177c0c666c69fa621806d2 --- /dev/null +++ b/aot/Pytorch-Correlation-extension/Correlation_Module/correlation.cpp @@ -0,0 +1,178 @@ +#include +using namespace torch; + +#include + +#define WITHIN_BOUNDS(x, y, H, W) (x >= 0 && x < H && y >= 0 && y < W) + +template +static void correlate_patch( + TensorAccessor input1, + TensorAccessor input2, + scalar_t *dst, + int kH, int kW, + int dilationH, int dilationW, + int u, int v, + int shiftU, int shiftV){ + const int C = input1.size(0); + const int iH = input1.size(1); + const int iW = input1.size(2); + for (int c=0; c +static void correlate_patch_grad( + TensorAccessor input1, + TensorAccessor gradInput1, + TensorAccessor input2, + TensorAccessor gradInput2, + scalar_t gradOutput, + int kH, int kW, + int dilationH, int dilationW, + int u, int v, + int shiftU, int shiftV){ + + const int C = input1.size(0); + const int iH = input1.size(1); + const int iW = input1.size(2); + + for (int c=0; c(); + auto input2_acc = input2.accessor(); + auto output_acc = output.accessor(); + for (h = 0; h < oH; ++h) { + for (w = 0; w < oW; ++w) { + correlate_patch(input1_acc[n], + input2_acc[n], + &output_acc[n][ph][pw][h][w], + kH, kW, + dilationH, dilationW, + -padH + h * dH, + -padW + w * dW, + (ph - patchRadH) * dilation_patchH, + (pw - patchRadW) * dilation_patchW); + } + } + })); + } + } + } + return output; +} + +std::vector correlation_cpp_backward( + torch::Tensor input1, + torch::Tensor input2, + torch::Tensor gradOutput, + int kH, int kW, + int patchH, int patchW, + int padH, int padW, + int dilationH, int dilationW, + int dilation_patchH, int dilation_patchW, + int dH, int dW) { + + const int batch_size = input1.size(0); + const int patchRadH = (patchH - 1) / 2; + const int patchRadW = (patchW - 1) / 2; + const int oH = gradOutput.size(3); + const int oW = gradOutput.size(4); + + auto gradInput1 = torch::zeros_like(input1); + + auto gradInput2 = torch::zeros_like(input2); + + int n, ph, pw, h, w; + #pragma omp parallel for private(n, ph, pw, h, w) + for (n = 0; n < batch_size; ++n) { + AT_DISPATCH_FLOATING_TYPES(input1.scalar_type(), "correlation_backward_cpp", ([&] { + auto input1_acc = input1.accessor(); + auto gradInput1_acc = gradInput1.accessor(); + auto input2_acc = input2.accessor(); + auto gradInput2_acc = gradInput2.accessor(); + auto gradOutput_acc = gradOutput.accessor(); + + for(ph = 0; ph < patchH; ++ph){ + for(pw = 0; pw < patchW; ++pw){ + for (h = 0; h < oH; ++h) { + for (w = 0; w < oW; ++w) { + correlate_patch_grad(input1_acc[n], gradInput1_acc[n], + input2_acc[n], gradInput2_acc[n], + gradOutput_acc[n][ph][pw][h][w], + kH, kW, + dilationH, dilationW, + -padH + h * dH, + -padW + w * dW, + (ph - patchRadH) * dilation_patchH, + (pw - patchRadW) * dilation_patchW); + } + } + } + } + })); + } + + return {gradInput1, gradInput2}; +} diff --git a/aot/Pytorch-Correlation-extension/Correlation_Module/correlation_cuda_kernel.cu b/aot/Pytorch-Correlation-extension/Correlation_Module/correlation_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..84b395494685f8f9b35e8723ff884869f7930bb6 --- /dev/null +++ b/aot/Pytorch-Correlation-extension/Correlation_Module/correlation_cuda_kernel.cu @@ -0,0 +1,327 @@ +#include +using namespace torch; + +#include +#include + +#include +#include + +// Cuda tensor accessor definitions +// restrict pointer traits piroritize speed over memory consumption +#define TensorAcc4R PackedTensorAccessor32 +#define TensorAcc5R PackedTensorAccessor32 +#define WITHIN_BOUNDS(x, y, H, W) (x >= 0 && x < H && y >= 0 && y < W) + +#define THREADS_FORWARD 32 +#define THREADS_BACKWARD 5 + + +namespace corr { +template +__global__ void correlation_cuda_forward_kernel( + const TensorAcc4R rInput1, + const TensorAcc4R rInput2, + TensorAcc5R output, + int kH, int kW, + int patchH, int patchW, + int padH, int padW, + int dilationH, int dilationW, + int dilation_patchH, int dilation_patchW, + int dH, int dW) { + + const int iH = rInput1.size(1); + const int iW = rInput1.size(2); + const int C = rInput1.size(3); + + const int n = blockIdx.x; + const int h = blockIdx.y; + const int w = blockIdx.z; + const int thread = threadIdx.x; + + const int start_i = -padH + h * dH; + const int start_j = -padW + w * dW; + + const int patchRadH = dilation_patchH * (patchH - 1) / 2; + const int patchRadW = dilation_patchW * (patchW - 1) / 2; + + __shared__ scalar_t prod_sum[THREADS_FORWARD]; + + for(int ph = 0; ph < patchH; ++ph){ + int ph_dilated = ph * dilation_patchH - patchRadH; + for(int pw = 0; pw < patchW; ++pw){ + int pw_dilated = pw * dilation_patchW - patchRadW; + prod_sum[thread] = 0; + for (int i=0; i +__global__ void correlation_cuda_backward_kernel_input1( + const TensorAcc5R gradOutput, + const TensorAcc4R input2, + TensorAcc4R gradInput1, + const int kH, const int kW, + const int patchH, const int patchW, + const int padH, const int padW, + const int dilationH, const int dilationW, + const int dilation_patchH, const int dilation_patchW, + const int dH, const int dW, + const int batch) { + const int iH = input2.size(2); + const int iW = input2.size(3); + + const int H = gradOutput.size(3); + const int W = gradOutput.size(4); + + const int patchRadH = (patchH - 1) / 2; + const int patchRadW = (patchW - 1) / 2; + + const int n = batch; + const int c = blockIdx.x; + const int h = blockIdx.y; + const int w = blockIdx.z; + const int ph_off = threadIdx.x; + const int pw_off = threadIdx.y; + + const int h_2 = h + padH; + const int w_2 = w + padW; + const int min_h = h_2 - kH * dilationH; + const int min_w = w_2 - kW * dilationW; + + __shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD]; + prod_sum[ph_off][pw_off] = 0; + + for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) { + int i1 = h + dilation_patchH * (ph - patchRadH); + for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) { + int j1 = w + dilation_patchW * (pw - patchRadW); + if (WITHIN_BOUNDS(i1, j1, iH, iW)){ + scalar_t val = input2[n][c][i1][j1]; + for(int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) { + int i2 = (h_3)/dH; + if (i2 * dH != h_3) + continue; + for(int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) { + int j2 = (w_3) / dW; + if(j2 * dW != w_3) + continue; + if WITHIN_BOUNDS(i2, j2, H, W) { + prod_sum[ph_off][pw_off] += gradOutput[n][ph][pw][i2][j2] * val; + } + } + } + } + } + } + + __syncthreads(); + + if (ph_off == 0 && pw_off == 0){ + scalar_t reduce_sum =0; + for (int ph = 0; ph < THREADS_BACKWARD; ++ph){ + for (int pw = 0; pw < THREADS_BACKWARD; ++pw){ + reduce_sum += prod_sum[ph][pw]; + } + } + gradInput1[n][c][h][w] = reduce_sum; + } +} + + +template +__global__ void correlation_cuda_backward_kernel_input2( + const TensorAcc5R gradOutput, + const TensorAcc4R input1, + TensorAcc4R gradInput2, + int kH, int kW, + int patchH, int patchW, + int padH, int padW, + int dilationH, int dilationW, + int dilation_patchH, int dilation_patchW, + int dH, int dW, + int batch) { + const int iH = input1.size(2); + const int iW = input1.size(3); + + const int patchRadH = (patchH - 1) / 2; + const int patchRadW = (patchW - 1) / 2; + + const int H = gradOutput.size(3); + const int W = gradOutput.size(4); + + const int dilatedKH = kH * dilationH; + const int dilatedKW = kW * dilationW; + + const int n = batch; + const int c = blockIdx.x; + const int h = blockIdx.y; + const int w = blockIdx.z; + const int ph_off = threadIdx.x; + const int pw_off = threadIdx.y; + + __shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD]; + prod_sum[ph_off][pw_off] = 0; + + for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) { + int i1 = h - dilation_patchH * (ph - patchRadH); + for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) { + int j1 = w - dilation_patchW * (pw - patchRadW); + if WITHIN_BOUNDS(i1, j1, iH, iW) { + scalar_t val = input1[n][c][i1][j1]; + + const int h_2 = i1 + padH; + const int w_2 = j1 + padW; + const int min_h = h_2 - dilatedKH; + const int min_w = w_2 - dilatedKW; + + for(int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) { + int i2 = (h_3)/dH; + if (i2 * dH != h_3) + continue; + for(int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) { + int j2 = (w_3) / dW; + if(j2 * dW != w_3) + continue; + if WITHIN_BOUNDS(i2, j2, H, W) { + prod_sum[ph_off][pw_off] += gradOutput[n][ph][pw][i2][j2] * val; + } + } + } + } + } + } + + __syncthreads(); + + if (ph_off == 0 && pw_off == 0){ + scalar_t reduce_sum =0; + for (int ph = 0; ph < THREADS_BACKWARD; ++ph){ + for (int pw = 0; pw < THREADS_BACKWARD; ++pw){ + reduce_sum += prod_sum[ph][pw]; + } + } + gradInput2[n][c][h][w] = reduce_sum; + } +} +} // namsepace corr + +torch::Tensor correlation_cuda_forward( + torch::Tensor input1, + torch::Tensor input2, + int kH, int kW, + int patchH, int patchW, + int padH, int padW, + int dilationH, int dilationW, + int dilation_patchH, int dilation_patchW, + int dH, int dW) { + + const int batch_size = input1.size(0); + const int iH = input1.size(2); + const int iW = input1.size(3); + const int dilatedKH = (kH - 1) * dilationH + 1; + const int dilatedKW = (kW - 1) * dilationW + 1; + + const auto oH = (iH + 2 * padH - dilatedKH) / dH + 1; + const auto oW = (iW + 2 * padW - dilatedKW) / dW + 1; + auto output = torch::zeros({batch_size, patchH, patchW, oH, oW}, input1.options()); + + auto trInput1 = input1.permute({0, 2, 3, 1}).contiguous(); + auto trInput2 = input2.permute({0, 2, 3, 1}).contiguous(); + + const int threads = THREADS_FORWARD; + const dim3 blocks(batch_size, oH, oW); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.scalar_type(), "correlation_forward_cuda", ([&] { + TensorAcc4R trInput1_acc = trInput1.packed_accessor32(); + TensorAcc4R trInput2_acc = trInput2.packed_accessor32(); + TensorAcc5R output_acc = output.packed_accessor32(); + corr::correlation_cuda_forward_kernel<<>>( + trInput1_acc, trInput2_acc, output_acc, + kH, kW, patchH, patchW, padH, padW, dilationH, dilationW, + dilation_patchH, dilation_patchW, dH, dW); + })); + + return output; +} + +std::vector correlation_cuda_backward( + torch::Tensor input1, + torch::Tensor input2, + torch::Tensor gradOutput, + int kH, int kW, + int patchH, int patchW, + int padH, int padW, + int dilationH, int dilationW, + int dilation_patchH, int dilation_patchW, + int dH, int dW) { + + auto gradInput1 = torch::zeros_like(input1); + auto gradInput2 = torch::zeros_like(input2); + + const int batch_size = input1.size(0); + const int iH = input1.size(2); + const int iW = input1.size(3); + const int C = input1.size(1); + + const dim3 blocks(C, iH, iW); + const dim3 threads(THREADS_BACKWARD, THREADS_BACKWARD); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.scalar_type(), "correlation_backward_cuda", ([&] { + TensorAcc4R input1_acc = input1.packed_accessor32(); + TensorAcc4R input2_acc = input2.packed_accessor32(); + TensorAcc4R gradInput1_acc = gradInput1.packed_accessor32(); + TensorAcc4R gradInput2_acc = gradInput2.packed_accessor32(); + TensorAcc5R gradOutput_acc = gradOutput.packed_accessor32(); + + + for (int n = 0; n < batch_size; ++n){ + corr::correlation_cuda_backward_kernel_input1<<>>( + gradOutput_acc, input2_acc, gradInput1_acc, + kH, kW, patchH, patchW, padH, padW, + dilationH, dilationW, + dilation_patchH, dilation_patchW, + dH, dW, + n); + } + + for (int n = 0; n < batch_size; ++n){ + corr::correlation_cuda_backward_kernel_input2<<>>( + gradOutput_acc, input1_acc, gradInput2_acc, + kH, kW, patchH, patchW, padH, padW, + dilationH, dilationW, + dilation_patchH, dilation_patchW, + dH, dW, + n); + } + })); + + return {gradInput1, gradInput2}; +} diff --git a/aot/Pytorch-Correlation-extension/Correlation_Module/correlation_sampler.cpp b/aot/Pytorch-Correlation-extension/Correlation_Module/correlation_sampler.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a21c41589796298ff83616d39f66c0a6c1b0af32 --- /dev/null +++ b/aot/Pytorch-Correlation-extension/Correlation_Module/correlation_sampler.cpp @@ -0,0 +1,138 @@ +#include +#include +#include +#include + +// declarations + +torch::Tensor correlation_cpp_forward( + torch::Tensor input1, + torch::Tensor input2, + int kH, int kW, + int patchH, int patchW, + int padH, int padW, + int dilationH, int dilationW, + int dilation_patchH, int dilation_patchW, + int dH, int dW); + +std::vector correlation_cpp_backward( + torch::Tensor grad_output, + torch::Tensor input1, + torch::Tensor input2, + int kH, int kW, + int patchH, int patchW, + int padH, int padW, + int dilationH, int dilationW, + int dilation_patchH, int dilation_patchW, + int dH, int dW); + +#ifdef USE_CUDA + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x, " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) +#define CHECK_SAME_DEVICE(x, y) TORCH_CHECK(x.device() == y.device(), #x " is not on same device as " #y) + +torch::Tensor correlation_cuda_forward( + torch::Tensor input1, + torch::Tensor input2, + int kH, int kW, + int patchH, int patchW, + int padH, int padW, + int dilationH, int dilationW, + int dilation_patchH, int dilation_patchW, + int dH, int dW); + +std::vector correlation_cuda_backward( + torch::Tensor grad_output, + torch::Tensor input1, + torch::Tensor input2, + int kH, int kW, + int patchH, int patchW, + int padH, int padW, + int dilationH, int dilationW, + int dilation_patchH, int dilation_patchW, + int dH, int dW); + +// C++ interface + +torch::Tensor correlation_sample_forward( + torch::Tensor input1, + torch::Tensor input2, + int kH, int kW, + int patchH, int patchW, + int padH, int padW, + int dilationH, int dilationW, + int dilation_patchH, int dilation_patchW, + int dH, int dW) { + if (input1.device().is_cuda()){ + CHECK_INPUT(input1); + CHECK_INPUT(input2); + + // set device of input1 as default CUDA device + // https://pytorch.org/cppdocs/api/structc10_1_1cuda_1_1_optional_c_u_d_a_guard.html + const at::cuda::OptionalCUDAGuard guard_input1(device_of(input1)); + CHECK_SAME_DEVICE(input1, input2); + + return correlation_cuda_forward(input1, input2, kH, kW, patchH, patchW, + padH, padW, dilationH, dilationW, + dilation_patchH, dilation_patchW, + dH, dW); + }else{ + return correlation_cpp_forward(input1, input2, kH, kW, patchH, patchW, + padH, padW, dilationH, dilationW, + dilation_patchH, dilation_patchW, + dH, dW); + } +} + +std::vector correlation_sample_backward( + torch::Tensor input1, + torch::Tensor input2, + torch::Tensor grad_output, + int kH, int kW, + int patchH, int patchW, + int padH, int padW, + int dilationH, int dilationW, + int dilation_patchH, int dilation_patchW, + int dH, int dW) { + + if(grad_output.device().is_cuda()){ + CHECK_INPUT(input1); + CHECK_INPUT(input2); + + // set device of input1 as default CUDA device + const at::cuda::OptionalCUDAGuard guard_input1(device_of(input1)); + CHECK_SAME_DEVICE(input1, input2); + CHECK_SAME_DEVICE(input1, grad_output); + + return correlation_cuda_backward(input1, input2, grad_output, + kH, kW, patchH, patchW, + padH, padW, + dilationH, dilationW, + dilation_patchH, dilation_patchW, + dH, dW); + }else{ + return correlation_cpp_backward( + input1, input2, grad_output, + kH, kW, patchH, patchW, + padH, padW, + dilationH, dilationW, + dilation_patchH, dilation_patchW, + dH, dW); + } +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &correlation_sample_forward, "Spatial Correlation Sampler Forward"); + m.def("backward", &correlation_sample_backward, "Spatial Correlation Sampler backward"); +} + +#else + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &correlation_cpp_forward, "Spatial Correlation Sampler Forward"); + m.def("backward", &correlation_cpp_backward, "Spatial Correlation Sampler backward"); +} + +#endif diff --git a/aot/Pytorch-Correlation-extension/Correlation_Module/spatial_correlation_sampler/__init__.py b/aot/Pytorch-Correlation-extension/Correlation_Module/spatial_correlation_sampler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b99f47f5fd1008ca5e8eb4783dfc6d68fa90a2ac --- /dev/null +++ b/aot/Pytorch-Correlation-extension/Correlation_Module/spatial_correlation_sampler/__init__.py @@ -0,0 +1 @@ +from .spatial_correlation_sampler import SpatialCorrelationSampler, spatial_correlation_sample \ No newline at end of file diff --git a/aot/Pytorch-Correlation-extension/Correlation_Module/spatial_correlation_sampler/spatial_correlation_sampler.py b/aot/Pytorch-Correlation-extension/Correlation_Module/spatial_correlation_sampler/spatial_correlation_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..f8db65cbf5723ed3e73cf9fa94bc19803732feb9 --- /dev/null +++ b/aot/Pytorch-Correlation-extension/Correlation_Module/spatial_correlation_sampler/spatial_correlation_sampler.py @@ -0,0 +1,107 @@ +from torch import nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair + +import spatial_correlation_sampler_backend as correlation + + +def spatial_correlation_sample(input1, + input2, + kernel_size=1, + patch_size=1, + stride=1, + padding=0, + dilation=1, + dilation_patch=1): + """Apply spatial correlation sampling on from input1 to input2, + + Every parameter except input1 and input2 can be either single int + or a pair of int. For more information about Spatial Correlation + Sampling, see this page. + https://lmb.informatik.uni-freiburg.de/Publications/2015/DFIB15/ + + Args: + input1 : The first parameter. + input2 : The second parameter. + kernel_size : total size of your correlation kernel, in pixels + patch_size : total size of your patch, determining how many + different shifts will be applied + stride : stride of the spatial sampler, will modify output + height and width + padding : padding applied to input1 and input2 before applying + the correlation sampling, will modify output height and width + dilation_patch : step for every shift in patch + + Returns: + Tensor: Result of correlation sampling + + """ + return SpatialCorrelationSamplerFunction.apply(input1, input2, + kernel_size, patch_size, + stride, padding, dilation, dilation_patch) + + +class SpatialCorrelationSamplerFunction(Function): + + @staticmethod + def forward(ctx, + input1, + input2, + kernel_size=1, + patch_size=1, + stride=1, + padding=0, + dilation=1, + dilation_patch=1): + + ctx.save_for_backward(input1, input2) + kH, kW = ctx.kernel_size = _pair(kernel_size) + patchH, patchW = ctx.patch_size = _pair(patch_size) + padH, padW = ctx.padding = _pair(padding) + dilationH, dilationW = ctx.dilation = _pair(dilation) + dilation_patchH, dilation_patchW = ctx.dilation_patch = _pair(dilation_patch) + dH, dW = ctx.stride = _pair(stride) + + output = correlation.forward(input1, input2, + kH, kW, patchH, patchW, + padH, padW, dilationH, dilationW, + dilation_patchH, dilation_patchW, + dH, dW) + + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input1, input2 = ctx.saved_variables + + kH, kW = ctx.kernel_size + patchH, patchW = ctx.patch_size + padH, padW = ctx.padding + dilationH, dilationW = ctx.dilation + dilation_patchH, dilation_patchW = ctx.dilation_patch + dH, dW = ctx.stride + + grad_input1, grad_input2 = correlation.backward(input1, input2, grad_output, + kH, kW, patchH, patchW, + padH, padW, dilationH, dilationW, + dilation_patchH, dilation_patchW, + dH, dW) + return grad_input1, grad_input2, None, None, None, None, None, None + + +class SpatialCorrelationSampler(nn.Module): + def __init__(self, kernel_size=1, patch_size=1, stride=1, padding=0, dilation=1, dilation_patch=1): + super(SpatialCorrelationSampler, self).__init__() + self.kernel_size = kernel_size + self.patch_size = patch_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.dilation_patch = dilation_patch + + def forward(self, input1, input2): + return SpatialCorrelationSamplerFunction.apply(input1, input2, self.kernel_size, + self.patch_size, self.stride, + self.padding, self.dilation, self.dilation_patch) diff --git a/aot/Pytorch-Correlation-extension/LICENSE b/aot/Pytorch-Correlation-extension/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..63b4b681cb65bcf92db3d26bc3664a1298cbeea8 --- /dev/null +++ b/aot/Pytorch-Correlation-extension/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) [year] [fullname] + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/aot/Pytorch-Correlation-extension/README.md b/aot/Pytorch-Correlation-extension/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a9b0c8c006924db94ec181fef781950a9bf11a2e --- /dev/null +++ b/aot/Pytorch-Correlation-extension/README.md @@ -0,0 +1,155 @@ + +[![PyPI](https://img.shields.io/pypi/v/spatial-correlation-sampler.svg)](https://pypi.org/project/spatial-correlation-sampler/) + + +# Pytorch Correlation module + +this is a custom C++/Cuda implementation of Correlation module, used e.g. in [FlowNetC](https://arxiv.org/abs/1504.06852) + +This [tutorial](http://pytorch.org/tutorials/advanced/cpp_extension.html) was used as a basis for implementation, as well as +[NVIDIA's cuda code](https://github.com/NVIDIA/flownet2-pytorch/tree/master/networks/correlation_package) + +- Build and Install C++ and CUDA extensions by executing `python setup.py install`, +- Benchmark C++ vs. CUDA by running `python benchmark.py {cpu, cuda}`, +- Run gradient checks on the code by running `python grad_check.py --backend {cpu, cuda}`. + +# Requirements + +This module is expected to compile for Pytorch `2.1.0`. + +Before installation please check compatibility of your GPU and CUDA (_Compute Capability_) [nvidia docs](https://developer.nvidia.com/cuda-gpus). +e.g RTX 6000 is using CC=8.9 so we are setting the environment variable to + +`export TORCH_CUDA_ARCH_LIST="8.9+PTX"` + +# Installation + +be reminded this module requires `python3-dev` to compile C++ code, e.g. on Ubuntu run: + +`apt install python3-dev` + +this module is available on pip + +`pip install spatial-correlation-sampler` + +For a cpu-only version, you can install from source with + +`python setup_cpu.py install` + +# Known Problems + +This module needs compatible gcc version and CUDA to be compiled. +Namely, CUDA 9.1 and below will need gcc5, while CUDA 9.2 and 10.0 will need gcc7 +See [this issue](https://github.com/ClementPinard/Pytorch-Correlation-extension/issues/1) for more information + +# Usage + +API has a few difference with NVIDIA's module + * output is now a 5D tensor, which reflects the shifts horizontal and vertical. + ``` +input (B x C x H x W) -> output (B x PatchH x PatchW x oH x oW) + ``` + * Output sizes `oH` and `oW` are no longer dependant of patch size, but only of kernel size and padding + * Patch size `patch_size` is now the whole patch, and not only the radii. + * `stride1` is now `stride` and`stride2` is `dilation_patch`, which behave like dilated convolutions + * equivalent `max_displacement` is then `dilation_patch * (patch_size - 1) / 2`. + * `dilation` is a new parameter, it acts the same way as dilated convolution regarding the correlation kernel + * to get the right parameters for FlowNetC, you would have + ``` +kernel_size=1 +patch_size=21, +stride=1, +padding=0, +dilation=1 +dilation_patch=2 + ``` + + +## Example +```python +import torch +from spatial_correlation_sampler import SpatialCorrelationSampler, + +device = "cuda" +batch_size = 1 +channel = 1 +H = 10 +W = 10 +dtype = torch.float32 + +input1 = torch.randint(1, 4, (batch_size, channel, H, W), dtype=dtype, device=device, requires_grad=True) +input2 = torch.randint_like(input1, 1, 4).requires_grad_(True) + +#You can either use the function or the module. Note that the module doesn't contain any parameter tensor. + +#function + +out = spatial_correlation_sample(input1, + input2, + kernel_size=3, + patch_size=1, + stride=2, + padding=0, + dilation=2, + dilation_patch=1) + +#module + +correlation_sampler = SpatialCorrelationSampler( + kernel_size=3, + patch_size=1, + stride=2, + padding=0, + dilation=2, + dilation_patch=1) +out = correlation_sampler(input1, input2) + +``` + +# Benchmark + + * default parameters are from `benchmark.py`, FlowNetC parameters are same as use in `FlowNetC` with a batch size of 4, described in [this paper](https://arxiv.org/abs/1504.06852), implemented [here](https://github.com/lmb-freiburg/flownet2) and [here](https://github.com/NVIDIA/flownet2-pytorch/blob/master/networks/FlowNetC.py). + * Feel free to file an issue to add entries to this with your hardware ! + +## CUDA Benchmark + + * See [here](https://gist.github.com/ClementPinard/270e910147119831014932f67fb1b5ea) for a benchmark script working with [NVIDIA](https://github.com/NVIDIA/flownet2-pytorch/tree/master/networks/correlation_package)'s code, and Pytorch. + * Benchmark are launched with environment variable `CUDA_LAUNCH_BLOCKING` set to `1`. + * Only `float32` is benchmarked. + * FlowNetC correlation parameters where launched with the following command: + + ```bash + CUDA_LAUNCH_BLOCKING=1 python benchmark.py --scale ms -k1 --patch 21 -s1 -p0 --patch_dilation 2 -b4 --height 48 --width 64 -c256 cuda -d float + + CUDA_LAUNCH_BLOCKING=1 python NV_correlation_benchmark.py --scale ms -k1 --patch 21 -s1 -p0 --patch_dilation 2 -b4 --height 48 --width 64 -c256 + ``` + + | implementation | Correlation parameters | device | pass | min time | avg time | + | -------------- | ---------------------- | ------- | -------- | ------------: | ------------: | + | ours | default | 980 GTX | forward | **5.745 ms** | **5.851 ms** | + | ours | default | 980 GTX | backward | 77.694 ms | 77.957 ms | + | NVIDIA | default | 980 GTX | forward | 13.779 ms | 13.853 ms | + | NVIDIA | default | 980 GTX | backward | **73.383 ms** | **73.708 ms** | + | | | | | | | + | ours | FlowNetC | 980 GTX | forward | **26.102 ms** | **26.179 ms** | + | ours | FlowNetC | 980 GTX | backward | **208.091 ms** | **208.510 ms** | + | NVIDIA | FlowNetC | 980 GTX | forward | 35.363 ms | 35.550 ms | + | NVIDIA | FlowNetC | 980 GTX | backward | 283.748 ms | 284.346 ms | + +### Notes + * The overhead of our implementation regarding `kernel_size` > 1 during backward needs some investigation, feel free to + dive in the code to improve it ! + * The backward pass of NVIDIA is not entirely correct when stride1 > 1 and kernel_size > 1, because not everything + is computed, see [here](https://github.com/NVIDIA/flownet2-pytorch/blob/master/networks/correlation_package/src/correlation_cuda_kernel.cu#L120). + +## CPU Benchmark + + * No other implementation is avalaible on CPU. + * It is obviously not recommended to run it on CPU if you have a GPU. + + | Correlation parameters | device | pass | min time | avg time | + | ---------------------- | -------------------- | -------- | ----------: | ----------: | + | default | E5-2630 v3 @ 2.40GHz | forward | 159.616 ms | 188.727 ms | + | default | E5-2630 v3 @ 2.40GHz | backward | 282.641 ms | 294.194 ms | + | FlowNetC | E5-2630 v3 @ 2.40GHz | forward | 2.138 s | 2.144 s | + | FlowNetC | E5-2630 v3 @ 2.40GHz | backward | 7.006 s | 7.075 s | diff --git a/aot/Pytorch-Correlation-extension/benchmark.py b/aot/Pytorch-Correlation-extension/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..83f3bf476cc364fde5f2b374ed90a1f0d568ccdc --- /dev/null +++ b/aot/Pytorch-Correlation-extension/benchmark.py @@ -0,0 +1,90 @@ +from __future__ import division +from __future__ import print_function + +import argparse +import time + +import torch +from spatial_correlation_sampler import SpatialCorrelationSampler +from tqdm import trange + +TIME_SCALES = {'s': 1, 'ms': 1000, 'us': 1000000} + +parser = argparse.ArgumentParser() +parser.add_argument('backend', choices=['cpu', 'cuda'], default='cuda') +parser.add_argument('-b', '--batch-size', type=int, default=16) +parser.add_argument('-k', '--kernel-size', type=int, default=3) +parser.add_argument('--patch', type=int, default=3) +parser.add_argument('--patch_dilation', type=int, default=2) +parser.add_argument('-c', '--channel', type=int, default=64) +parser.add_argument('--height', type=int, default=100) +parser.add_argument('-w', '--width', type=int, default=100) +parser.add_argument('-s', '--stride', type=int, default=2) +parser.add_argument('-p', '--pad', type=int, default=1) +parser.add_argument('--scale', choices=['s', 'ms', 'us'], default='us') +parser.add_argument('-r', '--runs', type=int, default=100) +parser.add_argument('--dilation', type=int, default=2) +parser.add_argument('-d', '--dtype', choices=['half', 'float', 'double']) + +args = parser.parse_args() + +device = torch.device(args.backend) + +if args.dtype == 'half': + dtype = torch.float16 +elif args.dtype == 'float': + dtype = torch.float32 +else: + dtype = torch.float64 + + +input1 = torch.randn(args.batch_size, + args.channel, + args.height, + args.width, + dtype=dtype, + device=device, + requires_grad=True) +input2 = torch.randn_like(input1) + +correlation_sampler = SpatialCorrelationSampler( + args.kernel_size, + args.patch, + args.stride, + args.pad, + args.dilation, + args.patch_dilation) + +# Force CUDA initialization +output = correlation_sampler(input1, input2) +print(output.size()) +output.mean().backward() +forward_min = float('inf') +forward_time = 0 +backward_min = float('inf') +backward_time = 0 +for _ in trange(args.runs): + correlation_sampler.zero_grad() + + start = time.time() + output = correlation_sampler(input1, input2) + elapsed = time.time() - start + forward_min = min(forward_min, elapsed) + forward_time += elapsed + output = output.mean() + + start = time.time() + (output.mean()).backward() + elapsed = time.time() - start + backward_min = min(backward_min, elapsed) + backward_time += elapsed + +scale = TIME_SCALES[args.scale] +forward_min *= scale +backward_min *= scale +forward_average = forward_time / args.runs * scale +backward_average = backward_time / args.runs * scale + +print('Forward: {0:.3f}/{1:.3f} {4} | Backward {2:.3f}/{3:.3f} {4}'.format( + forward_min, forward_average, backward_min, backward_average, + args.scale)) diff --git a/aot/Pytorch-Correlation-extension/check.py b/aot/Pytorch-Correlation-extension/check.py new file mode 100644 index 0000000000000000000000000000000000000000..0033f978f13f9de80c1e8cd2ea80b2eea5588124 --- /dev/null +++ b/aot/Pytorch-Correlation-extension/check.py @@ -0,0 +1,119 @@ +from __future__ import division +from __future__ import print_function + +import argparse +import numpy as np +import torch + +from spatial_correlation_sampler import SpatialCorrelationSampler + + +def check_equal(first, second, verbose): + if verbose: + print() + for i, (x, y) in enumerate(zip(first, second)): + x = x.cpu().detach().numpy() + y = y.cpu().detach().numpy() + if verbose: + print("x = {}".format(x.flatten())) + print("y = {}".format(y.flatten())) + print('-' * 80) + np.testing.assert_allclose(x, y, err_msg="Index: {}".format(i)) + + +def zero_grad(variables): + for variable in variables: + if variable.grad is not None: variable.grad.zero_() + + +def get_grads(variables): + return [var.grad.clone() for var in variables] + + +def check_forward(input1, input2, correlation_sampler, verbose, gpu_index=0): + device = torch.device(f"cuda:{gpu_index}") + + cpu_values = correlation_sampler(input1, input2) + cuda_values = correlation_sampler(input1.to(device), input2.to(device)) + + print(f"Forward: CPU vs. CUDA device:{gpu_index} ... ", end='') + check_equal(cpu_values, cuda_values, verbose) + print('Ok') + + +def check_backward(input1, input2, correlation_sampler, verbose, gpu_index=0): + device = torch.device(f"cuda:{gpu_index}") + + zero_grad([input1, input2]) + + cpu_values = correlation_sampler(input1, input2) + cpu_values.sum().backward() + grad_cpu = get_grads([input1, input2]) + + zero_grad([input1, input2]) + + cuda_values = correlation_sampler(input1.to(device), input2.to(device)) + cuda_values.sum().backward() + grad_cuda = get_grads([input1, input2]) + + print(f"Backward: CPU vs. CUDA device:{gpu_index} ... ", end='') + check_equal(grad_cpu, grad_cuda, verbose) + print('Ok') + + +def check_multi_gpu_forward(correlation_sampler, verbose): + print("Multi-GPU forward") + total_gpus = torch.cuda.device_count() + for gpu in range(total_gpus): + check_forward(input1, input2, correlation_sampler, verbose, gpu_index=gpu) + +def check_multi_gpu_backward(correlation_sampler, verbose): + print("Multi-GPU backward") + total_gpus = torch.cuda.device_count() + for gpu in range(total_gpus): + check_backward(input1, input2, correlation_sampler, verbose, gpu_index=gpu) + + +parser = argparse.ArgumentParser() +parser.add_argument('direction', choices=['forward', 'backward'], nargs='+') +parser.add_argument('-b', '--batch-size', type=int, default=1) +parser.add_argument('-k', '--kernel-size', type=int, default=3) +parser.add_argument('--patch', type=int, default=3) +parser.add_argument('--patch_dilation', type=int, default=2) +parser.add_argument('-c', '--channel', type=int, default=10) +parser.add_argument('--height', type=int, default=10) +parser.add_argument('-w', '--width', type=int, default=10) +parser.add_argument('-s', '--stride', type=int, default=2) +parser.add_argument('-p', '--pad', type=int, default=5) +parser.add_argument('-v', '--verbose', action='store_true', default=False) +parser.add_argument('-d', '--dilation', type=int, default=2) +args = parser.parse_args() +print(args) + +assert(torch.cuda.is_available()), "no comparison to make" +input1 = torch.randn(args.batch_size, + args.channel, + args.height, + args.width).double() +input2 = torch.randn(args.batch_size, + args.channel, + args.height, + args.width).double() +input1.requires_grad = True +input2.requires_grad = True + +correlation_sampler = SpatialCorrelationSampler( + args.kernel_size, + args.patch, + args.stride, + args.pad, + args.dilation, + args.patch_dilation) + +if 'forward' in args.direction: + check_forward(input1, input2, correlation_sampler, args.verbose) + if torch.cuda.device_count() > 1: check_multi_gpu_forward(correlation_sampler, args.verbose) + +if 'backward' in args.direction: + check_backward(input1, input2, correlation_sampler, args.verbose) + if torch.cuda.device_count() > 1: check_multi_gpu_backward(correlation_sampler, args.verbose) diff --git a/aot/Pytorch-Correlation-extension/grad_check.py b/aot/Pytorch-Correlation-extension/grad_check.py new file mode 100644 index 0000000000000000000000000000000000000000..bed39ea5c7c8540af2d0a5def2d0d89da1b664d8 --- /dev/null +++ b/aot/Pytorch-Correlation-extension/grad_check.py @@ -0,0 +1,47 @@ +import argparse +import torch +# torch.set_printoptions(precision=1, threshold=10000) +from torch.autograd import gradcheck +from spatial_correlation_sampler import SpatialCorrelationSampler + +parser = argparse.ArgumentParser() +parser.add_argument('backend', choices=['cpu', 'cuda'], default='cuda') +parser.add_argument('-b', '--batch-size', type=int, default=2) +parser.add_argument('-k', '--kernel-size', type=int, default=3) +parser.add_argument('--patch', type=int, default=3) +parser.add_argument('--patch_dilation', type=int, default=2) +parser.add_argument('-c', '--channel', type=int, default=2) +parser.add_argument('--height', type=int, default=10) +parser.add_argument('-w', '--width', type=int, default=10) +parser.add_argument('-s', '--stride', type=int, default=2) +parser.add_argument('-p', '--pad', type=int, default=1) +parser.add_argument('-d', '--dilation', type=int, default=2) + +args = parser.parse_args() + +input1 = torch.randn(args.batch_size, + args.channel, + args.height, + args.width, + dtype=torch.float64, + device=torch.device(args.backend)) +input2 = torch.randn(args.batch_size, + args.channel, + args.height, + args.width, + dtype=torch.float64, + device=torch.device(args.backend)) + +input1.requires_grad = True +input2.requires_grad = True + +correlation_sampler = SpatialCorrelationSampler(args.kernel_size, + args.patch, + args.stride, + args.pad, + args.dilation, + args.patch_dilation) + + +if gradcheck(correlation_sampler, [input1, input2]): + print('Ok') diff --git a/aot/Pytorch-Correlation-extension/requirements.txt b/aot/Pytorch-Correlation-extension/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3d922e41730707fbb9c91993e7a83e9dd9222049 --- /dev/null +++ b/aot/Pytorch-Correlation-extension/requirements.txt @@ -0,0 +1,2 @@ +torch>=1.0.1 +numpy diff --git a/aot/Pytorch-Correlation-extension/setup.py b/aot/Pytorch-Correlation-extension/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..3d1ec4801aed65d39a82aa4b45ffd6006f6f460e --- /dev/null +++ b/aot/Pytorch-Correlation-extension/setup.py @@ -0,0 +1,69 @@ +import os +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension +from os.path import join + +CPU_ONLY = False +project_root = 'Correlation_Module' + +source_files = ['correlation.cpp', 'correlation_sampler.cpp'] + +cxx_args = ['-std=c++17', '-fopenmp'] + +def generate_nvcc_args(gpu_archs): + nvcc_args = [] + for arch in gpu_archs: + nvcc_args.extend(['-gencode', f'arch=compute_{arch},code=sm_{arch}']) + return nvcc_args + +gpu_arch = os.environ.get('GPU_ARCH', '').split() +nvcc_args = generate_nvcc_args(gpu_arch) + +with open("README.md", "r") as fh: + long_description = fh.read() + + +def launch_setup(): + if CPU_ONLY: + Extension = CppExtension + macro = [] + else: + Extension = CUDAExtension + source_files.append('correlation_cuda_kernel.cu') + macro = [("USE_CUDA", None)] + + sources = [join(project_root, file) for file in source_files] + + setup( + name='spatial_correlation_sampler', + version="0.4.0", + author="Clément Pinard", + author_email="clement.pinard@ensta-paristech.fr", + description="Correlation module for pytorch", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/ClementPinard/Pytorch-Correlation-extension", + install_requires=['torch>=1.1', 'numpy'], + ext_modules=[ + Extension('spatial_correlation_sampler_backend', + sources, + define_macros=macro, + extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}, + extra_link_args=['-lgomp']) + ], + package_dir={'': project_root}, + packages=['spatial_correlation_sampler'], + cmdclass={ + 'build_ext': BuildExtension + }, + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: POSIX :: Linux", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering :: Artificial Intelligence" + ]) + + +if __name__ == '__main__': + launch_setup() diff --git a/aot/Pytorch-Correlation-extension/setup_cpu.py b/aot/Pytorch-Correlation-extension/setup_cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..d4620c22d9d61b1cd6a621288d54ae85dd17d4d7 --- /dev/null +++ b/aot/Pytorch-Correlation-extension/setup_cpu.py @@ -0,0 +1,4 @@ +import setup + +setup.CPU_ONLY = True +setup.launch_setup() diff --git a/aot/README.md b/aot/README.md new file mode 100644 index 0000000000000000000000000000000000000000..451c956c3b05d620b2e1ec1561a80e98dc3c21b6 --- /dev/null +++ b/aot/README.md @@ -0,0 +1,152 @@ +# AOT Series Frameworks in PyTorch + +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/decoupling-features-in-hierarchical/semi-supervised-video-object-segmentation-on-15)](https://paperswithcode.com/sota/semi-supervised-video-object-segmentation-on-15?p=decoupling-features-in-hierarchical) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/associating-objects-with-scalable/video-object-segmentation-on-youtube-vos)](https://paperswithcode.com/sota/video-object-segmentation-on-youtube-vos?p=associating-objects-with-scalable) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/associating-objects-with-scalable/semi-supervised-video-object-segmentation-on-18)](https://paperswithcode.com/sota/semi-supervised-video-object-segmentation-on-18?p=associating-objects-with-scalable) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/associating-objects-with-scalable/semi-supervised-video-object-segmentation-on-1)](https://paperswithcode.com/sota/semi-supervised-video-object-segmentation-on-1?p=associating-objects-with-scalable) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/associating-objects-with-scalable/visual-object-tracking-on-davis-2017)](https://paperswithcode.com/sota/visual-object-tracking-on-davis-2017?p=associating-objects-with-scalable) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/associating-objects-with-scalable/visual-object-tracking-on-davis-2016)](https://paperswithcode.com/sota/visual-object-tracking-on-davis-2016?p=associating-objects-with-scalable) + +A modular reference PyTorch implementation of AOT series frameworks: +- **DeAOT**: Decoupling Features in Hierachical Propagation for Video Object Segmentation (NeurIPS 2022, Spotlight) [[OpenReview](https://openreview.net/forum?id=DgM7-7eMkq0)][[PDF](https://arxiv.org/pdf/2210.09782.pdf)] + + +- **AOT**: Associating Objects with Transformers for Video Object Segmentation (NeurIPS 2021, Score 8/8/7/8) [[OpenReview](https://openreview.net/forum?id=hl3v8io3ZYt)][[PDF](https://arxiv.org/abs/2106.02638)] + + +An extension of AOT, [AOST](https://arxiv.org/abs/2203.11442) (under review), is available now. AOST is a more robust and flexible framework, supporting run-time speed-accuracy trade-offs. + +## Examples +Benchmark examples: + + + +General examples (Messi and Kobe): + + + +## Highlights +- **High performance:** up to **85.5%** ([R50-AOTL](MODEL_ZOO.md#youtube-vos-2018-val)) on YouTube-VOS 2018 and **82.1%** ([SwinB-AOTL]((MODEL_ZOO.md#youtube-vos-2018-val))) on DAVIS-2017 Test-dev under standard settings (without any test-time augmentation and post processing). +- **High efficiency:** up to **51fps** ([AOTT](MODEL_ZOO.md#davis-2017-test)) on DAVIS-2017 (480p) even with **10** objects and **41fps** on YouTube-VOS (1.3x480p). AOT can process multiple objects (less than a pre-defined number, 10 is the default) as efficiently as processing a single object. This project also supports inferring any number of objects together within a video by automatic separation and aggregation. +- **Multi-GPU training and inference** +- **Mixed precision training and inference** +- **Test-time augmentation:** multi-scale and flipping augmentations are supported. + +## Requirements + * Python3 + * pytorch >= 1.7.0 and torchvision + * opencv-python + * Pillow + * Pytorch Correlation (Recommend to install from [source](https://github.com/ClementPinard/Pytorch-Correlation-extension) instead of using `pip`. **The project can also work without this module but will lose some efficiency of the short-term attention**.) + +Optional: + * scikit-image (if you want to run our **Demo**, please install) + +## Model Zoo and Results +Pre-trained models, benckmark scores, and pre-computed results reproduced by this project can be found in [MODEL_ZOO.md](MODEL_ZOO.md). + +## Demo - Panoptic Propagation +We provide a simple demo to demonstrate AOT's effectiveness. The demo will propagate more than **40** objects, including semantic regions (like sky) and instances (like person), together within a single complex scenario and predict its video panoptic segmentation. + +To run the demo, download the [checkpoint](https://drive.google.com/file/d/1qJDYn3Ibpquu4ffYoQmVjg1YCbr2JQep/view?usp=sharing) of R50-AOTL into [pretrain_models](pretrain_models), and then run: +```bash +python tools/demo.py +``` +which will predict the given scenarios in the resolution of 1.3x480p. You can also run this demo with other AOTs ([MODEL_ZOO.md](MODEL_ZOO.md)) by setting `--model` (model type) and `--ckpt_path` (checkpoint path). + +Two scenarios from [VSPW](https://www.vspwdataset.com/home) are supplied in [datasets/Demo](datasets/Demo): + +- 1001_3iEIq5HBY1s: 44 objects. 1080P. +- 1007_YCTBBdbKSSg: 43 objects. 1080P. + +Results: + + + + +## Getting Started +0. Prepare a valid environment follow the [requirements](#requirements). + +1. Prepare datasets: + + Please follow the below instruction to prepare datasets in each corresponding folder. + * **Static** + + [datasets/Static](datasets/Static): pre-training dataset with static images. Guidance can be found in [AFB-URR](https://github.com/xmlyqing00/AFB-URR), which we referred to in the implementation of the pre-training. + * **YouTube-VOS** + + A commonly-used large-scale VOS dataset. + + [datasets/YTB/2019](datasets/YTB/2019): version 2019, download [link](https://drive.google.com/drive/folders/1BWzrCWyPEmBEKm0lOHe5KLuBuQxUSwqz?usp=sharing). `train` is required for training. `valid` (6fps) and `valid_all_frames` (30fps, optional) are used for evaluation. + + [datasets/YTB/2018](datasets/YTB/2018): version 2018, download [link](https://drive.google.com/drive/folders/1bI5J1H3mxsIGo7Kp-pPZU8i6rnykOw7f?usp=sharing). Only `valid` (6fps) and `valid_all_frames` (30fps, optional) are required for this project and used for evaluation. + + * **DAVIS** + + A commonly-used small-scale VOS dataset. + + [datasets/DAVIS](datasets/DAVIS): [TrainVal](https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip) (480p) contains both the training and validation split. [Test-Dev](https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-test-dev-480p.zip) (480p) contains the Test-dev split. The [full-resolution version](https://davischallenge.org/davis2017/code.html) is also supported for training and evaluation but not required. + + +2. Prepare ImageNet pre-trained encoders + + Select and download below checkpoints into [pretrain_models](pretrain_models): + + - [MobileNet-V2](https://download.pytorch.org/models/mobilenet_v2-b0353104.pth) (default encoder) + - [MobileNet-V3](https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth) + - [ResNet-50](https://download.pytorch.org/models/resnet50-0676ba61.pth) + - [ResNet-101](https://download.pytorch.org/models/resnet101-63fe2227.pth) + - [ResNeSt-50](https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/resnest50-528c19ca.pth) + - [ResNeSt-101](https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/resnest101-22405ba7.pth) + - [Swin-Base](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth) + + The current default training configs are not optimized for encoders larger than ResNet-50. If you want to use larger encoders, we recommend early stopping the main-training stage at 80,000 iterations (100,000 in default) to avoid over-fitting on the seen classes of YouTube-VOS. + + + +3. Training and Evaluation + + The [example script](train_eval.sh) will train AOTT with 2 stages using 4 GPUs and auto-mixed precision (`--amp`). The first stage is a pre-training stage using `Static` dataset, and the second stage is a main-training stage, which uses both `YouTube-VOS 2019 train` and `DAVIS-2017 train` for training, resulting in a model that can generalize to different domains (YouTube-VOS and DAVIS) and different frame rates (6fps, 24fps, and 30fps). + + Notably, you can use only the `YouTube-VOS 2019 train` split in the second stage by changing `pre_ytb_dav` to `pre_ytb`, which leads to better YouTube-VOS performance on unseen classes. Besides, if you don't want to do the first stage, you can start the training from stage `ytb`, but the performance will drop about 1~2% absolutely. + + After the training is finished (about 0.6 days for each stage with 4 Tesla V100 GPUs), the [example script](train_eval.sh) will evaluate the model on YouTube-VOS and DAVIS, and the results will be packed into Zip files. For calculating scores, please use official YouTube-VOS servers ([2018 server](https://competitions.codalab.org/competitions/19544) and [2019 server](https://competitions.codalab.org/competitions/20127)), official [DAVIS toolkit](https://github.com/davisvideochallenge/davis-2017) (for Val), and official [DAVIS server](https://competitions.codalab.org/competitions/20516#learn_the_details) (for Test-dev). + +## Adding your own dataset +Coming + +## Troubleshooting +Waiting + +## TODO +- [ ] Code documentation +- [ ] Adding your own dataset +- [ ] Results with test-time augmentations in Model Zoo +- [ ] Support gradient accumulation +- [x] Demo tool + +## Citations +Please consider citing the related paper(s) in your publications if it helps your research. +``` +@inproceedings{yang2022deaot, + title={Decoupling Features in Hierarchical Propagation for Video Object Segmentation}, + author={Yang, Zongxin and Yang, Yi}, + booktitle={Advances in Neural Information Processing Systems (NeurIPS)}, + year={2022} +} +@article{yang2021aost, + title={Scalable Multi-object Identification for Video Object Segmentation}, + author={Yang, Zongxin and Miao, Jiaxu and Wang, Xiaohan and Wei, Yunchao and Yang, Yi}, + journal={arXiv preprint arXiv:2203.11442}, + year={2022} +} +@inproceedings{yang2021aot, + title={Associating Objects with Transformers for Video Object Segmentation}, + author={Yang, Zongxin and Wei, Yunchao and Yang, Yi}, + booktitle={Advances in Neural Information Processing Systems (NeurIPS)}, + year={2021} +} +``` + +## License +This project is released under the BSD-3-Clause license. See [LICENSE](LICENSE) for additional details. diff --git a/aot/__init__.py b/aot/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/aot/__pycache__/__init__.cpython-310.pyc b/aot/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0118cd29a9d063b77f279ccddf473792abce5c0 Binary files /dev/null and b/aot/__pycache__/__init__.cpython-310.pyc differ diff --git a/aot/configs/__pycache__/default.cpython-310.pyc b/aot/configs/__pycache__/default.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7194f49aba4070ba02f07771285abd73634f711a Binary files /dev/null and b/aot/configs/__pycache__/default.cpython-310.pyc differ diff --git a/aot/configs/__pycache__/pre_ytb_dav.cpython-310.pyc b/aot/configs/__pycache__/pre_ytb_dav.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11fe6f91d5ccceb475c2182971edacd5f7926489 Binary files /dev/null and b/aot/configs/__pycache__/pre_ytb_dav.cpython-310.pyc differ diff --git a/aot/configs/default.py b/aot/configs/default.py new file mode 100644 index 0000000000000000000000000000000000000000..fc96c45cb196c50a70d447ba628b26bd5a293e0e --- /dev/null +++ b/aot/configs/default.py @@ -0,0 +1,138 @@ +import os +import importlib + + +class DefaultEngineConfig(): + def __init__(self, exp_name='default', model='aott'): + model_cfg = importlib.import_module('configs.models.' + + model).ModelConfig() + self.__dict__.update(model_cfg.__dict__) # add model config + + self.EXP_NAME = exp_name + '_' + self.MODEL_NAME + + self.STAGE_NAME = 'YTB' + + self.DATASETS = ['youtubevos'] + self.DATA_WORKERS = 8 + self.DATA_RANDOMCROP = (465, + 465) if self.MODEL_ALIGN_CORNERS else (464, + 464) + self.DATA_RANDOMFLIP = 0.5 + self.DATA_MAX_CROP_STEPS = 10 + self.DATA_SHORT_EDGE_LEN = 480 + self.DATA_MIN_SCALE_FACTOR = 0.7 + self.DATA_MAX_SCALE_FACTOR = 1.3 + self.DATA_RANDOM_REVERSE_SEQ = True + self.DATA_SEQ_LEN = 5 + self.DATA_DAVIS_REPEAT = 5 + self.DATA_RANDOM_GAP_DAVIS = 12 # max frame interval between two sampled frames for DAVIS (24fps) + self.DATA_RANDOM_GAP_YTB = 3 # max frame interval between two sampled frames for YouTube-VOS (6fps) + self.DATA_DYNAMIC_MERGE_PROB = 0.3 + + self.PRETRAIN = True + self.PRETRAIN_FULL = False # if False, load encoder only + self.PRETRAIN_MODEL = './data_wd/pretrain_model/mobilenet_v2.pth' + # self.PRETRAIN_MODEL = './pretrain_models/mobilenet_v2-b0353104.pth' + + self.TRAIN_TOTAL_STEPS = 100000 + self.TRAIN_START_STEP = 0 + self.TRAIN_WEIGHT_DECAY = 0.07 + self.TRAIN_WEIGHT_DECAY_EXCLUSIVE = { + # 'encoder.': 0.01 + } + self.TRAIN_WEIGHT_DECAY_EXEMPTION = [ + 'absolute_pos_embed', 'relative_position_bias_table', + 'relative_emb_v', 'conv_out' + ] + self.TRAIN_LR = 2e-4 + self.TRAIN_LR_MIN = 2e-5 if 'mobilenetv2' in self.MODEL_ENCODER else 1e-5 + self.TRAIN_LR_POWER = 0.9 + self.TRAIN_LR_ENCODER_RATIO = 0.1 + self.TRAIN_LR_WARM_UP_RATIO = 0.05 + self.TRAIN_LR_COSINE_DECAY = False + self.TRAIN_LR_RESTART = 1 + self.TRAIN_LR_UPDATE_STEP = 1 + self.TRAIN_AUX_LOSS_WEIGHT = 1.0 + self.TRAIN_AUX_LOSS_RATIO = 1.0 + self.TRAIN_OPT = 'adamw' + self.TRAIN_SGD_MOMENTUM = 0.9 + self.TRAIN_GPUS = 4 + self.TRAIN_BATCH_SIZE = 16 + self.TRAIN_TBLOG = False + self.TRAIN_TBLOG_STEP = 50 + self.TRAIN_LOG_STEP = 20 + self.TRAIN_IMG_LOG = True + self.TRAIN_TOP_K_PERCENT_PIXELS = 0.15 + self.TRAIN_SEQ_TRAINING_FREEZE_PARAMS = ['patch_wise_id_bank'] + self.TRAIN_SEQ_TRAINING_START_RATIO = 0.5 + self.TRAIN_HARD_MINING_RATIO = 0.5 + self.TRAIN_EMA_RATIO = 0.1 + self.TRAIN_CLIP_GRAD_NORM = 5. + self.TRAIN_SAVE_STEP = 5000 + self.TRAIN_MAX_KEEP_CKPT = 8 + self.TRAIN_RESUME = False + self.TRAIN_RESUME_CKPT = None + self.TRAIN_RESUME_STEP = 0 + self.TRAIN_AUTO_RESUME = True + self.TRAIN_DATASET_FULL_RESOLUTION = False + self.TRAIN_ENABLE_PREV_FRAME = False + self.TRAIN_ENCODER_FREEZE_AT = 2 + self.TRAIN_LSTT_EMB_DROPOUT = 0. + self.TRAIN_LSTT_ID_DROPOUT = 0. + self.TRAIN_LSTT_DROPPATH = 0.1 + self.TRAIN_LSTT_DROPPATH_SCALING = False + self.TRAIN_LSTT_DROPPATH_LST = False + self.TRAIN_LSTT_LT_DROPOUT = 0. + self.TRAIN_LSTT_ST_DROPOUT = 0. + + self.TEST_GPU_ID = 0 + self.TEST_GPU_NUM = 1 + self.TEST_FRAME_LOG = False + self.TEST_DATASET = 'youtubevos' + self.TEST_DATASET_FULL_RESOLUTION = False + self.TEST_DATASET_SPLIT = 'val' + self.TEST_CKPT_PATH = None + # if "None", evaluate the latest checkpoint. + self.TEST_CKPT_STEP = None + self.TEST_FLIP = False + self.TEST_MULTISCALE = [1] + self.TEST_MAX_SHORT_EDGE = None + self.TEST_MAX_LONG_EDGE = 800 * 1.3 + self.TEST_WORKERS = 4 + + # GPU distribution + self.DIST_ENABLE = True + self.DIST_BACKEND = "nccl" # "gloo" + self.DIST_URL = "tcp://127.0.0.1:13241" + self.DIST_START_GPU = 0 + + def init_dir(self): + self.DIR_DATA = '../VOS02/datasets'#'./datasets' + self.DIR_DAVIS = os.path.join(self.DIR_DATA, 'DAVIS') + self.DIR_YTB = os.path.join(self.DIR_DATA, 'YTB') + self.DIR_STATIC = os.path.join(self.DIR_DATA, 'Static') + + self.DIR_ROOT = './'#'./data_wd/youtube_vos_jobs' + + self.DIR_RESULT = os.path.join(self.DIR_ROOT, 'result', self.EXP_NAME, + self.STAGE_NAME) + self.DIR_CKPT = os.path.join(self.DIR_RESULT, 'ckpt') + self.DIR_EMA_CKPT = os.path.join(self.DIR_RESULT, 'ema_ckpt') + self.DIR_LOG = os.path.join(self.DIR_RESULT, 'log') + self.DIR_TB_LOG = os.path.join(self.DIR_RESULT, 'log', 'tensorboard') + # self.DIR_IMG_LOG = os.path.join(self.DIR_RESULT, 'log', 'img') + # self.DIR_EVALUATION = os.path.join(self.DIR_RESULT, 'eval') + self.DIR_IMG_LOG = './img_logs' + self.DIR_EVALUATION = './results' + + for path in [ + self.DIR_RESULT, self.DIR_CKPT, self.DIR_EMA_CKPT, + self.DIR_LOG, self.DIR_EVALUATION, self.DIR_IMG_LOG, + self.DIR_TB_LOG + ]: + if not os.path.isdir(path): + try: + os.makedirs(path) + except Exception as inst: + print(inst) + print('Failed to make dir: {}.'.format(path)) diff --git a/aot/configs/models/__pycache__/default.cpython-310.pyc b/aot/configs/models/__pycache__/default.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e75de09e98c5b1c6d500bd2d28453683d9c6e32d Binary files /dev/null and b/aot/configs/models/__pycache__/default.cpython-310.pyc differ diff --git a/aot/configs/models/__pycache__/r50_aotl.cpython-310.pyc b/aot/configs/models/__pycache__/r50_aotl.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17820a1002052be51b2522f5d37821a07a2b9be0 Binary files /dev/null and b/aot/configs/models/__pycache__/r50_aotl.cpython-310.pyc differ diff --git a/aot/configs/models/aotb.py b/aot/configs/models/aotb.py new file mode 100644 index 0000000000000000000000000000000000000000..9d0bc396f04e399e5d06f4cb40dd86e8dedd2019 --- /dev/null +++ b/aot/configs/models/aotb.py @@ -0,0 +1,9 @@ +import os +from .default import DefaultModelConfig + +class ModelConfig(DefaultModelConfig): + def __init__(self): + super().__init__() + self.MODEL_NAME = 'AOTB' + + self.MODEL_LSTT_NUM = 3 diff --git a/aot/configs/models/aotl.py b/aot/configs/models/aotl.py new file mode 100644 index 0000000000000000000000000000000000000000..4fcedefa3b39bc3e49e70eae357273757c332778 --- /dev/null +++ b/aot/configs/models/aotl.py @@ -0,0 +1,13 @@ +import os +from .default import DefaultModelConfig + +class ModelConfig(DefaultModelConfig): + def __init__(self): + super().__init__() + self.MODEL_NAME = 'AOTL' + + self.MODEL_LSTT_NUM = 3 + + self.TRAIN_LONG_TERM_MEM_GAP = 2 + + self.TEST_LONG_TERM_MEM_GAP = 5 \ No newline at end of file diff --git a/aot/configs/models/aots.py b/aot/configs/models/aots.py new file mode 100644 index 0000000000000000000000000000000000000000..cb5e8458e5747116fc3da7101fc70413452aebd4 --- /dev/null +++ b/aot/configs/models/aots.py @@ -0,0 +1,9 @@ +import os +from .default import DefaultModelConfig + +class ModelConfig(DefaultModelConfig): + def __init__(self): + super().__init__() + self.MODEL_NAME = 'AOTS' + + self.MODEL_LSTT_NUM = 2 diff --git a/aot/configs/models/aott.py b/aot/configs/models/aott.py new file mode 100644 index 0000000000000000000000000000000000000000..587fce66d43c23ddc2eed105e1033650f3ef5080 --- /dev/null +++ b/aot/configs/models/aott.py @@ -0,0 +1,7 @@ +import os +from .default import DefaultModelConfig + +class ModelConfig(DefaultModelConfig): + def __init__(self): + super().__init__() + self.MODEL_NAME = 'AOTT' diff --git a/aot/configs/models/deaotb.py b/aot/configs/models/deaotb.py new file mode 100644 index 0000000000000000000000000000000000000000..6fcf2c1251f5a7e032375d425b1e979ff68bdaee --- /dev/null +++ b/aot/configs/models/deaotb.py @@ -0,0 +1,9 @@ +from .default_deaot import DefaultModelConfig + + +class ModelConfig(DefaultModelConfig): + def __init__(self): + super().__init__() + self.MODEL_NAME = 'DeAOTB' + + self.MODEL_LSTT_NUM = 3 diff --git a/aot/configs/models/deaotl.py b/aot/configs/models/deaotl.py new file mode 100644 index 0000000000000000000000000000000000000000..b61601e36a01dc536c29f45efab27dd0f8e857ba --- /dev/null +++ b/aot/configs/models/deaotl.py @@ -0,0 +1,13 @@ +from .default_deaot import DefaultModelConfig + + +class ModelConfig(DefaultModelConfig): + def __init__(self): + super().__init__() + self.MODEL_NAME = 'DeAOTL' + + self.MODEL_LSTT_NUM = 3 + + self.TRAIN_LONG_TERM_MEM_GAP = 2 + + self.TEST_LONG_TERM_MEM_GAP = 5 diff --git a/aot/configs/models/deaots.py b/aot/configs/models/deaots.py new file mode 100644 index 0000000000000000000000000000000000000000..632916c59e9c92cf26c6d12c9a0d2aadd2cd07cf --- /dev/null +++ b/aot/configs/models/deaots.py @@ -0,0 +1,9 @@ +from .default_deaot import DefaultModelConfig + + +class ModelConfig(DefaultModelConfig): + def __init__(self): + super().__init__() + self.MODEL_NAME = 'DeAOTS' + + self.MODEL_LSTT_NUM = 2 diff --git a/aot/configs/models/deaott.py b/aot/configs/models/deaott.py new file mode 100644 index 0000000000000000000000000000000000000000..78a414b74e53572ac34e30d74c0dd91a61cae4a1 --- /dev/null +++ b/aot/configs/models/deaott.py @@ -0,0 +1,7 @@ +from .default_deaot import DefaultModelConfig + + +class ModelConfig(DefaultModelConfig): + def __init__(self): + super().__init__() + self.MODEL_NAME = 'DeAOTT' diff --git a/aot/configs/models/default.py b/aot/configs/models/default.py new file mode 100644 index 0000000000000000000000000000000000000000..2ec250c637882027824483babaef3618044a347d --- /dev/null +++ b/aot/configs/models/default.py @@ -0,0 +1,27 @@ +class DefaultModelConfig(): + def __init__(self): + self.MODEL_NAME = 'AOTDefault' + + self.MODEL_VOS = 'aot' + self.MODEL_ENGINE = 'aotengine' + self.MODEL_ALIGN_CORNERS = True + self.MODEL_ENCODER = 'mobilenetv2' + self.MODEL_ENCODER_PRETRAIN = './pretrain_models/mobilenet_v2-b0353104.pth' + self.MODEL_ENCODER_DIM = [24, 32, 96, 1280] # 4x, 8x, 16x, 16x + self.MODEL_ENCODER_EMBEDDING_DIM = 256 + self.MODEL_DECODER_INTERMEDIATE_LSTT = True + self.MODEL_FREEZE_BN = True + self.MODEL_FREEZE_BACKBONE = False + self.MODEL_MAX_OBJ_NUM = 10 + self.MODEL_SELF_HEADS = 8 + self.MODEL_ATT_HEADS = 8 + self.MODEL_LSTT_NUM = 1 + self.MODEL_EPSILON = 1e-5 + self.MODEL_USE_PREV_PROB = False + + self.TRAIN_LONG_TERM_MEM_GAP = 9999 + self.TRAIN_AUG_TYPE = 'v1' + + self.TEST_LONG_TERM_MEM_GAP = 9999 + + self.TEST_SHORT_TERM_MEM_SKIP = 1 diff --git a/aot/configs/models/default_deaot.py b/aot/configs/models/default_deaot.py new file mode 100644 index 0000000000000000000000000000000000000000..f28a52e99ab79c37346848ea9f6329521da91e36 --- /dev/null +++ b/aot/configs/models/default_deaot.py @@ -0,0 +1,17 @@ +from .default import DefaultModelConfig as BaseConfig + + +class DefaultModelConfig(BaseConfig): + def __init__(self): + super().__init__() + self.MODEL_NAME = 'DeAOTDefault' + + self.MODEL_VOS = 'deaot' + self.MODEL_ENGINE = 'deaotengine' + + self.MODEL_DECODER_INTERMEDIATE_LSTT = False + + self.MODEL_SELF_HEADS = 1 + self.MODEL_ATT_HEADS = 1 + + self.TRAIN_AUG_TYPE = 'v2' diff --git a/aot/configs/models/r101_aotl.py b/aot/configs/models/r101_aotl.py new file mode 100644 index 0000000000000000000000000000000000000000..1687165de3f066648aefc985298b0f783a3f4a48 --- /dev/null +++ b/aot/configs/models/r101_aotl.py @@ -0,0 +1,16 @@ +from .default import DefaultModelConfig + + +class ModelConfig(DefaultModelConfig): + def __init__(self): + super().__init__() + self.MODEL_NAME = 'R101_AOTL' + + self.MODEL_ENCODER = 'resnet101' + self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnet101-63fe2227.pth' # https://download.pytorch.org/models/resnet101-63fe2227.pth + self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x + self.MODEL_LSTT_NUM = 3 + + self.TRAIN_LONG_TERM_MEM_GAP = 2 + + self.TEST_LONG_TERM_MEM_GAP = 5 \ No newline at end of file diff --git a/aot/configs/models/r50_aotl.py b/aot/configs/models/r50_aotl.py new file mode 100644 index 0000000000000000000000000000000000000000..941b9228f06e7b7fe7ef8fda6596c19120a254c0 --- /dev/null +++ b/aot/configs/models/r50_aotl.py @@ -0,0 +1,16 @@ +from .default import DefaultModelConfig + + +class ModelConfig(DefaultModelConfig): + def __init__(self): + super().__init__() + self.MODEL_NAME = 'R50_AOTL' + + self.MODEL_ENCODER = 'resnet50' + self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnet50-0676ba61.pth' # https://download.pytorch.org/models/resnet50-0676ba61.pth + self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x + self.MODEL_LSTT_NUM = 3 + + self.TRAIN_LONG_TERM_MEM_GAP = 2 + + self.TEST_LONG_TERM_MEM_GAP = 5 \ No newline at end of file diff --git a/aot/configs/models/r50_deaotl.py b/aot/configs/models/r50_deaotl.py new file mode 100644 index 0000000000000000000000000000000000000000..216abdb07c20b3fad131f868fa9d5b96cb17e8f9 --- /dev/null +++ b/aot/configs/models/r50_deaotl.py @@ -0,0 +1,16 @@ +from .default_deaot import DefaultModelConfig + + +class ModelConfig(DefaultModelConfig): + def __init__(self): + super().__init__() + self.MODEL_NAME = 'R50_DeAOTL' + + self.MODEL_ENCODER = 'resnet50' + self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x + + self.MODEL_LSTT_NUM = 3 + + self.TRAIN_LONG_TERM_MEM_GAP = 2 + + self.TEST_LONG_TERM_MEM_GAP = 5 diff --git a/aot/configs/models/rs101_aotl.py b/aot/configs/models/rs101_aotl.py new file mode 100644 index 0000000000000000000000000000000000000000..b1636ec2d13db08758d1765a71c5acf717ff143a --- /dev/null +++ b/aot/configs/models/rs101_aotl.py @@ -0,0 +1,16 @@ +from .default import DefaultModelConfig + + +class ModelConfig(DefaultModelConfig): + def __init__(self): + super().__init__() + self.MODEL_NAME = 'R101_AOTL' + + self.MODEL_ENCODER = 'resnest101' + self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnest101-22405ba7.pth' # https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/resnest101-22405ba7.pth + self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x + self.MODEL_LSTT_NUM = 3 + + self.TRAIN_LONG_TERM_MEM_GAP = 2 + + self.TEST_LONG_TERM_MEM_GAP = 5 \ No newline at end of file diff --git a/aot/configs/models/swinb_aotl.py b/aot/configs/models/swinb_aotl.py new file mode 100644 index 0000000000000000000000000000000000000000..360a16d33184ca6e265bbe5a7315f72ce755b53a --- /dev/null +++ b/aot/configs/models/swinb_aotl.py @@ -0,0 +1,17 @@ +from .default import DefaultModelConfig + + +class ModelConfig(DefaultModelConfig): + def __init__(self): + super().__init__() + self.MODEL_NAME = 'SwinB_AOTL' + + self.MODEL_ENCODER = 'swin_base' + self.MODEL_ENCODER_PRETRAIN = './pretrain_models/swin_base_patch4_window7_224_22k.pth' # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth + self.MODEL_ALIGN_CORNERS = False + self.MODEL_ENCODER_DIM = [128, 256, 512, 512] # 4x, 8x, 16x, 16x + self.MODEL_LSTT_NUM = 3 + + self.TRAIN_LONG_TERM_MEM_GAP = 2 + + self.TEST_LONG_TERM_MEM_GAP = 5 \ No newline at end of file diff --git a/aot/configs/models/swinb_deaotl.py b/aot/configs/models/swinb_deaotl.py new file mode 100644 index 0000000000000000000000000000000000000000..463a3fa61b45740a3f821b7bc4bcbb432950f62b --- /dev/null +++ b/aot/configs/models/swinb_deaotl.py @@ -0,0 +1,17 @@ +from .default_deaot import DefaultModelConfig + + +class ModelConfig(DefaultModelConfig): + def __init__(self): + super().__init__() + self.MODEL_NAME = 'SwinB_DeAOTL' + + self.MODEL_ENCODER = 'swin_base' + self.MODEL_ALIGN_CORNERS = False + self.MODEL_ENCODER_DIM = [128, 256, 512, 512] # 4x, 8x, 16x, 16x + + self.MODEL_LSTT_NUM = 3 + + self.TRAIN_LONG_TERM_MEM_GAP = 2 + + self.TEST_LONG_TERM_MEM_GAP = 5 \ No newline at end of file diff --git a/aot/configs/pre.py b/aot/configs/pre.py new file mode 100644 index 0000000000000000000000000000000000000000..53b8b0283a59eb3c048e64ce200836a33c5be7ab --- /dev/null +++ b/aot/configs/pre.py @@ -0,0 +1,19 @@ +from .default import DefaultEngineConfig + + +class EngineConfig(DefaultEngineConfig): + def __init__(self, exp_name='default', model='AOTT'): + super().__init__(exp_name, model) + self.STAGE_NAME = 'PRE' + + self.init_dir() + + self.DATASETS = ['static'] + + self.DATA_DYNAMIC_MERGE_PROB = 1.0 + + self.TRAIN_LR = 4e-4 + self.TRAIN_LR_MIN = 2e-5 + self.TRAIN_WEIGHT_DECAY = 0.03 + self.TRAIN_SEQ_TRAINING_START_RATIO = 1.0 + self.TRAIN_AUX_LOSS_RATIO = 0.1 diff --git a/aot/configs/pre_dav.py b/aot/configs/pre_dav.py new file mode 100644 index 0000000000000000000000000000000000000000..2abf75f557815ba2c0499d6c7f68539079b25293 --- /dev/null +++ b/aot/configs/pre_dav.py @@ -0,0 +1,21 @@ +import os +from .default import DefaultEngineConfig + + +class EngineConfig(DefaultEngineConfig): + def __init__(self, exp_name='default', model='AOTT'): + super().__init__(exp_name, model) + self.STAGE_NAME = 'PRE_DAV' + + self.init_dir() + + self.DATASETS = ['davis2017'] + + self.TRAIN_TOTAL_STEPS = 50000 + + pretrain_stage = 'PRE' + pretrain_ckpt = 'save_step_100000.pth' + self.PRETRAIN_FULL = True # if False, load encoder only + self.PRETRAIN_MODEL = os.path.join(self.DIR_ROOT, 'result', + self.EXP_NAME, pretrain_stage, + 'ema_ckpt', pretrain_ckpt) diff --git a/aot/configs/pre_ytb.py b/aot/configs/pre_ytb.py new file mode 100644 index 0000000000000000000000000000000000000000..a1edbb1103b2d9fcc5606cf41ee03bece7cb2d93 --- /dev/null +++ b/aot/configs/pre_ytb.py @@ -0,0 +1,17 @@ +import os +from .default import DefaultEngineConfig + + +class EngineConfig(DefaultEngineConfig): + def __init__(self, exp_name='default', model='AOTT'): + super().__init__(exp_name, model) + self.STAGE_NAME = 'PRE_YTB' + + self.init_dir() + + pretrain_stage = 'PRE' + pretrain_ckpt = 'save_step_100000.pth' + self.PRETRAIN_FULL = True # if False, load encoder only + self.PRETRAIN_MODEL = os.path.join(self.DIR_ROOT, 'result', + self.EXP_NAME, pretrain_stage, + 'ema_ckpt', pretrain_ckpt) diff --git a/aot/configs/pre_ytb_dav.py b/aot/configs/pre_ytb_dav.py new file mode 100644 index 0000000000000000000000000000000000000000..0d58a5dc20af434394d157750a3f0f2f19095027 --- /dev/null +++ b/aot/configs/pre_ytb_dav.py @@ -0,0 +1,19 @@ +import os +from .default import DefaultEngineConfig + + +class EngineConfig(DefaultEngineConfig): + def __init__(self, exp_name='default', model='AOTT'): + super().__init__(exp_name, model) + self.STAGE_NAME = 'PRE_YTB_DAV' + + self.init_dir() + + self.DATASETS = ['youtubevos', 'davis2017'] + + pretrain_stage = 'PRE' + pretrain_ckpt = 'save_step_100000.pth' + self.PRETRAIN_FULL = True # if False, load encoder only + self.PRETRAIN_MODEL = os.path.join(self.DIR_ROOT, 'result', + self.EXP_NAME, pretrain_stage, + 'ema_ckpt', pretrain_ckpt) diff --git a/aot/configs/ytb.py b/aot/configs/ytb.py new file mode 100644 index 0000000000000000000000000000000000000000..f476ee106290fe390cccf2b9e8f116ee1c8fbd61 --- /dev/null +++ b/aot/configs/ytb.py @@ -0,0 +1,10 @@ +import os +from .default import DefaultEngineConfig + + +class EngineConfig(DefaultEngineConfig): + def __init__(self, exp_name='default', model='AOTT'): + super().__init__(exp_name, model) + self.STAGE_NAME = 'YTB' + + self.init_dir() diff --git a/aot/dataloaders/__init__.py b/aot/dataloaders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/aot/dataloaders/__pycache__/__init__.cpython-310.pyc b/aot/dataloaders/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e3a0f1bf85e24f4bb594bfa47b7b7f5529fd98d Binary files /dev/null and b/aot/dataloaders/__pycache__/__init__.cpython-310.pyc differ diff --git a/aot/dataloaders/__pycache__/eval_datasets.cpython-310.pyc b/aot/dataloaders/__pycache__/eval_datasets.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2793005b63872ce5bf1d1434fb098605c055fa9c Binary files /dev/null and b/aot/dataloaders/__pycache__/eval_datasets.cpython-310.pyc differ diff --git a/aot/dataloaders/__pycache__/image_transforms.cpython-310.pyc b/aot/dataloaders/__pycache__/image_transforms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..957f6eb258bdb9a546d339dcf0c42373b2743d14 Binary files /dev/null and b/aot/dataloaders/__pycache__/image_transforms.cpython-310.pyc differ diff --git a/aot/dataloaders/__pycache__/video_transforms.cpython-310.pyc b/aot/dataloaders/__pycache__/video_transforms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..311b9cbe829d5295ab6eede420b1a1ae6ab98ae4 Binary files /dev/null and b/aot/dataloaders/__pycache__/video_transforms.cpython-310.pyc differ diff --git a/aot/dataloaders/eval_datasets.py b/aot/dataloaders/eval_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..e5c5bc22f2c9b14d90d201d0e128fdacb9efc58f --- /dev/null +++ b/aot/dataloaders/eval_datasets.py @@ -0,0 +1,411 @@ +from __future__ import division +import os +import shutil +import json +import cv2 +from PIL import Image + +import numpy as np +from torch.utils.data import Dataset + +from utils.image import _palette + + +class VOSTest(Dataset): + def __init__(self, + image_root, + label_root, + seq_name, + images, + labels, + rgb=True, + transform=None, + single_obj=False, + resolution=None): + self.image_root = image_root + self.label_root = label_root + self.seq_name = seq_name + self.images = images + self.labels = labels + self.obj_num = 1 + self.num_frame = len(self.images) + self.transform = transform + self.rgb = rgb + self.single_obj = single_obj + self.resolution = resolution + + self.obj_nums = [] + self.obj_indices = [] + + curr_objs = [0] + for img_name in self.images: + self.obj_nums.append(len(curr_objs) - 1) + current_label_name = img_name.split('.')[0] + '.png' + if current_label_name in self.labels: + current_label = self.read_label(current_label_name) + curr_obj = list(np.unique(current_label)) + for obj_idx in curr_obj: + if obj_idx not in curr_objs: + curr_objs.append(obj_idx) + self.obj_indices.append(curr_objs.copy()) + + self.obj_nums[0] = self.obj_nums[1] + + def __len__(self): + return len(self.images) + + def read_image(self, idx): + img_name = self.images[idx] + img_path = os.path.join(self.image_root, self.seq_name, img_name) + img = cv2.imread(img_path) + img = np.array(img, dtype=np.float32) + if self.rgb: + img = img[:, :, [2, 1, 0]] + return img + + def read_label(self, label_name, squeeze_idx=None): + label_path = os.path.join(self.label_root, self.seq_name, label_name) + label = Image.open(label_path) + label = np.array(label, dtype=np.uint8) + if self.single_obj: + label = (label > 0).astype(np.uint8) + elif squeeze_idx is not None: + squeezed_label = label * 0 + for idx in range(len(squeeze_idx)): + obj_id = squeeze_idx[idx] + if obj_id == 0: + continue + mask = label == obj_id + squeezed_label += (mask * idx).astype(np.uint8) + label = squeezed_label + return label + + def __getitem__(self, idx): + img_name = self.images[idx] + current_img = self.read_image(idx) + height, width, channels = current_img.shape + if self.resolution is not None: + width = int(np.ceil( + float(width) * self.resolution / float(height))) + height = int(self.resolution) + + current_label_name = img_name.split('.')[0] + '.png' + obj_num = self.obj_nums[idx] + obj_idx = self.obj_indices[idx] + + if current_label_name in self.labels: + current_label = self.read_label(current_label_name, obj_idx) + sample = { + 'current_img': current_img, + 'current_label': current_label + } + else: + sample = {'current_img': current_img} + + sample['meta'] = { + 'seq_name': self.seq_name, + 'frame_num': self.num_frame, + 'obj_num': obj_num, + 'current_name': img_name, + 'height': height, + 'width': width, + 'flip': False, + 'obj_idx': obj_idx + } + + if self.transform is not None: + sample = self.transform(sample) + return sample + + +class YOUTUBEVOS_Test(object): + def __init__(self, + root='./datasets/YTB', + year=2018, + split='val', + transform=None, + rgb=True, + result_root=None): + if split == 'val': + split = 'valid' + root = os.path.join(root, str(year), split) + self.db_root_dir = root + self.result_root = result_root + self.rgb = rgb + self.transform = transform + self.seq_list_file = os.path.join(self.db_root_dir, 'meta.json') + self._check_preprocess() + self.seqs = list(self.ann_f.keys()) + self.image_root = os.path.join(root, 'JPEGImages') + self.label_root = os.path.join(root, 'Annotations') + + def __len__(self): + return len(self.seqs) + + def __getitem__(self, idx): + seq_name = self.seqs[idx] + data = self.ann_f[seq_name]['objects'] + obj_names = list(data.keys()) + images = [] + labels = [] + for obj_n in obj_names: + images += map(lambda x: x + '.jpg', list(data[obj_n]["frames"])) + labels.append(data[obj_n]["frames"][0] + '.png') + images = np.sort(np.unique(images)) + labels = np.sort(np.unique(labels)) + + try: + if not os.path.isfile( + os.path.join(self.result_root, seq_name, labels[0])): + if not os.path.exists(os.path.join(self.result_root, + seq_name)): + os.makedirs(os.path.join(self.result_root, seq_name)) + shutil.copy( + os.path.join(self.label_root, seq_name, labels[0]), + os.path.join(self.result_root, seq_name, labels[0])) + except Exception as inst: + print(inst) + print('Failed to create a result folder for sequence {}.'.format( + seq_name)) + + seq_dataset = VOSTest(self.image_root, + self.label_root, + seq_name, + images, + labels, + transform=self.transform, + rgb=self.rgb) + return seq_dataset + + def _check_preprocess(self): + _seq_list_file = self.seq_list_file + if not os.path.isfile(_seq_list_file): + print(_seq_list_file) + return False + else: + self.ann_f = json.load(open(self.seq_list_file, 'r'))['videos'] + return True + + +class YOUTUBEVOS_DenseTest(object): + def __init__(self, + root='./datasets/YTB', + year=2018, + split='val', + transform=None, + rgb=True, + result_root=None): + if split == 'val': + split = 'valid' + root_sparse = os.path.join(root, str(year), split) + root_dense = root_sparse + '_all_frames' + self.db_root_dir = root_dense + self.result_root = result_root + self.rgb = rgb + self.transform = transform + self.seq_list_file = os.path.join(root_sparse, 'meta.json') + self._check_preprocess() + self.seqs = list(self.ann_f.keys()) + self.image_root = os.path.join(root_dense, 'JPEGImages') + self.label_root = os.path.join(root_sparse, 'Annotations') + + def __len__(self): + return len(self.seqs) + + def __getitem__(self, idx): + seq_name = self.seqs[idx] + + data = self.ann_f[seq_name]['objects'] + obj_names = list(data.keys()) + images_sparse = [] + for obj_n in obj_names: + images_sparse += map(lambda x: x + '.jpg', + list(data[obj_n]["frames"])) + images_sparse = np.sort(np.unique(images_sparse)) + + images = np.sort( + list(os.listdir(os.path.join(self.image_root, seq_name)))) + start_img = images_sparse[0] + end_img = images_sparse[-1] + for start_idx in range(len(images)): + if start_img in images[start_idx]: + break + for end_idx in range(len(images))[::-1]: + if end_img in images[end_idx]: + break + images = images[start_idx:(end_idx + 1)] + labels = np.sort( + list(os.listdir(os.path.join(self.label_root, seq_name)))) + + try: + if not os.path.isfile( + os.path.join(self.result_root, seq_name, labels[0])): + if not os.path.exists(os.path.join(self.result_root, + seq_name)): + os.makedirs(os.path.join(self.result_root, seq_name)) + shutil.copy( + os.path.join(self.label_root, seq_name, labels[0]), + os.path.join(self.result_root, seq_name, labels[0])) + except Exception as inst: + print(inst) + print('Failed to create a result folder for sequence {}.'.format( + seq_name)) + + seq_dataset = VOSTest(self.image_root, + self.label_root, + seq_name, + images, + labels, + transform=self.transform, + rgb=self.rgb) + seq_dataset.images_sparse = images_sparse + + return seq_dataset + + def _check_preprocess(self): + _seq_list_file = self.seq_list_file + if not os.path.isfile(_seq_list_file): + print(_seq_list_file) + return False + else: + self.ann_f = json.load(open(self.seq_list_file, 'r'))['videos'] + return True + + +class DAVIS_Test(object): + def __init__(self, + split=['val'], + root='./DAVIS', + year=2017, + transform=None, + rgb=True, + full_resolution=False, + result_root=None): + self.transform = transform + self.rgb = rgb + self.result_root = result_root + if year == 2016: + self.single_obj = True + else: + self.single_obj = False + if full_resolution: + resolution = 'Full-Resolution' + else: + resolution = '480p' + self.image_root = os.path.join(root, 'JPEGImages', resolution) + self.label_root = os.path.join(root, 'Annotations', resolution) + seq_names = [] + for spt in split: + if spt == 'test': + spt = 'test-dev' + with open(os.path.join(root, 'ImageSets', str(year), + spt + '.txt')) as f: + seqs_tmp = f.readlines() + seqs_tmp = list(map(lambda elem: elem.strip(), seqs_tmp)) + seq_names.extend(seqs_tmp) + self.seqs = list(np.unique(seq_names)) + + def __len__(self): + return len(self.seqs) + + def __getitem__(self, idx): + seq_name = self.seqs[idx] + images = list( + np.sort(os.listdir(os.path.join(self.image_root, seq_name)))) + labels = [images[0].replace('jpg', 'png')] + + if not os.path.isfile( + os.path.join(self.result_root, seq_name, labels[0])): + seq_result_folder = os.path.join(self.result_root, seq_name) + try: + if not os.path.exists(seq_result_folder): + os.makedirs(seq_result_folder) + except Exception as inst: + print(inst) + print( + 'Failed to create a result folder for sequence {}.'.format( + seq_name)) + source_label_path = os.path.join(self.label_root, seq_name, + labels[0]) + result_label_path = os.path.join(self.result_root, seq_name, + labels[0]) + if self.single_obj: + label = Image.open(source_label_path) + label = np.array(label, dtype=np.uint8) + label = (label > 0).astype(np.uint8) + label = Image.fromarray(label).convert('P') + label.putpalette(_palette) + label.save(result_label_path) + else: + shutil.copy(source_label_path, result_label_path) + + seq_dataset = VOSTest(self.image_root, + self.label_root, + seq_name, + images, + labels, + transform=self.transform, + rgb=self.rgb, + single_obj=self.single_obj, + resolution=480) + return seq_dataset + + +class _EVAL_TEST(Dataset): + def __init__(self, transform, seq_name): + self.seq_name = seq_name + self.num_frame = 10 + self.transform = transform + + def __len__(self): + return self.num_frame + + def __getitem__(self, idx): + current_frame_obj_num = 2 + height = 400 + width = 400 + img_name = 'test{}.jpg'.format(idx) + current_img = np.zeros((height, width, 3)).astype(np.float32) + if idx == 0: + current_label = (current_frame_obj_num * np.ones( + (height, width))).astype(np.uint8) + sample = { + 'current_img': current_img, + 'current_label': current_label + } + else: + sample = {'current_img': current_img} + + sample['meta'] = { + 'seq_name': self.seq_name, + 'frame_num': self.num_frame, + 'obj_num': current_frame_obj_num, + 'current_name': img_name, + 'height': height, + 'width': width, + 'flip': False + } + + if self.transform is not None: + sample = self.transform(sample) + return sample + + +class EVAL_TEST(object): + def __init__(self, transform=None, result_root=None): + self.transform = transform + self.result_root = result_root + + self.seqs = ['test1', 'test2', 'test3'] + + def __len__(self): + return len(self.seqs) + + def __getitem__(self, idx): + seq_name = self.seqs[idx] + + if not os.path.exists(os.path.join(self.result_root, seq_name)): + os.makedirs(os.path.join(self.result_root, seq_name)) + + seq_dataset = _EVAL_TEST(self.transform, seq_name) + return seq_dataset diff --git a/aot/dataloaders/image_transforms.py b/aot/dataloaders/image_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..0c90be41e911f8770820277474b3b782d9be3024 --- /dev/null +++ b/aot/dataloaders/image_transforms.py @@ -0,0 +1,530 @@ +import math +import warnings +import random +import numbers +import numpy as np +from PIL import Image, ImageFilter +from collections.abc import Sequence + +import torch +import torchvision.transforms.functional as TF + +_pil_interpolation_to_str = { + Image.NEAREST: 'PIL.Image.NEAREST', + Image.BILINEAR: 'PIL.Image.BILINEAR', + Image.BICUBIC: 'PIL.Image.BICUBIC', + Image.LANCZOS: 'PIL.Image.LANCZOS', + Image.HAMMING: 'PIL.Image.HAMMING', + Image.BOX: 'PIL.Image.BOX', +} + + +def _get_image_size(img): + if TF._is_pil_image(img): + return img.size + elif isinstance(img, torch.Tensor) and img.dim() > 2: + return img.shape[-2:][::-1] + else: + raise TypeError("Unexpected type {}".format(type(img))) + + +class RandomHorizontalFlip(object): + """Horizontal flip the given PIL Image randomly with a given probability. + + Args: + p (float): probability of the image being flipped. Default value is 0.5 + """ + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img, mask): + """ + Args: + img (PIL Image): Image to be flipped. + + Returns: + PIL Image: Randomly flipped image. + """ + if random.random() < self.p: + img = TF.hflip(img) + mask = TF.hflip(mask) + return img, mask + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p) + + +class RandomVerticalFlip(object): + """Vertical flip the given PIL Image randomly with a given probability. + + Args: + p (float): probability of the image being flipped. Default value is 0.5 + """ + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img, mask): + """ + Args: + img (PIL Image): Image to be flipped. + + Returns: + PIL Image: Randomly flipped image. + """ + if random.random() < self.p: + img = TF.vflip(img) + mask = TF.vflip(mask) + return img, mask + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p) + + +class GaussianBlur(object): + """Gaussian blur augmentation from SimCLR: https://arxiv.org/abs/2002.05709""" + def __init__(self, sigma=[.1, 2.]): + self.sigma = sigma + + def __call__(self, x): + sigma = random.uniform(self.sigma[0], self.sigma[1]) + x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) + return x + + +class RandomAffine(object): + """Random affine transformation of the image keeping center invariant + + Args: + degrees (sequence or float or int): Range of degrees to select from. + If degrees is a number instead of sequence like (min, max), the range of degrees + will be (-degrees, +degrees). Set to 0 to deactivate rotations. + translate (tuple, optional): tuple of maximum absolute fraction for horizontal + and vertical translations. For example translate=(a, b), then horizontal shift + is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is + randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default. + scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is + randomly sampled from the range a <= scale <= b. Will keep original scale by default. + shear (sequence or float or int, optional): Range of degrees to select from. + If shear is a number, a shear parallel to the x axis in the range (-shear, +shear) + will be apllied. Else if shear is a tuple or list of 2 values a shear parallel to the x axis in the + range (shear[0], shear[1]) will be applied. Else if shear is a tuple or list of 4 values, + a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied. + Will not apply shear by default + resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): + An optional resampling filter. See `filters`_ for more information. + If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. + fillcolor (tuple or int): Optional fill color (Tuple for RGB Image And int for grayscale) for the area + outside the transform in the output image.(Pillow>=5.0.0) + + .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters + + """ + def __init__(self, + degrees, + translate=None, + scale=None, + shear=None, + resample=False, + fillcolor=0): + if isinstance(degrees, numbers.Number): + if degrees < 0: + raise ValueError( + "If degrees is a single number, it must be positive.") + self.degrees = (-degrees, degrees) + else: + assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \ + "degrees should be a list or tuple and it must be of length 2." + self.degrees = degrees + + if translate is not None: + assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ + "translate should be a list or tuple and it must be of length 2." + for t in translate: + if not (0.0 <= t <= 1.0): + raise ValueError( + "translation values should be between 0 and 1") + self.translate = translate + + if scale is not None: + assert isinstance(scale, (tuple, list)) and len(scale) == 2, \ + "scale should be a list or tuple and it must be of length 2." + for s in scale: + if s <= 0: + raise ValueError("scale values should be positive") + self.scale = scale + + if shear is not None: + if isinstance(shear, numbers.Number): + if shear < 0: + raise ValueError( + "If shear is a single number, it must be positive.") + self.shear = (-shear, shear) + else: + assert isinstance(shear, (tuple, list)) and \ + (len(shear) == 2 or len(shear) == 4), \ + "shear should be a list or tuple and it must be of length 2 or 4." + # X-Axis shear with [min, max] + if len(shear) == 2: + self.shear = [shear[0], shear[1], 0., 0.] + elif len(shear) == 4: + self.shear = [s for s in shear] + else: + self.shear = shear + + self.resample = resample + self.fillcolor = fillcolor + + @staticmethod + def get_params(degrees, translate, scale_ranges, shears, img_size): + """Get parameters for affine transformation + + Returns: + sequence: params to be passed to the affine transformation + """ + angle = random.uniform(degrees[0], degrees[1]) + if translate is not None: + max_dx = translate[0] * img_size[0] + max_dy = translate[1] * img_size[1] + translations = (np.round(random.uniform(-max_dx, max_dx)), + np.round(random.uniform(-max_dy, max_dy))) + else: + translations = (0, 0) + + if scale_ranges is not None: + scale = random.uniform(scale_ranges[0], scale_ranges[1]) + else: + scale = 1.0 + + if shears is not None: + if len(shears) == 2: + shear = [random.uniform(shears[0], shears[1]), 0.] + elif len(shears) == 4: + shear = [ + random.uniform(shears[0], shears[1]), + random.uniform(shears[2], shears[3]) + ] + else: + shear = 0.0 + + return angle, translations, scale, shear + + def __call__(self, img, mask): + """ + img (PIL Image): Image to be transformed. + + Returns: + PIL Image: Affine transformed image. + """ + ret = self.get_params(self.degrees, self.translate, self.scale, + self.shear, img.size) + img = TF.affine(img, + *ret, + resample=self.resample, + fillcolor=self.fillcolor) + mask = TF.affine(mask, *ret, resample=Image.NEAREST, fillcolor=0) + return img, mask + + def __repr__(self): + s = '{name}(degrees={degrees}' + if self.translate is not None: + s += ', translate={translate}' + if self.scale is not None: + s += ', scale={scale}' + if self.shear is not None: + s += ', shear={shear}' + if self.resample > 0: + s += ', resample={resample}' + if self.fillcolor != 0: + s += ', fillcolor={fillcolor}' + s += ')' + d = dict(self.__dict__) + d['resample'] = _pil_interpolation_to_str[d['resample']] + return s.format(name=self.__class__.__name__, **d) + + +class RandomCrop(object): + """Crop the given PIL Image at a random location. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + padding (int or sequence, optional): Optional padding on each border + of the image. Default is None, i.e no padding. If a sequence of length + 4 is provided, it is used to pad left, top, right, bottom borders + respectively. If a sequence of length 2 is provided, it is used to + pad left/right, top/bottom borders, respectively. + pad_if_needed (boolean): It will pad the image if smaller than the + desired size to avoid raising an exception. Since cropping is done + after padding, the padding seems to be done at a random offset. + fill: Pixel fill value for constant fill. Default is 0. If a tuple of + length 3, it is used to fill R, G, B channels respectively. + This value is only used when the padding_mode is constant + padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. + + - constant: pads with a constant value, this value is specified with fill + + - edge: pads with the last value on the edge of the image + + - reflect: pads with reflection of image (without repeating the last value on the edge) + + padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode + will result in [3, 2, 1, 2, 3, 4, 3, 2] + + - symmetric: pads with reflection of image (repeating the last value on the edge) + + padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode + will result in [2, 1, 1, 2, 3, 4, 4, 3] + + """ + def __init__(self, + size, + padding=None, + pad_if_needed=False, + fill=0, + padding_mode='constant'): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + self.padding = padding + self.pad_if_needed = pad_if_needed + self.fill = fill + self.padding_mode = padding_mode + + @staticmethod + def get_params(img, output_size): + """Get parameters for ``crop`` for a random crop. + + Args: + img (PIL Image): Image to be cropped. + output_size (tuple): Expected output size of the crop. + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. + """ + w, h = _get_image_size(img) + th, tw = output_size + if w == tw and h == th: + return 0, 0, h, w + + i = random.randint(0, h - th) + j = random.randint(0, w - tw) + return i, j, th, tw + + def __call__(self, img, mask): + """ + Args: + img (PIL Image): Image to be cropped. + + Returns: + PIL Image: Cropped image. + """ + # if self.padding is not None: + # img = TF.pad(img, self.padding, self.fill, self.padding_mode) + # + # # pad the width if needed + # if self.pad_if_needed and img.size[0] < self.size[1]: + # img = TF.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode) + # # pad the height if needed + # if self.pad_if_needed and img.size[1] < self.size[0]: + # img = TF.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode) + + i, j, h, w = self.get_params(img, self.size) + img = TF.crop(img, i, j, h, w) + mask = TF.crop(mask, i, j, h, w) + + return img, mask + + def __repr__(self): + return self.__class__.__name__ + '(size={0}, padding={1})'.format( + self.size, self.padding) + + +class RandomResizedCrop(object): + """Crop the given PIL Image to random size and aspect ratio. + + A crop of random size (default: of 0.08 to 1.0) of the original size and a random + aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop + is finally resized to given size. + This is popularly used to train the Inception networks. + + Args: + size: expected output size of each edge + scale: range of size of the origin size cropped + ratio: range of aspect ratio of the origin aspect ratio cropped + interpolation: Default: PIL.Image.BILINEAR + """ + def __init__(self, + size, + scale=(0.08, 1.0), + ratio=(3. / 4., 4. / 3.), + interpolation=Image.BILINEAR): + if isinstance(size, (tuple, list)): + self.size = size + else: + self.size = (size, size) + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + warnings.warn("range should be of kind (min, max)") + + self.interpolation = interpolation + self.scale = scale + self.ratio = ratio + + @staticmethod + def get_params(img, scale, ratio): + """Get parameters for ``crop`` for a random sized crop. + + Args: + img (PIL Image): Image to be cropped. + scale (tuple): range of size of the origin size cropped + ratio (tuple): range of aspect ratio of the origin aspect ratio cropped + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. + """ + width, height = _get_image_size(img) + area = height * width + + for _ in range(10): + target_area = random.uniform(*scale) * area + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if 0 < w <= width and 0 < h <= height: + i = random.randint(0, height - h) + j = random.randint(0, width - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = float(width) / float(height) + if (in_ratio < min(ratio)): + w = width + h = int(round(w / min(ratio))) + elif (in_ratio > max(ratio)): + h = height + w = int(round(h * max(ratio))) + else: # whole image + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + return i, j, h, w + + def __call__(self, img, mask): + """ + Args: + img (PIL Image): Image to be cropped and resized. + + Returns: + PIL Image: Randomly cropped and resized image. + """ + i, j, h, w = self.get_params(img, self.scale, self.ratio) + # print(i, j, h, w) + img = TF.resized_crop(img, i, j, h, w, self.size, self.interpolation) + mask = TF.resized_crop(mask, i, j, h, w, self.size, Image.NEAREST) + return img, mask + + def __repr__(self): + interpolate_str = _pil_interpolation_to_str[self.interpolation] + format_string = self.__class__.__name__ + '(size={0}'.format(self.size) + format_string += ', scale={0}'.format( + tuple(round(s, 4) for s in self.scale)) + format_string += ', ratio={0}'.format( + tuple(round(r, 4) for r in self.ratio)) + format_string += ', interpolation={0})'.format(interpolate_str) + return format_string + + +class ToOnehot(object): + """To oneshot tensor + + Args: + max_obj_n (float): Maximum number of the objects + """ + def __init__(self, max_obj_n, shuffle): + self.max_obj_n = max_obj_n + self.shuffle = shuffle + + def __call__(self, mask, obj_list=None): + """ + Args: + mask (Mask in Numpy): Mask to be converted. + + Returns: + Tensor: Converted mask in onehot format. + """ + + new_mask = np.zeros((self.max_obj_n + 1, *mask.shape), np.uint8) + + if not obj_list: + obj_list = list() + obj_max = mask.max() + 1 + for i in range(1, obj_max): + tmp = (mask == i).astype(np.uint8) + if tmp.max() > 0: + obj_list.append(i) + + if self.shuffle: + random.shuffle(obj_list) + obj_list = obj_list[:self.max_obj_n] + + for i in range(len(obj_list)): + new_mask[i + 1] = (mask == obj_list[i]).astype(np.uint8) + new_mask[0] = 1 - np.sum(new_mask, axis=0) + + return torch.from_numpy(new_mask), obj_list + + def __repr__(self): + return self.__class__.__name__ + '(max_obj_n={})'.format( + self.max_obj_n) + + +class Resize(torch.nn.Module): + """Resize the input image to the given size. + The image can be a PIL Image or a torch Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions + + Args: + size (sequence or int): Desired output size. If size is a sequence like + (h, w), output size will be matched to this. If size is an int, + smaller edge of the image will be matched to this number. + i.e, if height > width, then image will be rescaled to + (size * height / width, size). + In torchscript mode padding as single int is not supported, use a tuple or + list of length 1: ``[size, ]``. + interpolation (int, optional): Desired interpolation enum defined by `filters`_. + Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR`` + and ``PIL.Image.BICUBIC`` are supported. + """ + def __init__(self, size, interpolation=Image.BILINEAR): + super().__init__() + if not isinstance(size, (int, Sequence)): + raise TypeError("Size should be int or sequence. Got {}".format( + type(size))) + if isinstance(size, Sequence) and len(size) not in (1, 2): + raise ValueError( + "If size is a sequence, it should have 1 or 2 values") + self.size = size + self.interpolation = interpolation + + def forward(self, img, mask): + """ + Args: + img (PIL Image or Tensor): Image to be scaled. + + Returns: + PIL Image or Tensor: Rescaled image. + """ + img = TF.resize(img, self.size, self.interpolation) + mask = TF.resize(mask, self.size, Image.NEAREST) + return img, mask + + def __repr__(self): + interpolate_str = _pil_interpolation_to_str[self.interpolation] + return self.__class__.__name__ + '(size={0}, interpolation={1})'.format( + self.size, interpolate_str) diff --git a/aot/dataloaders/train_datasets.py b/aot/dataloaders/train_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..eadc573fa23d6f16233a764bc14ec1dd24671e52 --- /dev/null +++ b/aot/dataloaders/train_datasets.py @@ -0,0 +1,682 @@ +from __future__ import division +import os +from glob import glob +import json +import random +import cv2 +from PIL import Image + +import numpy as np +import torch +from torch.utils.data import Dataset +import torchvision.transforms as TF + +import dataloaders.image_transforms as IT + +cv2.setNumThreads(0) + + +def _get_images(sample): + return [sample['ref_img'], sample['prev_img']] + sample['curr_img'] + + +def _get_labels(sample): + return [sample['ref_label'], sample['prev_label']] + sample['curr_label'] + + +def _merge_sample(sample1, sample2, min_obj_pixels=100, max_obj_n=10): + + sample1_images = _get_images(sample1) + sample2_images = _get_images(sample2) + + sample1_labels = _get_labels(sample1) + sample2_labels = _get_labels(sample2) + + obj_idx = torch.arange(0, max_obj_n * 2 + 1).view(max_obj_n * 2 + 1, 1, 1) + selected_idx = None + selected_obj = None + + all_img = [] + all_mask = [] + for idx, (s1_img, s2_img, s1_label, s2_label) in enumerate( + zip(sample1_images, sample2_images, sample1_labels, + sample2_labels)): + s2_fg = (s2_label > 0).float() + s2_bg = 1 - s2_fg + merged_img = s1_img * s2_bg + s2_img * s2_fg + merged_mask = s1_label * s2_bg.long() + ( + (s2_label + max_obj_n) * s2_fg.long()) + merged_mask = (merged_mask == obj_idx).float() + if idx == 0: + after_merge_pixels = merged_mask.sum(dim=(1, 2), keepdim=True) + selected_idx = after_merge_pixels > min_obj_pixels + selected_idx[0] = True + obj_num = selected_idx.sum().int().item() - 1 + selected_idx = selected_idx.expand(-1, + s1_label.size()[1], + s1_label.size()[2]) + if obj_num > max_obj_n: + selected_obj = list(range(1, obj_num + 1)) + random.shuffle(selected_obj) + selected_obj = [0] + selected_obj[:max_obj_n] + + merged_mask = merged_mask[selected_idx].view(obj_num + 1, + s1_label.size()[1], + s1_label.size()[2]) + if obj_num > max_obj_n: + merged_mask = merged_mask[selected_obj] + merged_mask[0] += 0.1 + merged_mask = torch.argmax(merged_mask, dim=0, keepdim=True).long() + + all_img.append(merged_img) + all_mask.append(merged_mask) + + sample = { + 'ref_img': all_img[0], + 'prev_img': all_img[1], + 'curr_img': all_img[2:], + 'ref_label': all_mask[0], + 'prev_label': all_mask[1], + 'curr_label': all_mask[2:] + } + sample['meta'] = sample1['meta'] + sample['meta']['obj_num'] = min(obj_num, max_obj_n) + return sample + + +class StaticTrain(Dataset): + def __init__(self, + root, + output_size, + seq_len=5, + max_obj_n=10, + dynamic_merge=True, + merge_prob=1.0, + aug_type='v1'): + self.root = root + self.clip_n = seq_len + self.output_size = output_size + self.max_obj_n = max_obj_n + + self.dynamic_merge = dynamic_merge + self.merge_prob = merge_prob + + self.img_list = list() + self.mask_list = list() + + dataset_list = list() + lines = ['COCO', 'ECSSD', 'MSRA10K', 'PASCAL-S', 'PASCALVOC2012'] + for line in lines: + dataset_name = line.strip() + + img_dir = os.path.join(root, 'JPEGImages', dataset_name) + mask_dir = os.path.join(root, 'Annotations', dataset_name) + + img_list = sorted(glob(os.path.join(img_dir, '*.jpg'))) + \ + sorted(glob(os.path.join(img_dir, '*.png'))) + mask_list = sorted(glob(os.path.join(mask_dir, '*.png'))) + + if len(img_list) > 0: + if len(img_list) == len(mask_list): + dataset_list.append(dataset_name) + self.img_list += img_list + self.mask_list += mask_list + print(f'\t{dataset_name}: {len(img_list)} imgs.') + else: + print( + f'\tPreTrain dataset {dataset_name} has {len(img_list)} imgs and {len(mask_list)} annots. Not match! Skip.' + ) + else: + print( + f'\tPreTrain dataset {dataset_name} doesn\'t exist. Skip.') + + print( + f'{len(self.img_list)} imgs are used for PreTrain. They are from {dataset_list}.' + ) + + self.aug_type = aug_type + + self.pre_random_horizontal_flip = IT.RandomHorizontalFlip(0.5) + + self.random_horizontal_flip = IT.RandomHorizontalFlip(0.3) + + if self.aug_type == 'v1': + self.color_jitter = TF.ColorJitter(0.1, 0.1, 0.1, 0.03) + elif self.aug_type == 'v2': + self.color_jitter = TF.RandomApply( + [TF.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8) + self.gray_scale = TF.RandomGrayscale(p=0.2) + self.blur = TF.RandomApply([IT.GaussianBlur([.1, 2.])], p=0.3) + else: + assert NotImplementedError + + self.random_affine = IT.RandomAffine(degrees=20, + translate=(0.1, 0.1), + scale=(0.9, 1.1), + shear=10, + resample=Image.BICUBIC, + fillcolor=(124, 116, 104)) + base_ratio = float(output_size[1]) / output_size[0] + self.random_resize_crop = IT.RandomResizedCrop( + output_size, (0.8, 1), + ratio=(base_ratio * 3. / 4., base_ratio * 4. / 3.), + interpolation=Image.BICUBIC) + self.to_tensor = TF.ToTensor() + self.to_onehot = IT.ToOnehot(max_obj_n, shuffle=True) + self.normalize = TF.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)) + + def __len__(self): + return len(self.img_list) + + def load_image_in_PIL(self, path, mode='RGB'): + img = Image.open(path) + img.load() # Very important for loading large image + return img.convert(mode) + + def sample_sequence(self, idx): + img_pil = self.load_image_in_PIL(self.img_list[idx], 'RGB') + mask_pil = self.load_image_in_PIL(self.mask_list[idx], 'P') + + frames = [] + masks = [] + + img_pil, mask_pil = self.pre_random_horizontal_flip(img_pil, mask_pil) + # img_pil, mask_pil = self.pre_random_vertical_flip(img_pil, mask_pil) + + for i in range(self.clip_n): + img, mask = img_pil, mask_pil + + if i > 0: + img, mask = self.random_horizontal_flip(img, mask) + img, mask = self.random_affine(img, mask) + + img = self.color_jitter(img) + + img, mask = self.random_resize_crop(img, mask) + + if self.aug_type == 'v2': + img = self.gray_scale(img) + img = self.blur(img) + + mask = np.array(mask, np.uint8) + + if i == 0: + mask, obj_list = self.to_onehot(mask) + obj_num = len(obj_list) + else: + mask, _ = self.to_onehot(mask, obj_list) + + mask = torch.argmax(mask, dim=0, keepdim=True) + + frames.append(self.normalize(self.to_tensor(img))) + masks.append(mask) + + sample = { + 'ref_img': frames[0], + 'prev_img': frames[1], + 'curr_img': frames[2:], + 'ref_label': masks[0], + 'prev_label': masks[1], + 'curr_label': masks[2:] + } + sample['meta'] = { + 'seq_name': self.img_list[idx], + 'frame_num': 1, + 'obj_num': obj_num + } + + return sample + + def __getitem__(self, idx): + sample1 = self.sample_sequence(idx) + + if self.dynamic_merge and (sample1['meta']['obj_num'] == 0 + or random.random() < self.merge_prob): + rand_idx = np.random.randint(len(self.img_list)) + while (rand_idx == idx): + rand_idx = np.random.randint(len(self.img_list)) + + sample2 = self.sample_sequence(rand_idx) + + sample = self.merge_sample(sample1, sample2) + else: + sample = sample1 + + return sample + + def merge_sample(self, sample1, sample2, min_obj_pixels=100): + return _merge_sample(sample1, sample2, min_obj_pixels, self.max_obj_n) + + +class VOSTrain(Dataset): + def __init__(self, + image_root, + label_root, + imglistdic, + transform=None, + rgb=True, + repeat_time=1, + rand_gap=3, + seq_len=5, + rand_reverse=True, + dynamic_merge=True, + enable_prev_frame=False, + merge_prob=0.3, + max_obj_n=10): + self.image_root = image_root + self.label_root = label_root + self.rand_gap = rand_gap + self.seq_len = seq_len + self.rand_reverse = rand_reverse + self.repeat_time = repeat_time + self.transform = transform + self.dynamic_merge = dynamic_merge + self.merge_prob = merge_prob + self.enable_prev_frame = enable_prev_frame + self.max_obj_n = max_obj_n + self.rgb = rgb + self.imglistdic = imglistdic + self.seqs = list(self.imglistdic.keys()) + print('Video Num: {} X {}'.format(len(self.seqs), self.repeat_time)) + + def __len__(self): + return int(len(self.seqs) * self.repeat_time) + + def reverse_seq(self, imagelist, lablist): + if np.random.randint(2) == 1: + imagelist = imagelist[::-1] + lablist = lablist[::-1] + return imagelist, lablist + + def get_ref_index(self, + seqname, + lablist, + objs, + min_fg_pixels=200, + max_try=5): + bad_indices = [] + for _ in range(max_try): + ref_index = np.random.randint(len(lablist)) + if ref_index in bad_indices: + continue + ref_label = Image.open( + os.path.join(self.label_root, seqname, lablist[ref_index])) + ref_label = np.array(ref_label, dtype=np.uint8) + ref_objs = list(np.unique(ref_label)) + is_consistent = True + for obj in ref_objs: + if obj == 0: + continue + if obj not in objs: + is_consistent = False + xs, ys = np.nonzero(ref_label) + if len(xs) > min_fg_pixels and is_consistent: + break + bad_indices.append(ref_index) + return ref_index + + def get_ref_index_v2(self, + seqname, + lablist, + min_fg_pixels=200, + max_try=20, + total_gap=0): + search_range = len(lablist) - total_gap + if search_range <= 1: + return 0 + bad_indices = [] + for _ in range(max_try): + ref_index = np.random.randint(search_range) + if ref_index in bad_indices: + continue + ref_label = Image.open( + os.path.join(self.label_root, seqname, lablist[ref_index])) + ref_label = np.array(ref_label, dtype=np.uint8) + xs, ys = np.nonzero(ref_label) + if len(xs) > min_fg_pixels: + break + bad_indices.append(ref_index) + return ref_index + + def get_curr_gaps(self, seq_len, max_gap=999, max_try=10): + for _ in range(max_try): + curr_gaps = [] + total_gap = 0 + for _ in range(seq_len): + gap = int(np.random.randint(self.rand_gap) + 1) + total_gap += gap + curr_gaps.append(gap) + if total_gap <= max_gap: + break + return curr_gaps, total_gap + + def get_prev_index(self, lablist, total_gap): + search_range = len(lablist) - total_gap + if search_range > 1: + prev_index = np.random.randint(search_range) + else: + prev_index = 0 + return prev_index + + def check_index(self, total_len, index, allow_reflect=True): + if total_len <= 1: + return 0 + + if index < 0: + if allow_reflect: + index = -index + index = self.check_index(total_len, index, True) + else: + index = 0 + elif index >= total_len: + if allow_reflect: + index = 2 * (total_len - 1) - index + index = self.check_index(total_len, index, True) + else: + index = total_len - 1 + + return index + + def get_curr_indices(self, lablist, prev_index, gaps): + total_len = len(lablist) + curr_indices = [] + now_index = prev_index + for gap in gaps: + now_index += gap + curr_indices.append(self.check_index(total_len, now_index)) + return curr_indices + + def get_image_label(self, seqname, imagelist, lablist, index): + image = cv2.imread( + os.path.join(self.image_root, seqname, imagelist[index])) + image = np.array(image, dtype=np.float32) + + if self.rgb: + image = image[:, :, [2, 1, 0]] + + label = Image.open( + os.path.join(self.label_root, seqname, lablist[index])) + label = np.array(label, dtype=np.uint8) + + return image, label + + def sample_sequence(self, idx): + idx = idx % len(self.seqs) + seqname = self.seqs[idx] + imagelist, lablist = self.imglistdic[seqname] + frame_num = len(imagelist) + if self.rand_reverse: + imagelist, lablist = self.reverse_seq(imagelist, lablist) + + is_consistent = False + max_try = 5 + try_step = 0 + while (is_consistent is False and try_step < max_try): + try_step += 1 + + # generate random gaps + curr_gaps, total_gap = self.get_curr_gaps(self.seq_len - 1) + + if self.enable_prev_frame: # prev frame is randomly sampled + # get prev frame + prev_index = self.get_prev_index(lablist, total_gap) + prev_image, prev_label = self.get_image_label( + seqname, imagelist, lablist, prev_index) + prev_objs = list(np.unique(prev_label)) + + # get curr frames + curr_indices = self.get_curr_indices(lablist, prev_index, + curr_gaps) + curr_images, curr_labels, curr_objs = [], [], [] + for curr_index in curr_indices: + curr_image, curr_label = self.get_image_label( + seqname, imagelist, lablist, curr_index) + c_objs = list(np.unique(curr_label)) + curr_images.append(curr_image) + curr_labels.append(curr_label) + curr_objs.extend(c_objs) + + objs = list(np.unique(prev_objs + curr_objs)) + + start_index = prev_index + end_index = max(curr_indices) + # get ref frame + _try_step = 0 + ref_index = self.get_ref_index_v2(seqname, lablist) + while (ref_index > start_index and ref_index <= end_index + and _try_step < max_try): + _try_step += 1 + ref_index = self.get_ref_index_v2(seqname, lablist) + ref_image, ref_label = self.get_image_label( + seqname, imagelist, lablist, ref_index) + ref_objs = list(np.unique(ref_label)) + else: # prev frame is next to ref frame + # get ref frame + ref_index = self.get_ref_index_v2(seqname, lablist) + + ref_image, ref_label = self.get_image_label( + seqname, imagelist, lablist, ref_index) + ref_objs = list(np.unique(ref_label)) + + # get curr frames + curr_indices = self.get_curr_indices(lablist, ref_index, + curr_gaps) + curr_images, curr_labels, curr_objs = [], [], [] + for curr_index in curr_indices: + curr_image, curr_label = self.get_image_label( + seqname, imagelist, lablist, curr_index) + c_objs = list(np.unique(curr_label)) + curr_images.append(curr_image) + curr_labels.append(curr_label) + curr_objs.extend(c_objs) + + objs = list(np.unique(curr_objs)) + prev_image, prev_label = curr_images[0], curr_labels[0] + curr_images, curr_labels = curr_images[1:], curr_labels[1:] + + is_consistent = True + for obj in objs: + if obj == 0: + continue + if obj not in ref_objs: + is_consistent = False + break + + # get meta info + obj_num = list(np.sort(ref_objs))[-1] + + sample = { + 'ref_img': ref_image, + 'prev_img': prev_image, + 'curr_img': curr_images, + 'ref_label': ref_label, + 'prev_label': prev_label, + 'curr_label': curr_labels + } + sample['meta'] = { + 'seq_name': seqname, + 'frame_num': frame_num, + 'obj_num': obj_num + } + + if self.transform is not None: + sample = self.transform(sample) + + return sample + + def __getitem__(self, idx): + sample1 = self.sample_sequence(idx) + + if self.dynamic_merge and (sample1['meta']['obj_num'] == 0 + or random.random() < self.merge_prob): + rand_idx = np.random.randint(len(self.seqs)) + while (rand_idx == (idx % len(self.seqs))): + rand_idx = np.random.randint(len(self.seqs)) + + sample2 = self.sample_sequence(rand_idx) + + sample = self.merge_sample(sample1, sample2) + else: + sample = sample1 + + return sample + + def merge_sample(self, sample1, sample2, min_obj_pixels=100): + return _merge_sample(sample1, sample2, min_obj_pixels, self.max_obj_n) + + +class DAVIS2017_Train(VOSTrain): + def __init__(self, + split=['train'], + root='./DAVIS', + transform=None, + rgb=True, + repeat_time=1, + full_resolution=True, + year=2017, + rand_gap=3, + seq_len=5, + rand_reverse=True, + dynamic_merge=True, + enable_prev_frame=False, + max_obj_n=10, + merge_prob=0.3): + if full_resolution: + resolution = 'Full-Resolution' + if not os.path.exists(os.path.join(root, 'JPEGImages', + resolution)): + print('No Full-Resolution, use 480p instead.') + resolution = '480p' + else: + resolution = '480p' + image_root = os.path.join(root, 'JPEGImages', resolution) + label_root = os.path.join(root, 'Annotations', resolution) + seq_names = [] + for spt in split: + with open(os.path.join(root, 'ImageSets', str(year), + spt + '.txt')) as f: + seqs_tmp = f.readlines() + seqs_tmp = list(map(lambda elem: elem.strip(), seqs_tmp)) + seq_names.extend(seqs_tmp) + imglistdic = {} + for seq_name in seq_names: + images = list( + np.sort(os.listdir(os.path.join(image_root, seq_name)))) + labels = list( + np.sort(os.listdir(os.path.join(label_root, seq_name)))) + imglistdic[seq_name] = (images, labels) + + super(DAVIS2017_Train, self).__init__(image_root, + label_root, + imglistdic, + transform, + rgb, + repeat_time, + rand_gap, + seq_len, + rand_reverse, + dynamic_merge, + enable_prev_frame, + merge_prob=merge_prob, + max_obj_n=max_obj_n) + + +class YOUTUBEVOS_Train(VOSTrain): + def __init__(self, + root='./datasets/YTB', + year=2019, + transform=None, + rgb=True, + rand_gap=3, + seq_len=3, + rand_reverse=True, + dynamic_merge=True, + enable_prev_frame=False, + max_obj_n=10, + merge_prob=0.3): + root = os.path.join(root, str(year), 'train') + image_root = os.path.join(root, 'JPEGImages') + label_root = os.path.join(root, 'Annotations') + self.seq_list_file = os.path.join(root, 'meta.json') + self._check_preprocess() + seq_names = list(self.ann_f.keys()) + + imglistdic = {} + for seq_name in seq_names: + data = self.ann_f[seq_name]['objects'] + obj_names = list(data.keys()) + images = [] + labels = [] + for obj_n in obj_names: + if len(data[obj_n]["frames"]) < 2: + print("Short object: " + seq_name + '-' + obj_n) + continue + images += list( + map(lambda x: x + '.jpg', list(data[obj_n]["frames"]))) + labels += list( + map(lambda x: x + '.png', list(data[obj_n]["frames"]))) + images = np.sort(np.unique(images)) + labels = np.sort(np.unique(labels)) + if len(images) < 2: + print("Short video: " + seq_name) + continue + imglistdic[seq_name] = (images, labels) + + super(YOUTUBEVOS_Train, self).__init__(image_root, + label_root, + imglistdic, + transform, + rgb, + 1, + rand_gap, + seq_len, + rand_reverse, + dynamic_merge, + enable_prev_frame, + merge_prob=merge_prob, + max_obj_n=max_obj_n) + + def _check_preprocess(self): + if not os.path.isfile(self.seq_list_file): + print('No such file: {}.'.format(self.seq_list_file)) + return False + else: + self.ann_f = json.load(open(self.seq_list_file, 'r'))['videos'] + return True + + +class TEST(Dataset): + def __init__( + self, + seq_len=3, + obj_num=3, + transform=None, + ): + self.seq_len = seq_len + self.obj_num = obj_num + self.transform = transform + + def __len__(self): + return 3000 + + def __getitem__(self, idx): + img = np.zeros((800, 800, 3)).astype(np.float32) + label = np.ones((800, 800)).astype(np.uint8) + sample = { + 'ref_img': img, + 'prev_img': img, + 'curr_img': [img] * (self.seq_len - 2), + 'ref_label': label, + 'prev_label': label, + 'curr_label': [label] * (self.seq_len - 2) + } + sample['meta'] = { + 'seq_name': 'test', + 'frame_num': 100, + 'obj_num': self.obj_num + } + + if self.transform is not None: + sample = self.transform(sample) + return sample diff --git a/aot/dataloaders/video_transforms.py b/aot/dataloaders/video_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..b7c3eaafbff595730b747d36170edee3ff5e16b6 --- /dev/null +++ b/aot/dataloaders/video_transforms.py @@ -0,0 +1,715 @@ +import random +import cv2 +import numpy as np +from PIL import Image + +import torch +import torchvision.transforms as TF +import dataloaders.image_transforms as IT + +cv2.setNumThreads(0) + + +class Resize(object): + """Rescale the image in a sample to a given size. + + Args: + output_size (tuple or int): Desired output size. If tuple, output is + matched to output_size. If int, smaller of image edges is matched + to output_size keeping aspect ratio the same. + """ + def __init__(self, output_size, use_padding=False): + assert isinstance(output_size, (int, tuple)) + if isinstance(output_size, int): + self.output_size = (output_size, output_size) + else: + self.output_size = output_size + self.use_padding = use_padding + + def __call__(self, sample): + return self.padding(sample) if self.use_padding else self.rescale( + sample) + + def rescale(self, sample): + prev_img = sample['prev_img'] + h, w = prev_img.shape[:2] + if self.output_size == (h, w): + return sample + else: + new_h, new_w = self.output_size + + for elem in sample.keys(): + if 'meta' in elem: + continue + tmp = sample[elem] + + if elem == 'prev_img' or elem == 'curr_img' or elem == 'ref_img': + flagval = cv2.INTER_CUBIC + else: + flagval = cv2.INTER_NEAREST + + if elem == 'curr_img' or elem == 'curr_label': + new_tmp = [] + all_tmp = tmp + for tmp in all_tmp: + tmp = cv2.resize(tmp, + dsize=(new_w, new_h), + interpolation=flagval) + new_tmp.append(tmp) + tmp = new_tmp + else: + tmp = cv2.resize(tmp, + dsize=(new_w, new_h), + interpolation=flagval) + + sample[elem] = tmp + + return sample + + def padding(self, sample): + prev_img = sample['prev_img'] + h, w = prev_img.shape[:2] + if self.output_size == (h, w): + return sample + else: + new_h, new_w = self.output_size + + def sep_pad(x): + x0 = np.random.randint(0, x + 1) + x1 = x - x0 + return x0, x1 + + top_pad, bottom_pad = sep_pad(new_h - h) + left_pad, right_pad = sep_pad(new_w - w) + + for elem in sample.keys(): + if 'meta' in elem: + continue + tmp = sample[elem] + + if elem == 'prev_img' or elem == 'curr_img' or elem == 'ref_img': + pad_value = (124, 116, 104) + else: + pad_value = (0) + + if elem == 'curr_img' or elem == 'curr_label': + new_tmp = [] + all_tmp = tmp + for tmp in all_tmp: + tmp = cv2.copyMakeBorder(tmp, + top_pad, + bottom_pad, + left_pad, + right_pad, + cv2.BORDER_CONSTANT, + value=pad_value) + new_tmp.append(tmp) + tmp = new_tmp + else: + tmp = cv2.copyMakeBorder(tmp, + top_pad, + bottom_pad, + left_pad, + right_pad, + cv2.BORDER_CONSTANT, + value=pad_value) + + sample[elem] = tmp + + return sample + + +class BalancedRandomCrop(object): + """Crop randomly the image in a sample. + + Args: + output_size (tuple or int): Desired output size. If int, square crop + is made. + """ + def __init__(self, + output_size, + max_step=5, + max_obj_num=5, + min_obj_pixel_num=100): + assert isinstance(output_size, (int, tuple)) + if isinstance(output_size, int): + self.output_size = (output_size, output_size) + else: + assert len(output_size) == 2 + self.output_size = output_size + self.max_step = max_step + self.max_obj_num = max_obj_num + self.min_obj_pixel_num = min_obj_pixel_num + + def __call__(self, sample): + + image = sample['prev_img'] + h, w = image.shape[:2] + new_h, new_w = self.output_size + new_h = h if new_h >= h else new_h + new_w = w if new_w >= w else new_w + ref_label = sample["ref_label"] + prev_label = sample["prev_label"] + curr_label = sample["curr_label"] + + is_contain_obj = False + step = 0 + while (not is_contain_obj) and (step < self.max_step): + step += 1 + top = np.random.randint(0, h - new_h + 1) + left = np.random.randint(0, w - new_w + 1) + after_crop = [] + contains = [] + for elem in ([ref_label, prev_label] + curr_label): + tmp = elem[top:top + new_h, left:left + new_w] + contains.append(np.unique(tmp)) + after_crop.append(tmp) + + all_obj = list(np.sort(contains[0])) + + if all_obj[-1] == 0: + continue + + # remove background + if all_obj[0] == 0: + all_obj = all_obj[1:] + + # remove small obj + new_all_obj = [] + for obj_id in all_obj: + after_crop_pixels = np.sum(after_crop[0] == obj_id) + if after_crop_pixels > self.min_obj_pixel_num: + new_all_obj.append(obj_id) + + if len(new_all_obj) == 0: + is_contain_obj = False + else: + is_contain_obj = True + + if len(new_all_obj) > self.max_obj_num: + random.shuffle(new_all_obj) + new_all_obj = new_all_obj[:self.max_obj_num] + + all_obj = [0] + new_all_obj + + post_process = [] + for elem in after_crop: + new_elem = elem * 0 + for idx in range(len(all_obj)): + obj_id = all_obj[idx] + if obj_id == 0: + continue + mask = elem == obj_id + + new_elem += (mask * idx).astype(np.uint8) + post_process.append(new_elem.astype(np.uint8)) + + sample["ref_label"] = post_process[0] + sample["prev_label"] = post_process[1] + curr_len = len(sample["curr_img"]) + sample["curr_label"] = [] + for idx in range(curr_len): + sample["curr_label"].append(post_process[idx + 2]) + + for elem in sample.keys(): + if 'meta' in elem or 'label' in elem: + continue + if elem == 'curr_img': + new_tmp = [] + for tmp_ in sample[elem]: + tmp_ = tmp_[top:top + new_h, left:left + new_w] + new_tmp.append(tmp_) + sample[elem] = new_tmp + else: + tmp = sample[elem] + tmp = tmp[top:top + new_h, left:left + new_w] + sample[elem] = tmp + + obj_num = len(all_obj) - 1 + + sample['meta']['obj_num'] = obj_num + + return sample + + +class RandomScale(object): + """Randomly resize the image and the ground truth to specified scales. + Args: + scales (list): the list of scales + """ + def __init__(self, min_scale=1., max_scale=1.3, short_edge=None): + self.min_scale = min_scale + self.max_scale = max_scale + self.short_edge = short_edge + + def __call__(self, sample): + # Fixed range of scales + sc = np.random.uniform(self.min_scale, self.max_scale) + # Align short edge + if self.short_edge is not None: + image = sample['prev_img'] + h, w = image.shape[:2] + if h > w: + sc *= float(self.short_edge) / w + else: + sc *= float(self.short_edge) / h + + for elem in sample.keys(): + if 'meta' in elem: + continue + tmp = sample[elem] + + if elem == 'prev_img' or elem == 'curr_img' or elem == 'ref_img': + flagval = cv2.INTER_CUBIC + else: + flagval = cv2.INTER_NEAREST + + if elem == 'curr_img' or elem == 'curr_label': + new_tmp = [] + for tmp_ in tmp: + tmp_ = cv2.resize(tmp_, + None, + fx=sc, + fy=sc, + interpolation=flagval) + new_tmp.append(tmp_) + tmp = new_tmp + else: + tmp = cv2.resize(tmp, + None, + fx=sc, + fy=sc, + interpolation=flagval) + + sample[elem] = tmp + + return sample + + +class RandomScaleV2(object): + """Randomly resize the image and the ground truth to specified scales. + Args: + scales (list): the list of scales + """ + def __init__(self, + min_scale=0.36, + max_scale=1.0, + short_edge=None, + ratio=[3. / 4., 4. / 3.]): + self.min_scale = min_scale + self.max_scale = max_scale + self.short_edge = short_edge + self.ratio = ratio + + def __call__(self, sample): + image = sample['prev_img'] + h, w = image.shape[:2] + + new_h, new_w = self.get_params(h, w) + + sc_x = float(new_w) / w + sc_y = float(new_h) / h + + # Align short edge + if not (self.short_edge is None): + if h > w: + sc_x *= float(self.short_edge) / w + sc_y *= float(self.short_edge) / w + else: + sc_x *= float(self.short_edge) / h + sc_y *= float(self.short_edge) / h + + for elem in sample.keys(): + if 'meta' in elem: + continue + tmp = sample[elem] + + if elem == 'prev_img' or elem == 'curr_img' or elem == 'ref_img': + flagval = cv2.INTER_CUBIC + else: + flagval = cv2.INTER_NEAREST + + if elem == 'curr_img' or elem == 'curr_label': + new_tmp = [] + for tmp_ in tmp: + tmp_ = cv2.resize(tmp_, + None, + fx=sc_x, + fy=sc_y, + interpolation=flagval) + new_tmp.append(tmp_) + tmp = new_tmp + else: + tmp = cv2.resize(tmp, + None, + fx=sc_x, + fy=sc_y, + interpolation=flagval) + + sample[elem] = tmp + + return sample + + def get_params(self, height, width): + area = height * width + + log_ratio = [np.log(item) for item in self.ratio] + for _ in range(10): + target_area = area * np.random.uniform(self.min_scale**2, + self.max_scale**2) + aspect_ratio = np.exp(np.random.uniform(log_ratio[0], + log_ratio[1])) + + w = int(round(np.sqrt(target_area * aspect_ratio))) + h = int(round(np.sqrt(target_area / aspect_ratio))) + + if 0 < w <= width and 0 < h <= height: + return h, w + + # Fallback to central crop + in_ratio = float(width) / float(height) + if in_ratio < min(self.ratio): + w = width + h = int(round(w / min(self.ratio))) + elif in_ratio > max(self.ratio): + h = height + w = int(round(h * max(self.ratio))) + else: # whole image + w = width + h = height + + return h, w + +class RestrictSize(object): + """Randomly resize the image and the ground truth to specified scales. + Args: + scales (list): the list of scales + """ + def __init__(self, max_short_edge=None, max_long_edge=800 * 1.3): + self.max_short_edge = max_short_edge + self.max_long_edge = max_long_edge + assert ((max_short_edge is None)) or ((max_long_edge is None)) + + def __call__(self, sample): + + # Fixed range of scales + sc = None + image = sample['ref_img'] + h, w = image.shape[:2] + # Align short edge + if not (self.max_short_edge is None): + if h > w: + short_edge = w + else: + short_edge = h + if short_edge < self.max_short_edge: + sc = float(self.max_short_edge) / short_edge + else: + if h > w: + long_edge = h + else: + long_edge = w + if long_edge > self.max_long_edge: + sc = float(self.max_long_edge) / long_edge + + if sc is None: + new_h = h + new_w = w + else: + new_h = int(sc * h) + new_w = int(sc * w) + new_h = new_h - (new_h - 1) % 4 + new_w = new_w - (new_w - 1) % 4 + if new_h == h and new_w == w: + return sample + + for elem in sample.keys(): + if 'meta' in elem: + continue + tmp = sample[elem] + + if 'label' in elem: + flagval = cv2.INTER_NEAREST + else: + flagval = cv2.INTER_CUBIC + + tmp = cv2.resize(tmp, dsize=(new_w, new_h), interpolation=flagval) + + sample[elem] = tmp + + return sample + + +class RandomHorizontalFlip(object): + """Horizontally flip the given image and ground truth randomly with a probability of 0.5.""" + def __init__(self, prob): + self.p = prob + + def __call__(self, sample): + + if random.random() < self.p: + for elem in sample.keys(): + if 'meta' in elem: + continue + if elem == 'curr_img' or elem == 'curr_label': + new_tmp = [] + for tmp_ in sample[elem]: + tmp_ = cv2.flip(tmp_, flipCode=1) + new_tmp.append(tmp_) + sample[elem] = new_tmp + else: + tmp = sample[elem] + tmp = cv2.flip(tmp, flipCode=1) + sample[elem] = tmp + + return sample + + +class RandomVerticalFlip(object): + """Vertically flip the given image and ground truth randomly with a probability of 0.5.""" + def __init__(self, prob=0.3): + self.p = prob + + def __call__(self, sample): + + if random.random() < self.p: + for elem in sample.keys(): + if 'meta' in elem: + continue + if elem == 'curr_img' or elem == 'curr_label': + new_tmp = [] + for tmp_ in sample[elem]: + tmp_ = cv2.flip(tmp_, flipCode=0) + new_tmp.append(tmp_) + sample[elem] = new_tmp + else: + tmp = sample[elem] + tmp = cv2.flip(tmp, flipCode=0) + sample[elem] = tmp + + return sample + + +class RandomGaussianBlur(object): + def __init__(self, prob=0.3, sigma=[0.1, 2.]): + self.aug = TF.RandomApply([IT.GaussianBlur(sigma)], p=prob) + + def __call__(self, sample): + for elem in sample.keys(): + if 'meta' in elem or 'label' in elem: + continue + + if elem == 'curr_img': + new_tmp = [] + for tmp_ in sample[elem]: + tmp_ = self.apply_augmentation(tmp_) + new_tmp.append(tmp_) + sample[elem] = new_tmp + else: + tmp = sample[elem] + tmp = self.apply_augmentation(tmp) + sample[elem] = tmp + return sample + + def apply_augmentation(self, x): + x = Image.fromarray(np.uint8(x)) + x = self.aug(x) + x = np.array(x, dtype=np.float32) + return x + + +class RandomGrayScale(RandomGaussianBlur): + def __init__(self, prob=0.2): + self.aug = TF.RandomGrayscale(p=prob) + + +class RandomColorJitter(RandomGaussianBlur): + def __init__(self, + prob=0.8, + brightness=0.4, + contrast=0.4, + saturation=0.2, + hue=0.1): + self.aug = TF.RandomApply( + [TF.ColorJitter(brightness, contrast, saturation, hue)], p=prob) + + +class SubtractMeanImage(object): + def __init__(self, mean, change_channels=False): + self.mean = mean + self.change_channels = change_channels + + def __call__(self, sample): + for elem in sample.keys(): + if 'image' in elem: + if self.change_channels: + sample[elem] = sample[elem][:, :, [2, 1, 0]] + sample[elem] = np.subtract( + sample[elem], np.array(self.mean, dtype=np.float32)) + return sample + + def __str__(self): + return 'SubtractMeanImage' + str(self.mean) + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + def __call__(self, sample): + + for elem in sample.keys(): + if 'meta' in elem: + continue + tmp = sample[elem] + + if elem == 'curr_img' or elem == 'curr_label': + new_tmp = [] + for tmp_ in tmp: + if tmp_.ndim == 2: + tmp_ = tmp_[:, :, np.newaxis] + tmp_ = tmp_.transpose((2, 0, 1)) + new_tmp.append(torch.from_numpy(tmp_).int()) + else: + tmp_ = tmp_ / 255. + tmp_ -= (0.485, 0.456, 0.406) + tmp_ /= (0.229, 0.224, 0.225) + tmp_ = tmp_.transpose((2, 0, 1)) + new_tmp.append(torch.from_numpy(tmp_)) + tmp = new_tmp + else: + if tmp.ndim == 2: + tmp = tmp[:, :, np.newaxis] + tmp = tmp.transpose((2, 0, 1)) + tmp = torch.from_numpy(tmp).int() + else: + tmp = tmp / 255. + tmp -= (0.485, 0.456, 0.406) + tmp /= (0.229, 0.224, 0.225) + tmp = tmp.transpose((2, 0, 1)) + tmp = torch.from_numpy(tmp) + sample[elem] = tmp + + return sample + + +class MultiRestrictSize(object): + def __init__(self, + max_short_edge=None, + max_long_edge=800, + flip=False, + multi_scale=[1.3], + align_corners=True, + max_stride=16): + self.max_short_edge = max_short_edge + self.max_long_edge = max_long_edge + self.multi_scale = multi_scale + self.flip = flip + self.align_corners = align_corners + self.max_stride = max_stride + + def __call__(self, sample): + samples = [] + image = sample['current_img'] + h, w = image.shape[:2] + for scale in self.multi_scale: + # restrict short edge + sc = 1. + if self.max_short_edge is not None: + if h > w: + short_edge = w + else: + short_edge = h + if short_edge > self.max_short_edge: + sc *= float(self.max_short_edge) / short_edge + new_h, new_w = sc * h, sc * w + + # restrict long edge + sc = 1. + if self.max_long_edge is not None: + if new_h > new_w: + long_edge = new_h + else: + long_edge = new_w + if long_edge > self.max_long_edge: + sc *= float(self.max_long_edge) / long_edge + + new_h, new_w = sc * new_h, sc * new_w + + new_h = int(new_h * scale) + new_w = int(new_w * scale) + + if self.align_corners: + if (new_h - 1) % self.max_stride != 0: + new_h = int( + np.around((new_h - 1) / self.max_stride) * + self.max_stride + 1) + if (new_w - 1) % self.max_stride != 0: + new_w = int( + np.around((new_w - 1) / self.max_stride) * + self.max_stride + 1) + else: + if new_h % self.max_stride != 0: + new_h = int( + np.around(new_h / self.max_stride) * self.max_stride) + if new_w % self.max_stride != 0: + new_w = int( + np.around(new_w / self.max_stride) * self.max_stride) + + if new_h == h and new_w == w: + samples.append(sample) + else: + new_sample = {} + for elem in sample.keys(): + if 'meta' in elem: + new_sample[elem] = sample[elem] + continue + tmp = sample[elem] + if 'label' in elem: + new_sample[elem] = sample[elem] + continue + else: + flagval = cv2.INTER_CUBIC + tmp = cv2.resize(tmp, + dsize=(new_w, new_h), + interpolation=flagval) + new_sample[elem] = tmp + samples.append(new_sample) + + if self.flip: + now_sample = samples[-1] + new_sample = {} + for elem in now_sample.keys(): + if 'meta' in elem: + new_sample[elem] = now_sample[elem].copy() + new_sample[elem]['flip'] = True + continue + tmp = now_sample[elem] + tmp = tmp[:, ::-1].copy() + new_sample[elem] = tmp + samples.append(new_sample) + + return samples + + +class MultiToTensor(object): + def __call__(self, samples): + for idx in range(len(samples)): + sample = samples[idx] + for elem in sample.keys(): + if 'meta' in elem: + continue + tmp = sample[elem] + if tmp is None: + continue + + if tmp.ndim == 2: + tmp = tmp[:, :, np.newaxis] + tmp = tmp.transpose((2, 0, 1)) + samples[idx][elem] = torch.from_numpy(tmp).int() + else: + tmp = tmp / 255. + tmp -= (0.485, 0.456, 0.406) + tmp /= (0.229, 0.224, 0.225) + tmp = tmp.transpose((2, 0, 1)) + samples[idx][elem] = torch.from_numpy(tmp) + + return samples diff --git a/aot/datasets/.DS_Store b/aot/datasets/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..567774adc688ea1b1460d745f0a1886870056d26 Binary files /dev/null and b/aot/datasets/.DS_Store differ diff --git a/aot/datasets/DAVIS/README.md b/aot/datasets/DAVIS/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d1c18f61dc18ce6fd6d012e983d90253b139762d --- /dev/null +++ b/aot/datasets/DAVIS/README.md @@ -0,0 +1 @@ +Put DAVIS 2017 here. \ No newline at end of file diff --git a/aot/datasets/Static/README.md b/aot/datasets/Static/README.md new file mode 100644 index 0000000000000000000000000000000000000000..084f78a76a0b82ea59555db819e9ca4137e69d0a --- /dev/null +++ b/aot/datasets/Static/README.md @@ -0,0 +1 @@ +Put the static dataset here. Guidance can be found in [AFB-URR](https://github.com/xmlyqing00/AFB-URR), which we referred to in the implementation of the pre-training. diff --git a/aot/datasets/YTB/2018/train/README.md b/aot/datasets/YTB/2018/train/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b9bfd5ec65839444fc0536479ad05a93e5839ae6 --- /dev/null +++ b/aot/datasets/YTB/2018/train/README.md @@ -0,0 +1 @@ +Put the training split of YouTube-VOS 2018 here. \ No newline at end of file diff --git a/aot/datasets/YTB/2018/valid/README.md b/aot/datasets/YTB/2018/valid/README.md new file mode 100644 index 0000000000000000000000000000000000000000..37036488a307621c617aba2825c60ad3895fe4ba --- /dev/null +++ b/aot/datasets/YTB/2018/valid/README.md @@ -0,0 +1 @@ +Put the validation split of YouTube-VOS 2018 here. \ No newline at end of file diff --git a/aot/datasets/YTB/2018/valid_all_frames/README.md b/aot/datasets/YTB/2018/valid_all_frames/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7fc905e2d328ddba5c00915d2cae1a9f4175fe29 --- /dev/null +++ b/aot/datasets/YTB/2018/valid_all_frames/README.md @@ -0,0 +1 @@ +Put the all-frame validation split of YouTube-VOS 2018 here. \ No newline at end of file diff --git a/aot/datasets/YTB/2019/train/README.md b/aot/datasets/YTB/2019/train/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f816ba6895e0c7c3d4b4bb63b83ffb19c21c0ed4 --- /dev/null +++ b/aot/datasets/YTB/2019/train/README.md @@ -0,0 +1 @@ +Put the training split of YouTube-VOS 2019 here. \ No newline at end of file diff --git a/aot/datasets/YTB/2019/valid/README.md b/aot/datasets/YTB/2019/valid/README.md new file mode 100644 index 0000000000000000000000000000000000000000..445bb13cf1938869995902f580478daf7d20c364 --- /dev/null +++ b/aot/datasets/YTB/2019/valid/README.md @@ -0,0 +1 @@ +Put the validation split of YouTube-VOS 2019 here. \ No newline at end of file diff --git a/aot/datasets/YTB/2019/valid_all_frames/README.md b/aot/datasets/YTB/2019/valid_all_frames/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7fc905e2d328ddba5c00915d2cae1a9f4175fe29 --- /dev/null +++ b/aot/datasets/YTB/2019/valid_all_frames/README.md @@ -0,0 +1 @@ +Put the all-frame validation split of YouTube-VOS 2018 here. \ No newline at end of file diff --git a/aot/networks/.DS_Store b/aot/networks/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..996bc3ce8680c9ed43d713cdae21fba9989ff7b1 Binary files /dev/null and b/aot/networks/.DS_Store differ diff --git a/aot/networks/__init__.py b/aot/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/aot/networks/__pycache__/__init__.cpython-310.pyc b/aot/networks/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5af2104dbcae7d846d18c525d23ae0cfd3bb4838 Binary files /dev/null and b/aot/networks/__pycache__/__init__.cpython-310.pyc differ diff --git a/aot/networks/decoders/__init__.py b/aot/networks/decoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c5cd58d425f47367ef77c97e89a5f005cc2d2001 --- /dev/null +++ b/aot/networks/decoders/__init__.py @@ -0,0 +1,9 @@ +from networks.decoders.fpn import FPNSegmentationHead + + +def build_decoder(name, **kwargs): + + if name == 'fpn': + return FPNSegmentationHead(**kwargs) + else: + raise NotImplementedError diff --git a/aot/networks/decoders/__pycache__/__init__.cpython-310.pyc b/aot/networks/decoders/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d714729edfbb5bc4894ef9df9c92cfae709c9c3 Binary files /dev/null and b/aot/networks/decoders/__pycache__/__init__.cpython-310.pyc differ diff --git a/aot/networks/decoders/__pycache__/fpn.cpython-310.pyc b/aot/networks/decoders/__pycache__/fpn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..615a95e08cb91ee917cba4b28874ec58d5fc5e42 Binary files /dev/null and b/aot/networks/decoders/__pycache__/fpn.cpython-310.pyc differ diff --git a/aot/networks/decoders/fpn.py b/aot/networks/decoders/fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..4ce4d3f0a380078b646cc0b28ffd8e91b9810adb --- /dev/null +++ b/aot/networks/decoders/fpn.py @@ -0,0 +1,63 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from networks.layers.basic import ConvGN + + +class FPNSegmentationHead(nn.Module): + def __init__(self, + in_dim, + out_dim, + decode_intermediate_input=True, + hidden_dim=256, + shortcut_dims=[24, 32, 96, 1280], + align_corners=True): + super().__init__() + self.align_corners = align_corners + + self.decode_intermediate_input = decode_intermediate_input + + self.conv_in = ConvGN(in_dim, hidden_dim, 1) + + self.conv_16x = ConvGN(hidden_dim, hidden_dim, 3) + self.conv_8x = ConvGN(hidden_dim, hidden_dim // 2, 3) + self.conv_4x = ConvGN(hidden_dim // 2, hidden_dim // 2, 3) + + self.adapter_16x = nn.Conv2d(shortcut_dims[-2], hidden_dim, 1) + self.adapter_8x = nn.Conv2d(shortcut_dims[-3], hidden_dim, 1) + self.adapter_4x = nn.Conv2d(shortcut_dims[-4], hidden_dim // 2, 1) + + self.conv_out = nn.Conv2d(hidden_dim // 2, out_dim, 1) + + self._init_weight() + + def forward(self, inputs, shortcuts): + + if self.decode_intermediate_input: + x = torch.cat(inputs, dim=1) + else: + x = inputs[-1] + + x = F.relu_(self.conv_in(x)) + x = F.relu_(self.conv_16x(self.adapter_16x(shortcuts[-2]) + x)) + + x = F.interpolate(x, + size=shortcuts[-3].size()[-2:], + mode="bilinear", + align_corners=self.align_corners) + x = F.relu_(self.conv_8x(self.adapter_8x(shortcuts[-3]) + x)) + + x = F.interpolate(x, + size=shortcuts[-4].size()[-2:], + mode="bilinear", + align_corners=self.align_corners) + x = F.relu_(self.conv_4x(self.adapter_4x(shortcuts[-4]) + x)) + + x = self.conv_out(x) + + return x + + def _init_weight(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) diff --git a/aot/networks/encoders/.DS_Store b/aot/networks/encoders/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..a280bcfdd4d5bc576488337f5b69531bcbb7017d Binary files /dev/null and b/aot/networks/encoders/.DS_Store differ diff --git a/aot/networks/encoders/__init__.py b/aot/networks/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bf240be6ec6c2ce2f6e1c81adb86af4196223c83 --- /dev/null +++ b/aot/networks/encoders/__init__.py @@ -0,0 +1,35 @@ +from networks.encoders.mobilenetv2 import MobileNetV2 +from networks.encoders.mobilenetv3 import MobileNetV3Large +from networks.encoders.resnet import ResNet101, ResNet50 +from networks.encoders.resnest import resnest +from networks.encoders.swin import build_swin_model +from networks.layers.normalization import FrozenBatchNorm2d +from torch import nn + + +def build_encoder(name, frozen_bn=True, freeze_at=-1): + if frozen_bn: + BatchNorm = FrozenBatchNorm2d + else: + BatchNorm = nn.BatchNorm2d + + if name == 'mobilenetv2': + return MobileNetV2(16, BatchNorm, freeze_at=freeze_at) + elif name == 'mobilenetv3': + return MobileNetV3Large(16, BatchNorm, freeze_at=freeze_at) + elif name == 'resnet50': + return ResNet50(16, BatchNorm, freeze_at=freeze_at) + elif name == 'resnet101': + return ResNet101(16, BatchNorm, freeze_at=freeze_at) + elif name == 'resnest50': + return resnest.resnest50(norm_layer=BatchNorm, + dilation=2, + freeze_at=freeze_at) + elif name == 'resnest101': + return resnest.resnest101(norm_layer=BatchNorm, + dilation=2, + freeze_at=freeze_at) + elif 'swin' in name: + return build_swin_model(name, freeze_at=freeze_at) + else: + raise NotImplementedError diff --git a/aot/networks/encoders/__pycache__/__init__.cpython-310.pyc b/aot/networks/encoders/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfb209c7b7242cfc6bd3da15c7ab380d32256ed0 Binary files /dev/null and b/aot/networks/encoders/__pycache__/__init__.cpython-310.pyc differ diff --git a/aot/networks/encoders/__pycache__/mobilenetv2.cpython-310.pyc b/aot/networks/encoders/__pycache__/mobilenetv2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be704dc26b37b0991bb1ed6a95c950ad5a2491f6 Binary files /dev/null and b/aot/networks/encoders/__pycache__/mobilenetv2.cpython-310.pyc differ diff --git a/aot/networks/encoders/__pycache__/mobilenetv3.cpython-310.pyc b/aot/networks/encoders/__pycache__/mobilenetv3.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d20927ce710409176e89809bfcb22df476b3c86f Binary files /dev/null and b/aot/networks/encoders/__pycache__/mobilenetv3.cpython-310.pyc differ diff --git a/aot/networks/encoders/__pycache__/resnet.cpython-310.pyc b/aot/networks/encoders/__pycache__/resnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d09423b9a7a64d02946bc3692ae290fabec6fd27 Binary files /dev/null and b/aot/networks/encoders/__pycache__/resnet.cpython-310.pyc differ diff --git a/aot/networks/encoders/mobilenetv2.py b/aot/networks/encoders/mobilenetv2.py new file mode 100644 index 0000000000000000000000000000000000000000..7ff40d99f609c2c443110d02dd16875a9209b3b1 --- /dev/null +++ b/aot/networks/encoders/mobilenetv2.py @@ -0,0 +1,247 @@ +from torch import nn +from torch import Tensor +from typing import Callable, Optional, List +from utils.learning import freeze_params + +__all__ = ['MobileNetV2'] + + +def _make_divisible(v: float, + divisor: int, + min_value: Optional[int] = None) -> int: + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNActivation(nn.Sequential): + def __init__( + self, + in_planes: int, + out_planes: int, + kernel_size: int = 3, + stride: int = 1, + groups: int = 1, + padding: int = -1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + activation_layer: Optional[Callable[..., nn.Module]] = None, + dilation: int = 1, + ) -> None: + if padding == -1: + padding = (kernel_size - 1) // 2 * dilation + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if activation_layer is None: + activation_layer = nn.ReLU6 + super().__init__( + nn.Conv2d(in_planes, + out_planes, + kernel_size, + stride, + padding, + dilation=dilation, + groups=groups, + bias=False), norm_layer(out_planes), + activation_layer(inplace=True)) + self.out_channels = out_planes + + +# necessary for backwards compatibility +ConvBNReLU = ConvBNActivation + + +class InvertedResidual(nn.Module): + def __init__( + self, + inp: int, + oup: int, + stride: int, + dilation: int, + expand_ratio: int, + norm_layer: Optional[Callable[..., nn.Module]] = None) -> None: + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + if norm_layer is None: + norm_layer = nn.BatchNorm2d + + self.kernel_size = 3 + self.dilation = dilation + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers: List[nn.Module] = [] + if expand_ratio != 1: + # pw + layers.append( + ConvBNReLU(inp, + hidden_dim, + kernel_size=1, + norm_layer=norm_layer)) + layers.extend([ + # dw + ConvBNReLU(hidden_dim, + hidden_dim, + stride=stride, + dilation=dilation, + groups=hidden_dim, + norm_layer=norm_layer), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + norm_layer(oup), + ]) + self.conv = nn.Sequential(*layers) + self.out_channels = oup + self._is_cn = stride > 1 + + def forward(self, x: Tensor) -> Tensor: + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + def __init__(self, + output_stride=8, + norm_layer: Optional[Callable[..., nn.Module]] = None, + width_mult: float = 1.0, + inverted_residual_setting: Optional[List[List[int]]] = None, + round_nearest: int = 8, + block: Optional[Callable[..., nn.Module]] = None, + freeze_at=0) -> None: + """ + MobileNet V2 main class + Args: + num_classes (int): Number of classes + width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to be a multiple of this number + Set to 1 to turn off rounding + block: Module specifying inverted residual building block for mobilenet + norm_layer: Module specifying the normalization layer to use + """ + super(MobileNetV2, self).__init__() + + if block is None: + block = InvertedResidual + + if norm_layer is None: + norm_layer = nn.BatchNorm2d + + last_channel = 1280 + input_channel = 32 + current_stride = 1 + rate = 1 + + if inverted_residual_setting is None: + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] + + # only check the first element, assuming user knows t,c,n,s are required + if len(inverted_residual_setting) == 0 or len( + inverted_residual_setting[0]) != 4: + raise ValueError("inverted_residual_setting should be non-empty " + "or a 4-element list, got {}".format( + inverted_residual_setting)) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, + round_nearest) + self.last_channel = _make_divisible( + last_channel * max(1.0, width_mult), round_nearest) + features: List[nn.Module] = [ + ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer) + ] + current_stride *= 2 + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + if current_stride == output_stride: + stride = 1 + dilation = rate + rate *= s + else: + stride = s + dilation = 1 + current_stride *= s + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + if i == 0: + features.append( + block(input_channel, output_channel, stride, dilation, + t, norm_layer)) + else: + features.append( + block(input_channel, output_channel, 1, rate, t, + norm_layer)) + input_channel = output_channel + + # building last several layers + features.append( + ConvBNReLU(input_channel, + self.last_channel, + kernel_size=1, + norm_layer=norm_layer)) + # make it nn.Sequential + self.features = nn.Sequential(*features) + + self._initialize_weights() + + feature_4x = self.features[0:4] + feautre_8x = self.features[4:7] + feature_16x = self.features[7:14] + feature_32x = self.features[14:] + + self.stages = [feature_4x, feautre_8x, feature_16x, feature_32x] + + self.freeze(freeze_at) + + def forward(self, x): + xs = [] + for stage in self.stages: + x = stage(x) + xs.append(x) + return xs + + def _initialize_weights(self): + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + + def freeze(self, freeze_at): + if freeze_at >= 1: + for m in self.stages[0][0]: + freeze_params(m) + + for idx, stage in enumerate(self.stages, start=2): + if freeze_at >= idx: + freeze_params(stage) diff --git a/aot/networks/encoders/mobilenetv3.py b/aot/networks/encoders/mobilenetv3.py new file mode 100644 index 0000000000000000000000000000000000000000..47bd0db64f62f6c36d17ca7d8f586b86a01d35bb --- /dev/null +++ b/aot/networks/encoders/mobilenetv3.py @@ -0,0 +1,239 @@ +""" +Creates a MobileNetV3 Model as defined in: +Andrew Howard, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang, Yukun Zhu, Ruoming Pang, Vijay Vasudevan, Quoc V. Le, Hartwig Adam. (2019). +Searching for MobileNetV3 +arXiv preprint arXiv:1905.02244. +""" + +import torch.nn as nn +import math +from utils.learning import freeze_params + + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class h_sigmoid(nn.Module): + def __init__(self, inplace=True): + super(h_sigmoid, self).__init__() + self.relu = nn.ReLU6(inplace=inplace) + + def forward(self, x): + return self.relu(x + 3) / 6 + + +class h_swish(nn.Module): + def __init__(self, inplace=True): + super(h_swish, self).__init__() + self.sigmoid = h_sigmoid(inplace=inplace) + + def forward(self, x): + return x * self.sigmoid(x) + + +class SELayer(nn.Module): + def __init__(self, channel, reduction=4): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, _make_divisible(channel // reduction, 8)), + nn.ReLU(inplace=True), + nn.Linear(_make_divisible(channel // reduction, 8), channel), + h_sigmoid()) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +def conv_3x3_bn(inp, oup, stride, norm_layer=nn.BatchNorm2d): + return nn.Sequential(nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + norm_layer(oup), h_swish()) + + +def conv_1x1_bn(inp, oup, norm_layer=nn.BatchNorm2d): + return nn.Sequential(nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + norm_layer(oup), h_swish()) + + +class InvertedResidual(nn.Module): + def __init__(self, + inp, + hidden_dim, + oup, + kernel_size, + stride, + use_se, + use_hs, + dilation=1, + norm_layer=nn.BatchNorm2d): + super(InvertedResidual, self).__init__() + assert stride in [1, 2] + + self.identity = stride == 1 and inp == oup + + if inp == hidden_dim: + self.conv = nn.Sequential( + # dw + nn.Conv2d(hidden_dim, + hidden_dim, + kernel_size, + stride, (kernel_size - 1) // 2 * dilation, + dilation=dilation, + groups=hidden_dim, + bias=False), + norm_layer(hidden_dim), + h_swish() if use_hs else nn.ReLU(inplace=True), + # Squeeze-and-Excite + SELayer(hidden_dim) if use_se else nn.Identity(), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + norm_layer(oup), + ) + else: + self.conv = nn.Sequential( + # pw + nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), + norm_layer(hidden_dim), + h_swish() if use_hs else nn.ReLU(inplace=True), + # dw + nn.Conv2d(hidden_dim, + hidden_dim, + kernel_size, + stride, (kernel_size - 1) // 2 * dilation, + dilation=dilation, + groups=hidden_dim, + bias=False), + norm_layer(hidden_dim), + # Squeeze-and-Excite + SELayer(hidden_dim) if use_se else nn.Identity(), + h_swish() if use_hs else nn.ReLU(inplace=True), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + norm_layer(oup), + ) + + def forward(self, x): + if self.identity: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV3Large(nn.Module): + def __init__(self, + output_stride=16, + norm_layer=nn.BatchNorm2d, + width_mult=1., + freeze_at=0): + super(MobileNetV3Large, self).__init__() + """ + Constructs a MobileNetV3-Large model + """ + cfgs = [ + # k, t, c, SE, HS, s + [3, 1, 16, 0, 0, 1], + [3, 4, 24, 0, 0, 2], + [3, 3, 24, 0, 0, 1], + [5, 3, 40, 1, 0, 2], + [5, 3, 40, 1, 0, 1], + [5, 3, 40, 1, 0, 1], + [3, 6, 80, 0, 1, 2], + [3, 2.5, 80, 0, 1, 1], + [3, 2.3, 80, 0, 1, 1], + [3, 2.3, 80, 0, 1, 1], + [3, 6, 112, 1, 1, 1], + [3, 6, 112, 1, 1, 1], + [5, 6, 160, 1, 1, 2], + [5, 6, 160, 1, 1, 1], + [5, 6, 160, 1, 1, 1] + ] + self.cfgs = cfgs + + # building first layer + input_channel = _make_divisible(16 * width_mult, 8) + layers = [conv_3x3_bn(3, input_channel, 2, norm_layer)] + # building inverted residual blocks + block = InvertedResidual + now_stride = 2 + rate = 1 + for k, t, c, use_se, use_hs, s in self.cfgs: + if now_stride == output_stride: + dilation = rate + rate *= s + s = 1 + else: + dilation = 1 + now_stride *= s + output_channel = _make_divisible(c * width_mult, 8) + exp_size = _make_divisible(input_channel * t, 8) + layers.append( + block(input_channel, exp_size, output_channel, k, s, use_se, + use_hs, dilation, norm_layer)) + input_channel = output_channel + + self.features = nn.Sequential(*layers) + self.conv = conv_1x1_bn(input_channel, exp_size, norm_layer) + # building last several layers + + self._initialize_weights() + + feature_4x = self.features[0:4] + feautre_8x = self.features[4:7] + feature_16x = self.features[7:13] + feature_32x = self.features[13:] + + self.stages = [feature_4x, feautre_8x, feature_16x, feature_32x] + + self.freeze(freeze_at) + + def forward(self, x): + xs = [] + for stage in self.stages: + x = stage(x) + xs.append(x) + xs[-1] = self.conv(xs[-1]) + return xs + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + n = m.weight.size(1) + m.weight.data.normal_(0, 0.01) + m.bias.data.zero_() + + def freeze(self, freeze_at): + if freeze_at >= 1: + for m in self.stages[0][0]: + freeze_params(m) + + for idx, stage in enumerate(self.stages, start=2): + if freeze_at >= idx: + freeze_params(stage) diff --git a/aot/networks/encoders/resnest/__init__.py b/aot/networks/encoders/resnest/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d46216bfb8fed8ea1f951bf23211d6e9d44b4148 --- /dev/null +++ b/aot/networks/encoders/resnest/__init__.py @@ -0,0 +1 @@ +from .resnest import * diff --git a/aot/networks/encoders/resnest/__pycache__/__init__.cpython-310.pyc b/aot/networks/encoders/resnest/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7b9419b37911f9f8c7506a08c1d8df2735b80ca Binary files /dev/null and b/aot/networks/encoders/resnest/__pycache__/__init__.cpython-310.pyc differ diff --git a/aot/networks/encoders/resnest/__pycache__/resnest.cpython-310.pyc b/aot/networks/encoders/resnest/__pycache__/resnest.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5460a7eb3bc32b3fbd957c82348c40fd446250aa Binary files /dev/null and b/aot/networks/encoders/resnest/__pycache__/resnest.cpython-310.pyc differ diff --git a/aot/networks/encoders/resnest/__pycache__/resnet.cpython-310.pyc b/aot/networks/encoders/resnest/__pycache__/resnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8cf13000d2776daba04f72cbf8ff90dcb8bb9e55 Binary files /dev/null and b/aot/networks/encoders/resnest/__pycache__/resnet.cpython-310.pyc differ diff --git a/aot/networks/encoders/resnest/__pycache__/splat.cpython-310.pyc b/aot/networks/encoders/resnest/__pycache__/splat.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95053fab15390df4a748539d2781cb73744917fc Binary files /dev/null and b/aot/networks/encoders/resnest/__pycache__/splat.cpython-310.pyc differ diff --git a/aot/networks/encoders/resnest/resnest.py b/aot/networks/encoders/resnest/resnest.py new file mode 100644 index 0000000000000000000000000000000000000000..b22600e8717fedaa8b95077c428eaac3ea300ba1 --- /dev/null +++ b/aot/networks/encoders/resnest/resnest.py @@ -0,0 +1,108 @@ +import torch +from .resnet import ResNet, Bottleneck + +__all__ = ['resnest50', 'resnest101', 'resnest200', 'resnest269'] + +_url_format = 'https://s3.us-west-1.wasabisys.com/resnest/torch/{}-{}.pth' + +_model_sha256 = { + name: checksum + for checksum, name in [ + ('528c19ca', 'resnest50'), + ('22405ba7', 'resnest101'), + ('75117900', 'resnest200'), + ('0cc87c48', 'resnest269'), + ] +} + + +def short_hash(name): + if name not in _model_sha256: + raise ValueError( + 'Pretrained model for {name} is not available.'.format(name=name)) + return _model_sha256[name][:8] + + +resnest_model_urls = { + name: _url_format.format(name, short_hash(name)) + for name in _model_sha256.keys() +} + + +def resnest50(pretrained=False, root='~/.encoding/models', **kwargs): + model = ResNet(Bottleneck, [3, 4, 6, 3], + radix=2, + groups=1, + bottleneck_width=64, + deep_stem=True, + stem_width=32, + avg_down=True, + avd=True, + avd_first=False, + **kwargs) + if pretrained: + model.load_state_dict( + torch.hub.load_state_dict_from_url(resnest_model_urls['resnest50'], + progress=True, + check_hash=True)) + return model + + +def resnest101(pretrained=False, root='~/.encoding/models', **kwargs): + model = ResNet(Bottleneck, [3, 4, 23, 3], + radix=2, + groups=1, + bottleneck_width=64, + deep_stem=True, + stem_width=64, + avg_down=True, + avd=True, + avd_first=False, + **kwargs) + if pretrained: + model.load_state_dict( + torch.hub.load_state_dict_from_url( + resnest_model_urls['resnest101'], + progress=True, + check_hash=True)) + return model + + +def resnest200(pretrained=False, root='~/.encoding/models', **kwargs): + model = ResNet(Bottleneck, [3, 24, 36, 3], + radix=2, + groups=1, + bottleneck_width=64, + deep_stem=True, + stem_width=64, + avg_down=True, + avd=True, + avd_first=False, + **kwargs) + if pretrained: + model.load_state_dict( + torch.hub.load_state_dict_from_url( + resnest_model_urls['resnest200'], + progress=True, + check_hash=True)) + return model + + +def resnest269(pretrained=False, root='~/.encoding/models', **kwargs): + model = ResNet(Bottleneck, [3, 30, 48, 8], + radix=2, + groups=1, + bottleneck_width=64, + deep_stem=True, + stem_width=64, + avg_down=True, + avd=True, + avd_first=False, + **kwargs) + if pretrained: + model.load_state_dict( + torch.hub.load_state_dict_from_url( + resnest_model_urls['resnest269'], + progress=True, + check_hash=True)) + return model diff --git a/aot/networks/encoders/resnest/resnet.py b/aot/networks/encoders/resnest/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..ce8c99b3c236bfd6934599d30949bb92bb0124d7 --- /dev/null +++ b/aot/networks/encoders/resnest/resnet.py @@ -0,0 +1,444 @@ +import math +import torch.nn as nn + +from .splat import SplAtConv2d, DropBlock2D +from utils.learning import freeze_params + +__all__ = ['ResNet', 'Bottleneck'] + +_url_format = 'https://s3.us-west-1.wasabisys.com/resnest/torch/{}-{}.pth' + +_model_sha256 = {name: checksum for checksum, name in []} + + +def short_hash(name): + if name not in _model_sha256: + raise ValueError( + 'Pretrained model for {name} is not available.'.format(name=name)) + return _model_sha256[name][:8] + + +resnest_model_urls = { + name: _url_format.format(name, short_hash(name)) + for name in _model_sha256.keys() +} + + +class GlobalAvgPool2d(nn.Module): + def __init__(self): + """Global average pooling over the input's spatial dimensions""" + super(GlobalAvgPool2d, self).__init__() + + def forward(self, inputs): + return nn.functional.adaptive_avg_pool2d(inputs, + 1).view(inputs.size(0), -1) + + +class Bottleneck(nn.Module): + """ResNet Bottleneck + """ + # pylint: disable=unused-argument + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + radix=1, + cardinality=1, + bottleneck_width=64, + avd=False, + avd_first=False, + dilation=1, + is_first=False, + rectified_conv=False, + rectify_avg=False, + norm_layer=None, + dropblock_prob=0.0, + last_gamma=False): + super(Bottleneck, self).__init__() + group_width = int(planes * (bottleneck_width / 64.)) * cardinality + self.conv1 = nn.Conv2d(inplanes, + group_width, + kernel_size=1, + bias=False) + self.bn1 = norm_layer(group_width) + self.dropblock_prob = dropblock_prob + self.radix = radix + self.avd = avd and (stride > 1 or is_first) + self.avd_first = avd_first + + if self.avd: + self.avd_layer = nn.AvgPool2d(3, stride, padding=1) + stride = 1 + + if dropblock_prob > 0.0: + self.dropblock1 = DropBlock2D(dropblock_prob, 3) + if radix == 1: + self.dropblock2 = DropBlock2D(dropblock_prob, 3) + self.dropblock3 = DropBlock2D(dropblock_prob, 3) + + if radix >= 1: + self.conv2 = SplAtConv2d(group_width, + group_width, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + groups=cardinality, + bias=False, + radix=radix, + rectify=rectified_conv, + rectify_avg=rectify_avg, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob) + elif rectified_conv: + from rfconv import RFConv2d + self.conv2 = RFConv2d(group_width, + group_width, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + groups=cardinality, + bias=False, + average_mode=rectify_avg) + self.bn2 = norm_layer(group_width) + else: + self.conv2 = nn.Conv2d(group_width, + group_width, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + groups=cardinality, + bias=False) + self.bn2 = norm_layer(group_width) + + self.conv3 = nn.Conv2d(group_width, + planes * 4, + kernel_size=1, + bias=False) + self.bn3 = norm_layer(planes * 4) + + if last_gamma: + from torch.nn.init import zeros_ + zeros_(self.bn3.weight) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.dilation = dilation + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + if self.dropblock_prob > 0.0: + out = self.dropblock1(out) + out = self.relu(out) + + if self.avd and self.avd_first: + out = self.avd_layer(out) + + out = self.conv2(out) + if self.radix == 0: + out = self.bn2(out) + if self.dropblock_prob > 0.0: + out = self.dropblock2(out) + out = self.relu(out) + + if self.avd and not self.avd_first: + out = self.avd_layer(out) + + out = self.conv3(out) + out = self.bn3(out) + if self.dropblock_prob > 0.0: + out = self.dropblock3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + """ResNet Variants + Parameters + ---------- + block : Block + Class for the residual block. Options are BasicBlockV1, BottleneckV1. + layers : list of int + Numbers of layers in each block + classes : int, default 1000 + Number of classification classes. + dilated : bool, default False + Applying dilation strategy to pretrained ResNet yielding a stride-8 model, + typically used in Semantic Segmentation. + norm_layer : object + Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`; + for Synchronized Cross-GPU BachNormalization). + Reference: + - He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016. + - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions." + """ + + # pylint: disable=unused-variable + def __init__(self, + block, + layers, + radix=1, + groups=1, + bottleneck_width=64, + num_classes=1000, + dilated=False, + dilation=1, + deep_stem=False, + stem_width=64, + avg_down=False, + rectified_conv=False, + rectify_avg=False, + avd=False, + avd_first=False, + final_drop=0.0, + dropblock_prob=0, + last_gamma=False, + norm_layer=nn.BatchNorm2d, + freeze_at=0): + self.cardinality = groups + self.bottleneck_width = bottleneck_width + # ResNet-D params + self.inplanes = stem_width * 2 if deep_stem else 64 + self.avg_down = avg_down + self.last_gamma = last_gamma + # ResNeSt params + self.radix = radix + self.avd = avd + self.avd_first = avd_first + + super(ResNet, self).__init__() + self.rectified_conv = rectified_conv + self.rectify_avg = rectify_avg + if rectified_conv: + from rfconv import RFConv2d + conv_layer = RFConv2d + else: + conv_layer = nn.Conv2d + conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {} + if deep_stem: + self.conv1 = nn.Sequential( + conv_layer(3, + stem_width, + kernel_size=3, + stride=2, + padding=1, + bias=False, + **conv_kwargs), + norm_layer(stem_width), + nn.ReLU(inplace=True), + conv_layer(stem_width, + stem_width, + kernel_size=3, + stride=1, + padding=1, + bias=False, + **conv_kwargs), + norm_layer(stem_width), + nn.ReLU(inplace=True), + conv_layer(stem_width, + stem_width * 2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + **conv_kwargs), + ) + else: + self.conv1 = conv_layer(3, + 64, + kernel_size=7, + stride=2, + padding=3, + bias=False, + **conv_kwargs) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, + 64, + layers[0], + norm_layer=norm_layer, + is_first=False) + self.layer2 = self._make_layer(block, + 128, + layers[1], + stride=2, + norm_layer=norm_layer) + if dilated or dilation == 4: + self.layer3 = self._make_layer(block, + 256, + layers[2], + stride=1, + dilation=2, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob) + elif dilation == 2: + self.layer3 = self._make_layer(block, + 256, + layers[2], + stride=2, + dilation=1, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob) + else: + self.layer3 = self._make_layer(block, + 256, + layers[2], + stride=2, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob) + + self.stem = [self.conv1, self.bn1] + self.stages = [self.layer1, self.layer2, self.layer3] + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, norm_layer): + m.weight.data.fill_(1) + m.bias.data.zero_() + + self.freeze(freeze_at) + + def _make_layer(self, + block, + planes, + blocks, + stride=1, + dilation=1, + norm_layer=None, + dropblock_prob=0.0, + is_first=True): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + down_layers = [] + if self.avg_down: + if dilation == 1: + down_layers.append( + nn.AvgPool2d(kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False)) + else: + down_layers.append( + nn.AvgPool2d(kernel_size=1, + stride=1, + ceil_mode=True, + count_include_pad=False)) + down_layers.append( + nn.Conv2d(self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=1, + bias=False)) + else: + down_layers.append( + nn.Conv2d(self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False)) + down_layers.append(norm_layer(planes * block.expansion)) + downsample = nn.Sequential(*down_layers) + + layers = [] + if dilation == 1 or dilation == 2: + layers.append( + block(self.inplanes, + planes, + stride, + downsample=downsample, + radix=self.radix, + cardinality=self.cardinality, + bottleneck_width=self.bottleneck_width, + avd=self.avd, + avd_first=self.avd_first, + dilation=1, + is_first=is_first, + rectified_conv=self.rectified_conv, + rectify_avg=self.rectify_avg, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob, + last_gamma=self.last_gamma)) + elif dilation == 4: + layers.append( + block(self.inplanes, + planes, + stride, + downsample=downsample, + radix=self.radix, + cardinality=self.cardinality, + bottleneck_width=self.bottleneck_width, + avd=self.avd, + avd_first=self.avd_first, + dilation=2, + is_first=is_first, + rectified_conv=self.rectified_conv, + rectify_avg=self.rectify_avg, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob, + last_gamma=self.last_gamma)) + else: + raise RuntimeError("=> unknown dilation size: {}".format(dilation)) + + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append( + block(self.inplanes, + planes, + radix=self.radix, + cardinality=self.cardinality, + bottleneck_width=self.bottleneck_width, + avd=self.avd, + avd_first=self.avd_first, + dilation=dilation, + rectified_conv=self.rectified_conv, + rectify_avg=self.rectify_avg, + norm_layer=norm_layer, + dropblock_prob=dropblock_prob, + last_gamma=self.last_gamma)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + xs = [] + + x = self.layer1(x) + xs.append(x) # 4X + x = self.layer2(x) + xs.append(x) # 8X + x = self.layer3(x) + xs.append(x) # 16X + # Following STMVOS, we drop stage 5. + xs.append(x) # 16X + + return xs + + def freeze(self, freeze_at): + if freeze_at >= 1: + for m in self.stem: + freeze_params(m) + + for idx, stage in enumerate(self.stages, start=2): + if freeze_at >= idx: + freeze_params(stage) diff --git a/aot/networks/encoders/resnest/splat.py b/aot/networks/encoders/resnest/splat.py new file mode 100644 index 0000000000000000000000000000000000000000..147d684332e378ac390e2be2cfed2daf9a94ad87 --- /dev/null +++ b/aot/networks/encoders/resnest/splat.py @@ -0,0 +1,132 @@ +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn import Conv2d, Module, ReLU +from torch.nn.modules.utils import _pair + +__all__ = ['SplAtConv2d', 'DropBlock2D'] + + +class DropBlock2D(object): + def __init__(self, *args, **kwargs): + raise NotImplementedError + + +class SplAtConv2d(Module): + """Split-Attention Conv2d + """ + def __init__(self, + in_channels, + channels, + kernel_size, + stride=(1, 1), + padding=(0, 0), + dilation=(1, 1), + groups=1, + bias=True, + radix=2, + reduction_factor=4, + rectify=False, + rectify_avg=False, + norm_layer=None, + dropblock_prob=0.0, + **kwargs): + super(SplAtConv2d, self).__init__() + padding = _pair(padding) + self.rectify = rectify and (padding[0] > 0 or padding[1] > 0) + self.rectify_avg = rectify_avg + inter_channels = max(in_channels * radix // reduction_factor, 32) + self.radix = radix + self.cardinality = groups + self.channels = channels + self.dropblock_prob = dropblock_prob + if self.rectify: + from rfconv import RFConv2d + self.conv = RFConv2d(in_channels, + channels * radix, + kernel_size, + stride, + padding, + dilation, + groups=groups * radix, + bias=bias, + average_mode=rectify_avg, + **kwargs) + else: + self.conv = Conv2d(in_channels, + channels * radix, + kernel_size, + stride, + padding, + dilation, + groups=groups * radix, + bias=bias, + **kwargs) + self.use_bn = norm_layer is not None + if self.use_bn: + self.bn0 = norm_layer(channels * radix) + self.relu = ReLU(inplace=True) + self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality) + if self.use_bn: + self.bn1 = norm_layer(inter_channels) + self.fc2 = Conv2d(inter_channels, + channels * radix, + 1, + groups=self.cardinality) + if dropblock_prob > 0.0: + self.dropblock = DropBlock2D(dropblock_prob, 3) + self.rsoftmax = rSoftMax(radix, groups) + + def forward(self, x): + x = self.conv(x) + if self.use_bn: + x = self.bn0(x) + if self.dropblock_prob > 0.0: + x = self.dropblock(x) + x = self.relu(x) + + batch, rchannel = x.shape[:2] + if self.radix > 1: + if torch.__version__ < '1.5': + splited = torch.split(x, int(rchannel // self.radix), dim=1) + else: + splited = torch.split(x, rchannel // self.radix, dim=1) + gap = sum(splited) + else: + gap = x + gap = F.adaptive_avg_pool2d(gap, 1) + gap = self.fc1(gap) + + if self.use_bn: + gap = self.bn1(gap) + gap = self.relu(gap) + + atten = self.fc2(gap) + atten = self.rsoftmax(atten).view(batch, -1, 1, 1) + + if self.radix > 1: + if torch.__version__ < '1.5': + attens = torch.split(atten, int(rchannel // self.radix), dim=1) + else: + attens = torch.split(atten, rchannel // self.radix, dim=1) + out = sum([att * split for (att, split) in zip(attens, splited)]) + else: + out = atten * x + return out.contiguous() + + +class rSoftMax(nn.Module): + def __init__(self, radix, cardinality): + super().__init__() + self.radix = radix + self.cardinality = cardinality + + def forward(self, x): + batch = x.size(0) + if self.radix > 1: + x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) + x = F.softmax(x, dim=1) + x = x.reshape(batch, -1) + else: + x = torch.sigmoid(x) + return x diff --git a/aot/networks/encoders/resnet.py b/aot/networks/encoders/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..0ba2845e98990b3baadb02dff820f9da6d0e8d37 --- /dev/null +++ b/aot/networks/encoders/resnet.py @@ -0,0 +1,208 @@ +import math +import torch.nn as nn +from utils.learning import freeze_params + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + BatchNorm=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = BatchNorm(planes) + self.conv2 = nn.Conv2d(planes, + planes, + kernel_size=3, + stride=stride, + dilation=dilation, + padding=dilation, + bias=False) + self.bn2 = BatchNorm(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = BatchNorm(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + def __init__(self, block, layers, output_stride, BatchNorm, freeze_at=0): + self.inplanes = 64 + super(ResNet, self).__init__() + + if output_stride == 16: + strides = [1, 2, 2, 1] + dilations = [1, 1, 1, 2] + elif output_stride == 8: + strides = [1, 2, 1, 1] + dilations = [1, 1, 2, 4] + else: + raise NotImplementedError + + # Modules + self.conv1 = nn.Conv2d(3, + 64, + kernel_size=7, + stride=2, + padding=3, + bias=False) + self.bn1 = BatchNorm(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layer1 = self._make_layer(block, + 64, + layers[0], + stride=strides[0], + dilation=dilations[0], + BatchNorm=BatchNorm) + self.layer2 = self._make_layer(block, + 128, + layers[1], + stride=strides[1], + dilation=dilations[1], + BatchNorm=BatchNorm) + self.layer3 = self._make_layer(block, + 256, + layers[2], + stride=strides[2], + dilation=dilations[2], + BatchNorm=BatchNorm) + # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) + + self.stem = [self.conv1, self.bn1] + self.stages = [self.layer1, self.layer2, self.layer3] + + self._init_weight() + self.freeze(freeze_at) + + def _make_layer(self, + block, + planes, + blocks, + stride=1, + dilation=1, + BatchNorm=None): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + BatchNorm(planes * block.expansion), + ) + + layers = [] + layers.append( + block(self.inplanes, planes, stride, max(dilation // 2, 1), + downsample, BatchNorm)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append( + block(self.inplanes, + planes, + dilation=dilation, + BatchNorm=BatchNorm)) + + return nn.Sequential(*layers) + + def forward(self, input): + x = self.conv1(input) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + xs = [] + + x = self.layer1(x) + xs.append(x) # 4X + x = self.layer2(x) + xs.append(x) # 8X + x = self.layer3(x) + xs.append(x) # 16X + # Following STMVOS, we drop stage 5. + xs.append(x) # 16X + + return xs + + def _init_weight(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def freeze(self, freeze_at): + if freeze_at >= 1: + for m in self.stem: + freeze_params(m) + + for idx, stage in enumerate(self.stages, start=2): + if freeze_at >= idx: + freeze_params(stage) + + +def ResNet50(output_stride, BatchNorm, freeze_at=0): + """Constructs a ResNet-50 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], + output_stride, + BatchNorm, + freeze_at=freeze_at) + return model + + +def ResNet101(output_stride, BatchNorm, freeze_at=0): + """Constructs a ResNet-101 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 23, 3], + output_stride, + BatchNorm, + freeze_at=freeze_at) + return model + + +if __name__ == "__main__": + import torch + model = ResNet101(BatchNorm=nn.BatchNorm2d, output_stride=8) + input = torch.rand(1, 3, 512, 512) + output, low_level_feat = model(input) + print(output.size()) + print(low_level_feat.size()) diff --git a/aot/networks/encoders/swin/__init__.py b/aot/networks/encoders/swin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..828b95150cbc9f2cd42ead98a85d0370a62e8b9a --- /dev/null +++ b/aot/networks/encoders/swin/__init__.py @@ -0,0 +1 @@ +from .build import build_swin_model \ No newline at end of file diff --git a/aot/networks/encoders/swin/__pycache__/__init__.cpython-310.pyc b/aot/networks/encoders/swin/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cffc2aecddade40527129994a8a5ddd82fb1dc08 Binary files /dev/null and b/aot/networks/encoders/swin/__pycache__/__init__.cpython-310.pyc differ diff --git a/aot/networks/encoders/swin/__pycache__/build.cpython-310.pyc b/aot/networks/encoders/swin/__pycache__/build.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1365c80952b05630f749fd232d90c80f587b7d4 Binary files /dev/null and b/aot/networks/encoders/swin/__pycache__/build.cpython-310.pyc differ diff --git a/aot/networks/encoders/swin/__pycache__/swin_transformer.cpython-310.pyc b/aot/networks/encoders/swin/__pycache__/swin_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41a6268422d72b12ff43632a69b81aeea285344b Binary files /dev/null and b/aot/networks/encoders/swin/__pycache__/swin_transformer.cpython-310.pyc differ diff --git a/aot/networks/encoders/swin/build.py b/aot/networks/encoders/swin/build.py new file mode 100644 index 0000000000000000000000000000000000000000..4d832de035be764f7e3caac25ee595c36cacc575 --- /dev/null +++ b/aot/networks/encoders/swin/build.py @@ -0,0 +1,27 @@ +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu +# -------------------------------------------------------- + +from .swin_transformer import SwinTransformer + + +def build_swin_model(model_type, freeze_at=0): + if model_type == 'swin_base': + model = SwinTransformer(embed_dim=128, + depths=[2, 2, 18, 2], + num_heads=[4, 8, 16, 32], + window_size=7, + drop_path_rate=0.3, + out_indices=(0, 1, 2), + ape=False, + patch_norm=True, + frozen_stages=freeze_at, + use_checkpoint=False) + + else: + raise NotImplementedError(f"Unkown model: {model_type}") + + return model diff --git a/aot/networks/encoders/swin/swin_transformer.py b/aot/networks/encoders/swin/swin_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..743cfaae8888d919cd131b2e33736bafd0f53991 --- /dev/null +++ b/aot/networks/encoders/swin/swin_transformer.py @@ -0,0 +1,716 @@ +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu +# -------------------------------------------------------- +from itertools import repeat +import collections.abc + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from networks.layers.basic import DropPath + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_2tuple = _ntuple(2) + + +def trunc_normal_(tensor, mean=0, std=1): + size = tensor.shape + tmp = tensor.new_empty(size + (4, )).normal_() + valid = (tmp < 2) & (tmp > -2) + ind = valid.max(-1, keepdim=True)[1] + tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) + tensor.data.mul_(std).add_(mean) + return tensor + + +class Mlp(nn.Module): + """ Multilayer perceptron.""" + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, + C) + windows = x.permute(0, 1, 3, 2, 4, + 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + def __init__(self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, + None] - coords_flatten[:, + None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, + 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", + relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + """ Swin Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + def __init__(self, + dim, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention(dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + self.H = None + self.W = None + + def forward(self, x, mask_matrix): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. + """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, + shifts=(-self.shift_size, -self.shift_size), + dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, + C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn( + x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, + self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, + Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, + shifts=(self.shift_size, self.shift_size), + dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + """ Patch Merging Layer + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, H, W): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + def __init__(self, + dim, + depth, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if + (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance( + drop_path, list) else drop_path, + norm_layer=norm_layer) for i in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, + self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-100.0)).masked_fill( + attn_mask == 0, float(0.0)) + + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + def __init__(self, + patch_size=4, + in_chans=3, + embed_dim=96, + norm_layer=None): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, + embed_dim, + kernel_size=patch_size, + stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, + (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + +class SwinTransformer(nn.Module): + """ Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + def __init__(self, + pretrain_img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + out_indices=(0, 1, 2), + frozen_stages=-1, + use_checkpoint=False): + super().__init__() + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) - 1 # remove the last stage + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_2tuple(pretrain_img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [ + pretrain_img_size[0] // patch_size[0], + pretrain_img_size[1] // patch_size[1] + ] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], + patches_resolution[1])) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2**i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if + (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop = nn.Identity() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + for block in m.blocks: + block.drop_path = nn.Identity() + block.attn.attn_drop = nn.Identity() + block.attn.proj_drop = nn.Identity() + for param in m.parameters(): + param.requires_grad = False + if m.downsample is not None: + for param in m.downsample.parameters(): + param.requires_grad = True + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + # logger = get_root_logger() + # load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate(self.absolute_pos_embed, + size=(Wh, Ww), + mode='bicubic') + x = (x + absolute_pos_embed).flatten(2).transpose(1, + 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = [] + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + x_out = norm_layer(x_out) + + out = x_out.view(-1, H, W, + self.num_features[i]).permute(0, 3, 1, + 2).contiguous() + outs.append(out) + + outs.append(outs[-1]) + + return outs diff --git a/aot/networks/engines/__init__.py b/aot/networks/engines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fcfa80a1ee8121eadcc2763869cab973207b5c41 --- /dev/null +++ b/aot/networks/engines/__init__.py @@ -0,0 +1,21 @@ +from networks.engines.aot_engine import AOTEngine, AOTInferEngine +from networks.engines.deaot_engine import DeAOTEngine, DeAOTInferEngine + + +def build_engine(name, phase='train', **kwargs): + if name == 'aotengine': + if phase == 'train': + return AOTEngine(**kwargs) + elif phase == 'eval': + return AOTInferEngine(**kwargs) + else: + raise NotImplementedError + elif name == 'deaotengine': + if phase == 'train': + return DeAOTEngine(**kwargs) + elif phase == 'eval': + return DeAOTInferEngine(**kwargs) + else: + raise NotImplementedError + else: + raise NotImplementedError diff --git a/aot/networks/engines/__pycache__/__init__.cpython-310.pyc b/aot/networks/engines/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bc9ec45737259ea08a74f1193b0e2c8dac95232 Binary files /dev/null and b/aot/networks/engines/__pycache__/__init__.cpython-310.pyc differ diff --git a/aot/networks/engines/__pycache__/aot_engine.cpython-310.pyc b/aot/networks/engines/__pycache__/aot_engine.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efa1ef41b147d5d40a4b1eb255e11286fcb60f11 Binary files /dev/null and b/aot/networks/engines/__pycache__/aot_engine.cpython-310.pyc differ diff --git a/aot/networks/engines/__pycache__/deaot_engine.cpython-310.pyc b/aot/networks/engines/__pycache__/deaot_engine.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b68bfb399802a150168403cbaeddfc87fe089ea Binary files /dev/null and b/aot/networks/engines/__pycache__/deaot_engine.cpython-310.pyc differ diff --git a/aot/networks/engines/aot_engine.py b/aot/networks/engines/aot_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..4474d8cef3828308ac2c42c4b15056bd84247247 --- /dev/null +++ b/aot/networks/engines/aot_engine.py @@ -0,0 +1,643 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np + +from utils.math import generate_permute_matrix +from utils.image import one_hot_mask + +from networks.layers.basic import seq_to_2d + + +class AOTEngine(nn.Module): + def __init__(self, + aot_model, + gpu_id=0, + long_term_mem_gap=9999, + short_term_mem_skip=1, + max_len_long_term=9999): + super().__init__() + + self.cfg = aot_model.cfg + self.align_corners = aot_model.cfg.MODEL_ALIGN_CORNERS + self.AOT = aot_model + + self.max_obj_num = aot_model.max_obj_num + self.gpu_id = gpu_id + self.long_term_mem_gap = long_term_mem_gap + self.short_term_mem_skip = short_term_mem_skip + self.max_len_long_term = max_len_long_term + self.losses = None + + self.restart_engine() + + def forward(self, + all_frames, + all_masks, + batch_size, + obj_nums, + step=0, + tf_board=False, + use_prev_pred=False, + enable_prev_frame=False, + use_prev_prob=False): # only used for training + if self.losses is None: + self._init_losses() + + self.freeze_id = True if use_prev_pred else False + aux_weight = self.aux_weight * max(self.aux_step - step, + 0.) / self.aux_step + + self.offline_encoder(all_frames, all_masks) + + self.add_reference_frame(frame_step=0, obj_nums=obj_nums) + + grad_state = torch.no_grad if aux_weight == 0 else torch.enable_grad + with grad_state(): + ref_aux_loss, ref_aux_mask = self.generate_loss_mask( + self.offline_masks[self.frame_step], step) + + aux_losses = [ref_aux_loss] + aux_masks = [ref_aux_mask] + + curr_losses, curr_masks = [], [] + if enable_prev_frame: + self.set_prev_frame(frame_step=1) + with grad_state(): + prev_aux_loss, prev_aux_mask = self.generate_loss_mask( + self.offline_masks[self.frame_step], step) + aux_losses.append(prev_aux_loss) + aux_masks.append(prev_aux_mask) + else: + self.match_propogate_one_frame() + curr_loss, curr_mask, curr_prob = self.generate_loss_mask( + self.offline_masks[self.frame_step], step, return_prob=True) + self.update_short_term_memory( + curr_mask if not use_prev_prob else curr_prob, + None if use_prev_pred else self.assign_identity( + self.offline_one_hot_masks[self.frame_step])) + curr_losses.append(curr_loss) + curr_masks.append(curr_mask) + + self.match_propogate_one_frame() + curr_loss, curr_mask, curr_prob = self.generate_loss_mask( + self.offline_masks[self.frame_step], step, return_prob=True) + curr_losses.append(curr_loss) + curr_masks.append(curr_mask) + for _ in range(self.total_offline_frame_num - 3): + self.update_short_term_memory( + curr_mask if not use_prev_prob else curr_prob, + None if use_prev_pred else self.assign_identity( + self.offline_one_hot_masks[self.frame_step])) + self.match_propogate_one_frame() + curr_loss, curr_mask, curr_prob = self.generate_loss_mask( + self.offline_masks[self.frame_step], step, return_prob=True) + curr_losses.append(curr_loss) + curr_masks.append(curr_mask) + + aux_loss = torch.cat(aux_losses, dim=0).mean(dim=0) + pred_loss = torch.cat(curr_losses, dim=0).mean(dim=0) + + loss = aux_weight * aux_loss + pred_loss + + all_pred_mask = aux_masks + curr_masks + + all_frame_loss = aux_losses + curr_losses + + boards = {'image': {}, 'scalar': {}} + + return loss, all_pred_mask, all_frame_loss, boards + + def _init_losses(self): + cfg = self.cfg + + from networks.layers.loss import CrossEntropyLoss, SoftJaccordLoss + bce_loss = CrossEntropyLoss( + cfg.TRAIN_TOP_K_PERCENT_PIXELS, + cfg.TRAIN_HARD_MINING_RATIO * cfg.TRAIN_TOTAL_STEPS) + iou_loss = SoftJaccordLoss() + + losses = [bce_loss, iou_loss] + loss_weights = [0.5, 0.5] + + self.losses = nn.ModuleList(losses) + self.loss_weights = loss_weights + self.aux_weight = cfg.TRAIN_AUX_LOSS_WEIGHT + self.aux_step = cfg.TRAIN_TOTAL_STEPS * cfg.TRAIN_AUX_LOSS_RATIO + 1e-5 + + def encode_one_img_mask(self, img=None, mask=None, frame_step=-1): + if frame_step == -1: + frame_step = self.frame_step + + if self.enable_offline_enc: + curr_enc_embs = self.offline_enc_embs[frame_step] + elif img is None: + curr_enc_embs = None + else: + curr_enc_embs = self.AOT.encode_image(img) + + if mask is not None: + curr_one_hot_mask = one_hot_mask(mask, self.max_obj_num) + elif self.enable_offline_enc: + curr_one_hot_mask = self.offline_one_hot_masks[frame_step] + else: + curr_one_hot_mask = None + + return curr_enc_embs, curr_one_hot_mask + + def offline_encoder(self, all_frames, all_masks=None): + self.enable_offline_enc = True + self.offline_frames = all_frames.size(0) // self.batch_size + + # extract backbone features + self.offline_enc_embs = self.split_frames( + self.AOT.encode_image(all_frames), self.batch_size) + self.total_offline_frame_num = len(self.offline_enc_embs) + + if all_masks is not None: + # extract mask embeddings + offline_one_hot_masks = one_hot_mask(all_masks, self.max_obj_num) + self.offline_masks = list( + torch.split(all_masks, self.batch_size, dim=0)) + self.offline_one_hot_masks = list( + torch.split(offline_one_hot_masks, self.batch_size, dim=0)) + + if self.input_size_2d is None: + self.update_size(all_frames.size()[2:], + self.offline_enc_embs[0][-1].size()[2:]) + + def assign_identity(self, one_hot_mask): + if self.enable_id_shuffle: + one_hot_mask = torch.einsum('bohw,bot->bthw', one_hot_mask, + self.id_shuffle_matrix) + + id_emb = self.AOT.get_id_emb(one_hot_mask).view( + self.batch_size, -1, self.enc_hw).permute(2, 0, 1) + + if self.training and self.freeze_id: + id_emb = id_emb.detach() + + return id_emb + + def split_frames(self, xs, chunk_size): + new_xs = [] + for x in xs: + all_x = list(torch.split(x, chunk_size, dim=0)) + new_xs.append(all_x) + return list(zip(*new_xs)) + + def add_reference_frame(self, + img=None, + mask=None, + frame_step=-1, + obj_nums=None, + img_embs=None): + if self.obj_nums is None and obj_nums is None: + print('No objects for reference frame!') + exit() + elif obj_nums is not None: + self.obj_nums = obj_nums + + if frame_step == -1: + frame_step = self.frame_step + + if img_embs is None: + curr_enc_embs, curr_one_hot_mask = self.encode_one_img_mask( + img, mask, frame_step) + else: + _, curr_one_hot_mask = self.encode_one_img_mask( + None, mask, frame_step) + curr_enc_embs = img_embs + + if curr_enc_embs is None: + print('No image for reference frame!') + exit() + + if curr_one_hot_mask is None: + print('No mask for reference frame!') + exit() + + if self.input_size_2d is None: + self.update_size(img.size()[2:], curr_enc_embs[-1].size()[2:]) + + self.curr_enc_embs = curr_enc_embs + self.curr_one_hot_mask = curr_one_hot_mask + + if self.pos_emb is None: + self.pos_emb = self.AOT.get_pos_emb(curr_enc_embs[-1]).expand( + self.batch_size, -1, -1, + -1).view(self.batch_size, -1, self.enc_hw).permute(2, 0, 1) + + curr_id_emb = self.assign_identity(curr_one_hot_mask) + self.curr_id_embs = curr_id_emb + + # self matching and propagation + self.curr_lstt_output = self.AOT.LSTT_forward(curr_enc_embs, + None, + None, + curr_id_emb, + pos_emb=self.pos_emb, + size_2d=self.enc_size_2d) + + lstt_embs, lstt_curr_memories, lstt_long_memories, lstt_short_memories = self.curr_lstt_output + + if self.long_term_memories is None: + self.long_term_memories = lstt_long_memories + else: + self.update_long_term_memory(lstt_long_memories) + + self.last_mem_step = self.frame_step + + self.short_term_memories_list = [lstt_short_memories] + self.short_term_memories = lstt_short_memories + + def set_prev_frame(self, img=None, mask=None, frame_step=1): + self.frame_step = frame_step + curr_enc_embs, curr_one_hot_mask = self.encode_one_img_mask( + img, mask, frame_step) + + if curr_enc_embs is None: + print('No image for previous frame!') + exit() + + if curr_one_hot_mask is None: + print('No mask for previous frame!') + exit() + + self.curr_enc_embs = curr_enc_embs + self.curr_one_hot_mask = curr_one_hot_mask + + curr_id_emb = self.assign_identity(curr_one_hot_mask) + self.curr_id_embs = curr_id_emb + + # self matching and propagation + self.curr_lstt_output = self.AOT.LSTT_forward(curr_enc_embs, + None, + None, + curr_id_emb, + pos_emb=self.pos_emb, + size_2d=self.enc_size_2d) + + lstt_embs, lstt_curr_memories, lstt_long_memories, lstt_short_memories = self.curr_lstt_output + + if self.long_term_memories is None: + self.long_term_memories = lstt_long_memories + else: + self.update_long_term_memory(lstt_long_memories) + self.last_mem_step = frame_step + + self.short_term_memories_list = [lstt_short_memories] + self.short_term_memories = lstt_short_memories + + def update_long_term_memory(self, new_long_term_memories): + TOKEN_NUM = new_long_term_memories[0][0].shape[0] + if self.long_term_memories is None: + self.long_term_memories = new_long_term_memories + updated_long_term_memories = [] + for new_long_term_memory, last_long_term_memory in zip( + new_long_term_memories, self.long_term_memories): + updated_e = [] + for new_e, last_e in zip(new_long_term_memory, + last_long_term_memory): + if new_e is None or last_e is None: + updated_e.append(None) + else: + if last_e.shape[0] >= self.max_len_long_term * TOKEN_NUM: + last_e = last_e[:(self.max_len_long_term - 1) * TOKEN_NUM] + updated_e.append(torch.cat([new_e, last_e], dim=0)) + updated_long_term_memories.append(updated_e) + self.long_term_memories = updated_long_term_memories + + def update_short_term_memory(self, curr_mask, curr_id_emb=None, skip_long_term_update=False): + if curr_id_emb is None: + if len(curr_mask.size()) == 3 or curr_mask.size()[0] == 1: + curr_one_hot_mask = one_hot_mask(curr_mask, self.max_obj_num) + else: + curr_one_hot_mask = curr_mask + curr_id_emb = self.assign_identity(curr_one_hot_mask) + + lstt_curr_memories = self.curr_lstt_output[1] + lstt_curr_memories_2d = [] + for layer_idx in range(len(lstt_curr_memories)): + curr_k, curr_v = lstt_curr_memories[layer_idx][ + 0], lstt_curr_memories[layer_idx][1] + curr_k, curr_v = self.AOT.LSTT.layers[layer_idx].fuse_key_value_id( + curr_k, curr_v, curr_id_emb) + lstt_curr_memories[layer_idx][0], lstt_curr_memories[layer_idx][ + 1] = curr_k, curr_v + lstt_curr_memories_2d.append([ + seq_to_2d(lstt_curr_memories[layer_idx][0], self.enc_size_2d), + seq_to_2d(lstt_curr_memories[layer_idx][1], self.enc_size_2d) + ]) + + self.short_term_memories_list.append(lstt_curr_memories_2d) + self.short_term_memories_list = self.short_term_memories_list[ + -self.short_term_mem_skip:] + self.short_term_memories = self.short_term_memories_list[0] + + if self.frame_step - self.last_mem_step >= self.long_term_mem_gap: + # skip the update of long-term memory or not + if not skip_long_term_update: + self.update_long_term_memory(lstt_curr_memories) + self.last_mem_step = self.frame_step + + def match_propogate_one_frame(self, img=None, img_embs=None): + self.frame_step += 1 + if img_embs is None: + curr_enc_embs, _ = self.encode_one_img_mask( + img, None, self.frame_step) + else: + curr_enc_embs = img_embs + self.curr_enc_embs = curr_enc_embs + + self.curr_lstt_output = self.AOT.LSTT_forward(curr_enc_embs, + self.long_term_memories, + self.short_term_memories, + None, + pos_emb=self.pos_emb, + size_2d=self.enc_size_2d) + + def decode_current_logits(self, output_size=None): + curr_enc_embs = self.curr_enc_embs + curr_lstt_embs = self.curr_lstt_output[0] + + pred_id_logits = self.AOT.decode_id_logits(curr_lstt_embs, + curr_enc_embs) + + if self.enable_id_shuffle: # reverse shuffle + pred_id_logits = torch.einsum('bohw,bto->bthw', pred_id_logits, + self.id_shuffle_matrix) + + # remove unused identities + for batch_idx, obj_num in enumerate(self.obj_nums): + pred_id_logits[batch_idx, (obj_num+1):] = - \ + 1e+10 if pred_id_logits.dtype == torch.float32 else -1e+4 + + self.pred_id_logits = pred_id_logits + + if output_size is not None: + pred_id_logits = F.interpolate(pred_id_logits, + size=output_size, + mode="bilinear", + align_corners=self.align_corners) + + return pred_id_logits + + def predict_current_mask(self, output_size=None, return_prob=False): + if output_size is None: + output_size = self.input_size_2d + + pred_id_logits = F.interpolate(self.pred_id_logits, + size=output_size, + mode="bilinear", + align_corners=self.align_corners) + pred_mask = torch.argmax(pred_id_logits, dim=1) + + if not return_prob: + return pred_mask + else: + pred_prob = torch.softmax(pred_id_logits, dim=1) + return pred_mask, pred_prob + + def calculate_current_loss(self, gt_mask, step): + pred_id_logits = self.pred_id_logits + + pred_id_logits = F.interpolate(pred_id_logits, + size=gt_mask.size()[-2:], + mode="bilinear", + align_corners=self.align_corners) + + label_list = [] + logit_list = [] + for batch_idx, obj_num in enumerate(self.obj_nums): + now_label = gt_mask[batch_idx].long() + now_logit = pred_id_logits[batch_idx, :(obj_num + 1)].unsqueeze(0) + label_list.append(now_label.long()) + logit_list.append(now_logit) + + total_loss = 0 + for loss, loss_weight in zip(self.losses, self.loss_weights): + total_loss = total_loss + loss_weight * \ + loss(logit_list, label_list, step) + + return total_loss + + def generate_loss_mask(self, gt_mask, step, return_prob=False): + self.decode_current_logits() + loss = self.calculate_current_loss(gt_mask, step) + if return_prob: + mask, prob = self.predict_current_mask(return_prob=True) + return loss, mask, prob + else: + mask = self.predict_current_mask() + return loss, mask + + def keep_gt_mask(self, pred_mask, keep_prob=0.2): + pred_mask = pred_mask.float() + gt_mask = self.offline_masks[self.frame_step].float().squeeze(1) + + shape = [1 for _ in range(pred_mask.ndim)] + shape[0] = self.batch_size + random_tensor = keep_prob + torch.rand( + shape, dtype=pred_mask.dtype, device=pred_mask.device) + random_tensor.floor_() # binarize + + pred_mask = pred_mask * (1 - random_tensor) + gt_mask * random_tensor + + return pred_mask + + def restart_engine(self, batch_size=1, enable_id_shuffle=False): + + self.batch_size = batch_size + self.frame_step = 0 + self.last_mem_step = -1 + self.enable_id_shuffle = enable_id_shuffle + self.freeze_id = False + + self.obj_nums = None + self.pos_emb = None + self.enc_size_2d = None + self.enc_hw = None + self.input_size_2d = None + + self.long_term_memories = None + self.short_term_memories_list = [] + self.short_term_memories = None + + self.enable_offline_enc = False + self.offline_enc_embs = None + self.offline_one_hot_masks = None + self.offline_frames = -1 + self.total_offline_frame_num = 0 + + self.curr_enc_embs = None + self.curr_memories = None + self.curr_id_embs = None + + if enable_id_shuffle: + self.id_shuffle_matrix = generate_permute_matrix( + self.max_obj_num + 1, batch_size, gpu_id=self.gpu_id) + else: + self.id_shuffle_matrix = None + + def update_size(self, input_size, enc_size): + self.input_size_2d = input_size + self.enc_size_2d = enc_size + self.enc_hw = self.enc_size_2d[0] * self.enc_size_2d[1] + + +class AOTInferEngine(nn.Module): + def __init__(self, + aot_model, + gpu_id=0, + long_term_mem_gap=9999, + short_term_mem_skip=1, + max_aot_obj_num=None, + max_len_long_term=9999,): + super().__init__() + + self.cfg = aot_model.cfg + self.AOT = aot_model + + if max_aot_obj_num is None or max_aot_obj_num > aot_model.max_obj_num: + self.max_aot_obj_num = aot_model.max_obj_num + else: + self.max_aot_obj_num = max_aot_obj_num + + self.gpu_id = gpu_id + self.long_term_mem_gap = long_term_mem_gap + self.short_term_mem_skip = short_term_mem_skip + self.max_len_long_term = max_len_long_term + self.aot_engines = [] + + self.restart_engine() + def restart_engine(self): + del (self.aot_engines) + self.aot_engines = [] + self.obj_nums = None + + def separate_mask(self, mask, obj_nums): + if mask is None: + return [None] * len(self.aot_engines) + if len(self.aot_engines) == 1: + return [mask], [obj_nums] + + separated_obj_nums = [ + self.max_aot_obj_num for _ in range(len(self.aot_engines)) + ] + if obj_nums % self.max_aot_obj_num > 0: + separated_obj_nums[-1] = obj_nums % self.max_aot_obj_num + + if len(mask.size()) == 3 or mask.size()[0] == 1: + separated_masks = [] + for idx in range(len(self.aot_engines)): + start_id = idx * self.max_aot_obj_num + 1 + end_id = (idx + 1) * self.max_aot_obj_num + fg_mask = ((mask >= start_id) & (mask <= end_id)).float() + separated_mask = (fg_mask * mask - start_id + 1) * fg_mask + separated_masks.append(separated_mask) + return separated_masks, separated_obj_nums + else: + prob = mask + separated_probs = [] + for idx in range(len(self.aot_engines)): + start_id = idx * self.max_aot_obj_num + 1 + end_id = (idx + 1) * self.max_aot_obj_num + fg_prob = prob[start_id:(end_id + 1)] + bg_prob = 1. - torch.sum(fg_prob, dim=1, keepdim=True) + separated_probs.append(torch.cat([bg_prob, fg_prob], dim=1)) + return separated_probs, separated_obj_nums + + def min_logit_aggregation(self, all_logits): + if len(all_logits) == 1: + return all_logits[0] + + fg_logits = [] + bg_logits = [] + + for logit in all_logits: + bg_logits.append(logit[:, 0:1]) + fg_logits.append(logit[:, 1:1 + self.max_aot_obj_num]) + + bg_logit, _ = torch.min(torch.cat(bg_logits, dim=1), + dim=1, + keepdim=True) + merged_logit = torch.cat([bg_logit] + fg_logits, dim=1) + + return merged_logit + + def soft_logit_aggregation(self, all_logits): + if len(all_logits) == 1: + return all_logits[0] + + fg_probs = [] + bg_probs = [] + + for logit in all_logits: + prob = torch.softmax(logit, dim=1) + bg_probs.append(prob[:, 0:1]) + fg_probs.append(prob[:, 1:1 + self.max_aot_obj_num]) + + bg_prob = torch.prod(torch.cat(bg_probs, dim=1), dim=1, keepdim=True) + merged_prob = torch.cat([bg_prob] + fg_probs, + dim=1).clamp(1e-5, 1 - 1e-5) + merged_logit = torch.logit(merged_prob) + + return merged_logit + + def add_reference_frame(self, img, mask, obj_nums, frame_step=-1): + if isinstance(obj_nums, list): + obj_nums = obj_nums[0] + self.obj_nums = obj_nums + aot_num = max(np.ceil(obj_nums / self.max_aot_obj_num), 1) + while (aot_num > len(self.aot_engines)): + new_engine = AOTEngine(self.AOT, self.gpu_id, + self.long_term_mem_gap, + self.short_term_mem_skip, + self.max_len_long_term,) + new_engine.eval() + self.aot_engines.append(new_engine) + + separated_masks, separated_obj_nums = self.separate_mask( + mask, obj_nums) + img_embs = None + for aot_engine, separated_mask, separated_obj_num in zip( + self.aot_engines, separated_masks, separated_obj_nums): + aot_engine.add_reference_frame(img, + separated_mask, + obj_nums=[separated_obj_num], + frame_step=frame_step, + img_embs=img_embs) + + if img_embs is None: # reuse image embeddings + img_embs = aot_engine.curr_enc_embs + + self.update_size() + + def match_propogate_one_frame(self, img=None): + img_embs = None + for aot_engine in self.aot_engines: + aot_engine.match_propogate_one_frame(img, img_embs=img_embs) + if img_embs is None: # reuse image embeddings + img_embs = aot_engine.curr_enc_embs + + def decode_current_logits(self, output_size=None): + all_logits = [] + for aot_engine in self.aot_engines: + all_logits.append(aot_engine.decode_current_logits(output_size)) + pred_id_logits = self.soft_logit_aggregation(all_logits) + return pred_id_logits + + def update_memory(self, curr_mask, skip_long_term_update=False): + _curr_mask = F.interpolate(curr_mask,self.input_size_2d) + separated_masks, _ = self.separate_mask(_curr_mask, self.obj_nums) + for aot_engine, separated_mask in zip(self.aot_engines, + separated_masks): + aot_engine.update_short_term_memory(separated_mask, + skip_long_term_update=skip_long_term_update) + + def update_size(self): + self.input_size_2d = self.aot_engines[0].input_size_2d + self.enc_size_2d = self.aot_engines[0].enc_size_2d + self.enc_hw = self.aot_engines[0].enc_hw diff --git a/aot/networks/engines/deaot_engine.py b/aot/networks/engines/deaot_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..b27be6fba82b31f0ff5542ce703f31e5efb9e2d7 --- /dev/null +++ b/aot/networks/engines/deaot_engine.py @@ -0,0 +1,98 @@ +import numpy as np + +from utils.image import one_hot_mask + +from networks.layers.basic import seq_to_2d +from networks.engines.aot_engine import AOTEngine, AOTInferEngine + + +class DeAOTEngine(AOTEngine): + def __init__(self, + aot_model, + gpu_id=0, + long_term_mem_gap=9999, + short_term_mem_skip=1, + layer_loss_scaling_ratio=2., + max_len_long_term=9999): + super().__init__(aot_model, gpu_id, long_term_mem_gap, + short_term_mem_skip, max_len_long_term) + self.layer_loss_scaling_ratio = layer_loss_scaling_ratio + def update_short_term_memory(self, curr_mask, curr_id_emb=None, skip_long_term_update=False): + + if curr_id_emb is None: + if len(curr_mask.size()) == 3 or curr_mask.size()[0] == 1: + curr_one_hot_mask = one_hot_mask(curr_mask, self.max_obj_num) + else: + curr_one_hot_mask = curr_mask + curr_id_emb = self.assign_identity(curr_one_hot_mask) + + lstt_curr_memories = self.curr_lstt_output[1] + lstt_curr_memories_2d = [] + for layer_idx in range(len(lstt_curr_memories)): + curr_k, curr_v, curr_id_k, curr_id_v = lstt_curr_memories[ + layer_idx] + curr_id_k, curr_id_v = self.AOT.LSTT.layers[ + layer_idx].fuse_key_value_id(curr_id_k, curr_id_v, curr_id_emb) + lstt_curr_memories[layer_idx][2], lstt_curr_memories[layer_idx][ + 3] = curr_id_k, curr_id_v + local_curr_id_k = seq_to_2d( + curr_id_k, self.enc_size_2d) if curr_id_k is not None else None + local_curr_id_v = seq_to_2d(curr_id_v, self.enc_size_2d) + lstt_curr_memories_2d.append([ + seq_to_2d(curr_k, self.enc_size_2d), + seq_to_2d(curr_v, self.enc_size_2d), local_curr_id_k, + local_curr_id_v + ]) + + self.short_term_memories_list.append(lstt_curr_memories_2d) + self.short_term_memories_list = self.short_term_memories_list[ + -self.short_term_mem_skip:] + self.short_term_memories = self.short_term_memories_list[0] + + if self.frame_step - self.last_mem_step >= self.long_term_mem_gap: + # skip the update of long-term memory or not + if not skip_long_term_update: + self.update_long_term_memory(lstt_curr_memories) + self.last_mem_step = self.frame_step + + +class DeAOTInferEngine(AOTInferEngine): + def __init__(self, + aot_model, + gpu_id=0, + long_term_mem_gap=9999, + short_term_mem_skip=1, + max_aot_obj_num=None, + max_len_long_term=9999): + super().__init__(aot_model, gpu_id, long_term_mem_gap, + short_term_mem_skip, max_aot_obj_num, max_len_long_term) + def add_reference_frame(self, img, mask, obj_nums, frame_step=-1): + if isinstance(obj_nums, list): + obj_nums = obj_nums[0] + self.obj_nums = obj_nums + aot_num = max(np.ceil(obj_nums / self.max_aot_obj_num), 1) + while (aot_num > len(self.aot_engines)): + new_engine = DeAOTEngine(self.AOT, self.gpu_id, + self.long_term_mem_gap, + self.short_term_mem_skip, + max_len_long_term = self.max_len_long_term) + new_engine.eval() + self.aot_engines.append(new_engine) + + separated_masks, separated_obj_nums = self.separate_mask( + mask, obj_nums) + img_embs = None + for aot_engine, separated_mask, separated_obj_num in zip( + self.aot_engines, separated_masks, separated_obj_nums): + if aot_engine.obj_nums is None or aot_engine.obj_nums[0] < separated_obj_num: + aot_engine.add_reference_frame(img, + separated_mask, + obj_nums=[separated_obj_num], + frame_step=frame_step, + img_embs=img_embs) + else: + aot_engine.update_short_term_memory(separated_mask) + if img_embs is None: # reuse image embeddings + img_embs = aot_engine.curr_enc_embs + + self.update_size() diff --git a/aot/networks/layers/__init__.py b/aot/networks/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/aot/networks/layers/__pycache__/__init__.cpython-310.pyc b/aot/networks/layers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1d1e485e74fafdbe882fe91c89d92547fd01319 Binary files /dev/null and b/aot/networks/layers/__pycache__/__init__.cpython-310.pyc differ diff --git a/aot/networks/layers/__pycache__/attention.cpython-310.pyc b/aot/networks/layers/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f877a92c3e7e8109343ca5742d70102a21f6a04 Binary files /dev/null and b/aot/networks/layers/__pycache__/attention.cpython-310.pyc differ diff --git a/aot/networks/layers/__pycache__/basic.cpython-310.pyc b/aot/networks/layers/__pycache__/basic.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd8fc4a22cb381e530ef8279abcfb18af3d32ed8 Binary files /dev/null and b/aot/networks/layers/__pycache__/basic.cpython-310.pyc differ diff --git a/aot/networks/layers/__pycache__/normalization.cpython-310.pyc b/aot/networks/layers/__pycache__/normalization.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec832a9d81e5fe4aeb5379609d73f61855d364be Binary files /dev/null and b/aot/networks/layers/__pycache__/normalization.cpython-310.pyc differ diff --git a/aot/networks/layers/__pycache__/position.cpython-310.pyc b/aot/networks/layers/__pycache__/position.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6fb93d07a5611c23b9569700f6c6e39c8b44ca4 Binary files /dev/null and b/aot/networks/layers/__pycache__/position.cpython-310.pyc differ diff --git a/aot/networks/layers/__pycache__/transformer.cpython-310.pyc b/aot/networks/layers/__pycache__/transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88fc5462c6c026fb54ef337a624135a91a44e20d Binary files /dev/null and b/aot/networks/layers/__pycache__/transformer.cpython-310.pyc differ diff --git a/aot/networks/layers/attention.py b/aot/networks/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..2bd2598a7ca768c99187bbdacbecac8e3fbd3adb --- /dev/null +++ b/aot/networks/layers/attention.py @@ -0,0 +1,905 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from networks.layers.basic import DropOutLogit, ScaleOffset, DWConv2d + + +def multiply_by_ychunks(x, y, chunks=1): + if chunks <= 1: + return x @ y + else: + return torch.cat([x @ _y for _y in y.chunk(chunks, dim=-1)], dim=-1) + + +def multiply_by_xchunks(x, y, chunks=1): + if chunks <= 1: + return x @ y + else: + return torch.cat([_x @ y for _x in x.chunk(chunks, dim=-2)], dim=-2) + + +# Long-term attention +class MultiheadAttention(nn.Module): + def __init__(self, + d_model, + num_head=8, + dropout=0., + use_linear=True, + d_att=None, + use_dis=False, + qk_chunks=1, + max_mem_len_ratio=-1, + top_k=-1): + super().__init__() + self.d_model = d_model + self.num_head = num_head + self.use_dis = use_dis + self.qk_chunks = qk_chunks + self.max_mem_len_ratio = float(max_mem_len_ratio) + self.top_k = top_k + + self.hidden_dim = d_model // num_head + self.d_att = self.hidden_dim if d_att is None else d_att + self.T = self.d_att**0.5 + self.use_linear = use_linear + + if use_linear: + self.linear_Q = nn.Linear(d_model, d_model) + self.linear_K = nn.Linear(d_model, d_model) + self.linear_V = nn.Linear(d_model, d_model) + + self.dropout = nn.Dropout(dropout) + self.drop_prob = dropout + self.projection = nn.Linear(d_model, d_model) + self._init_weight() + + def forward(self, Q, K, V): + """ + :param Q: A 3d tensor with shape of [T_q, bs, C_q] + :param K: A 3d tensor with shape of [T_k, bs, C_k] + :param V: A 3d tensor with shape of [T_v, bs, C_v] + """ + num_head = self.num_head + hidden_dim = self.hidden_dim + + bs = Q.size()[1] + + # Linear projections + if self.use_linear: + Q = self.linear_Q(Q) + K = self.linear_K(K) + V = self.linear_V(V) + + # Scale + Q = Q / self.T + + if not self.training and self.max_mem_len_ratio > 0: + mem_len_ratio = float(K.size(0)) / Q.size(0) + if mem_len_ratio > self.max_mem_len_ratio: + scaling_ratio = math.log(mem_len_ratio) / math.log( + self.max_mem_len_ratio) + Q = Q * scaling_ratio + + # Multi-head + Q = Q.view(-1, bs, num_head, self.d_att).permute(1, 2, 0, 3) + K = K.view(-1, bs, num_head, self.d_att).permute(1, 2, 3, 0) + V = V.view(-1, bs, num_head, hidden_dim).permute(1, 2, 0, 3) + + # Multiplication + QK = multiply_by_ychunks(Q, K, self.qk_chunks) + if self.use_dis: + QK = 2 * QK - K.pow(2).sum(dim=-2, keepdim=True) + + # Activation + if not self.training and self.top_k > 0 and self.top_k < QK.size()[-1]: + top_QK, indices = torch.topk(QK, k=self.top_k, dim=-1) + top_attn = torch.softmax(top_QK, dim=-1) + attn = torch.zeros_like(QK).scatter_(-1, indices, top_attn) + else: + attn = torch.softmax(QK, dim=-1) + + # Dropouts + attn = self.dropout(attn) + + # Weighted sum + outputs = multiply_by_xchunks(attn, V, + self.qk_chunks).permute(2, 0, 1, 3) + + # Restore shape + outputs = outputs.reshape(-1, bs, self.d_model) + + outputs = self.projection(outputs) + + return outputs, attn + + def _init_weight(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + +# Short-term attention +class MultiheadLocalAttentionV1(nn.Module): + def __init__(self, + d_model, + num_head, + dropout=0., + max_dis=7, + dilation=1, + use_linear=True, + enable_corr=True): + super().__init__() + self.dilation = dilation + self.window_size = 2 * max_dis + 1 + self.max_dis = max_dis + self.num_head = num_head + self.T = ((d_model / num_head)**0.5) + + self.use_linear = use_linear + if use_linear: + self.linear_Q = nn.Conv2d(d_model, d_model, kernel_size=1) + self.linear_K = nn.Conv2d(d_model, d_model, kernel_size=1) + self.linear_V = nn.Conv2d(d_model, d_model, kernel_size=1) + + self.relative_emb_k = nn.Conv2d(d_model, + num_head * self.window_size * + self.window_size, + kernel_size=1, + groups=num_head) + self.relative_emb_v = nn.Parameter( + torch.zeros([ + self.num_head, d_model // self.num_head, + self.window_size * self.window_size + ])) + + self.enable_corr = enable_corr + + if enable_corr: + from spatial_correlation_sampler import SpatialCorrelationSampler + self.correlation_sampler = SpatialCorrelationSampler( + kernel_size=1, + patch_size=self.window_size, + stride=1, + padding=0, + dilation=1, + dilation_patch=self.dilation) + + self.projection = nn.Linear(d_model, d_model) + + self.dropout = nn.Dropout(dropout) + self.drop_prob = dropout + + def forward(self, q, k, v): + n, c, h, w = v.size() + + if self.use_linear: + q = self.linear_Q(q) + k = self.linear_K(k) + v = self.linear_V(v) + + hidden_dim = c // self.num_head + + relative_emb = self.relative_emb_k(q) + memory_mask = torch.ones((1, 1, h, w), device=v.device).float() + + # Scale + q = q / self.T + + q = q.view(-1, hidden_dim, h, w) + k = k.reshape(-1, hidden_dim, h, w).contiguous() + unfolded_vu = self.pad_and_unfold(v).view( + n, self.num_head, hidden_dim, self.window_size * self.window_size, + h * w) + self.relative_emb_v.unsqueeze(0).unsqueeze(-1) + + relative_emb = relative_emb.view(n, self.num_head, + self.window_size * self.window_size, + h * w) + unfolded_k_mask = self.pad_and_unfold(memory_mask).bool().view( + 1, 1, self.window_size * self.window_size, + h * w).expand(n, self.num_head, -1, -1) + + if self.enable_corr: + qk = self.correlation_sampler(q, k).view( + n, self.num_head, self.window_size * self.window_size, + h * w) + relative_emb + else: + unfolded_k = self.pad_and_unfold(k).view( + n * self.num_head, hidden_dim, + self.window_size * self.window_size, h, w) + qk = (q.unsqueeze(2) * unfolded_k).sum(dim=1).view( + n, self.num_head, self.window_size * self.window_size, + h * w) + relative_emb + + qk_mask = 1 - unfolded_k_mask + + qk -= qk_mask * 1e+8 if qk.dtype == torch.float32 else qk_mask * 1e+4 + + local_attn = torch.softmax(qk, dim=2) + + local_attn = self.dropout(local_attn) + + output = (local_attn.unsqueeze(2) * unfolded_vu).sum(dim=3).permute( + 3, 0, 1, 2).view(h * w, n, c) + + output = self.projection(output) + + return output, local_attn + + def pad_and_unfold(self, x): + pad_pixel = self.max_dis * self.dilation + x = F.pad(x, (pad_pixel, pad_pixel, pad_pixel, pad_pixel), + mode='constant', + value=0) + x = F.unfold(x, + kernel_size=(self.window_size, self.window_size), + stride=(1, 1), + dilation=self.dilation) + return x + + +class MultiheadLocalAttentionV2(nn.Module): + def __init__(self, + d_model, + num_head, + dropout=0., + max_dis=7, + dilation=1, + use_linear=True, + enable_corr=True, + d_att=None, + use_dis=False): + super().__init__() + self.dilation = dilation + self.window_size = 2 * max_dis + 1 + self.max_dis = max_dis + self.num_head = num_head + self.hidden_dim = d_model // num_head + self.d_att = self.hidden_dim if d_att is None else d_att + self.T = self.d_att**0.5 + self.use_dis = use_dis + + self.use_linear = use_linear + if use_linear: + self.linear_Q = nn.Conv2d(d_model, d_model, kernel_size=1) + self.linear_K = nn.Conv2d(d_model, d_model, kernel_size=1) + self.linear_V = nn.Conv2d(d_model, d_model, kernel_size=1) + + self.relative_emb_k = nn.Conv2d(self.d_att * self.num_head, + num_head * self.window_size * + self.window_size, + kernel_size=1, + groups=num_head) + self.relative_emb_v = nn.Parameter( + torch.zeros([ + self.num_head, d_model // self.num_head, + self.window_size * self.window_size + ])) + + self.enable_corr = enable_corr + + if enable_corr: + from spatial_correlation_sampler import SpatialCorrelationSampler + self.correlation_sampler = SpatialCorrelationSampler( + kernel_size=1, + patch_size=self.window_size, + stride=1, + padding=0, + dilation=1, + dilation_patch=self.dilation) + + self.projection = nn.Linear(d_model, d_model) + + self.dropout = nn.Dropout(dropout) + + self.drop_prob = dropout + + self.local_mask = None + self.last_size_2d = None + self.qk_mask = None + + def forward(self, q, k, v): + n, c, h, w = v.size() + + if self.use_linear: + q = self.linear_Q(q) + k = self.linear_K(k) + v = self.linear_V(v) + + hidden_dim = self.hidden_dim + + if self.qk_mask is not None and (h, w) == self.last_size_2d: + qk_mask = self.qk_mask + else: + memory_mask = torch.ones((1, 1, h, w), device=v.device).float() + unfolded_k_mask = self.pad_and_unfold(memory_mask).view( + 1, 1, self.window_size * self.window_size, h * w) + qk_mask = 1 - unfolded_k_mask + self.qk_mask = qk_mask + + relative_emb = self.relative_emb_k(q) + + # Scale + q = q / self.T + + q = q.view(-1, self.d_att, h, w) + k = k.view(-1, self.d_att, h, w) + v = v.view(-1, self.num_head, hidden_dim, h * w) + + relative_emb = relative_emb.view(n, self.num_head, + self.window_size * self.window_size, + h * w) + + if self.enable_corr: + qk = self.correlation_sampler(q, k).view( + n, self.num_head, self.window_size * self.window_size, h * w) + else: + unfolded_k = self.pad_and_unfold(k).view( + n * self.num_head, hidden_dim, + self.window_size * self.window_size, h, w) + qk = (q.unsqueeze(2) * unfolded_k).sum(dim=1).view( + n, self.num_head, self.window_size * self.window_size, h * w) + if self.use_dis: + qk = 2 * qk - self.pad_and_unfold( + k.pow(2).sum(dim=1, keepdim=True)).view( + n, self.num_head, self.window_size * self.window_size, + h * w) + + qk = qk + relative_emb + + qk -= qk_mask * 1e+8 if qk.dtype == torch.float32 else qk_mask * 1e+4 + + local_attn = torch.softmax(qk, dim=2) + + local_attn = self.dropout(local_attn) + + agg_bias = torch.einsum('bhwn,hcw->bhnc', local_attn, + self.relative_emb_v) + + global_attn = self.local2global(local_attn, h, w) + + agg_value = (global_attn @ v.transpose(-2, -1)) + + output = (agg_value + agg_bias).permute(2, 0, 1, + 3).reshape(h * w, n, c) + + output = self.projection(output) + + self.last_size_2d = (h, w) + return output, local_attn + + def local2global(self, local_attn, height, width): + batch_size = local_attn.size()[0] + + pad_height = height + 2 * self.max_dis + pad_width = width + 2 * self.max_dis + + if self.local_mask is not None and (height, + width) == self.last_size_2d: + local_mask = self.local_mask + else: + ky, kx = torch.meshgrid([ + torch.arange(0, pad_height, device=local_attn.device), + torch.arange(0, pad_width, device=local_attn.device) + ]) + qy, qx = torch.meshgrid([ + torch.arange(0, height, device=local_attn.device), + torch.arange(0, width, device=local_attn.device) + ]) + + offset_y = qy.reshape(-1, 1) - ky.reshape(1, -1) + self.max_dis + offset_x = qx.reshape(-1, 1) - kx.reshape(1, -1) + self.max_dis + + local_mask = (offset_y.abs() <= self.max_dis) & (offset_x.abs() <= + self.max_dis) + local_mask = local_mask.view(1, 1, height * width, pad_height, + pad_width) + self.local_mask = local_mask + + global_attn = torch.zeros( + (batch_size, self.num_head, height * width, pad_height, pad_width), + device=local_attn.device) + global_attn[local_mask.expand(batch_size, self.num_head, + -1, -1, -1)] = local_attn.transpose( + -1, -2).reshape(-1) + global_attn = global_attn[:, :, :, self.max_dis:-self.max_dis, + self.max_dis:-self.max_dis].reshape( + batch_size, self.num_head, + height * width, height * width) + + return global_attn + + def pad_and_unfold(self, x): + pad_pixel = self.max_dis * self.dilation + x = F.pad(x, (pad_pixel, pad_pixel, pad_pixel, pad_pixel), + mode='constant', + value=0) + x = F.unfold(x, + kernel_size=(self.window_size, self.window_size), + stride=(1, 1), + dilation=self.dilation) + return x + + +class MultiheadLocalAttentionV3(nn.Module): + def __init__(self, + d_model, + num_head, + dropout=0., + max_dis=7, + dilation=1, + use_linear=True): + super().__init__() + self.dilation = dilation + self.window_size = 2 * max_dis + 1 + self.max_dis = max_dis + self.num_head = num_head + self.T = ((d_model / num_head)**0.5) + + self.use_linear = use_linear + if use_linear: + self.linear_Q = nn.Conv2d(d_model, d_model, kernel_size=1) + self.linear_K = nn.Conv2d(d_model, d_model, kernel_size=1) + self.linear_V = nn.Conv2d(d_model, d_model, kernel_size=1) + + self.relative_emb_k = nn.Conv2d(d_model, + num_head * self.window_size * + self.window_size, + kernel_size=1, + groups=num_head) + self.relative_emb_v = nn.Parameter( + torch.zeros([ + self.num_head, d_model // self.num_head, + self.window_size * self.window_size + ])) + + self.projection = nn.Linear(d_model, d_model) + self.dropout = DropOutLogit(dropout) + + self.padded_local_mask = None + self.local_mask = None + self.last_size_2d = None + self.qk_mask = None + + def forward(self, q, k, v): + n, c, h, w = q.size() + + if self.use_linear: + q = self.linear_Q(q) + k = self.linear_K(k) + v = self.linear_V(v) + + hidden_dim = c // self.num_head + + relative_emb = self.relative_emb_k(q) + relative_emb = relative_emb.view(n, self.num_head, + self.window_size * self.window_size, + h * w) + padded_local_mask, local_mask = self.compute_mask(h, + w, + device=q.device) + qk_mask = (~padded_local_mask).float() + + # Scale + q = q / self.T + + q = q.view(-1, self.num_head, hidden_dim, h * w) + k = k.view(-1, self.num_head, hidden_dim, h * w) + v = v.view(-1, self.num_head, hidden_dim, h * w) + + qk = q.transpose(-1, -2) @ k # [B, nH, kL, qL] + + pad_pixel = self.max_dis * self.dilation + + padded_qk = F.pad(qk.view(-1, self.num_head, h * w, h, w), + (pad_pixel, pad_pixel, pad_pixel, pad_pixel), + mode='constant', + value=-1e+8 if qk.dtype == torch.float32 else -1e+4) + + qk_mask = qk_mask * 1e+8 if (padded_qk.dtype + == torch.float32) else qk_mask * 1e+4 + padded_qk = padded_qk - qk_mask + + padded_qk[padded_local_mask.expand(n, self.num_head, -1, -1, + -1)] += relative_emb.transpose( + -1, -2).reshape(-1) + padded_qk = self.dropout(padded_qk) + + local_qk = padded_qk[padded_local_mask.expand(n, self.num_head, -1, -1, + -1)] + + global_qk = padded_qk[:, :, :, self.max_dis:-self.max_dis, + self.max_dis:-self.max_dis].reshape( + n, self.num_head, h * w, h * w) + + local_attn = torch.softmax(local_qk.reshape( + n, self.num_head, h * w, self.window_size * self.window_size), + dim=3) + global_attn = torch.softmax(global_qk, dim=3) + + agg_bias = torch.einsum('bhnw,hcw->nbhc', local_attn, + self.relative_emb_v).reshape(h * w, n, c) + + agg_value = (global_attn @ v.transpose(-2, -1)) + + output = agg_value + agg_bias + + output = self.projection(output) + + self.last_size_2d = (h, w) + return output, local_attn + + def compute_mask(self, height, width, device=None): + pad_height = height + 2 * self.max_dis + pad_width = width + 2 * self.max_dis + + if self.padded_local_mask is not None and (height, + width) == self.last_size_2d: + padded_local_mask = self.padded_local_mask + local_mask = self.local_mask + + else: + ky, kx = torch.meshgrid([ + torch.arange(0, pad_height, device=device), + torch.arange(0, pad_width, device=device) + ]) + qy, qx = torch.meshgrid([ + torch.arange(0, height, device=device), + torch.arange(0, width, device=device) + ]) + + qy = qy.reshape(-1, 1) + qx = qx.reshape(-1, 1) + offset_y = qy - ky.reshape(1, -1) + self.max_dis + offset_x = qx - kx.reshape(1, -1) + self.max_dis + padded_local_mask = (offset_y.abs() <= self.max_dis) & ( + offset_x.abs() <= self.max_dis) + padded_local_mask = padded_local_mask.view(1, 1, height * width, + pad_height, pad_width) + local_mask = padded_local_mask[:, :, :, self.max_dis:-self.max_dis, + self.max_dis:-self.max_dis] + pad_pixel = self.max_dis * self.dilation + local_mask = F.pad(local_mask.float(), + (pad_pixel, pad_pixel, pad_pixel, pad_pixel), + mode='constant', + value=0).view(1, 1, height * width, pad_height, + pad_width) + self.padded_local_mask = padded_local_mask + self.local_mask = local_mask + + return padded_local_mask, local_mask + + +def linear_gate(x, dim=-1): + # return F.relu_(x).pow(2.) / x.size()[dim] + return torch.softmax(x, dim=dim) + + +def silu(x): + return x * torch.sigmoid(x) + + +class GatedPropagation(nn.Module): + def __init__(self, + d_qk, + d_vu, + num_head=8, + dropout=0., + use_linear=True, + d_att=None, + use_dis=False, + qk_chunks=1, + max_mem_len_ratio=-1, + top_k=-1, + expand_ratio=2.): + super().__init__() + expand_ratio = expand_ratio + self.expand_d_vu = int(d_vu * expand_ratio) + self.d_vu = d_vu + self.d_qk = d_qk + self.num_head = num_head + self.use_dis = use_dis + self.qk_chunks = qk_chunks + self.max_mem_len_ratio = float(max_mem_len_ratio) + self.top_k = top_k + + self.hidden_dim = self.expand_d_vu // num_head + self.d_att = d_qk // num_head if d_att is None else d_att + self.T = self.d_att**0.5 + self.use_linear = use_linear + self.d_middle = self.d_att * self.num_head + + if use_linear: + self.linear_QK = nn.Linear(d_qk, self.d_middle) + half_d_vu = self.hidden_dim * num_head // 2 + self.linear_V1 = nn.Linear(d_vu // 2, half_d_vu) + self.linear_V2 = nn.Linear(d_vu // 2, half_d_vu) + self.linear_U1 = nn.Linear(d_vu // 2, half_d_vu) + self.linear_U2 = nn.Linear(d_vu // 2, half_d_vu) + + self.dropout = nn.Dropout(dropout) + self.drop_prob = dropout + + self.dw_conv = DWConv2d(self.expand_d_vu) + self.projection = nn.Linear(self.expand_d_vu, d_vu) + + self._init_weight() + + def forward(self, Q, K, V, U, size_2d): + """ + :param Q: A 3d tensor with shape of [T_q, bs, C_q] + :param K: A 3d tensor with shape of [T_k, bs, C_k] + :param V: A 3d tensor with shape of [T_v, bs, C_v] + """ + num_head = self.num_head + hidden_dim = self.hidden_dim + + l, bs, _ = Q.size() + + # Linear projections + if self.use_linear: + Q = K = self.linear_QK(Q) + + def cat(X1, X2): + if num_head > 1: + X1 = X1.view(-1, bs, num_head, hidden_dim // 2) + X2 = X2.view(-1, bs, num_head, hidden_dim // 2) + X = torch.cat([X1, X2], + dim=-1).view(-1, bs, num_head * hidden_dim) + else: + X = torch.cat([X1, X2], dim=-1) + return X + + V1, V2 = torch.split(V, self.d_vu // 2, dim=-1) + V1 = self.linear_V1(V1) + V2 = self.linear_V2(V2) + V = silu(cat(V1, V2)) + + U1, U2 = torch.split(U, self.d_vu // 2, dim=-1) + U1 = self.linear_U1(U1) + U2 = self.linear_U2(U2) + U = silu(cat(U1, U2)) + + # Scale + Q = Q / self.T + + if not self.training and self.max_mem_len_ratio > 0: + mem_len_ratio = float(K.size(0)) / Q.size(0) + if mem_len_ratio > self.max_mem_len_ratio: + scaling_ratio = math.log(mem_len_ratio) / math.log( + self.max_mem_len_ratio) + Q = Q * scaling_ratio + + # Multi-head + Q = Q.view(-1, bs, num_head, self.d_att).permute(1, 2, 0, 3) + K = K.view(-1, bs, num_head, self.d_att).permute(1, 2, 3, 0) + V = V.view(-1, bs, num_head, hidden_dim).permute(1, 2, 0, 3) + + # Multiplication + QK = multiply_by_ychunks(Q, K, self.qk_chunks) + if self.use_dis: + QK = 2 * QK - K.pow(2).sum(dim=-2, keepdim=True) + + # Activation + if not self.training and self.top_k > 0 and self.top_k < QK.size()[-1]: + top_QK, indices = torch.topk(QK, k=self.top_k, dim=-1) + top_attn = linear_gate(top_QK, dim=-1) + attn = torch.zeros_like(QK).scatter_(-1, indices, top_attn) + else: + attn = linear_gate(QK, dim=-1) + + # Dropouts + attn = self.dropout(attn) + + # Weighted sum + outputs = multiply_by_xchunks(attn, V, + self.qk_chunks).permute(2, 0, 1, 3) + + # Restore shape + outputs = outputs.reshape(l, bs, -1) * U + + outputs = self.dw_conv(outputs, size_2d) + outputs = self.projection(outputs) + + return outputs, attn + + def _init_weight(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + +class LocalGatedPropagation(nn.Module): + def __init__(self, + d_qk, + d_vu, + num_head, + dropout=0., + max_dis=7, + dilation=1, + use_linear=True, + enable_corr=True, + d_att=None, + use_dis=False, + expand_ratio=2.): + super().__init__() + expand_ratio = expand_ratio + self.expand_d_vu = int(d_vu * expand_ratio) + self.d_qk = d_qk + self.d_vu = d_vu + self.dilation = dilation + self.window_size = 2 * max_dis + 1 + self.max_dis = max_dis + self.num_head = num_head + self.hidden_dim = self.expand_d_vu // num_head + self.d_att = d_qk // num_head if d_att is None else d_att + self.T = self.d_att**0.5 + self.use_dis = use_dis + + self.d_middle = self.d_att * self.num_head + self.use_linear = use_linear + if use_linear: + self.linear_QK = nn.Conv2d(d_qk, self.d_middle, kernel_size=1) + self.linear_V = nn.Conv2d(d_vu, + self.expand_d_vu, + kernel_size=1, + groups=2) + self.linear_U = nn.Conv2d(d_vu, + self.expand_d_vu, + kernel_size=1, + groups=2) + + self.relative_emb_k = nn.Conv2d(self.d_middle, + num_head * self.window_size * + self.window_size, + kernel_size=1, + groups=num_head) + + self.enable_corr = enable_corr + + if enable_corr: + from spatial_correlation_sampler import SpatialCorrelationSampler + self.correlation_sampler = SpatialCorrelationSampler( + kernel_size=1, + patch_size=self.window_size, + stride=1, + padding=0, + dilation=1, + dilation_patch=self.dilation) + + self.dw_conv = DWConv2d(self.expand_d_vu) + self.projection = nn.Linear(self.expand_d_vu, d_vu) + + self.dropout = nn.Dropout(dropout) + + self.drop_prob = dropout + + self.local_mask = None + self.last_size_2d = None + self.qk_mask = None + + def forward(self, q, k, v, u, size_2d): + n, c, h, w = v.size() + hidden_dim = self.hidden_dim + + if self.use_linear: + q = k = self.linear_QK(q) + v = silu(self.linear_V(v)) + u = silu(self.linear_U(u)) + if self.num_head > 1: + v = v.view(-1, 2, self.num_head, hidden_dim // 2, + h * w).permute(0, 2, 1, 3, 4).reshape(n, -1, h, w) + u = u.view(-1, 2, self.num_head, hidden_dim // 2, + h * w).permute(4, 0, 2, 1, 3).reshape(h * w, n, -1) + else: + u = u.permute(2, 3, 0, 1).reshape(h * w, n, -1) + + if self.qk_mask is not None and (h, w) == self.last_size_2d: + qk_mask = self.qk_mask + else: + memory_mask = torch.ones((1, 1, h, w), device=v.device).float() + unfolded_k_mask = self.pad_and_unfold(memory_mask).view( + 1, 1, self.window_size * self.window_size, h * w) + qk_mask = 1 - unfolded_k_mask + self.qk_mask = qk_mask + + relative_emb = self.relative_emb_k(q) + + # Scale + q = q / self.T + + q = q.view(-1, self.d_att, h, w) + k = k.view(-1, self.d_att, h, w) + v = v.view(-1, self.num_head, hidden_dim, h * w) + + relative_emb = relative_emb.view(n, self.num_head, + self.window_size * self.window_size, + h * w) + + if self.enable_corr: + qk = self.correlation_sampler(q, k).view( + n, self.num_head, self.window_size * self.window_size, h * w) + else: + unfolded_k = self.pad_and_unfold(k).view( + n * self.num_head, self.d_att, + self.window_size * self.window_size, h, w) + qk = (q.unsqueeze(2) * unfolded_k).sum(dim=1).view( + n, self.num_head, self.window_size * self.window_size, h * w) + if self.use_dis: + qk = 2 * qk - self.pad_and_unfold( + k.pow(2).sum(dim=1, keepdim=True)).view( + n, self.num_head, self.window_size * self.window_size, + h * w) + + qk = qk + relative_emb + + qk -= qk_mask * 1e+8 if qk.dtype == torch.float32 else qk_mask * 1e+4 + + local_attn = linear_gate(qk, dim=2) + + local_attn = self.dropout(local_attn) + + global_attn = self.local2global(local_attn, h, w) + + agg_value = (global_attn @ v.transpose(-2, -1)).permute( + 2, 0, 1, 3).reshape(h * w, n, -1) + + output = agg_value * u + + output = self.dw_conv(output, size_2d) + output = self.projection(output) + + self.last_size_2d = (h, w) + return output, local_attn + + def local2global(self, local_attn, height, width): + batch_size = local_attn.size()[0] + + pad_height = height + 2 * self.max_dis + pad_width = width + 2 * self.max_dis + + if self.local_mask is not None and (height, + width) == self.last_size_2d: + local_mask = self.local_mask + else: + ky, kx = torch.meshgrid([ + torch.arange(0, pad_height, device=local_attn.device), + torch.arange(0, pad_width, device=local_attn.device) + ]) + qy, qx = torch.meshgrid([ + torch.arange(0, height, device=local_attn.device), + torch.arange(0, width, device=local_attn.device) + ]) + + offset_y = qy.reshape(-1, 1) - ky.reshape(1, -1) + self.max_dis + offset_x = qx.reshape(-1, 1) - kx.reshape(1, -1) + self.max_dis + + local_mask = (offset_y.abs() <= self.max_dis) & (offset_x.abs() <= + self.max_dis) + local_mask = local_mask.view(1, 1, height * width, pad_height, + pad_width) + self.local_mask = local_mask + + global_attn = torch.zeros( + (batch_size, self.num_head, height * width, pad_height, pad_width), + device=local_attn.device) + global_attn[local_mask.expand(batch_size, self.num_head, + -1, -1, -1)] = local_attn.transpose( + -1, -2).reshape(-1) + global_attn = global_attn[:, :, :, self.max_dis:-self.max_dis, + self.max_dis:-self.max_dis].reshape( + batch_size, self.num_head, + height * width, height * width) + + return global_attn + + def pad_and_unfold(self, x): + pad_pixel = self.max_dis * self.dilation + x = F.pad(x, (pad_pixel, pad_pixel, pad_pixel, pad_pixel), + mode='constant', + value=0) + x = F.unfold(x, + kernel_size=(self.window_size, self.window_size), + stride=(1, 1), + dilation=self.dilation) + return x diff --git a/aot/networks/layers/basic.py b/aot/networks/layers/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..1c9137c1c99c04f33658fc3cf0442ec5d23c50fa --- /dev/null +++ b/aot/networks/layers/basic.py @@ -0,0 +1,168 @@ +import torch +import torch.nn.functional as F +from torch import nn + + +class GroupNorm1D(nn.Module): + def __init__(self, indim, groups=8): + super().__init__() + self.gn = nn.GroupNorm(groups, indim) + + def forward(self, x): + return self.gn(x.permute(1, 2, 0)).permute(2, 0, 1) + + +class GNActDWConv2d(nn.Module): + def __init__(self, indim, gn_groups=32): + super().__init__() + self.gn = nn.GroupNorm(gn_groups, indim) + self.conv = nn.Conv2d(indim, + indim, + 5, + dilation=1, + padding=2, + groups=indim, + bias=False) + + def forward(self, x, size_2d): + h, w = size_2d + _, bs, c = x.size() + x = x.view(h, w, bs, c).permute(2, 3, 0, 1) + x = self.gn(x) + x = F.gelu(x) + x = self.conv(x) + x = x.view(bs, c, h * w).permute(2, 0, 1) + return x + + +class DWConv2d(nn.Module): + def __init__(self, indim, dropout=0.1): + super().__init__() + self.conv = nn.Conv2d(indim, + indim, + 5, + dilation=1, + padding=2, + groups=indim, + bias=False) + self.dropout = nn.Dropout2d(p=dropout, inplace=True) + + def forward(self, x, size_2d): + h, w = size_2d + _, bs, c = x.size() + x = x.view(h, w, bs, c).permute(2, 3, 0, 1) + x = self.conv(x) + x = self.dropout(x) + x = x.view(bs, c, h * w).permute(2, 0, 1) + return x + + +class ScaleOffset(nn.Module): + def __init__(self, indim): + super().__init__() + self.gamma = nn.Parameter(torch.ones(indim)) + # torch.nn.init.normal_(self.gamma, std=0.02) + self.beta = nn.Parameter(torch.zeros(indim)) + + def forward(self, x): + if len(x.size()) == 3: + return x * self.gamma + self.beta + else: + return x * self.gamma.view(1, -1, 1, 1) + self.beta.view( + 1, -1, 1, 1) + + +class ConvGN(nn.Module): + def __init__(self, indim, outdim, kernel_size, gn_groups=8): + super().__init__() + self.conv = nn.Conv2d(indim, + outdim, + kernel_size, + padding=kernel_size // 2) + self.gn = nn.GroupNorm(gn_groups, outdim) + + def forward(self, x): + return self.gn(self.conv(x)) + + +def seq_to_2d(tensor, size_2d): + h, w = size_2d + _, n, c = tensor.size() + tensor = tensor.view(h, w, n, c).permute(2, 3, 0, 1).contiguous() + return tensor + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = ( + x.shape[0], + x.shape[1], + ) + (1, ) * (x.ndim - 2 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand( + shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +def mask_out(x, y, mask_rate=0.15, training=False): + if mask_rate == 0. or not training: + return x + + keep_prob = 1 - mask_rate + shape = ( + x.shape[0], + x.shape[1], + ) + (1, ) * (x.ndim - 2 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand( + shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x * random_tensor + y * (1 - random_tensor) + + return output + + +class DropPath(nn.Module): + def __init__(self, drop_prob=None, batch_dim=0): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.batch_dim = batch_dim + + def forward(self, x): + return self.drop_path(x, self.drop_prob) + + def drop_path(self, x, drop_prob): + if drop_prob == 0. or not self.training: + return x + keep_prob = 1 - drop_prob + shape = [1 for _ in range(x.ndim)] + shape[self.batch_dim] = x.shape[self.batch_dim] + random_tensor = keep_prob + torch.rand( + shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropOutLogit(nn.Module): + def __init__(self, drop_prob=None): + super(DropOutLogit, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return self.drop_logit(x, self.drop_prob) + + def drop_logit(self, x, drop_prob): + if drop_prob == 0. or not self.training: + return x + random_tensor = drop_prob + torch.rand( + x.shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + mask = random_tensor * 1e+8 if ( + x.dtype == torch.float32) else random_tensor * 1e+4 + output = x - mask + return output diff --git a/aot/networks/layers/loss.py b/aot/networks/layers/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..acdd4daef65eb768f5696a11c07615a2fd2d5d8e --- /dev/null +++ b/aot/networks/layers/loss.py @@ -0,0 +1,188 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +try: + from itertools import ifilterfalse +except ImportError: # py3k + from itertools import filterfalse as ifilterfalse + + +def dice_loss(probas, labels, smooth=1): + + C = probas.size(1) + losses = [] + for c in list(range(C)): + fg = (labels == c).float() + if fg.sum() == 0: + continue + class_pred = probas[:, c] + p0 = class_pred + g0 = fg + numerator = 2 * torch.sum(p0 * g0) + smooth + denominator = torch.sum(p0) + torch.sum(g0) + smooth + losses.append(1 - ((numerator) / (denominator))) + return mean(losses) + + +def tversky_loss(probas, labels, alpha=0.5, beta=0.5, epsilon=1e-6): + ''' + Tversky loss function. + probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) + labels: [P] Tensor, ground truth labels (between 0 and C - 1) + + Same as soft dice loss when alpha=beta=0.5. + Same as Jaccord loss when alpha=beta=1.0. + See `Tversky loss function for image segmentation using 3D fully convolutional deep networks` + https://arxiv.org/pdf/1706.05721.pdf + ''' + C = probas.size(1) + losses = [] + for c in list(range(C)): + fg = (labels == c).float() + if fg.sum() == 0: + continue + class_pred = probas[:, c] + p0 = class_pred + p1 = 1 - class_pred + g0 = fg + g1 = 1 - fg + numerator = torch.sum(p0 * g0) + denominator = numerator + alpha * \ + torch.sum(p0*g1) + beta*torch.sum(p1*g0) + losses.append(1 - ((numerator) / (denominator + epsilon))) + return mean(losses) + + +def flatten_probas(probas, labels, ignore=255): + """ + Flattens predictions in the batch + """ + B, C, H, W = probas.size() + probas = probas.permute(0, 2, 3, + 1).contiguous().view(-1, C) # B * H * W, C = P, C + labels = labels.view(-1) + if ignore is None: + return probas, labels + valid = (labels != ignore) + vprobas = probas[valid.view(-1, 1).expand(-1, C)].reshape(-1, C) + # vprobas = probas[torch.nonzero(valid).squeeze()] + vlabels = labels[valid] + return vprobas, vlabels + + +def isnan(x): + return x != x + + +def mean(l, ignore_nan=False, empty=0): + """ + nanmean compatible with generators. + """ + l = iter(l) + if ignore_nan: + l = ifilterfalse(isnan, l) + try: + n = 1 + acc = next(l) + except StopIteration: + if empty == 'raise': + raise ValueError('Empty mean') + return empty + for n, v in enumerate(l, 2): + acc += v + if n == 1: + return acc + return acc / n + + +class DiceLoss(nn.Module): + def __init__(self, ignore_index=255): + super(DiceLoss, self).__init__() + self.ignore_index = ignore_index + + def forward(self, tmp_dic, label_dic, step=None): + total_loss = [] + for idx in range(len(tmp_dic)): + pred = tmp_dic[idx] + label = label_dic[idx] + pred = F.softmax(pred, dim=1) + label = label.view(1, 1, pred.size()[2], pred.size()[3]) + loss = dice_loss( + *flatten_probas(pred, label, ignore=self.ignore_index)) + total_loss.append(loss.unsqueeze(0)) + total_loss = torch.cat(total_loss, dim=0) + return total_loss + + +class SoftJaccordLoss(nn.Module): + def __init__(self, ignore_index=255): + super(SoftJaccordLoss, self).__init__() + self.ignore_index = ignore_index + + def forward(self, tmp_dic, label_dic, step=None): + total_loss = [] + for idx in range(len(tmp_dic)): + pred = tmp_dic[idx] + label = label_dic[idx] + pred = F.softmax(pred, dim=1) + label = label.view(1, 1, pred.size()[2], pred.size()[3]) + loss = tversky_loss(*flatten_probas(pred, + label, + ignore=self.ignore_index), + alpha=1.0, + beta=1.0) + total_loss.append(loss.unsqueeze(0)) + total_loss = torch.cat(total_loss, dim=0) + return total_loss + + +class CrossEntropyLoss(nn.Module): + def __init__(self, + top_k_percent_pixels=None, + hard_example_mining_step=100000): + super(CrossEntropyLoss, self).__init__() + self.top_k_percent_pixels = top_k_percent_pixels + if top_k_percent_pixels is not None: + assert (top_k_percent_pixels > 0 and top_k_percent_pixels < 1) + self.hard_example_mining_step = hard_example_mining_step + 1e-5 + if self.top_k_percent_pixels is None: + self.celoss = nn.CrossEntropyLoss(ignore_index=255, + reduction='mean') + else: + self.celoss = nn.CrossEntropyLoss(ignore_index=255, + reduction='none') + + def forward(self, dic_tmp, y, step): + total_loss = [] + for i in range(len(dic_tmp)): + pred_logits = dic_tmp[i] + gts = y[i] + if self.top_k_percent_pixels is None: + final_loss = self.celoss(pred_logits, gts) + else: + # Only compute the loss for top k percent pixels. + # First, compute the loss for all pixels. Note we do not put the loss + # to loss_collection and set reduction = None to keep the shape. + num_pixels = float(pred_logits.size(2) * pred_logits.size(3)) + pred_logits = pred_logits.view( + -1, pred_logits.size(1), + pred_logits.size(2) * pred_logits.size(3)) + gts = gts.view(-1, gts.size(1) * gts.size(2)) + pixel_losses = self.celoss(pred_logits, gts) + if self.hard_example_mining_step == 0: + top_k_pixels = int(self.top_k_percent_pixels * num_pixels) + else: + ratio = min(1.0, + step / float(self.hard_example_mining_step)) + top_k_pixels = int((ratio * self.top_k_percent_pixels + + (1.0 - ratio)) * num_pixels) + top_k_loss, top_k_indices = torch.topk(pixel_losses, + k=top_k_pixels, + dim=1) + + final_loss = torch.mean(top_k_loss) + final_loss = final_loss.unsqueeze(0) + total_loss.append(final_loss) + total_loss = torch.cat(total_loss, dim=0) + return total_loss diff --git a/aot/networks/layers/normalization.py b/aot/networks/layers/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..3d89d9c43e5273c4141983fa0654ef5b912f2b92 --- /dev/null +++ b/aot/networks/layers/normalization.py @@ -0,0 +1,43 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FrozenBatchNorm2d(nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters + are fixed + """ + def __init__(self, n, epsilon=1e-5): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n) - epsilon) + self.epsilon = epsilon + + def forward(self, x): + """ + Refer to Detectron2 (https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/layers/batch_norm.py) + """ + if x.requires_grad: + # When gradients are needed, F.batch_norm will use extra memory + # because its backward op computes gradients for weight/bias as well. + scale = self.weight * (self.running_var + self.epsilon).rsqrt() + bias = self.bias - self.running_mean * scale + scale = scale.reshape(1, -1, 1, 1) + bias = bias.reshape(1, -1, 1, 1) + out_dtype = x.dtype # may be half + return x * scale.to(out_dtype) + bias.to(out_dtype) + else: + # When gradients are not needed, F.batch_norm is a single fused op + # and provide more optimization opportunities. + return F.batch_norm( + x, + self.running_mean, + self.running_var, + self.weight, + self.bias, + training=False, + eps=self.epsilon, + ) diff --git a/aot/networks/layers/position.py b/aot/networks/layers/position.py new file mode 100644 index 0000000000000000000000000000000000000000..8d37f7273bcf509dd90f67d6ced7534339611ed0 --- /dev/null +++ b/aot/networks/layers/position.py @@ -0,0 +1,90 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from utils.math import truncated_normal_ + + +class Downsample2D(nn.Module): + def __init__(self, mode='nearest', scale=4): + super().__init__() + self.mode = mode + self.scale = scale + + def forward(self, x): + n, c, h, w = x.size() + x = F.interpolate(x, + size=(h // self.scale + 1, w // self.scale + 1), + mode=self.mode) + return x + + +def generate_coord(x): + _, _, h, w = x.size() + device = x.device + col = torch.arange(0, h, device=device) + row = torch.arange(0, w, device=device) + grid_h, grid_w = torch.meshgrid(col, row) + return grid_h, grid_w + + +class PositionEmbeddingSine(nn.Module): + def __init__(self, + num_pos_feats=64, + temperature=10000, + normalize=False, + scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x): + grid_y, grid_x = generate_coord(x) + + y_embed = grid_y.unsqueeze(0).float() + x_embed = grid_x.unsqueeze(0).float() + + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, + dtype=torch.float32, + device=x.device) + dim_t = self.temperature**(2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class PositionEmbeddingLearned(nn.Module): + def __init__(self, num_pos_feats=64, H=30, W=30): + super().__init__() + self.H = H + self.W = W + self.pos_emb = nn.Parameter( + truncated_normal_(torch.zeros(1, num_pos_feats, H, W))) + + def forward(self, x): + bs, _, h, w = x.size() + pos_emb = self.pos_emb + if h != self.H or w != self.W: + pos_emb = F.interpolate(pos_emb, size=(h, w), mode="bilinear") + return pos_emb diff --git a/aot/networks/layers/transformer.py b/aot/networks/layers/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..5f20ab82112214101b9481b793e2b1931539b34b --- /dev/null +++ b/aot/networks/layers/transformer.py @@ -0,0 +1,690 @@ +import torch +import torch.nn.functional as F +from torch import nn + +from networks.layers.basic import DropPath, GroupNorm1D, GNActDWConv2d, seq_to_2d, ScaleOffset, mask_out +from networks.layers.attention import silu, MultiheadAttention, MultiheadLocalAttentionV2, MultiheadLocalAttentionV3, GatedPropagation, LocalGatedPropagation + + +def _get_norm(indim, type='ln', groups=8): + if type == 'gn': + return GroupNorm1D(indim, groups) + else: + return nn.LayerNorm(indim) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError( + F"activation should be relu/gele/glu, not {activation}.") + + +class LongShortTermTransformer(nn.Module): + def __init__(self, + num_layers=2, + d_model=256, + self_nhead=8, + att_nhead=8, + dim_feedforward=1024, + emb_dropout=0., + droppath=0.1, + lt_dropout=0., + st_dropout=0., + droppath_lst=False, + droppath_scaling=False, + activation="gelu", + return_intermediate=False, + intermediate_norm=True, + final_norm=True, + block_version="v1"): + + super().__init__() + self.intermediate_norm = intermediate_norm + self.final_norm = final_norm + self.num_layers = num_layers + self.return_intermediate = return_intermediate + + self.emb_dropout = nn.Dropout(emb_dropout, True) + self.mask_token = nn.Parameter(torch.randn([1, 1, d_model])) + + if block_version == "v1": + block = LongShortTermTransformerBlock + elif block_version == "v2": + block = LongShortTermTransformerBlockV2 + elif block_version == "v3": + block = LongShortTermTransformerBlockV3 + else: + raise NotImplementedError + + layers = [] + for idx in range(num_layers): + if droppath_scaling: + if num_layers == 1: + droppath_rate = 0 + else: + droppath_rate = droppath * idx / (num_layers - 1) + else: + droppath_rate = droppath + layers.append( + block(d_model, self_nhead, att_nhead, dim_feedforward, + droppath_rate, lt_dropout, st_dropout, droppath_lst, + activation)) + self.layers = nn.ModuleList(layers) + + num_norms = num_layers - 1 if intermediate_norm else 0 + if final_norm: + num_norms += 1 + self.decoder_norms = [ + _get_norm(d_model, type='ln') for _ in range(num_norms) + ] if num_norms > 0 else None + + if self.decoder_norms is not None: + self.decoder_norms = nn.ModuleList(self.decoder_norms) + + def forward(self, + tgt, + long_term_memories, + short_term_memories, + curr_id_emb=None, + self_pos=None, + size_2d=None): + + output = self.emb_dropout(tgt) + + # output = mask_out(output, self.mask_token, 0.15, self.training) + + intermediate = [] + intermediate_memories = [] + + for idx, layer in enumerate(self.layers): + output, memories = layer(output, + long_term_memories[idx] if + long_term_memories is not None else None, + short_term_memories[idx] if + short_term_memories is not None else None, + curr_id_emb=curr_id_emb, + self_pos=self_pos, + size_2d=size_2d) + + if self.return_intermediate: + intermediate.append(output) + intermediate_memories.append(memories) + + if self.decoder_norms is not None: + if self.final_norm: + output = self.decoder_norms[-1](output) + + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.intermediate_norm: + for idx in range(len(intermediate) - 1): + intermediate[idx] = self.decoder_norms[idx]( + intermediate[idx]) + + if self.return_intermediate: + return intermediate, intermediate_memories + + return output, memories + + +class DualBranchGPM(nn.Module): + def __init__(self, + num_layers=2, + d_model=256, + self_nhead=8, + att_nhead=8, + dim_feedforward=1024, + emb_dropout=0., + droppath=0.1, + lt_dropout=0., + st_dropout=0., + droppath_lst=False, + droppath_scaling=False, + activation="gelu", + return_intermediate=False, + intermediate_norm=True, + final_norm=True): + + super().__init__() + self.intermediate_norm = intermediate_norm + self.final_norm = final_norm + self.num_layers = num_layers + self.return_intermediate = return_intermediate + + self.emb_dropout = nn.Dropout(emb_dropout, True) + # self.mask_token = nn.Parameter(torch.randn([1, 1, d_model])) + + block = GatedPropagationModule + + layers = [] + for idx in range(num_layers): + if droppath_scaling: + if num_layers == 1: + droppath_rate = 0 + else: + droppath_rate = droppath * idx / (num_layers - 1) + else: + droppath_rate = droppath + layers.append( + block(d_model, + self_nhead, + att_nhead, + dim_feedforward, + droppath_rate, + lt_dropout, + st_dropout, + droppath_lst, + activation, + layer_idx=idx)) + self.layers = nn.ModuleList(layers) + + num_norms = num_layers - 1 if intermediate_norm else 0 + if final_norm: + num_norms += 1 + self.decoder_norms = [ + _get_norm(d_model * 2, type='gn', groups=2) + for _ in range(num_norms) + ] if num_norms > 0 else None + + if self.decoder_norms is not None: + self.decoder_norms = nn.ModuleList(self.decoder_norms) + + def forward(self, + tgt, + long_term_memories, + short_term_memories, + curr_id_emb=None, + self_pos=None, + size_2d=None): + + output = self.emb_dropout(tgt) + + # output = mask_out(output, self.mask_token, 0.15, self.training) + + intermediate = [] + intermediate_memories = [] + output_id = None + + for idx, layer in enumerate(self.layers): + output, output_id, memories = layer( + output, + output_id, + long_term_memories[idx] + if long_term_memories is not None else None, + short_term_memories[idx] + if short_term_memories is not None else None, + curr_id_emb=curr_id_emb, + self_pos=self_pos, + size_2d=size_2d) + + cat_output = torch.cat([output, output_id], dim=2) + + if self.return_intermediate: + intermediate.append(cat_output) + intermediate_memories.append(memories) + + if self.decoder_norms is not None: + if self.final_norm: + cat_output = self.decoder_norms[-1](cat_output) + + if self.return_intermediate: + intermediate.pop() + intermediate.append(cat_output) + + if self.intermediate_norm: + for idx in range(len(intermediate) - 1): + intermediate[idx] = self.decoder_norms[idx]( + intermediate[idx]) + + if self.return_intermediate: + return intermediate, intermediate_memories + + return cat_output, memories + + +class LongShortTermTransformerBlock(nn.Module): + def __init__(self, + d_model, + self_nhead, + att_nhead, + dim_feedforward=1024, + droppath=0.1, + lt_dropout=0., + st_dropout=0., + droppath_lst=False, + activation="gelu", + local_dilation=1, + enable_corr=True): + super().__init__() + + # Long Short-Term Attention + self.norm1 = _get_norm(d_model) + self.linear_Q = nn.Linear(d_model, d_model) + self.linear_V = nn.Linear(d_model, d_model) + + self.long_term_attn = MultiheadAttention(d_model, + att_nhead, + use_linear=False, + dropout=lt_dropout) + + # MultiheadLocalAttention = MultiheadLocalAttentionV2 if enable_corr else MultiheadLocalAttentionV3 + if enable_corr: + try: + import spatial_correlation_sampler + MultiheadLocalAttention = MultiheadLocalAttentionV2 + except Exception as inst: + print(inst) + print("Failed to import PyTorch Correlation, For better efficiency, please install it.") + MultiheadLocalAttention = MultiheadLocalAttentionV3 + else: + MultiheadLocalAttention = MultiheadLocalAttentionV3 + self.short_term_attn = MultiheadLocalAttention(d_model, + att_nhead, + dilation=local_dilation, + use_linear=False, + dropout=st_dropout) + self.lst_dropout = nn.Dropout(max(lt_dropout, st_dropout), True) + self.droppath_lst = droppath_lst + + # Self-attention + self.norm2 = _get_norm(d_model) + self.self_attn = MultiheadAttention(d_model, self_nhead) + + # Feed-forward + self.norm3 = _get_norm(d_model) + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.activation = GNActDWConv2d(dim_feedforward) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.droppath = DropPath(droppath, batch_dim=1) + self._init_weight() + + def with_pos_embed(self, tensor, pos=None): + size = tensor.size() + if len(size) == 4 and pos is not None: + n, c, h, w = size + pos = pos.view(h, w, n, c).permute(2, 3, 0, 1) + return tensor if pos is None else tensor + pos + + def forward(self, + tgt, + long_term_memory=None, + short_term_memory=None, + curr_id_emb=None, + self_pos=None, + size_2d=(30, 30)): + + # Self-attention + _tgt = self.norm1(tgt) + q = k = self.with_pos_embed(_tgt, self_pos) + v = _tgt + tgt2 = self.self_attn(q, k, v)[0] + + tgt = tgt + self.droppath(tgt2) + + # Long Short-Term Attention + _tgt = self.norm2(tgt) + + curr_Q = self.linear_Q(_tgt) + curr_K = curr_Q + curr_V = _tgt + + local_Q = seq_to_2d(curr_Q, size_2d) + + if curr_id_emb is not None: + global_K, global_V = self.fuse_key_value_id( + curr_K, curr_V, curr_id_emb) + local_K = seq_to_2d(global_K, size_2d) + local_V = seq_to_2d(global_V, size_2d) + else: + global_K, global_V = long_term_memory + local_K, local_V = short_term_memory + + tgt2 = self.long_term_attn(curr_Q, global_K, global_V)[0] + tgt3 = self.short_term_attn(local_Q, local_K, local_V)[0] + + if self.droppath_lst: + tgt = tgt + self.droppath(tgt2 + tgt3) + else: + tgt = tgt + self.lst_dropout(tgt2 + tgt3) + + # Feed-forward + _tgt = self.norm3(tgt) + + tgt2 = self.linear2(self.activation(self.linear1(_tgt), size_2d)) + + tgt = tgt + self.droppath(tgt2) + + return tgt, [[curr_K, curr_V], [global_K, global_V], + [local_K, local_V]] + + def fuse_key_value_id(self, key, value, id_emb): + K = key + V = self.linear_V(value + id_emb) + return K, V + + def _init_weight(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + +class LongShortTermTransformerBlockV2(nn.Module): + def __init__(self, + d_model, + self_nhead, + att_nhead, + dim_feedforward=1024, + droppath=0.1, + lt_dropout=0., + st_dropout=0., + droppath_lst=False, + activation="gelu", + local_dilation=1, + enable_corr=True): + super().__init__() + self.d_model = d_model + self.att_nhead = att_nhead + + # Self-attention + self.norm1 = _get_norm(d_model) + self.self_attn = MultiheadAttention(d_model, self_nhead) + + # Long Short-Term Attention + self.norm2 = _get_norm(d_model) + self.linear_QV = nn.Linear(d_model, 2 * d_model) + self.linear_ID_KV = nn.Linear(d_model, d_model + att_nhead) + + self.long_term_attn = MultiheadAttention(d_model, + att_nhead, + use_linear=False, + dropout=lt_dropout) + + # MultiheadLocalAttention = MultiheadLocalAttentionV2 if enable_corr else MultiheadLocalAttentionV3 + if enable_corr: + try: + import spatial_correlation_sampler + MultiheadLocalAttention = MultiheadLocalAttentionV2 + except Exception as inst: + print(inst) + print("Failed to import PyTorch Correlation, For better efficiency, please install it.") + MultiheadLocalAttention = MultiheadLocalAttentionV3 + else: + MultiheadLocalAttention = MultiheadLocalAttentionV3 + self.short_term_attn = MultiheadLocalAttention(d_model, + att_nhead, + dilation=local_dilation, + use_linear=False, + dropout=st_dropout) + self.lst_dropout = nn.Dropout(max(lt_dropout, st_dropout), True) + self.droppath_lst = droppath_lst + + # Feed-forward + self.norm3 = _get_norm(d_model) + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.activation = GNActDWConv2d(dim_feedforward) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.droppath = DropPath(droppath, batch_dim=1) + self._init_weight() + + def with_pos_embed(self, tensor, pos=None): + size = tensor.size() + if len(size) == 4 and pos is not None: + n, c, h, w = size + pos = pos.view(h, w, n, c).permute(2, 3, 0, 1) + return tensor if pos is None else tensor + pos + + def forward(self, + tgt, + long_term_memory=None, + short_term_memory=None, + curr_id_emb=None, + self_pos=None, + size_2d=(30, 30)): + + # Self-attention + _tgt = self.norm1(tgt) + q = k = self.with_pos_embed(_tgt, self_pos) + v = _tgt + tgt2 = self.self_attn(q, k, v)[0] + + tgt = tgt + self.droppath(tgt2) + + # Long Short-Term Attention + _tgt = self.norm2(tgt) + + curr_QV = self.linear_QV(_tgt) + curr_QV = torch.split(curr_QV, self.d_model, dim=2) + curr_Q = curr_K = curr_QV[0] + curr_V = curr_QV[1] + + local_Q = seq_to_2d(curr_Q, size_2d) + + if curr_id_emb is not None: + global_K, global_V = self.fuse_key_value_id( + curr_K, curr_V, curr_id_emb) + + local_K = seq_to_2d(global_K, size_2d) + local_V = seq_to_2d(global_V, size_2d) + else: + global_K, global_V = long_term_memory + local_K, local_V = short_term_memory + + tgt2 = self.long_term_attn(curr_Q, global_K, global_V)[0] + tgt3 = self.short_term_attn(local_Q, local_K, local_V)[0] + + if self.droppath_lst: + tgt = tgt + self.droppath(tgt2 + tgt3) + else: + tgt = tgt + self.lst_dropout(tgt2 + tgt3) + + # Feed-forward + _tgt = self.norm3(tgt) + + tgt2 = self.linear2(self.activation(self.linear1(_tgt), size_2d)) + + tgt = tgt + self.droppath(tgt2) + + return tgt, [[curr_K, curr_V], [global_K, global_V], + [local_K, local_V]] + + def fuse_key_value_id(self, key, value, id_emb): + ID_KV = self.linear_ID_KV(id_emb) + ID_K, ID_V = torch.split(ID_KV, [self.att_nhead, self.d_model], dim=2) + bs = key.size(1) + K = key.view(-1, bs, self.att_nhead, self.d_model // + self.att_nhead) * (1 + torch.tanh(ID_K)).unsqueeze(-1) + K = K.view(-1, bs, self.d_model) + V = value + ID_V + return K, V + + def _init_weight(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + +class GatedPropagationModule(nn.Module): + def __init__(self, + d_model, + self_nhead, + att_nhead, + dim_feedforward=1024, + droppath=0.1, + lt_dropout=0., + st_dropout=0., + droppath_lst=False, + activation="gelu", + local_dilation=1, + enable_corr=True, + max_local_dis=7, + layer_idx=0, + expand_ratio=2.): + super().__init__() + expand_ratio = expand_ratio + expand_d_model = int(d_model * expand_ratio) + self.expand_d_model = expand_d_model + self.d_model = d_model + self.att_nhead = att_nhead + + d_att = d_model // 2 if att_nhead == 1 else d_model // att_nhead + self.d_att = d_att + self.layer_idx = layer_idx + + # Long Short-Term Attention + self.norm1 = _get_norm(d_model) + self.linear_QV = nn.Linear(d_model, d_att * att_nhead + expand_d_model) + self.linear_U = nn.Linear(d_model, expand_d_model) + + if layer_idx == 0: + self.linear_ID_V = nn.Linear(d_model, expand_d_model) + else: + self.id_norm1 = _get_norm(d_model) + self.linear_ID_V = nn.Linear(d_model * 2, expand_d_model) + self.linear_ID_U = nn.Linear(d_model, expand_d_model) + + self.long_term_attn = GatedPropagation(d_qk=self.d_model, + d_vu=self.d_model * 2, + num_head=att_nhead, + use_linear=False, + dropout=lt_dropout, + d_att=d_att, + top_k=-1, + expand_ratio=expand_ratio) + + if enable_corr: + try: + import spatial_correlation_sampler + except Exception as inst: + print(inst) + print("Failed to import PyTorch Correlation, For better efficiency, please install it.") + enable_corr = False + self.short_term_attn = LocalGatedPropagation(d_qk=self.d_model, + d_vu=self.d_model * 2, + num_head=att_nhead, + dilation=local_dilation, + use_linear=False, + enable_corr=enable_corr, + dropout=st_dropout, + d_att=d_att, + max_dis=max_local_dis, + expand_ratio=expand_ratio) + + self.lst_dropout = nn.Dropout(max(lt_dropout, st_dropout), True) + self.droppath_lst = droppath_lst + + # Self-attention + self.norm2 = _get_norm(d_model) + self.id_norm2 = _get_norm(d_model) + self.self_attn = GatedPropagation(d_model * 2, + d_model * 2, + self_nhead, + d_att=d_att) + + self.droppath = DropPath(droppath, batch_dim=1) + self._init_weight() + + def with_pos_embed(self, tensor, pos=None): + size = tensor.size() + if len(size) == 4 and pos is not None: + n, c, h, w = size + pos = pos.view(h, w, n, c).permute(2, 3, 0, 1) + return tensor if pos is None else tensor + pos + + def forward(self, + tgt, + tgt_id=None, + long_term_memory=None, + short_term_memory=None, + curr_id_emb=None, + self_pos=None, + size_2d=(30, 30)): + + # Long Short-Term Attention + _tgt = self.norm1(tgt) + + curr_QV = self.linear_QV(_tgt) + curr_QV = torch.split( + curr_QV, [self.d_att * self.att_nhead, self.expand_d_model], dim=2) + curr_Q = curr_K = curr_QV[0] + local_Q = seq_to_2d(curr_Q, size_2d) + curr_V = silu(curr_QV[1]) + curr_U = self.linear_U(_tgt) + + if tgt_id is None: + tgt_id = 0 + cat_curr_U = torch.cat( + [silu(curr_U), torch.ones_like(curr_U)], dim=-1) + curr_ID_V = None + else: + _tgt_id = self.id_norm1(tgt_id) + curr_ID_V = _tgt_id + curr_ID_U = self.linear_ID_U(_tgt_id) + cat_curr_U = silu(torch.cat([curr_U, curr_ID_U], dim=-1)) + + if curr_id_emb is not None: + global_K, global_V = curr_K, curr_V + local_K = seq_to_2d(global_K, size_2d) + local_V = seq_to_2d(global_V, size_2d) + + _, global_ID_V = self.fuse_key_value_id(None, curr_ID_V, + curr_id_emb) + local_ID_V = seq_to_2d(global_ID_V, size_2d) + else: + global_K, global_V, _, global_ID_V = long_term_memory + local_K, local_V, _, local_ID_V = short_term_memory + + cat_global_V = torch.cat([global_V, global_ID_V], dim=-1) + cat_local_V = torch.cat([local_V, local_ID_V], dim=1) + + cat_tgt2, _ = self.long_term_attn(curr_Q, global_K, cat_global_V, + cat_curr_U, size_2d) + cat_tgt3, _ = self.short_term_attn(local_Q, local_K, cat_local_V, + cat_curr_U, size_2d) + + tgt2, tgt_id2 = torch.split(cat_tgt2, self.d_model, dim=-1) + tgt3, tgt_id3 = torch.split(cat_tgt3, self.d_model, dim=-1) + + if self.droppath_lst: + tgt = tgt + self.droppath(tgt2 + tgt3) + tgt_id = tgt_id + self.droppath(tgt_id2 + tgt_id3) + else: + tgt = tgt + self.lst_dropout(tgt2 + tgt3) + tgt_id = tgt_id + self.lst_dropout(tgt_id2 + tgt_id3) + + # Self-attention + _tgt = self.norm2(tgt) + _tgt_id = self.id_norm2(tgt_id) + q = k = v = u = torch.cat([_tgt, _tgt_id], dim=-1) + + cat_tgt2, _ = self.self_attn(q, k, v, u, size_2d) + + tgt2, tgt_id2 = torch.split(cat_tgt2, self.d_model, dim=-1) + + tgt = tgt + self.droppath(tgt2) + tgt_id = tgt_id + self.droppath(tgt_id2) + + return tgt, tgt_id, [[curr_K, curr_V, None, curr_ID_V], + [global_K, global_V, None, global_ID_V], + [local_K, local_V, None, local_ID_V]] + + def fuse_key_value_id(self, key, value, id_emb): + ID_K = None + if value is not None: + ID_V = silu(self.linear_ID_V(torch.cat([value, id_emb], dim=2))) + else: + ID_V = silu(self.linear_ID_V(id_emb)) + return ID_K, ID_V + + def _init_weight(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) diff --git a/aot/networks/managers/evaluator.py b/aot/networks/managers/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..7414519d6f878a4552847fa76b45f2fff12bff21 --- /dev/null +++ b/aot/networks/managers/evaluator.py @@ -0,0 +1,552 @@ +import os +import time +import datetime as datetime +import json + +import numpy as np +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torchvision import transforms + +from dataloaders.eval_datasets import YOUTUBEVOS_Test, YOUTUBEVOS_DenseTest, DAVIS_Test, EVAL_TEST +import dataloaders.video_transforms as tr + +from utils.image import flip_tensor, save_mask +from utils.checkpoint import load_network +from utils.eval import zip_folder + +from networks.models import build_vos_model +from networks.engines import build_engine + + +class Evaluator(object): + def __init__(self, cfg, rank=0, seq_queue=None, info_queue=None): + self.gpu = cfg.TEST_GPU_ID + rank + self.gpu_num = cfg.TEST_GPU_NUM + self.rank = rank + self.cfg = cfg + self.seq_queue = seq_queue + self.info_queue = info_queue + + self.print_log("Exp {}:".format(cfg.EXP_NAME)) + self.print_log(json.dumps(cfg.__dict__, indent=4, sort_keys=True)) + + print("Use GPU {} for evaluating.".format(self.gpu)) + torch.cuda.set_device(self.gpu) + + self.print_log('Build VOS model.') + self.model = build_vos_model(cfg.MODEL_VOS, cfg).cuda(self.gpu) + + self.process_pretrained_model() + + self.prepare_dataset() + + def process_pretrained_model(self): + cfg = self.cfg + + if cfg.TEST_CKPT_PATH == 'test': + self.ckpt = 'test' + self.print_log('Test evaluation.') + return + + if cfg.TEST_CKPT_PATH is None: + if cfg.TEST_CKPT_STEP is not None: + ckpt = str(cfg.TEST_CKPT_STEP) + else: + ckpts = os.listdir(cfg.DIR_CKPT) + if len(ckpts) > 0: + ckpts = list( + map(lambda x: int(x.split('_')[-1].split('.')[0]), + ckpts)) + ckpt = np.sort(ckpts)[-1] + else: + self.print_log('No checkpoint in {}.'.format(cfg.DIR_CKPT)) + exit() + self.ckpt = ckpt + if cfg.TEST_EMA: + cfg.DIR_CKPT = os.path.join(cfg.DIR_RESULT, 'ema_ckpt') + cfg.TEST_CKPT_PATH = os.path.join(cfg.DIR_CKPT, + 'save_step_%s.pth' % ckpt) + try: + self.model, removed_dict = load_network( + self.model, cfg.TEST_CKPT_PATH, self.gpu) + except Exception as inst: + self.print_log(inst) + self.print_log('Try to use backup checkpoint.') + DIR_RESULT = './backup/{}/{}'.format(cfg.EXP_NAME, + cfg.STAGE_NAME) + DIR_CKPT = os.path.join(DIR_RESULT, 'ema_ckpt') + TEST_CKPT_PATH = os.path.join(DIR_CKPT, + 'save_step_%s.pth' % ckpt) + self.model, removed_dict = load_network( + self.model, TEST_CKPT_PATH, self.gpu) + + if len(removed_dict) > 0: + self.print_log( + 'Remove {} from pretrained model.'.format(removed_dict)) + self.print_log('Load latest checkpoint from {}'.format( + cfg.TEST_CKPT_PATH)) + else: + self.ckpt = 'unknown' + self.model, removed_dict = load_network(self.model, + cfg.TEST_CKPT_PATH, + self.gpu) + if len(removed_dict) > 0: + self.print_log( + 'Remove {} from pretrained model.'.format(removed_dict)) + self.print_log('Load checkpoint from {}'.format( + cfg.TEST_CKPT_PATH)) + + def prepare_dataset(self): + cfg = self.cfg + self.print_log('Process dataset...') + eval_transforms = transforms.Compose([ + tr.MultiRestrictSize(cfg.TEST_MAX_SHORT_EDGE, + cfg.TEST_MAX_LONG_EDGE, cfg.TEST_FLIP, + cfg.TEST_MULTISCALE, cfg.MODEL_ALIGN_CORNERS), + tr.MultiToTensor() + ]) + + exp_name = cfg.EXP_NAME + if 'aost' in cfg.MODEL_VOS: + exp_name += '_L{}'.format(int(cfg.MODEL_LSTT_NUM)) + + eval_name = '{}_{}_{}_{}_ckpt_{}'.format(cfg.TEST_DATASET, + cfg.TEST_DATASET_SPLIT, + exp_name, cfg.STAGE_NAME, + self.ckpt) + + if cfg.TEST_EMA: + eval_name += '_ema' + if cfg.TEST_FLIP: + eval_name += '_flip' + if len(cfg.TEST_MULTISCALE) > 1: + eval_name += '_ms_' + str(cfg.TEST_MULTISCALE).replace( + '.', 'dot').replace('[', '').replace(']', '').replace( + ', ', '_') + + if 'youtubevos' in cfg.TEST_DATASET: + year = int(cfg.TEST_DATASET[-4:]) + self.result_root = os.path.join(cfg.DIR_EVALUATION, + cfg.TEST_DATASET, eval_name, + 'Annotations') + if '_all_frames' in cfg.TEST_DATASET_SPLIT: + split = cfg.TEST_DATASET_SPLIT.split('_')[0] + youtubevos_test = YOUTUBEVOS_DenseTest + + self.result_root_sparse = os.path.join(cfg.DIR_EVALUATION, + cfg.TEST_DATASET, + eval_name + '_sparse', + 'Annotations') + self.zip_dir_sparse = os.path.join( + cfg.DIR_EVALUATION, cfg.TEST_DATASET, + '{}_sparse.zip'.format(eval_name)) + else: + split = cfg.TEST_DATASET_SPLIT + youtubevos_test = YOUTUBEVOS_Test + + self.dataset = youtubevos_test(root=cfg.DIR_YTB, + year=year, + split=split, + transform=eval_transforms, + result_root=self.result_root) + + elif cfg.TEST_DATASET == 'davis2017': + resolution = 'Full-Resolution' if cfg.TEST_DATASET_FULL_RESOLUTION else '480p' + self.result_root = os.path.join(cfg.DIR_EVALUATION, + cfg.TEST_DATASET, eval_name, + 'Annotations', resolution) + self.dataset = DAVIS_Test( + split=[cfg.TEST_DATASET_SPLIT], + root=cfg.DIR_DAVIS, + year=2017, + transform=eval_transforms, + full_resolution=cfg.TEST_DATASET_FULL_RESOLUTION, + result_root=self.result_root) + + elif cfg.TEST_DATASET == 'davis2016': + resolution = 'Full-Resolution' if cfg.TEST_DATASET_FULL_RESOLUTION else '480p' + self.result_root = os.path.join(cfg.DIR_EVALUATION, + cfg.TEST_DATASET, eval_name, + 'Annotations', resolution) + self.dataset = DAVIS_Test( + split=[cfg.TEST_DATASET_SPLIT], + root=cfg.DIR_DAVIS, + year=2016, + transform=eval_transforms, + full_resolution=cfg.TEST_DATASET_FULL_RESOLUTION, + result_root=self.result_root) + + elif cfg.TEST_DATASET == 'test': + self.result_root = os.path.join(cfg.DIR_EVALUATION, + cfg.TEST_DATASET, eval_name, + 'Annotations') + self.dataset = EVAL_TEST(eval_transforms, self.result_root) + else: + self.print_log('Unknown dataset!') + exit() + + self.print_log('Eval {} on {} {}:'.format(cfg.EXP_NAME, + cfg.TEST_DATASET, + cfg.TEST_DATASET_SPLIT)) + self.source_folder = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, + eval_name, 'Annotations') + self.zip_dir = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, + '{}.zip'.format(eval_name)) + if not os.path.exists(self.result_root): + try: + os.makedirs(self.result_root) + except Exception as inst: + self.print_log(inst) + self.print_log('Failed to mask dir: {}.'.format( + self.result_root)) + self.print_log('Done!') + + def evaluating(self): + cfg = self.cfg + self.model.eval() + video_num = 0 + processed_video_num = 0 + total_time = 0 + total_frame = 0 + total_sfps = 0 + total_video_num = len(self.dataset) + start_eval_time = time.time() + + if self.seq_queue is not None: + if self.rank == 0: + for seq_idx in range(total_video_num): + self.seq_queue.put(seq_idx) + for _ in range(self.gpu_num): + self.seq_queue.put('END') + coming_seq_idx = self.seq_queue.get() + + all_engines = [] + with torch.no_grad(): + for seq_idx, seq_dataset in enumerate(self.dataset): + video_num += 1 + + if self.seq_queue is not None: + if coming_seq_idx == 'END': + break + elif coming_seq_idx != seq_idx: + continue + else: + coming_seq_idx = self.seq_queue.get() + + processed_video_num += 1 + + for engine in all_engines: + engine.restart_engine() + + seq_name = seq_dataset.seq_name + print('GPU {} - Processing Seq {} [{}/{}]:'.format( + self.gpu, seq_name, video_num, total_video_num)) + torch.cuda.empty_cache() + + seq_dataloader = DataLoader(seq_dataset, + batch_size=1, + shuffle=False, + num_workers=cfg.TEST_WORKERS, + pin_memory=True) + + if 'all_frames' in cfg.TEST_DATASET_SPLIT: + images_sparse = seq_dataset.images_sparse + seq_dir_sparse = os.path.join(self.result_root_sparse, + seq_name) + if not os.path.exists(seq_dir_sparse): + os.makedirs(seq_dir_sparse) + + seq_total_time = 0 + seq_total_frame = 0 + seq_pred_masks = {'dense': [], 'sparse': []} + seq_timers = [] + + for frame_idx, samples in enumerate(seq_dataloader): + + all_preds = [] + new_obj_label = None + aug_num = len(samples) + + for aug_idx in range(aug_num): + if len(all_engines) <= aug_idx: + all_engines.append( + build_engine(cfg.MODEL_ENGINE, + phase='eval', + aot_model=self.model, + gpu_id=self.gpu, + long_term_mem_gap=self.cfg. + TEST_LONG_TERM_MEM_GAP, + short_term_mem_skip=self.cfg. + TEST_SHORT_TERM_MEM_SKIP)) + all_engines[-1].eval() + + if aug_num > 1: # if use test-time augmentation + torch.cuda.empty_cache() # release GPU memory + + engine = all_engines[aug_idx] + + sample = samples[aug_idx] + + is_flipped = sample['meta']['flip'] + + obj_nums = sample['meta']['obj_num'] + imgname = sample['meta']['current_name'] + ori_height = sample['meta']['height'] + ori_width = sample['meta']['width'] + obj_idx = sample['meta']['obj_idx'] + + obj_nums = [int(obj_num) for obj_num in obj_nums] + obj_idx = [int(_obj_idx) for _obj_idx in obj_idx] + + current_img = sample['current_img'] + current_img = current_img.cuda(self.gpu, + non_blocking=True) + sample['current_img'] = current_img + + if 'current_label' in sample.keys(): + current_label = sample['current_label'].cuda( + self.gpu, non_blocking=True).float() + else: + current_label = None + + ############################################################# + + if frame_idx == 0: + _current_label = F.interpolate( + current_label, + size=current_img.size()[2:], + mode="nearest") + engine.add_reference_frame(current_img, + _current_label, + frame_step=0, + obj_nums=obj_nums) + else: + if aug_idx == 0: + seq_timers.append([]) + now_timer = torch.cuda.Event( + enable_timing=True) + now_timer.record() + seq_timers[-1].append(now_timer) + + engine.match_propogate_one_frame(current_img) + pred_logit = engine.decode_current_logits( + (ori_height, ori_width)) + + if is_flipped: + pred_logit = flip_tensor(pred_logit, 3) + + pred_prob = torch.softmax(pred_logit, dim=1) + all_preds.append(pred_prob) + + if not is_flipped and current_label is not None and new_obj_label is None: + new_obj_label = current_label + + if frame_idx > 0: + all_pred_probs = [ + torch.mean(pred, dim=0, keepdim=True) + for pred in all_preds + ] + all_pred_labels = [ + torch.argmax(prob, dim=1, keepdim=True).float() + for prob in all_pred_probs + ] + + cat_all_preds = torch.cat(all_preds, dim=0) + pred_prob = torch.mean(cat_all_preds, + dim=0, + keepdim=True) + pred_label = torch.argmax(pred_prob, + dim=1, + keepdim=True).float() + + if new_obj_label is not None: + keep = (new_obj_label == 0).float() + all_pred_labels = [label * \ + keep + new_obj_label * (1 - keep) for label in all_pred_labels] + + pred_label = pred_label * \ + keep + new_obj_label * (1 - keep) + new_obj_nums = [int(pred_label.max().item())] + + if cfg.TEST_FLIP: + all_flip_pred_labels = [ + flip_tensor(label, 3) + for label in all_pred_labels + ] + flip_pred_label = flip_tensor(pred_label, 3) + + for aug_idx in range(len(samples)): + engine = all_engines[aug_idx] + current_img = samples[aug_idx]['current_img'] + + # current_label = flip_pred_label if samples[ + # aug_idx]['meta']['flip'] else pred_label + current_label = all_flip_pred_labels[ + aug_idx] if samples[aug_idx]['meta'][ + 'flip'] else all_pred_labels[aug_idx] + current_label = F.interpolate( + current_label, + size=engine.input_size_2d, + mode="nearest") + engine.add_reference_frame( + current_img, + current_label, + obj_nums=new_obj_nums, + frame_step=frame_idx) + engine.decode_current_logits( + (ori_height, ori_width)) + engine.update_memory(current_label) + else: + if not cfg.MODEL_USE_PREV_PROB: + if cfg.TEST_FLIP: + all_flip_pred_labels = [ + flip_tensor(label, 3) + for label in all_pred_labels + ] + flip_pred_label = flip_tensor( + pred_label, 3) + + for aug_idx in range(len(samples)): + engine = all_engines[aug_idx] + # current_label = flip_pred_label if samples[ + # aug_idx]['meta']['flip'] else pred_label + current_label = all_flip_pred_labels[ + aug_idx] if samples[aug_idx]['meta'][ + 'flip'] else all_pred_labels[ + aug_idx] + current_label = F.interpolate( + current_label, + size=engine.input_size_2d, + mode="nearest") + engine.update_memory(current_label) + else: + if cfg.TEST_FLIP: + all_flip_pred_probs = [ + flip_tensor(prob, 3) + for prob in all_pred_probs + ] + flip_pred_prob = flip_tensor(pred_prob, 3) + + for aug_idx in range(len(samples)): + engine = all_engines[aug_idx] + # current_prob = flip_pred_prob if samples[ + # aug_idx]['meta']['flip'] else pred_prob + current_label = all_flip_pred_probs[ + aug_idx] if samples[aug_idx]['meta'][ + 'flip'] else all_pred_probs[aug_idx] + current_prob = F.interpolate( + current_prob, + size=engine.input_size_2d, + mode="nearest") + engine.update_memory(current_prob) + + now_timer = torch.cuda.Event(enable_timing=True) + now_timer.record() + seq_timers[-1].append((now_timer)) + + if cfg.TEST_FRAME_LOG: + torch.cuda.synchronize() + one_frametime = seq_timers[-1][0].elapsed_time( + seq_timers[-1][1]) / 1e3 + obj_num = obj_nums[0] + print( + 'GPU {} - Frame: {} - Obj Num: {}, Time: {}ms'. + format(self.gpu, imgname[0].split('.')[0], + obj_num, int(one_frametime * 1e3))) + # Save result + seq_pred_masks['dense'].append({ + 'path': + os.path.join(self.result_root, seq_name, + imgname[0].split('.')[0] + '.png'), + 'mask': + pred_label, + 'obj_idx': + obj_idx + }) + if 'all_frames' in cfg.TEST_DATASET_SPLIT and imgname in images_sparse: + seq_pred_masks['sparse'].append({ + 'path': + os.path.join(self.result_root_sparse, seq_name, + imgname[0].split('.')[0] + + '.png'), + 'mask': + pred_label, + 'obj_idx': + obj_idx + }) + + # Save result + for mask_result in seq_pred_masks['dense'] + seq_pred_masks[ + 'sparse']: + save_mask(mask_result['mask'].squeeze(0).squeeze(0), + mask_result['path'], mask_result['obj_idx']) + del (seq_pred_masks) + + for timer in seq_timers: + torch.cuda.synchronize() + one_frametime = timer[0].elapsed_time(timer[1]) / 1e3 + seq_total_time += one_frametime + seq_total_frame += 1 + del (seq_timers) + + seq_avg_time_per_frame = seq_total_time / seq_total_frame + total_time += seq_total_time + total_frame += seq_total_frame + total_avg_time_per_frame = total_time / total_frame + total_sfps += seq_avg_time_per_frame + avg_sfps = total_sfps / processed_video_num + max_mem = torch.cuda.max_memory_allocated( + device=self.gpu) / (1024.**3) + print( + "GPU {} - Seq {} - FPS: {:.2f}. All-Frame FPS: {:.2f}, All-Seq FPS: {:.2f}, Max Mem: {:.2f}G" + .format(self.gpu, seq_name, 1. / seq_avg_time_per_frame, + 1. / total_avg_time_per_frame, 1. / avg_sfps, + max_mem)) + + if self.seq_queue is not None: + if self.rank != 0: + self.info_queue.put({ + 'total_time': total_time, + 'total_frame': total_frame, + 'total_sfps': total_sfps, + 'processed_video_num': processed_video_num, + 'max_mem': max_mem + }) + print('Finished the evaluation on GPU {}.'.format(self.gpu)) + if self.rank == 0: + for _ in range(self.gpu_num - 1): + info_dict = self.info_queue.get() + total_time += info_dict['total_time'] + total_frame += info_dict['total_frame'] + total_sfps += info_dict['total_sfps'] + processed_video_num += info_dict['processed_video_num'] + max_mem = max(max_mem, info_dict['max_mem']) + all_reduced_total_avg_time_per_frame = total_time / total_frame + all_reduced_avg_sfps = total_sfps / processed_video_num + print( + "GPU {} - All-Frame FPS: {:.2f}, All-Seq FPS: {:.2f}, Max Mem: {:.2f}G" + .format(list(range(self.gpu_num)), + 1. / all_reduced_total_avg_time_per_frame, + 1. / all_reduced_avg_sfps, max_mem)) + else: + print( + "GPU {} - All-Frame FPS: {:.2f}, All-Seq FPS: {:.2f}, Max Mem: {:.2f}G" + .format(self.gpu, 1. / total_avg_time_per_frame, 1. / avg_sfps, + max_mem)) + + if self.rank == 0: + zip_folder(self.source_folder, self.zip_dir) + self.print_log('Saving result to {}.'.format(self.zip_dir)) + if 'all_frames' in cfg.TEST_DATASET_SPLIT: + zip_folder(self.result_root_sparse, self.zip_dir_sparse) + end_eval_time = time.time() + total_eval_time = str( + datetime.timedelta(seconds=int(end_eval_time - + start_eval_time))) + self.print_log("Total evaluation time: {}".format(total_eval_time)) + + def print_log(self, string): + if self.rank == 0: + print(string) diff --git a/aot/networks/managers/trainer.py b/aot/networks/managers/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..6fe4f42494720c0702e0dc871d2fa6aed83b85b9 --- /dev/null +++ b/aot/networks/managers/trainer.py @@ -0,0 +1,686 @@ +import os +import time +import json +import datetime as datetime + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +import torch.distributed as dist +from torch.utils.data import DataLoader +from torchvision import transforms + +from dataloaders.train_datasets import DAVIS2017_Train, YOUTUBEVOS_Train, StaticTrain, TEST +import dataloaders.video_transforms as tr + +from utils.meters import AverageMeter +from utils.image import label2colormap, masked_image, save_image +from utils.checkpoint import load_network_and_optimizer, load_network, save_network +from utils.learning import adjust_learning_rate, get_trainable_params +from utils.metric import pytorch_iou +from utils.ema import ExponentialMovingAverage, get_param_buffer_for_ema + +from networks.models import build_vos_model +from networks.engines import build_engine + + +class Trainer(object): + def __init__(self, rank, cfg, enable_amp=True): + self.gpu = rank + cfg.DIST_START_GPU + self.gpu_num = cfg.TRAIN_GPUS + self.rank = rank + self.cfg = cfg + + self.print_log("Exp {}:".format(cfg.EXP_NAME)) + self.print_log(json.dumps(cfg.__dict__, indent=4, sort_keys=True)) + + print("Use GPU {} for training VOS.".format(self.gpu)) + torch.cuda.set_device(self.gpu) + torch.backends.cudnn.benchmark = True if cfg.DATA_RANDOMCROP[ + 0] == cfg.DATA_RANDOMCROP[ + 1] and 'swin' not in cfg.MODEL_ENCODER else False + + self.print_log('Build VOS model.') + + self.model = build_vos_model(cfg.MODEL_VOS, cfg).cuda(self.gpu) + self.model_encoder = self.model.encoder + self.engine = build_engine( + cfg.MODEL_ENGINE, + 'train', + aot_model=self.model, + gpu_id=self.gpu, + long_term_mem_gap=cfg.TRAIN_LONG_TERM_MEM_GAP) + + if cfg.MODEL_FREEZE_BACKBONE: + for param in self.model_encoder.parameters(): + param.requires_grad = False + + if cfg.DIST_ENABLE: + dist.init_process_group(backend=cfg.DIST_BACKEND, + init_method=cfg.DIST_URL, + world_size=cfg.TRAIN_GPUS, + rank=rank, + timeout=datetime.timedelta(seconds=300)) + + self.model.encoder = nn.SyncBatchNorm.convert_sync_batchnorm( + self.model.encoder).cuda(self.gpu) + + self.dist_engine = torch.nn.parallel.DistributedDataParallel( + self.engine, + device_ids=[self.gpu], + output_device=self.gpu, + find_unused_parameters=True, + broadcast_buffers=False) + else: + self.dist_engine = self.engine + + self.use_frozen_bn = False + if 'swin' in cfg.MODEL_ENCODER: + self.print_log('Use LN in Encoder!') + elif not cfg.MODEL_FREEZE_BN: + if cfg.DIST_ENABLE: + self.print_log('Use Sync BN in Encoder!') + else: + self.print_log('Use BN in Encoder!') + else: + self.use_frozen_bn = True + self.print_log('Use Frozen BN in Encoder!') + + if self.rank == 0: + try: + total_steps = float(cfg.TRAIN_TOTAL_STEPS) + ema_decay = 1. - 1. / (total_steps * cfg.TRAIN_EMA_RATIO) + self.ema_params = get_param_buffer_for_ema( + self.model, update_buffer=(not cfg.MODEL_FREEZE_BN)) + self.ema = ExponentialMovingAverage(self.ema_params, + decay=ema_decay) + self.ema_dir = cfg.DIR_EMA_CKPT + except Exception as inst: + self.print_log(inst) + self.print_log('Error: failed to create EMA model!') + + self.print_log('Build optimizer.') + + trainable_params = get_trainable_params( + model=self.dist_engine, + base_lr=cfg.TRAIN_LR, + use_frozen_bn=self.use_frozen_bn, + weight_decay=cfg.TRAIN_WEIGHT_DECAY, + exclusive_wd_dict=cfg.TRAIN_WEIGHT_DECAY_EXCLUSIVE, + no_wd_keys=cfg.TRAIN_WEIGHT_DECAY_EXEMPTION) + + if cfg.TRAIN_OPT == 'sgd': + self.optimizer = optim.SGD(trainable_params, + lr=cfg.TRAIN_LR, + momentum=cfg.TRAIN_SGD_MOMENTUM, + nesterov=True) + else: + self.optimizer = optim.AdamW(trainable_params, + lr=cfg.TRAIN_LR, + weight_decay=cfg.TRAIN_WEIGHT_DECAY) + + self.enable_amp = enable_amp + if enable_amp: + self.scaler = torch.cuda.amp.GradScaler() + else: + self.scaler = None + + self.prepare_dataset() + self.process_pretrained_model() + + if cfg.TRAIN_TBLOG and self.rank == 0: + from tensorboardX import SummaryWriter + self.tblogger = SummaryWriter(cfg.DIR_TB_LOG) + + def process_pretrained_model(self): + cfg = self.cfg + + self.step = cfg.TRAIN_START_STEP + self.epoch = 0 + + if cfg.TRAIN_AUTO_RESUME: + ckpts = os.listdir(cfg.DIR_CKPT) + if len(ckpts) > 0: + ckpts = list( + map(lambda x: int(x.split('_')[-1].split('.')[0]), ckpts)) + ckpt = np.sort(ckpts)[-1] + cfg.TRAIN_RESUME = True + cfg.TRAIN_RESUME_CKPT = ckpt + cfg.TRAIN_RESUME_STEP = ckpt + else: + cfg.TRAIN_RESUME = False + + if cfg.TRAIN_RESUME: + if self.rank == 0: + try: + try: + ema_ckpt_dir = os.path.join( + self.ema_dir, + 'save_step_%s.pth' % (cfg.TRAIN_RESUME_CKPT)) + ema_model, removed_dict = load_network( + self.model, ema_ckpt_dir, self.gpu) + except Exception as inst: + self.print_log(inst) + self.print_log('Try to use backup EMA checkpoint.') + DIR_RESULT = './backup/{}/{}'.format( + cfg.EXP_NAME, cfg.STAGE_NAME) + DIR_EMA_CKPT = os.path.join(DIR_RESULT, 'ema_ckpt') + ema_ckpt_dir = os.path.join( + DIR_EMA_CKPT, + 'save_step_%s.pth' % (cfg.TRAIN_RESUME_CKPT)) + ema_model, removed_dict = load_network( + self.model, ema_ckpt_dir, self.gpu) + + if len(removed_dict) > 0: + self.print_log( + 'Remove {} from EMA model.'.format(removed_dict)) + ema_decay = self.ema.decay + del (self.ema) + + ema_params = get_param_buffer_for_ema( + ema_model, update_buffer=(not cfg.MODEL_FREEZE_BN)) + self.ema = ExponentialMovingAverage(ema_params, + decay=ema_decay) + self.ema.num_updates = cfg.TRAIN_RESUME_CKPT + except Exception as inst: + self.print_log(inst) + self.print_log('Error: EMA model not found!') + + try: + resume_ckpt = os.path.join( + cfg.DIR_CKPT, 'save_step_%s.pth' % (cfg.TRAIN_RESUME_CKPT)) + self.model, self.optimizer, removed_dict = load_network_and_optimizer( + self.model, + self.optimizer, + resume_ckpt, + self.gpu, + scaler=self.scaler) + except Exception as inst: + self.print_log(inst) + self.print_log('Try to use backup checkpoint.') + DIR_RESULT = './backup/{}/{}'.format(cfg.EXP_NAME, + cfg.STAGE_NAME) + DIR_CKPT = os.path.join(DIR_RESULT, 'ckpt') + resume_ckpt = os.path.join( + DIR_CKPT, 'save_step_%s.pth' % (cfg.TRAIN_RESUME_CKPT)) + self.model, self.optimizer, removed_dict = load_network_and_optimizer( + self.model, + self.optimizer, + resume_ckpt, + self.gpu, + scaler=self.scaler) + + if len(removed_dict) > 0: + self.print_log( + 'Remove {} from checkpoint.'.format(removed_dict)) + + self.step = cfg.TRAIN_RESUME_STEP + if cfg.TRAIN_TOTAL_STEPS <= self.step: + self.print_log("Your training has finished!") + exit() + self.epoch = int(np.ceil(self.step / len(self.train_loader))) + + self.print_log('Resume from step {}'.format(self.step)) + + elif cfg.PRETRAIN: + if cfg.PRETRAIN_FULL: + try: + self.model, removed_dict = load_network( + self.model, cfg.PRETRAIN_MODEL, self.gpu) + except Exception as inst: + self.print_log(inst) + self.print_log('Try to use backup EMA checkpoint.') + DIR_RESULT = './backup/{}/{}'.format( + cfg.EXP_NAME, cfg.STAGE_NAME) + DIR_EMA_CKPT = os.path.join(DIR_RESULT, 'ema_ckpt') + PRETRAIN_MODEL = os.path.join( + DIR_EMA_CKPT, + cfg.PRETRAIN_MODEL.split('/')[-1]) + self.model, removed_dict = load_network( + self.model, PRETRAIN_MODEL, self.gpu) + + if len(removed_dict) > 0: + self.print_log('Remove {} from pretrained model.'.format( + removed_dict)) + self.print_log('Load pretrained VOS model from {}.'.format( + cfg.PRETRAIN_MODEL)) + else: + model_encoder, removed_dict = load_network( + self.model_encoder, cfg.PRETRAIN_MODEL, self.gpu) + if len(removed_dict) > 0: + self.print_log('Remove {} from pretrained model.'.format( + removed_dict)) + self.print_log( + 'Load pretrained backbone model from {}.'.format( + cfg.PRETRAIN_MODEL)) + + def prepare_dataset(self): + cfg = self.cfg + self.enable_prev_frame = cfg.TRAIN_ENABLE_PREV_FRAME + + self.print_log('Process dataset...') + if cfg.TRAIN_AUG_TYPE == 'v1': + composed_transforms = transforms.Compose([ + tr.RandomScale(cfg.DATA_MIN_SCALE_FACTOR, + cfg.DATA_MAX_SCALE_FACTOR, + cfg.DATA_SHORT_EDGE_LEN), + tr.BalancedRandomCrop(cfg.DATA_RANDOMCROP, + max_obj_num=cfg.MODEL_MAX_OBJ_NUM), + tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP), + tr.Resize(cfg.DATA_RANDOMCROP, use_padding=True), + tr.ToTensor() + ]) + elif cfg.TRAIN_AUG_TYPE == 'v2': + composed_transforms = transforms.Compose([ + tr.RandomScale(cfg.DATA_MIN_SCALE_FACTOR, + cfg.DATA_MAX_SCALE_FACTOR, + cfg.DATA_SHORT_EDGE_LEN), + tr.BalancedRandomCrop(cfg.DATA_RANDOMCROP, + max_obj_num=cfg.MODEL_MAX_OBJ_NUM), + tr.RandomColorJitter(), + tr.RandomGrayScale(), + tr.RandomGaussianBlur(), + tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP), + tr.Resize(cfg.DATA_RANDOMCROP, use_padding=True), + tr.ToTensor() + ]) + else: + assert NotImplementedError + + train_datasets = [] + if 'static' in cfg.DATASETS: + pretrain_vos_dataset = StaticTrain( + cfg.DIR_STATIC, + cfg.DATA_RANDOMCROP, + seq_len=cfg.DATA_SEQ_LEN, + merge_prob=cfg.DATA_DYNAMIC_MERGE_PROB, + max_obj_n=cfg.MODEL_MAX_OBJ_NUM, + aug_type=cfg.TRAIN_AUG_TYPE) + train_datasets.append(pretrain_vos_dataset) + self.enable_prev_frame = False + + if 'davis2017' in cfg.DATASETS: + train_davis_dataset = DAVIS2017_Train( + root=cfg.DIR_DAVIS, + full_resolution=cfg.TRAIN_DATASET_FULL_RESOLUTION, + transform=composed_transforms, + repeat_time=cfg.DATA_DAVIS_REPEAT, + seq_len=cfg.DATA_SEQ_LEN, + rand_gap=cfg.DATA_RANDOM_GAP_DAVIS, + rand_reverse=cfg.DATA_RANDOM_REVERSE_SEQ, + merge_prob=cfg.DATA_DYNAMIC_MERGE_PROB, + enable_prev_frame=self.enable_prev_frame, + max_obj_n=cfg.MODEL_MAX_OBJ_NUM) + train_datasets.append(train_davis_dataset) + + if 'youtubevos' in cfg.DATASETS: + train_ytb_dataset = YOUTUBEVOS_Train( + root=cfg.DIR_YTB, + transform=composed_transforms, + seq_len=cfg.DATA_SEQ_LEN, + rand_gap=cfg.DATA_RANDOM_GAP_YTB, + rand_reverse=cfg.DATA_RANDOM_REVERSE_SEQ, + merge_prob=cfg.DATA_DYNAMIC_MERGE_PROB, + enable_prev_frame=self.enable_prev_frame, + max_obj_n=cfg.MODEL_MAX_OBJ_NUM) + train_datasets.append(train_ytb_dataset) + + if 'test' in cfg.DATASETS: + test_dataset = TEST(transform=composed_transforms, + seq_len=cfg.DATA_SEQ_LEN) + train_datasets.append(test_dataset) + + if len(train_datasets) > 1: + train_dataset = torch.utils.data.ConcatDataset(train_datasets) + elif len(train_datasets) == 1: + train_dataset = train_datasets[0] + else: + self.print_log('No dataset!') + exit(0) + + self.train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset) if self.cfg.DIST_ENABLE else None + self.train_loader = DataLoader(train_dataset, + batch_size=int(cfg.TRAIN_BATCH_SIZE / + cfg.TRAIN_GPUS), + shuffle=False if self.cfg.DIST_ENABLE else True, + num_workers=cfg.DATA_WORKERS, + pin_memory=True, + sampler=self.train_sampler, + drop_last=True, + prefetch_factor=4) + + self.print_log('Done!') + + def sequential_training(self): + + cfg = self.cfg + + if self.enable_prev_frame: + frame_names = ['Ref', 'Prev'] + else: + frame_names = ['Ref(Prev)'] + + for i in range(cfg.DATA_SEQ_LEN - 1): + frame_names.append('Curr{}'.format(i + 1)) + + seq_len = len(frame_names) + + running_losses = [] + running_ious = [] + for _ in range(seq_len): + running_losses.append(AverageMeter()) + running_ious.append(AverageMeter()) + batch_time = AverageMeter() + avg_obj = AverageMeter() + + optimizer = self.optimizer + model = self.dist_engine + train_sampler = self.train_sampler + train_loader = self.train_loader + step = self.step + epoch = self.epoch + max_itr = cfg.TRAIN_TOTAL_STEPS + start_seq_training_step = int(cfg.TRAIN_SEQ_TRAINING_START_RATIO * + max_itr) + use_prev_prob = cfg.MODEL_USE_PREV_PROB + + self.print_log('Start training:') + model.train() + while step < cfg.TRAIN_TOTAL_STEPS: + if self.cfg.DIST_ENABLE: + train_sampler.set_epoch(epoch) + epoch += 1 + last_time = time.time() + for frame_idx, sample in enumerate(train_loader): + if step > cfg.TRAIN_TOTAL_STEPS: + break + + if step % cfg.TRAIN_TBLOG_STEP == 0 and self.rank == 0 and cfg.TRAIN_TBLOG: + tf_board = True + else: + tf_board = False + + if step >= start_seq_training_step: + use_prev_pred = True + freeze_params = cfg.TRAIN_SEQ_TRAINING_FREEZE_PARAMS + else: + use_prev_pred = False + freeze_params = [] + + if step % cfg.TRAIN_LR_UPDATE_STEP == 0: + now_lr = adjust_learning_rate( + optimizer=optimizer, + base_lr=cfg.TRAIN_LR, + p=cfg.TRAIN_LR_POWER, + itr=step, + max_itr=max_itr, + restart=cfg.TRAIN_LR_RESTART, + warm_up_steps=cfg.TRAIN_LR_WARM_UP_RATIO * max_itr, + is_cosine_decay=cfg.TRAIN_LR_COSINE_DECAY, + min_lr=cfg.TRAIN_LR_MIN, + encoder_lr_ratio=cfg.TRAIN_LR_ENCODER_RATIO, + freeze_params=freeze_params) + + ref_imgs = sample['ref_img'] # batch_size * 3 * h * w + prev_imgs = sample['prev_img'] + curr_imgs = sample['curr_img'] + ref_labels = sample['ref_label'] # batch_size * 1 * h * w + prev_labels = sample['prev_label'] + curr_labels = sample['curr_label'] + obj_nums = sample['meta']['obj_num'] + bs, _, h, w = curr_imgs[0].size() + + ref_imgs = ref_imgs.cuda(self.gpu, non_blocking=True) + prev_imgs = prev_imgs.cuda(self.gpu, non_blocking=True) + curr_imgs = [ + curr_img.cuda(self.gpu, non_blocking=True) + for curr_img in curr_imgs + ] + ref_labels = ref_labels.cuda(self.gpu, non_blocking=True) + prev_labels = prev_labels.cuda(self.gpu, non_blocking=True) + curr_labels = [ + curr_label.cuda(self.gpu, non_blocking=True) + for curr_label in curr_labels + ] + obj_nums = list(obj_nums) + obj_nums = [int(obj_num) for obj_num in obj_nums] + + batch_size = ref_imgs.size(0) + + all_frames = torch.cat([ref_imgs, prev_imgs] + curr_imgs, + dim=0) + all_labels = torch.cat([ref_labels, prev_labels] + curr_labels, + dim=0) + + self.engine.restart_engine(batch_size, True) + optimizer.zero_grad(set_to_none=True) + + if self.enable_amp: + with torch.cuda.amp.autocast(enabled=True): + + loss, all_pred, all_loss, boards = model( + all_frames, + all_labels, + batch_size, + use_prev_pred=use_prev_pred, + obj_nums=obj_nums, + step=step, + tf_board=tf_board, + enable_prev_frame=self.enable_prev_frame, + use_prev_prob=use_prev_prob) + loss = torch.mean(loss) + + start = time.time() + self.scaler.scale(loss).backward() + end = time.time() + print(end-start) + self.scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), + cfg.TRAIN_CLIP_GRAD_NORM) + self.scaler.step(optimizer) + self.scaler.update() + + else: + loss, all_pred, all_loss, boards = model( + all_frames, + all_labels, + ref_imgs.size(0), + use_prev_pred=use_prev_pred, + obj_nums=obj_nums, + step=step, + tf_board=tf_board, + enable_prev_frame=self.enable_prev_frame, + use_prev_prob=use_prev_prob) + loss = torch.mean(loss) + + torch.nn.utils.clip_grad_norm_(model.parameters(), + cfg.TRAIN_CLIP_GRAD_NORM) + loss.backward() + optimizer.step() + + for idx in range(seq_len): + now_pred = all_pred[idx].detach() + now_label = all_labels[idx * bs:(idx + 1) * bs].detach() + now_loss = torch.mean(all_loss[idx].detach()) + now_iou = pytorch_iou(now_pred.unsqueeze(1), now_label, + obj_nums) * 100 + if self.cfg.DIST_ENABLE: + dist.all_reduce(now_loss) + dist.all_reduce(now_iou) + now_loss /= self.gpu_num + now_iou /= self.gpu_num + if self.rank == 0: + running_losses[idx].update(now_loss.item()) + running_ious[idx].update(now_iou.item()) + + if self.rank == 0: + self.ema.update(self.ema_params) + + avg_obj.update(sum(obj_nums) / float(len(obj_nums))) + curr_time = time.time() + batch_time.update(curr_time - last_time) + last_time = curr_time + + if step % cfg.TRAIN_TBLOG_STEP == 0: + all_f = [ref_imgs, prev_imgs] + curr_imgs + self.process_log(ref_imgs, all_f[-2], all_f[-1], + ref_labels, all_pred[-2], now_label, + now_pred, boards, running_losses, + running_ious, now_lr, step) + + if step % cfg.TRAIN_LOG_STEP == 0: + strs = 'I:{}, LR:{:.5f}, T:{:.1f}({:.1f})s, Obj:{:.1f}({:.1f})'.format( + step, now_lr, batch_time.val, + batch_time.moving_avg, avg_obj.val, + avg_obj.moving_avg) + batch_time.reset() + avg_obj.reset() + for idx in range(seq_len): + strs += ', {}: L {:.3f}({:.3f}) IoU {:.1f}({:.1f})%'.format( + frame_names[idx], running_losses[idx].val, + running_losses[idx].moving_avg, + running_ious[idx].val, + running_ious[idx].moving_avg) + running_losses[idx].reset() + running_ious[idx].reset() + + self.print_log(strs) + + step += 1 + + if step % cfg.TRAIN_SAVE_STEP == 0 and self.rank == 0: + max_mem = torch.cuda.max_memory_allocated( + device=self.gpu) / (1024.**3) + ETA = str( + datetime.timedelta( + seconds=int(batch_time.moving_avg * + (cfg.TRAIN_TOTAL_STEPS - step)))) + self.print_log('ETA: {}, Max Mem: {:.2f}G.'.format( + ETA, max_mem)) + self.print_log('Save CKPT (Step {}).'.format(step)) + save_network(self.model, + optimizer, + step, + cfg.DIR_CKPT, + cfg.TRAIN_MAX_KEEP_CKPT, + backup_dir='./backup/{}/{}/ckpt'.format( + cfg.EXP_NAME, cfg.STAGE_NAME), + scaler=self.scaler) + try: + torch.cuda.empty_cache() + # First save original parameters before replacing with EMA version + self.ema.store(self.ema_params) + # Copy EMA parameters to model + self.ema.copy_to(self.ema_params) + # Save EMA model + save_network( + self.model, + optimizer, + step, + self.ema_dir, + cfg.TRAIN_MAX_KEEP_CKPT, + backup_dir='./backup/{}/{}/ema_ckpt'.format( + cfg.EXP_NAME, cfg.STAGE_NAME), + scaler=self.scaler) + # Restore original parameters to resume training later + self.ema.restore(self.ema_params) + except Exception as inst: + self.print_log(inst) + self.print_log('Error: failed to save EMA model!') + + self.print_log('Stop training!') + + def print_log(self, string): + if self.rank == 0: + print(string) + + def process_log(self, ref_imgs, prev_imgs, curr_imgs, ref_labels, + prev_labels, curr_labels, curr_pred, boards, + running_losses, running_ious, now_lr, step): + cfg = self.cfg + + mean = np.array([[[0.485]], [[0.456]], [[0.406]]]) + sigma = np.array([[[0.229]], [[0.224]], [[0.225]]]) + + show_ref_img, show_prev_img, show_curr_img = [ + img.cpu().numpy()[0] * sigma + mean + for img in [ref_imgs, prev_imgs, curr_imgs] + ] + + show_gt, show_prev_gt, show_ref_gt, show_preds_s = [ + label.cpu()[0].squeeze(0).numpy() + for label in [curr_labels, prev_labels, ref_labels, curr_pred] + ] + + show_gtf, show_prev_gtf, show_ref_gtf, show_preds_sf = [ + label2colormap(label).transpose((2, 0, 1)) + for label in [show_gt, show_prev_gt, show_ref_gt, show_preds_s] + ] + + if cfg.TRAIN_IMG_LOG or cfg.TRAIN_TBLOG: + + show_ref_img = masked_image(show_ref_img, show_ref_gtf, + show_ref_gt) + if cfg.TRAIN_IMG_LOG: + save_image( + show_ref_img, + os.path.join(cfg.DIR_IMG_LOG, + '%06d_ref_img.jpeg' % (step))) + + show_prev_img = masked_image(show_prev_img, show_prev_gtf, + show_prev_gt) + if cfg.TRAIN_IMG_LOG: + save_image( + show_prev_img, + os.path.join(cfg.DIR_IMG_LOG, + '%06d_prev_img.jpeg' % (step))) + + show_img_pred = masked_image(show_curr_img, show_preds_sf, + show_preds_s) + if cfg.TRAIN_IMG_LOG: + save_image( + show_img_pred, + os.path.join(cfg.DIR_IMG_LOG, + '%06d_prediction.jpeg' % (step))) + + show_curr_img = masked_image(show_curr_img, show_gtf, show_gt) + if cfg.TRAIN_IMG_LOG: + save_image( + show_curr_img, + os.path.join(cfg.DIR_IMG_LOG, + '%06d_groundtruth.jpeg' % (step))) + + if cfg.TRAIN_TBLOG: + for seq_step, running_loss, running_iou in zip( + range(len(running_losses)), running_losses, + running_ious): + self.tblogger.add_scalar('S{}/Loss'.format(seq_step), + running_loss.avg, step) + self.tblogger.add_scalar('S{}/IoU'.format(seq_step), + running_iou.avg, step) + + self.tblogger.add_scalar('LR', now_lr, step) + self.tblogger.add_image('Ref/Image', show_ref_img, step) + self.tblogger.add_image('Ref/GT', show_ref_gtf, step) + + self.tblogger.add_image('Prev/Image', show_prev_img, step) + self.tblogger.add_image('Prev/GT', show_prev_gtf, step) + + self.tblogger.add_image('Curr/Image_GT', show_curr_img, step) + self.tblogger.add_image('Curr/Image_Pred', show_img_pred, step) + + self.tblogger.add_image('Curr/Mask_GT', show_gtf, step) + self.tblogger.add_image('Curr/Mask_Pred', show_preds_sf, step) + + for key in boards['image'].keys(): + tmp = boards['image'][key].cpu().numpy() + self.tblogger.add_image('S{}/' + key, tmp, step) + for key in boards['scalar'].keys(): + tmp = boards['scalar'][key].cpu().numpy() + self.tblogger.add_scalar('S{}/' + key, tmp, step) + + self.tblogger.flush() + + del (boards) diff --git a/aot/networks/models/__init__.py b/aot/networks/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..63995d437a912b4c2e7d44f7dc9ea9ec8dc3b546 --- /dev/null +++ b/aot/networks/models/__init__.py @@ -0,0 +1,11 @@ +from networks.models.aot import AOT +from networks.models.deaot import DeAOT + + +def build_vos_model(name, cfg, **kwargs): + if name == 'aot': + return AOT(cfg, encoder=cfg.MODEL_ENCODER, **kwargs) + elif name == 'deaot': + return DeAOT(cfg, encoder=cfg.MODEL_ENCODER, **kwargs) + else: + raise NotImplementedError diff --git a/aot/networks/models/__pycache__/__init__.cpython-310.pyc b/aot/networks/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19b2676c14974db7ae3c65a81ba9ce9c8c47b813 Binary files /dev/null and b/aot/networks/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/aot/networks/models/__pycache__/aot.cpython-310.pyc b/aot/networks/models/__pycache__/aot.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c73bbc9c7d44288825c1be9441720e3dfa82fe4 Binary files /dev/null and b/aot/networks/models/__pycache__/aot.cpython-310.pyc differ diff --git a/aot/networks/models/__pycache__/deaot.cpython-310.pyc b/aot/networks/models/__pycache__/deaot.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4e9d0f6bb0cd0a2b348d14cbbe418d71980ad6c Binary files /dev/null and b/aot/networks/models/__pycache__/deaot.cpython-310.pyc differ diff --git a/aot/networks/models/aot.py b/aot/networks/models/aot.py new file mode 100644 index 0000000000000000000000000000000000000000..813ee04be9028fe342310574d6f73793c4d506a8 --- /dev/null +++ b/aot/networks/models/aot.py @@ -0,0 +1,115 @@ +import torch.nn as nn + +from networks.encoders import build_encoder +from networks.layers.transformer import LongShortTermTransformer +from networks.decoders import build_decoder +from networks.layers.position import PositionEmbeddingSine + + +class AOT(nn.Module): + def __init__(self, cfg, encoder='mobilenetv2', decoder='fpn'): + super().__init__() + self.cfg = cfg + self.max_obj_num = cfg.MODEL_MAX_OBJ_NUM + self.epsilon = cfg.MODEL_EPSILON + + self.encoder = build_encoder(encoder, + frozen_bn=cfg.MODEL_FREEZE_BN, + freeze_at=cfg.TRAIN_ENCODER_FREEZE_AT) + self.encoder_projector = nn.Conv2d(cfg.MODEL_ENCODER_DIM[-1], + cfg.MODEL_ENCODER_EMBEDDING_DIM, + kernel_size=1) + + self.LSTT = LongShortTermTransformer( + cfg.MODEL_LSTT_NUM, + cfg.MODEL_ENCODER_EMBEDDING_DIM, + cfg.MODEL_SELF_HEADS, + cfg.MODEL_ATT_HEADS, + emb_dropout=cfg.TRAIN_LSTT_EMB_DROPOUT, + droppath=cfg.TRAIN_LSTT_DROPPATH, + lt_dropout=cfg.TRAIN_LSTT_LT_DROPOUT, + st_dropout=cfg.TRAIN_LSTT_ST_DROPOUT, + droppath_lst=cfg.TRAIN_LSTT_DROPPATH_LST, + droppath_scaling=cfg.TRAIN_LSTT_DROPPATH_SCALING, + intermediate_norm=cfg.MODEL_DECODER_INTERMEDIATE_LSTT, + return_intermediate=True) + + decoder_indim = cfg.MODEL_ENCODER_EMBEDDING_DIM * \ + (cfg.MODEL_LSTT_NUM + + 1) if cfg.MODEL_DECODER_INTERMEDIATE_LSTT else cfg.MODEL_ENCODER_EMBEDDING_DIM + + self.decoder = build_decoder( + decoder, + in_dim=decoder_indim, + out_dim=cfg.MODEL_MAX_OBJ_NUM + 1, + decode_intermediate_input=cfg.MODEL_DECODER_INTERMEDIATE_LSTT, + hidden_dim=cfg.MODEL_ENCODER_EMBEDDING_DIM, + shortcut_dims=cfg.MODEL_ENCODER_DIM, + align_corners=cfg.MODEL_ALIGN_CORNERS) + + if cfg.MODEL_ALIGN_CORNERS: + self.patch_wise_id_bank = nn.Conv2d( + cfg.MODEL_MAX_OBJ_NUM + 1, + cfg.MODEL_ENCODER_EMBEDDING_DIM, + kernel_size=17, + stride=16, + padding=8) + else: + self.patch_wise_id_bank = nn.Conv2d( + cfg.MODEL_MAX_OBJ_NUM + 1, + cfg.MODEL_ENCODER_EMBEDDING_DIM, + kernel_size=16, + stride=16, + padding=0) + + self.id_dropout = nn.Dropout(cfg.TRAIN_LSTT_ID_DROPOUT, True) + + self.pos_generator = PositionEmbeddingSine( + cfg.MODEL_ENCODER_EMBEDDING_DIM // 2, normalize=True) + + self._init_weight() + + def get_pos_emb(self, x): + pos_emb = self.pos_generator(x) + return pos_emb + + def get_id_emb(self, x): + id_emb = self.patch_wise_id_bank(x) + id_emb = self.id_dropout(id_emb) + return id_emb + + def encode_image(self, img): + xs = self.encoder(img) + xs[-1] = self.encoder_projector(xs[-1]) + return xs + + def decode_id_logits(self, lstt_emb, shortcuts): + n, c, h, w = shortcuts[-1].size() + decoder_inputs = [shortcuts[-1]] + for emb in lstt_emb: + decoder_inputs.append(emb.view(h, w, n, c).permute(2, 3, 0, 1)) + pred_logit = self.decoder(decoder_inputs, shortcuts) + return pred_logit + + def LSTT_forward(self, + curr_embs, + long_term_memories, + short_term_memories, + curr_id_emb=None, + pos_emb=None, + size_2d=(30, 30)): + n, c, h, w = curr_embs[-1].size() + curr_emb = curr_embs[-1].view(n, c, h * w).permute(2, 0, 1) + lstt_embs, lstt_memories = self.LSTT(curr_emb, long_term_memories, + short_term_memories, curr_id_emb, + pos_emb, size_2d) + lstt_curr_memories, lstt_long_memories, lstt_short_memories = zip( + *lstt_memories) + return lstt_embs, lstt_curr_memories, lstt_long_memories, lstt_short_memories + + def _init_weight(self): + nn.init.xavier_uniform_(self.encoder_projector.weight) + nn.init.orthogonal_( + self.patch_wise_id_bank.weight.view( + self.cfg.MODEL_ENCODER_EMBEDDING_DIM, -1).permute(0, 1), + gain=17**-2 if self.cfg.MODEL_ALIGN_CORNERS else 16**-2) diff --git a/aot/networks/models/deaot.py b/aot/networks/models/deaot.py new file mode 100644 index 0000000000000000000000000000000000000000..008dd43c75911d056843582bd073c0c226ddf37d --- /dev/null +++ b/aot/networks/models/deaot.py @@ -0,0 +1,55 @@ +import torch.nn as nn + +from networks.layers.transformer import DualBranchGPM +from networks.models.aot import AOT +from networks.decoders import build_decoder + + +class DeAOT(AOT): + def __init__(self, cfg, encoder='mobilenetv2', decoder='fpn'): + super().__init__(cfg, encoder, decoder) + + self.LSTT = DualBranchGPM( + cfg.MODEL_LSTT_NUM, + cfg.MODEL_ENCODER_EMBEDDING_DIM, + cfg.MODEL_SELF_HEADS, + cfg.MODEL_ATT_HEADS, + emb_dropout=cfg.TRAIN_LSTT_EMB_DROPOUT, + droppath=cfg.TRAIN_LSTT_DROPPATH, + lt_dropout=cfg.TRAIN_LSTT_LT_DROPOUT, + st_dropout=cfg.TRAIN_LSTT_ST_DROPOUT, + droppath_lst=cfg.TRAIN_LSTT_DROPPATH_LST, + droppath_scaling=cfg.TRAIN_LSTT_DROPPATH_SCALING, + intermediate_norm=cfg.MODEL_DECODER_INTERMEDIATE_LSTT, + return_intermediate=True) + + decoder_indim = cfg.MODEL_ENCODER_EMBEDDING_DIM * \ + (cfg.MODEL_LSTT_NUM * 2 + + 1) if cfg.MODEL_DECODER_INTERMEDIATE_LSTT else cfg.MODEL_ENCODER_EMBEDDING_DIM * 2 + + self.decoder = build_decoder( + decoder, + in_dim=decoder_indim, + out_dim=cfg.MODEL_MAX_OBJ_NUM + 1, + decode_intermediate_input=cfg.MODEL_DECODER_INTERMEDIATE_LSTT, + hidden_dim=cfg.MODEL_ENCODER_EMBEDDING_DIM, + shortcut_dims=cfg.MODEL_ENCODER_DIM, + align_corners=cfg.MODEL_ALIGN_CORNERS) + + self.id_norm = nn.LayerNorm(cfg.MODEL_ENCODER_EMBEDDING_DIM) + + self._init_weight() + + def decode_id_logits(self, lstt_emb, shortcuts): + n, c, h, w = shortcuts[-1].size() + decoder_inputs = [shortcuts[-1]] + for emb in lstt_emb: + decoder_inputs.append(emb.view(h, w, n, -1).permute(2, 3, 0, 1)) + pred_logit = self.decoder(decoder_inputs, shortcuts) + return pred_logit + + def get_id_emb(self, x): + id_emb = self.patch_wise_id_bank(x) + id_emb = self.id_norm(id_emb.permute(2, 3, 0, 1)).permute(2, 3, 0, 1) + id_emb = self.id_dropout(id_emb) + return id_emb diff --git a/aot/pretrain_models/README.md b/aot/pretrain_models/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6c72d93391cd314e85690ab2514b2d22642c4d97 --- /dev/null +++ b/aot/pretrain_models/README.md @@ -0,0 +1 @@ +Put pretrained models here. \ No newline at end of file diff --git a/aot/source/.DS_Store b/aot/source/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/aot/source/.DS_Store differ diff --git a/aot/source/overview.png b/aot/source/overview.png new file mode 100644 index 0000000000000000000000000000000000000000..0b3308870114399f601900ca7369a5904ed5de72 Binary files /dev/null and b/aot/source/overview.png differ diff --git a/aot/source/overview_deaot.png b/aot/source/overview_deaot.png new file mode 100644 index 0000000000000000000000000000000000000000..bdb15a162c7557aa62a1439d8cc6e922c7567db4 Binary files /dev/null and b/aot/source/overview_deaot.png differ diff --git a/aot/tools/demo.py b/aot/tools/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..64c30c0dd0afc9b335d286a0af7165738c24e3dd --- /dev/null +++ b/aot/tools/demo.py @@ -0,0 +1,286 @@ +import importlib +import sys +import os + +sys.path.append('.') +sys.path.append('..') + +import cv2 +from PIL import Image +from skimage.morphology.binary import binary_dilation + +import numpy as np +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torchvision import transforms + +from networks.models import build_vos_model +from networks.engines import build_engine +from utils.checkpoint import load_network + +from dataloaders.eval_datasets import VOSTest +import dataloaders.video_transforms as tr +from utils.image import save_mask + +_palette = [ + 255, 0, 0, 0, 0, 139, 255, 255, 84, 0, 255, 0, 139, 0, 139, 0, 128, 128, + 128, 128, 128, 139, 0, 0, 218, 165, 32, 144, 238, 144, 160, 82, 45, 148, 0, + 211, 255, 0, 255, 30, 144, 255, 255, 218, 185, 85, 107, 47, 255, 140, 0, + 50, 205, 50, 123, 104, 238, 240, 230, 140, 72, 61, 139, 128, 128, 0, 0, 0, + 205, 221, 160, 221, 143, 188, 143, 127, 255, 212, 176, 224, 230, 244, 164, + 96, 250, 128, 114, 70, 130, 180, 0, 128, 0, 173, 255, 47, 255, 105, 180, + 238, 130, 238, 154, 205, 50, 220, 20, 60, 176, 48, 96, 0, 206, 209, 0, 191, + 255, 40, 40, 40, 41, 41, 41, 42, 42, 42, 43, 43, 43, 44, 44, 44, 45, 45, + 45, 46, 46, 46, 47, 47, 47, 48, 48, 48, 49, 49, 49, 50, 50, 50, 51, 51, 51, + 52, 52, 52, 53, 53, 53, 54, 54, 54, 55, 55, 55, 56, 56, 56, 57, 57, 57, 58, + 58, 58, 59, 59, 59, 60, 60, 60, 61, 61, 61, 62, 62, 62, 63, 63, 63, 64, 64, + 64, 65, 65, 65, 66, 66, 66, 67, 67, 67, 68, 68, 68, 69, 69, 69, 70, 70, 70, + 71, 71, 71, 72, 72, 72, 73, 73, 73, 74, 74, 74, 75, 75, 75, 76, 76, 76, 77, + 77, 77, 78, 78, 78, 79, 79, 79, 80, 80, 80, 81, 81, 81, 82, 82, 82, 83, 83, + 83, 84, 84, 84, 85, 85, 85, 86, 86, 86, 87, 87, 87, 88, 88, 88, 89, 89, 89, + 90, 90, 90, 91, 91, 91, 92, 92, 92, 93, 93, 93, 94, 94, 94, 95, 95, 95, 96, + 96, 96, 97, 97, 97, 98, 98, 98, 99, 99, 99, 100, 100, 100, 101, 101, 101, + 102, 102, 102, 103, 103, 103, 104, 104, 104, 105, 105, 105, 106, 106, 106, + 107, 107, 107, 108, 108, 108, 109, 109, 109, 110, 110, 110, 111, 111, 111, + 112, 112, 112, 113, 113, 113, 114, 114, 114, 115, 115, 115, 116, 116, 116, + 117, 117, 117, 118, 118, 118, 119, 119, 119, 120, 120, 120, 121, 121, 121, + 122, 122, 122, 123, 123, 123, 124, 124, 124, 125, 125, 125, 126, 126, 126, + 127, 127, 127, 128, 128, 128, 129, 129, 129, 130, 130, 130, 131, 131, 131, + 132, 132, 132, 133, 133, 133, 134, 134, 134, 135, 135, 135, 136, 136, 136, + 137, 137, 137, 138, 138, 138, 139, 139, 139, 140, 140, 140, 141, 141, 141, + 142, 142, 142, 143, 143, 143, 144, 144, 144, 145, 145, 145, 146, 146, 146, + 147, 147, 147, 148, 148, 148, 149, 149, 149, 150, 150, 150, 151, 151, 151, + 152, 152, 152, 153, 153, 153, 154, 154, 154, 155, 155, 155, 156, 156, 156, + 157, 157, 157, 158, 158, 158, 159, 159, 159, 160, 160, 160, 161, 161, 161, + 162, 162, 162, 163, 163, 163, 164, 164, 164, 165, 165, 165, 166, 166, 166, + 167, 167, 167, 168, 168, 168, 169, 169, 169, 170, 170, 170, 171, 171, 171, + 172, 172, 172, 173, 173, 173, 174, 174, 174, 175, 175, 175, 176, 176, 176, + 177, 177, 177, 178, 178, 178, 179, 179, 179, 180, 180, 180, 181, 181, 181, + 182, 182, 182, 183, 183, 183, 184, 184, 184, 185, 185, 185, 186, 186, 186, + 187, 187, 187, 188, 188, 188, 189, 189, 189, 190, 190, 190, 191, 191, 191, + 192, 192, 192, 193, 193, 193, 194, 194, 194, 195, 195, 195, 196, 196, 196, + 197, 197, 197, 198, 198, 198, 199, 199, 199, 200, 200, 200, 201, 201, 201, + 202, 202, 202, 203, 203, 203, 204, 204, 204, 205, 205, 205, 206, 206, 206, + 207, 207, 207, 208, 208, 208, 209, 209, 209, 210, 210, 210, 211, 211, 211, + 212, 212, 212, 213, 213, 213, 214, 214, 214, 215, 215, 215, 216, 216, 216, + 217, 217, 217, 218, 218, 218, 219, 219, 219, 220, 220, 220, 221, 221, 221, + 222, 222, 222, 223, 223, 223, 224, 224, 224, 225, 225, 225, 226, 226, 226, + 227, 227, 227, 228, 228, 228, 229, 229, 229, 230, 230, 230, 231, 231, 231, + 232, 232, 232, 233, 233, 233, 234, 234, 234, 235, 235, 235, 236, 236, 236, + 237, 237, 237, 238, 238, 238, 239, 239, 239, 240, 240, 240, 241, 241, 241, + 242, 242, 242, 243, 243, 243, 244, 244, 244, 245, 245, 245, 246, 246, 246, + 247, 247, 247, 248, 248, 248, 249, 249, 249, 250, 250, 250, 251, 251, 251, + 252, 252, 252, 253, 253, 253, 254, 254, 254, 255, 255, 255, 0, 0, 0 +] +color_palette = np.array(_palette).reshape(-1, 3) + + +def overlay(image, mask, colors=[255, 0, 0], cscale=1, alpha=0.4): + colors = np.atleast_2d(colors) * cscale + + im_overlay = image.copy() + object_ids = np.unique(mask) + + for object_id in object_ids[1:]: + # Overlay color on binary mask + + foreground = image * alpha + np.ones( + image.shape) * (1 - alpha) * np.array(colors[object_id]) + binary_mask = mask == object_id + + # Compose image + im_overlay[binary_mask] = foreground[binary_mask] + + countours = binary_dilation(binary_mask) ^ binary_mask + im_overlay[countours, :] = 0 + + return im_overlay.astype(image.dtype) + + +def demo(cfg): + video_fps = 15 + gpu_id = cfg.TEST_GPU_ID + + # Load pre-trained model + print('Build AOT model.') + model = build_vos_model(cfg.MODEL_VOS, cfg).cuda(gpu_id) + + print('Load checkpoint from {}'.format(cfg.TEST_CKPT_PATH)) + model, _ = load_network(model, cfg.TEST_CKPT_PATH, gpu_id) + + print('Build AOT engine.') + engine = build_engine(cfg.MODEL_ENGINE, + phase='eval', + aot_model=model, + gpu_id=gpu_id, + long_term_mem_gap=cfg.TEST_LONG_TERM_MEM_GAP) + + # Prepare datasets for each sequence + transform = transforms.Compose([ + tr.MultiRestrictSize(cfg.TEST_MIN_SIZE, cfg.TEST_MAX_SIZE, + cfg.TEST_FLIP, cfg.TEST_MULTISCALE, + cfg.MODEL_ALIGN_CORNERS), + tr.MultiToTensor() + ]) + image_root = os.path.join(cfg.TEST_DATA_PATH, 'images') + label_root = os.path.join(cfg.TEST_DATA_PATH, 'masks') + + sequences = os.listdir(image_root) + seq_datasets = [] + for seq_name in sequences: + print('Build a dataset for sequence {}.'.format(seq_name)) + seq_images = np.sort(os.listdir(os.path.join(image_root, seq_name))) + seq_labels = [seq_images[0].replace('jpg', 'png')] + seq_dataset = VOSTest(image_root, + label_root, + seq_name, + seq_images, + seq_labels, + transform=transform) + seq_datasets.append(seq_dataset) + + # Infer + output_root = cfg.TEST_OUTPUT_PATH + output_mask_root = os.path.join(output_root, 'pred_masks') + if not os.path.exists(output_mask_root): + os.makedirs(output_mask_root) + + for seq_dataset in seq_datasets: + seq_name = seq_dataset.seq_name + image_seq_root = os.path.join(image_root, seq_name) + output_mask_seq_root = os.path.join(output_mask_root, seq_name) + if not os.path.exists(output_mask_seq_root): + os.makedirs(output_mask_seq_root) + print('Build a dataloader for sequence {}.'.format(seq_name)) + seq_dataloader = DataLoader(seq_dataset, + batch_size=1, + shuffle=False, + num_workers=cfg.TEST_WORKERS, + pin_memory=True) + + fourcc = cv2.VideoWriter_fourcc(*'XVID') + output_video_path = os.path.join( + output_root, '{}_{}fps.avi'.format(seq_name, video_fps)) + + print('Start the inference of sequence {}:'.format(seq_name)) + model.eval() + engine.restart_engine() + with torch.no_grad(): + for frame_idx, samples in enumerate(seq_dataloader): + sample = samples[0] + img_name = sample['meta']['current_name'][0] + + obj_nums = sample['meta']['obj_num'] + output_height = sample['meta']['height'] + output_width = sample['meta']['width'] + obj_idx = sample['meta']['obj_idx'] + + obj_nums = [int(obj_num) for obj_num in obj_nums] + obj_idx = [int(_obj_idx) for _obj_idx in obj_idx] + + current_img = sample['current_img'] + current_img = current_img.cuda(gpu_id, non_blocking=True) + + if frame_idx == 0: + videoWriter = cv2.VideoWriter( + output_video_path, fourcc, video_fps, + (int(output_width), int(output_height))) + print( + 'Object number: {}. Inference size: {}x{}. Output size: {}x{}.' + .format(obj_nums[0], + current_img.size()[2], + current_img.size()[3], int(output_height), + int(output_width))) + current_label = sample['current_label'].cuda( + gpu_id, non_blocking=True).float() + current_label = F.interpolate(current_label, + size=current_img.size()[2:], + mode="nearest") + # add reference frame + engine.add_reference_frame(current_img, + current_label, + frame_step=0, + obj_nums=obj_nums) + else: + print('Processing image {}...'.format(img_name)) + # predict segmentation + engine.match_propogate_one_frame(current_img) + pred_logit = engine.decode_current_logits( + (output_height, output_width)) + pred_prob = torch.softmax(pred_logit, dim=1) + pred_label = torch.argmax(pred_prob, dim=1, + keepdim=True).float() + _pred_label = F.interpolate(pred_label, + size=engine.input_size_2d, + mode="nearest") + # update memory + engine.update_memory(_pred_label) + + # save results + input_image_path = os.path.join(image_seq_root, img_name) + output_mask_path = os.path.join( + output_mask_seq_root, + img_name.split('.')[0] + '.png') + + pred_label = Image.fromarray( + pred_label.squeeze(0).squeeze(0).cpu().numpy().astype( + 'uint8')).convert('P') + pred_label.putpalette(_palette) + pred_label.save(output_mask_path) + + input_image = Image.open(input_image_path) + + overlayed_image = overlay( + np.array(input_image, dtype=np.uint8), + np.array(pred_label, dtype=np.uint8), color_palette) + videoWriter.write(overlayed_image[..., [2, 1, 0]]) + + print('Save a visualization video to {}.'.format(output_video_path)) + videoWriter.release() + + +def main(): + import argparse + parser = argparse.ArgumentParser(description="AOT Demo") + parser.add_argument('--exp_name', type=str, default='default') + + parser.add_argument('--stage', type=str, default='pre_ytb_dav') + parser.add_argument('--model', type=str, default='r50_aotl') + + parser.add_argument('--gpu_id', type=int, default=0) + + parser.add_argument('--data_path', type=str, default='./datasets/Demo') + parser.add_argument('--output_path', type=str, default='./demo_output') + parser.add_argument('--ckpt_path', + type=str, + default='./pretrain_models/R50_AOTL_PRE_YTB_DAV.pth') + + parser.add_argument('--max_resolution', type=float, default=480 * 1.3) + + parser.add_argument('--amp', action='store_true') + parser.set_defaults(amp=False) + + args = parser.parse_args() + + engine_config = importlib.import_module('configs.' + args.stage) + cfg = engine_config.EngineConfig(args.exp_name, args.model) + + cfg.TEST_GPU_ID = args.gpu_id + + cfg.TEST_CKPT_PATH = args.ckpt_path + cfg.TEST_DATA_PATH = args.data_path + cfg.TEST_OUTPUT_PATH = args.output_path + + cfg.TEST_MIN_SIZE = None + cfg.TEST_MAX_SIZE = args.max_resolution * 800. / 480. + + if args.amp: + with torch.cuda.amp.autocast(enabled=True): + demo(cfg) + else: + demo(cfg) + + +if __name__ == '__main__': + main() diff --git a/aot/tools/eval.py b/aot/tools/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..d3eef634c9ba60ad896be9cde9e20783a3679cae --- /dev/null +++ b/aot/tools/eval.py @@ -0,0 +1,112 @@ +import importlib +import sys + +sys.path.append('.') +sys.path.append('..') + +import torch +import torch.multiprocessing as mp + +from networks.managers.evaluator import Evaluator + + +def main_worker(gpu, cfg, seq_queue=None, info_queue=None, enable_amp=False): + # Initiate a evaluating manager + evaluator = Evaluator(rank=gpu, + cfg=cfg, + seq_queue=seq_queue, + info_queue=info_queue) + # Start evaluation + if enable_amp: + with torch.cuda.amp.autocast(enabled=True): + evaluator.evaluating() + else: + evaluator.evaluating() + + +def main(): + import argparse + parser = argparse.ArgumentParser(description="Eval VOS") + parser.add_argument('--exp_name', type=str, default='default') + + parser.add_argument('--stage', type=str, default='pre') + parser.add_argument('--model', type=str, default='aott') + parser.add_argument('--lstt_num', type=int, default=-1) + parser.add_argument('--lt_gap', type=int, default=-1) + parser.add_argument('--st_skip', type=int, default=-1) + parser.add_argument('--max_id_num', type=int, default='-1') + + parser.add_argument('--gpu_id', type=int, default=0) + parser.add_argument('--gpu_num', type=int, default=1) + + parser.add_argument('--ckpt_path', type=str, default='') + parser.add_argument('--ckpt_step', type=int, default=-1) + + parser.add_argument('--dataset', type=str, default='') + parser.add_argument('--split', type=str, default='') + + parser.add_argument('--ema', action='store_true') + parser.set_defaults(ema=False) + + parser.add_argument('--flip', action='store_true') + parser.set_defaults(flip=False) + parser.add_argument('--ms', nargs='+', type=float, default=[1.]) + + parser.add_argument('--max_resolution', type=float, default=480 * 1.3) + + parser.add_argument('--amp', action='store_true') + parser.set_defaults(amp=False) + + args = parser.parse_args() + + engine_config = importlib.import_module('configs.' + args.stage) + cfg = engine_config.EngineConfig(args.exp_name, args.model) + + cfg.TEST_EMA = args.ema + + cfg.TEST_GPU_ID = args.gpu_id + cfg.TEST_GPU_NUM = args.gpu_num + + if args.lstt_num > 0: + cfg.MODEL_LSTT_NUM = args.lstt_num + if args.lt_gap > 0: + cfg.TEST_LONG_TERM_MEM_GAP = args.lt_gap + if args.st_skip > 0: + cfg.TEST_SHORT_TERM_MEM_SKIP = args.st_skip + + if args.max_id_num > 0: + cfg.MODEL_MAX_OBJ_NUM = args.max_id_num + + if args.ckpt_path != '': + cfg.TEST_CKPT_PATH = args.ckpt_path + if args.ckpt_step > 0: + cfg.TEST_CKPT_STEP = args.ckpt_step + + if args.dataset != '': + cfg.TEST_DATASET = args.dataset + + if args.split != '': + cfg.TEST_DATASET_SPLIT = args.split + + cfg.TEST_FLIP = args.flip + cfg.TEST_MULTISCALE = args.ms + + if cfg.TEST_MULTISCALE != [1.]: + cfg.TEST_MAX_SHORT_EDGE = args.max_resolution # for preventing OOM + else: + cfg.TEST_MAX_SHORT_EDGE = None # the default resolution setting of CFBI and AOT + cfg.TEST_MAX_LONG_EDGE = args.max_resolution * 800. / 480. + + if args.gpu_num > 1: + mp.set_start_method('spawn') + seq_queue = mp.Queue() + info_queue = mp.Queue() + mp.spawn(main_worker, + nprocs=cfg.TEST_GPU_NUM, + args=(cfg, seq_queue, info_queue, args.amp)) + else: + main_worker(0, cfg, enable_amp=args.amp) + + +if __name__ == '__main__': + main() diff --git a/aot/tools/train.py b/aot/tools/train.py new file mode 100644 index 0000000000000000000000000000000000000000..177dd36a2d5341730672d59933f5bfe16cc9322f --- /dev/null +++ b/aot/tools/train.py @@ -0,0 +1,87 @@ +import importlib +import random +import sys + +sys.setrecursionlimit(10000) +sys.path.append('.') +sys.path.append('..') + +import torch.multiprocessing as mp + +from networks.managers.trainer import Trainer + + +def main_worker(gpu, cfg, enable_amp=True): + # Initiate a training manager + trainer = Trainer(rank=gpu, cfg=cfg, enable_amp=enable_amp) + # Start Training + trainer.sequential_training() + + +def main(): + import argparse + parser = argparse.ArgumentParser(description="Train VOS") + parser.add_argument('--exp_name', type=str, default='') + parser.add_argument('--stage', type=str, default='pre') + parser.add_argument('--model', type=str, default='aott') + parser.add_argument('--max_id_num', type=int, default='-1') + + parser.add_argument('--start_gpu', type=int, default=0) + parser.add_argument('--gpu_num', type=int, default=-1) + parser.add_argument('--batch_size', type=int, default=-1) + parser.add_argument('--dist_url', type=str, default='') + parser.add_argument('--amp', action='store_true') + parser.set_defaults(amp=False) + + parser.add_argument('--pretrained_path', type=str, default='') + + parser.add_argument('--datasets', nargs='+', type=str, default=[]) + parser.add_argument('--lr', type=float, default=-1.) + parser.add_argument('--total_step', type=int, default=-1.) + parser.add_argument('--start_step', type=int, default=-1.) + + args = parser.parse_args() + + engine_config = importlib.import_module('configs.' + args.stage) + + cfg = engine_config.EngineConfig(args.exp_name, args.model) + + if len(args.datasets) > 0: + cfg.DATASETS = args.datasets + + cfg.DIST_START_GPU = args.start_gpu + if args.gpu_num > 0: + cfg.TRAIN_GPUS = args.gpu_num + if args.batch_size > 0: + cfg.TRAIN_BATCH_SIZE = args.batch_size + + if args.pretrained_path != '': + cfg.PRETRAIN_MODEL = args.pretrained_path + + if args.max_id_num > 0: + cfg.MODEL_MAX_OBJ_NUM = args.max_id_num + + if args.lr > 0: + cfg.TRAIN_LR = args.lr + + if args.total_step > 0: + cfg.TRAIN_TOTAL_STEPS = args.total_step + + if args.start_step > 0: + cfg.TRAIN_START_STEP = args.start_step + + if args.dist_url == '': + cfg.DIST_URL = 'tcp://127.0.0.1:123' + str(random.randint(0, 9)) + str( + random.randint(0, 9)) + else: + cfg.DIST_URL = args.dist_url + + if cfg.TRAIN_GPUS > 1: + # Use torch.multiprocessing.spawn to launch distributed processes + mp.spawn(main_worker, nprocs=cfg.TRAIN_GPUS, args=(cfg, args.amp)) + else: + cfg.TRAIN_GPUS = 1 + main_worker(0, cfg, args.amp) + +if __name__ == '__main__': + main() diff --git a/aot/train_eval.sh b/aot/train_eval.sh new file mode 100644 index 0000000000000000000000000000000000000000..2f15c85e65e603186b9f581aa14ac8ba39cefac2 --- /dev/null +++ b/aot/train_eval.sh @@ -0,0 +1,50 @@ +exp="default" +gpu_num="4" + +model="aott" +# model="aots" +# model="aotb" +# model="aotl" +# model="r50_deaotl" +# model="swinb_aotl" + +## Training ## +stage="pre" +python tools/train.py --amp \ + --exp_name ${exp} \ + --stage ${stage} \ + --model ${model} \ + --gpu_num ${gpu_num} + +stage="pre_ytb_dav" +python tools/train.py --amp \ + --exp_name ${exp} \ + --stage ${stage} \ + --model ${model} \ + --gpu_num ${gpu_num} + +## Evaluation ## +dataset="davis2017" +split="test" +python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \ + --dataset ${dataset} --split ${split} --gpu_num ${gpu_num} + +dataset="davis2017" +split="val" +python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \ + --dataset ${dataset} --split ${split} --gpu_num ${gpu_num} + +dataset="davis2016" +split="val" +python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \ + --dataset ${dataset} --split ${split} --gpu_num ${gpu_num} + +dataset="youtubevos2018" +split="val" # or "val_all_frames" +python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \ + --dataset ${dataset} --split ${split} --gpu_num ${gpu_num} + +dataset="youtubevos2019" +split="val" # or "val_all_frames" +python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \ + --dataset ${dataset} --split ${split} --gpu_num ${gpu_num} \ No newline at end of file diff --git a/aot/utils/__init__.py b/aot/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/aot/utils/__pycache__/__init__.cpython-310.pyc b/aot/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f6ad6a535f3a015fc91967d6100d769cc201454 Binary files /dev/null and b/aot/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/aot/utils/__pycache__/checkpoint.cpython-310.pyc b/aot/utils/__pycache__/checkpoint.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07affa7d42ac38d972fe57fe26773c31b9e19860 Binary files /dev/null and b/aot/utils/__pycache__/checkpoint.cpython-310.pyc differ diff --git a/aot/utils/__pycache__/image.cpython-310.pyc b/aot/utils/__pycache__/image.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96c58f85f6c08527ed8c0b5ba037c3daa0832450 Binary files /dev/null and b/aot/utils/__pycache__/image.cpython-310.pyc differ diff --git a/aot/utils/__pycache__/learning.cpython-310.pyc b/aot/utils/__pycache__/learning.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2d257d1f1858b6f74a498dc829799125e98bbd7 Binary files /dev/null and b/aot/utils/__pycache__/learning.cpython-310.pyc differ diff --git a/aot/utils/__pycache__/math.cpython-310.pyc b/aot/utils/__pycache__/math.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11f5fd58088c517955094b5e2f6b1f0c976da669 Binary files /dev/null and b/aot/utils/__pycache__/math.cpython-310.pyc differ diff --git a/aot/utils/checkpoint.py b/aot/utils/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..fbee512afa972f2d61678b6b5752c72dd41023c7 --- /dev/null +++ b/aot/utils/checkpoint.py @@ -0,0 +1,163 @@ +import torch +import os +import shutil +import numpy as np + + +def load_network_and_optimizer(net, opt, pretrained_dir, gpu, scaler=None): + pretrained = torch.load(pretrained_dir, + map_location=torch.device("cuda:" + str(gpu))) + pretrained_dict = pretrained['state_dict'] + model_dict = net.state_dict() + pretrained_dict_update = {} + pretrained_dict_remove = [] + for k, v in pretrained_dict.items(): + if k in model_dict: + pretrained_dict_update[k] = v + elif k[:7] == 'module.': + if k[7:] in model_dict: + pretrained_dict_update[k[7:]] = v + else: + pretrained_dict_remove.append(k) + model_dict.update(pretrained_dict_update) + net.load_state_dict(model_dict) + opt.load_state_dict(pretrained['optimizer']) + if scaler is not None and 'scaler' in pretrained.keys(): + scaler.load_state_dict(pretrained['scaler']) + del (pretrained) + return net.cuda(gpu), opt, pretrained_dict_remove + + +def load_network_and_optimizer_v2(net, opt, pretrained_dir, gpu, scaler=None): + pretrained = torch.load(pretrained_dir, + map_location=torch.device("cuda:" + str(gpu))) + # load model + pretrained_dict = pretrained['state_dict'] + model_dict = net.state_dict() + pretrained_dict_update = {} + pretrained_dict_remove = [] + for k, v in pretrained_dict.items(): + if k in model_dict: + pretrained_dict_update[k] = v + elif k[:7] == 'module.': + if k[7:] in model_dict: + pretrained_dict_update[k[7:]] = v + else: + pretrained_dict_remove.append(k) + model_dict.update(pretrained_dict_update) + net.load_state_dict(model_dict) + + # load optimizer + opt_dict = opt.state_dict() + all_params = { + param_group['name']: param_group['params'][0] + for param_group in opt_dict['param_groups'] + } + pretrained_opt_dict = {'state': {}, 'param_groups': []} + for idx in range(len(pretrained['optimizer']['param_groups'])): + param_group = pretrained['optimizer']['param_groups'][idx] + if param_group['name'] in all_params.keys(): + pretrained_opt_dict['state'][all_params[ + param_group['name']]] = pretrained['optimizer']['state'][ + param_group['params'][0]] + param_group['params'][0] = all_params[param_group['name']] + pretrained_opt_dict['param_groups'].append(param_group) + + opt_dict.update(pretrained_opt_dict) + opt.load_state_dict(opt_dict) + + # load scaler + if scaler is not None and 'scaler' in pretrained.keys(): + scaler.load_state_dict(pretrained['scaler']) + del (pretrained) + return net.cuda(gpu), opt, pretrained_dict_remove + + +def load_network(net, pretrained_dir, gpu): + pretrained = torch.load(pretrained_dir, + map_location=torch.device("cuda:" + str(gpu))) + if 'state_dict' in pretrained.keys(): + pretrained_dict = pretrained['state_dict'] + elif 'model' in pretrained.keys(): + pretrained_dict = pretrained['model'] + else: + pretrained_dict = pretrained + model_dict = net.state_dict() + pretrained_dict_update = {} + pretrained_dict_remove = [] + for k, v in pretrained_dict.items(): + if k in model_dict: + pretrained_dict_update[k] = v + elif k[:7] == 'module.': + if k[7:] in model_dict: + pretrained_dict_update[k[7:]] = v + else: + pretrained_dict_remove.append(k) + model_dict.update(pretrained_dict_update) + net.load_state_dict(model_dict) + del (pretrained) + return net.cuda(gpu), pretrained_dict_remove + + +def save_network(net, + opt, + step, + save_path, + max_keep=8, + backup_dir='./saved_models', + scaler=None): + ckpt = {'state_dict': net.state_dict(), 'optimizer': opt.state_dict()} + if scaler is not None: + ckpt['scaler'] = scaler.state_dict() + + try: + if not os.path.exists(save_path): + os.makedirs(save_path) + save_file = 'save_step_%s.pth' % (step) + save_dir = os.path.join(save_path, save_file) + torch.save(ckpt, save_dir) + except: + save_path = backup_dir + if not os.path.exists(save_path): + os.makedirs(save_path) + save_file = 'save_step_%s.pth' % (step) + save_dir = os.path.join(save_path, save_file) + torch.save(ckpt, save_dir) + + all_ckpt = os.listdir(save_path) + if len(all_ckpt) > max_keep: + all_step = [] + for ckpt_name in all_ckpt: + step = int(ckpt_name.split('_')[-1].split('.')[0]) + all_step.append(step) + all_step = list(np.sort(all_step))[:-max_keep] + for step in all_step: + ckpt_path = os.path.join(save_path, 'save_step_%s.pth' % (step)) + os.system('rm {}'.format(ckpt_path)) + + +def cp_ckpt(remote_dir="data_wd/youtube_vos_jobs/result", curr_dir="backup"): + exps = os.listdir(curr_dir) + for exp in exps: + exp_dir = os.path.join(curr_dir, exp) + stages = os.listdir(exp_dir) + for stage in stages: + stage_dir = os.path.join(exp_dir, stage) + finals = ["ema_ckpt", "ckpt"] + for final in finals: + final_dir = os.path.join(stage_dir, final) + ckpts = os.listdir(final_dir) + for ckpt in ckpts: + if '.pth' not in ckpt: + continue + curr_ckpt_path = os.path.join(final_dir, ckpt) + remote_ckpt_path = os.path.join(remote_dir, exp, stage, + final, ckpt) + if os.path.exists(remote_ckpt_path): + os.system('rm {}'.format(remote_ckpt_path)) + try: + shutil.copy(curr_ckpt_path, remote_ckpt_path) + print("Copy {} to {}.".format(curr_ckpt_path, + remote_ckpt_path)) + except OSError as Inst: + return diff --git a/aot/utils/cp_ckpt.py b/aot/utils/cp_ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..22cb96e2291676b4c913deb8e4e47a99d5b0ee16 --- /dev/null +++ b/aot/utils/cp_ckpt.py @@ -0,0 +1,36 @@ +import os +import shutil + + +def cp_ckpt(remote_dir="data_wd/youtube_vos_jobs/result", curr_dir="backup"): + exps = os.listdir(curr_dir) + for exp in exps: + print("Exp: ", exp) + exp_dir = os.path.join(curr_dir, exp) + stages = os.listdir(exp_dir) + for stage in stages: + print("Stage: ", stage) + stage_dir = os.path.join(exp_dir, stage) + finals = ["ema_ckpt", "ckpt"] + for final in finals: + print("Final: ", final) + final_dir = os.path.join(stage_dir, final) + ckpts = os.listdir(final_dir) + for ckpt in ckpts: + if '.pth' not in ckpt: + continue + curr_ckpt_path = os.path.join(final_dir, ckpt) + remote_ckpt_path = os.path.join(remote_dir, exp, stage, + final, ckpt) + if os.path.exists(remote_ckpt_path): + os.system('rm {}'.format(remote_ckpt_path)) + try: + shutil.copy(curr_ckpt_path, remote_ckpt_path) + print(ckpt, ': OK') + except OSError as Inst: + print(Inst) + print(ckpt, ': Fail') + + +if __name__ == "__main__": + cp_ckpt() diff --git a/aot/utils/ema.py b/aot/utils/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..ba1bc69511b65c04c7658d6693f9e403279f44a4 --- /dev/null +++ b/aot/utils/ema.py @@ -0,0 +1,93 @@ +from __future__ import division +from __future__ import unicode_literals + +import torch + + +def get_param_buffer_for_ema(model, + update_buffer=False, + required_buffers=['running_mean', 'running_var']): + params = model.parameters() + all_param_buffer = [p for p in params if p.requires_grad] + if update_buffer: + named_buffers = model.named_buffers() + for key, value in named_buffers: + for buffer_name in required_buffers: + if buffer_name in key: + all_param_buffer.append(value) + break + return all_param_buffer + + +class ExponentialMovingAverage: + """ + Maintains (exponential) moving average of a set of parameters. + """ + def __init__(self, parameters, decay, use_num_updates=True): + """ + Args: + parameters: Iterable of `torch.nn.Parameter`; usually the result of + `model.parameters()`. + decay: The exponential decay. + use_num_updates: Whether to use number of updates when computing + averages. + """ + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + self.decay = decay + self.num_updates = 0 if use_num_updates else None + self.shadow_params = [p.clone().detach() for p in parameters] + self.collected_params = [] + + def update(self, parameters): + """ + Update currently maintained parameters. + Call this every time the parameters are updated, such as the result of + the `optimizer.step()` call. + Args: + parameters: Iterable of `torch.nn.Parameter`; usually the same set of + parameters used to initialize this object. + """ + decay = self.decay + if self.num_updates is not None: + self.num_updates += 1 + decay = min(decay, + (1 + self.num_updates) / (10 + self.num_updates)) + one_minus_decay = 1.0 - decay + with torch.no_grad(): + for s_param, param in zip(self.shadow_params, parameters): + s_param.sub_(one_minus_decay * (s_param - param)) + + def copy_to(self, parameters): + """ + Copy current parameters into given collection of parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. + """ + for s_param, param in zip(self.shadow_params, parameters): + param.data.copy_(s_param.data) + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) + del (self.collected_params) diff --git a/aot/utils/eval.py b/aot/utils/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..eff31ab65491c10c248061b2b102e26fb9ffacbc --- /dev/null +++ b/aot/utils/eval.py @@ -0,0 +1,13 @@ +import zipfile +import os + + +def zip_folder(source_folder, zip_dir): + f = zipfile.ZipFile(zip_dir, 'w', zipfile.ZIP_DEFLATED) + pre_len = len(os.path.dirname(source_folder)) + for dirpath, dirnames, filenames in os.walk(source_folder): + for filename in filenames: + pathfile = os.path.join(dirpath, filename) + arcname = pathfile[pre_len:].strip(os.path.sep) + f.write(pathfile, arcname) + f.close() \ No newline at end of file diff --git a/aot/utils/image.py b/aot/utils/image.py new file mode 100644 index 0000000000000000000000000000000000000000..062a810c2e3399c4cb624582c200badb5e7731de --- /dev/null +++ b/aot/utils/image.py @@ -0,0 +1,127 @@ +import numpy as np +from PIL import Image +import torch +import threading + +_palette = [ + 0, 0, 0, 128, 0, 0, 0, 128, 0, 128, 128, 0, 0, 0, 128, 128, 0, 128, 0, 128, + 128, 128, 128, 128, 64, 0, 0, 191, 0, 0, 64, 128, 0, 191, 128, 0, 64, 0, + 128, 191, 0, 128, 64, 128, 128, 191, 128, 128, 0, 64, 0, 128, 64, 0, 0, + 191, 0, 128, 191, 0, 0, 64, 128, 128, 64, 128, 22, 22, 22, 23, 23, 23, 24, + 24, 24, 25, 25, 25, 26, 26, 26, 27, 27, 27, 28, 28, 28, 29, 29, 29, 30, 30, + 30, 31, 31, 31, 32, 32, 32, 33, 33, 33, 34, 34, 34, 35, 35, 35, 36, 36, 36, + 37, 37, 37, 38, 38, 38, 39, 39, 39, 40, 40, 40, 41, 41, 41, 42, 42, 42, 43, + 43, 43, 44, 44, 44, 45, 45, 45, 46, 46, 46, 47, 47, 47, 48, 48, 48, 49, 49, + 49, 50, 50, 50, 51, 51, 51, 52, 52, 52, 53, 53, 53, 54, 54, 54, 55, 55, 55, + 56, 56, 56, 57, 57, 57, 58, 58, 58, 59, 59, 59, 60, 60, 60, 61, 61, 61, 62, + 62, 62, 63, 63, 63, 64, 64, 64, 65, 65, 65, 66, 66, 66, 67, 67, 67, 68, 68, + 68, 69, 69, 69, 70, 70, 70, 71, 71, 71, 72, 72, 72, 73, 73, 73, 74, 74, 74, + 75, 75, 75, 76, 76, 76, 77, 77, 77, 78, 78, 78, 79, 79, 79, 80, 80, 80, 81, + 81, 81, 82, 82, 82, 83, 83, 83, 84, 84, 84, 85, 85, 85, 86, 86, 86, 87, 87, + 87, 88, 88, 88, 89, 89, 89, 90, 90, 90, 91, 91, 91, 92, 92, 92, 93, 93, 93, + 94, 94, 94, 95, 95, 95, 96, 96, 96, 97, 97, 97, 98, 98, 98, 99, 99, 99, + 100, 100, 100, 101, 101, 101, 102, 102, 102, 103, 103, 103, 104, 104, 104, + 105, 105, 105, 106, 106, 106, 107, 107, 107, 108, 108, 108, 109, 109, 109, + 110, 110, 110, 111, 111, 111, 112, 112, 112, 113, 113, 113, 114, 114, 114, + 115, 115, 115, 116, 116, 116, 117, 117, 117, 118, 118, 118, 119, 119, 119, + 120, 120, 120, 121, 121, 121, 122, 122, 122, 123, 123, 123, 124, 124, 124, + 125, 125, 125, 126, 126, 126, 127, 127, 127, 128, 128, 128, 129, 129, 129, + 130, 130, 130, 131, 131, 131, 132, 132, 132, 133, 133, 133, 134, 134, 134, + 135, 135, 135, 136, 136, 136, 137, 137, 137, 138, 138, 138, 139, 139, 139, + 140, 140, 140, 141, 141, 141, 142, 142, 142, 143, 143, 143, 144, 144, 144, + 145, 145, 145, 146, 146, 146, 147, 147, 147, 148, 148, 148, 149, 149, 149, + 150, 150, 150, 151, 151, 151, 152, 152, 152, 153, 153, 153, 154, 154, 154, + 155, 155, 155, 156, 156, 156, 157, 157, 157, 158, 158, 158, 159, 159, 159, + 160, 160, 160, 161, 161, 161, 162, 162, 162, 163, 163, 163, 164, 164, 164, + 165, 165, 165, 166, 166, 166, 167, 167, 167, 168, 168, 168, 169, 169, 169, + 170, 170, 170, 171, 171, 171, 172, 172, 172, 173, 173, 173, 174, 174, 174, + 175, 175, 175, 176, 176, 176, 177, 177, 177, 178, 178, 178, 179, 179, 179, + 180, 180, 180, 181, 181, 181, 182, 182, 182, 183, 183, 183, 184, 184, 184, + 185, 185, 185, 186, 186, 186, 187, 187, 187, 188, 188, 188, 189, 189, 189, + 190, 190, 190, 191, 191, 191, 192, 192, 192, 193, 193, 193, 194, 194, 194, + 195, 195, 195, 196, 196, 196, 197, 197, 197, 198, 198, 198, 199, 199, 199, + 200, 200, 200, 201, 201, 201, 202, 202, 202, 203, 203, 203, 204, 204, 204, + 205, 205, 205, 206, 206, 206, 207, 207, 207, 208, 208, 208, 209, 209, 209, + 210, 210, 210, 211, 211, 211, 212, 212, 212, 213, 213, 213, 214, 214, 214, + 215, 215, 215, 216, 216, 216, 217, 217, 217, 218, 218, 218, 219, 219, 219, + 220, 220, 220, 221, 221, 221, 222, 222, 222, 223, 223, 223, 224, 224, 224, + 225, 225, 225, 226, 226, 226, 227, 227, 227, 228, 228, 228, 229, 229, 229, + 230, 230, 230, 231, 231, 231, 232, 232, 232, 233, 233, 233, 234, 234, 234, + 235, 235, 235, 236, 236, 236, 237, 237, 237, 238, 238, 238, 239, 239, 239, + 240, 240, 240, 241, 241, 241, 242, 242, 242, 243, 243, 243, 244, 244, 244, + 245, 245, 245, 246, 246, 246, 247, 247, 247, 248, 248, 248, 249, 249, 249, + 250, 250, 250, 251, 251, 251, 252, 252, 252, 253, 253, 253, 254, 254, 254, + 255, 255, 255 +] + + +def label2colormap(label): + + m = label.astype(np.uint8) + r, c = m.shape + cmap = np.zeros((r, c, 3), dtype=np.uint8) + cmap[:, :, 0] = (m & 1) << 7 | (m & 8) << 3 | (m & 64) >> 1 + cmap[:, :, 1] = (m & 2) << 6 | (m & 16) << 2 | (m & 128) >> 2 + cmap[:, :, 2] = (m & 4) << 5 | (m & 32) << 1 + return cmap + + +def one_hot_mask(mask, cls_num): + if len(mask.size()) == 3: + mask = mask.unsqueeze(1) + indices = torch.arange(0, cls_num + 1, + device=mask.device).view(1, -1, 1, 1) + return (mask == indices).float() + + +def masked_image(image, colored_mask, mask, alpha=0.7): + mask = np.expand_dims(mask > 0, axis=0) + mask = np.repeat(mask, 3, axis=0) + show_img = (image * alpha + colored_mask * + (1 - alpha)) * mask + image * (1 - mask) + return show_img + + +def save_image(image, path): + im = Image.fromarray(np.uint8(image * 255.).transpose((1, 2, 0))) + im.save(path) + + +def _save_mask(mask, path, squeeze_idx=None): + if squeeze_idx is not None: + unsqueezed_mask = mask * 0 + for idx in range(1, len(squeeze_idx)): + obj_id = squeeze_idx[idx] + mask_i = mask == idx + unsqueezed_mask += (mask_i * obj_id).astype(np.uint8) + mask = unsqueezed_mask + mask = Image.fromarray(mask).convert('P') + mask.putpalette(_palette) + mask.save(path) + + +def save_mask(mask_tensor, path, squeeze_idx=None): + mask = mask_tensor.cpu().numpy().astype('uint8') + threading.Thread(target=_save_mask, args=[mask, path, squeeze_idx]).start() + + +def flip_tensor(tensor, dim=0): + inv_idx = torch.arange(tensor.size(dim) - 1, -1, -1, + device=tensor.device).long() + tensor = tensor.index_select(dim, inv_idx) + return tensor + + +def shuffle_obj_mask(mask): + + bs, obj_num, _, _ = mask.size() + new_masks = [] + for idx in range(bs): + now_mask = mask[idx] + random_matrix = torch.eye(obj_num, device=mask.device) + fg = random_matrix[1:][torch.randperm(obj_num - 1)] + random_matrix = torch.cat([random_matrix[0:1], fg], dim=0) + now_mask = torch.einsum('nm,nhw->mhw', random_matrix, now_mask) + new_masks.append(now_mask) + + return torch.stack(new_masks, dim=0) diff --git a/aot/utils/learning.py b/aot/utils/learning.py new file mode 100644 index 0000000000000000000000000000000000000000..8a2147c26d041b7c73491c219ec6f767c74a7511 --- /dev/null +++ b/aot/utils/learning.py @@ -0,0 +1,106 @@ +import math + + +def adjust_learning_rate(optimizer, + base_lr, + p, + itr, + max_itr, + restart=1, + warm_up_steps=1000, + is_cosine_decay=False, + min_lr=1e-5, + encoder_lr_ratio=1.0, + freeze_params=[]): + + if restart > 1: + each_max_itr = int(math.ceil(float(max_itr) / restart)) + itr = itr % each_max_itr + warm_up_steps /= restart + max_itr = each_max_itr + + if itr < warm_up_steps: + now_lr = min_lr + (base_lr - min_lr) * itr / warm_up_steps + else: + itr = itr - warm_up_steps + max_itr = max_itr - warm_up_steps + if is_cosine_decay: + now_lr = min_lr + (base_lr - min_lr) * (math.cos(math.pi * itr / + (max_itr + 1)) + + 1.) * 0.5 + else: + now_lr = min_lr + (base_lr - min_lr) * (1 - itr / (max_itr + 1))**p + + for param_group in optimizer.param_groups: + if encoder_lr_ratio != 1.0 and "encoder." in param_group["name"]: + param_group['lr'] = (now_lr - min_lr) * encoder_lr_ratio + min_lr + else: + param_group['lr'] = now_lr + + for freeze_param in freeze_params: + if freeze_param in param_group["name"]: + param_group['lr'] = 0 + param_group['weight_decay'] = 0 + break + + return now_lr + + +def get_trainable_params(model, + base_lr, + weight_decay, + use_frozen_bn=False, + exclusive_wd_dict={}, + no_wd_keys=[]): + params = [] + memo = set() + total_param = 0 + for key, value in model.named_parameters(): + if value in memo: + continue + total_param += value.numel() + if not value.requires_grad: + continue + memo.add(value) + wd = weight_decay + for exclusive_key in exclusive_wd_dict.keys(): + if exclusive_key in key: + wd = exclusive_wd_dict[exclusive_key] + break + if len(value.shape) == 1: # normalization layers + if 'bias' in key: # bias requires no weight decay + wd = 0. + elif not use_frozen_bn: # if not use frozen BN, apply zero weight decay + wd = 0. + elif 'encoder.' not in key: # if use frozen BN, apply weight decay to all frozen BNs in the encoder + wd = 0. + else: + for no_wd_key in no_wd_keys: + if no_wd_key in key: + wd = 0. + break + params += [{ + "params": [value], + "lr": base_lr, + "weight_decay": wd, + "name": key + }] + + print('Total Param: {:.2f}M'.format(total_param / 1e6)) + return params + + +def freeze_params(module): + for p in module.parameters(): + p.requires_grad = False + + +def calculate_params(state_dict): + memo = set() + total_param = 0 + for key, value in state_dict.items(): + if value in memo: + continue + memo.add(value) + total_param += value.numel() + print('Total Param: {:.2f}M'.format(total_param / 1e6)) diff --git a/aot/utils/math.py b/aot/utils/math.py new file mode 100644 index 0000000000000000000000000000000000000000..7f9ddc106892632f355d8dc0ce3cc46089c98e36 --- /dev/null +++ b/aot/utils/math.py @@ -0,0 +1,24 @@ +import torch + + +def generate_permute_matrix(dim, num, keep_first=True, gpu_id=0): + all_matrix = [] + for idx in range(num): + random_matrix = torch.eye(dim, device=torch.device('cuda', gpu_id)) + if keep_first: + fg = random_matrix[1:][torch.randperm(dim - 1)] + random_matrix = torch.cat([random_matrix[0:1], fg], dim=0) + else: + random_matrix = random_matrix[torch.randperm(dim)] + all_matrix.append(random_matrix) + return torch.stack(all_matrix, dim=0) + + +def truncated_normal_(tensor, mean=0, std=.02): + size = tensor.shape + tmp = tensor.new_empty(size + (4, )).normal_() + valid = (tmp < 2) & (tmp > -2) + ind = valid.max(-1, keepdim=True)[1] + tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) + tensor.data.mul_(std).add_(mean) + return tensor diff --git a/aot/utils/meters.py b/aot/utils/meters.py new file mode 100644 index 0000000000000000000000000000000000000000..00f48d871f8088cb59710105a462679d344d4b0f --- /dev/null +++ b/aot/utils/meters.py @@ -0,0 +1,31 @@ +from __future__ import absolute_import + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, momentum=0.999): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + self.long_count = 0 + self.momentum = momentum + self.moving_avg = 0 + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + if self.long_count == 0: + self.moving_avg = val + else: + momentum = min(self.momentum, 1. - 1. / self.long_count) + self.moving_avg = self.moving_avg * momentum + val * (1 - momentum) + self.val = val + self.sum += val * n + self.count += n + self.long_count += n + self.avg = self.sum / self.count diff --git a/aot/utils/metric.py b/aot/utils/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..ad474f825fd33dac8bb5d3d34a61092c48470996 --- /dev/null +++ b/aot/utils/metric.py @@ -0,0 +1,36 @@ +import torch + + +def pytorch_iou(pred, target, obj_num, epsilon=1e-6): + ''' + pred: [bs, h, w] + target: [bs, h, w] + obj_num: [bs] + ''' + bs = pred.size(0) + all_iou = [] + for idx in range(bs): + now_pred = pred[idx].unsqueeze(0) + now_target = target[idx].unsqueeze(0) + now_obj_num = obj_num[idx] + + obj_ids = torch.arange(0, now_obj_num + 1, + device=now_pred.device).int().view(-1, 1, 1) + if obj_ids.size(0) == 1: # only contain background + continue + else: + obj_ids = obj_ids[1:] + now_pred = (now_pred == obj_ids).float() + now_target = (now_target == obj_ids).float() + + intersection = (now_pred * now_target).sum((1, 2)) + union = ((now_pred + now_target) > 0).float().sum((1, 2)) + + now_iou = (intersection + epsilon) / (union + epsilon) + + all_iou.append(now_iou.mean()) + if len(all_iou) > 0: + all_iou = torch.stack(all_iou).mean() + else: + all_iou = torch.ones((1), device=pred.device) + return all_iou diff --git a/assets/840_iSXIa0hE8Ek.zip b/assets/840_iSXIa0hE8Ek.zip new file mode 100644 index 0000000000000000000000000000000000000000..ff07d53b4c1319ab4ed2480a48d952d7f75bc0fa --- /dev/null +++ b/assets/840_iSXIa0hE8Ek.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b30c0c83ee62dce1e52dfe3a2ae5eed70aab5f5450623c658c5ab2c775657f4e +size 48605936 diff --git a/assets/blackswan.mp4 b/assets/blackswan.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..dcbeed354f792ea888de4423a63e4d05dfa1fe33 Binary files /dev/null and b/assets/blackswan.mp4 differ diff --git a/assets/cars.mp4 b/assets/cars.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..c008ead40c183fe6be71954bab812436bd01c13a --- /dev/null +++ b/assets/cars.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6c2d6626c933bc67141089b76ac1227a6f6efb35c58109ab0d16e0d61b13cd37 +size 6854222 diff --git a/assets/cell.mp4 b/assets/cell.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..cc9e56db01966c596ed438283e46dc74a8b8e900 --- /dev/null +++ b/assets/cell.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d0605cb1e23e5d0c435fa36b367f5638dd06f69c55ac40f732ee219f5179368a +size 4725839 diff --git a/assets/demo_3x2.gif b/assets/demo_3x2.gif new file mode 100644 index 0000000000000000000000000000000000000000..2fcc6d2c6d486d0a328db04272399c0427508c12 --- /dev/null +++ b/assets/demo_3x2.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3fb7dcf64ff4603e79251b8e1fce2d1c1778c280300a88c8a3360d635cc402b6 +size 3785934 diff --git a/assets/gradio.jpg b/assets/gradio.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6acb758bbe99b3909eba4fa3267831dcb9ad575e Binary files /dev/null and b/assets/gradio.jpg differ diff --git a/assets/interactive_webui.jpg b/assets/interactive_webui.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d42de0f34385f0e92385ffea6227016f911f3894 Binary files /dev/null and b/assets/interactive_webui.jpg differ diff --git a/assets/top.gif b/assets/top.gif new file mode 100644 index 0000000000000000000000000000000000000000..6a6fad55a37622e93b3392c868951bc5de3d0855 --- /dev/null +++ b/assets/top.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:24feff46d3e1d6f30f5cb9f24823ab51e0ebdf1ddd7715fb14654971d4a484d3 +size 4498684