Commit 5f9d349 by Olga
Initial commit (no parent commits)

This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
- .gitattributes +37 -0
- .gitignore +159 -0
- LICENSE.txt +12 -0
- README.md +14 -0
- app.py +434 -0
- assets/examples/video_bakery.mp4 +3 -0
- assets/examples/video_flowers.mp4 +3 -0
- assets/examples/video_fruits.mp4 +3 -0
- assets/examples/video_plant.mp4 +3 -0
- assets/examples/video_salad.mp4 +3 -0
- assets/examples/video_tram.mp4 +3 -0
- assets/examples/video_tulips.mp4 +3 -0
- assets/video_fruits_ours_full.mp4 +3 -0
- configs/gs/base.yaml +51 -0
- configs/train.yaml +38 -0
- requirements.txt +32 -0
- source/EDGS.code-workspace +11 -0
- source/__init__.py +0 -0
- source/corr_init.py +682 -0
- source/corr_init_new.py +904 -0
- source/data_utils.py +28 -0
- source/losses.py +100 -0
- source/networks.py +52 -0
- source/timer.py +24 -0
- source/trainer.py +262 -0
- source/utils_aux.py +92 -0
- source/utils_preprocess.py +334 -0
- source/vggt_to_colmap.py +598 -0
- source/visualization.py +1072 -0
- submodules/RoMa/.gitignore +11 -0
- submodules/RoMa/LICENSE +21 -0
- submodules/RoMa/README.md +123 -0
- submodules/RoMa/data/.gitignore +2 -0
- submodules/RoMa/demo/demo_3D_effect.py +47 -0
- submodules/RoMa/demo/demo_fundamental.py +34 -0
- submodules/RoMa/demo/demo_match.py +50 -0
- submodules/RoMa/demo/demo_match_opencv_sift.py +43 -0
- submodules/RoMa/demo/demo_match_tiny.py +77 -0
- submodules/RoMa/demo/gif/.gitignore +2 -0
- submodules/RoMa/experiments/eval_roma_outdoor.py +57 -0
- submodules/RoMa/experiments/eval_tiny_roma_v1_outdoor.py +84 -0
- submodules/RoMa/experiments/roma_indoor.py +320 -0
- submodules/RoMa/experiments/train_roma_outdoor.py +307 -0
- submodules/RoMa/experiments/train_tiny_roma_v1_outdoor.py +498 -0
- submodules/RoMa/requirements.txt +14 -0
- submodules/RoMa/romatch/__init__.py +8 -0
- submodules/RoMa/romatch/benchmarks/__init__.py +6 -0
- submodules/RoMa/romatch/benchmarks/hpatches_sequences_homog_benchmark.py +113 -0
- submodules/RoMa/romatch/benchmarks/megadepth_dense_benchmark.py +106 -0
- submodules/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark.py +118 -0
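The listing above gives the layout of the initial commit: Hydra configs under configs/ (train.yaml plus configs/gs/base.yaml), the training, preprocessing, and visualization code under source/, and the RoMa matcher vendored under submodules/RoMa. For orientation, here is a minimal sketch of how the training config is composed and overridden; it mirrors the calls made in app.py further down in this diff, and the exact keys are assumptions taken from that file rather than a documented API.

# Minimal sketch (not part of the commit): compose the Hydra config the way
# app.py does. Assumes hydra-core is installed and the working directory is
# the repository root, so that ./configs/train.yaml can be found.
from hydra import initialize, compose

with initialize(config_path="./configs", version_base="1.1"):
    cfg = compose(config_name="train")

# Example overrides mirroring the fields app.py touches before training.
cfg.wandb.mode = "disabled"                   # no Weights & Biases logging
cfg.gs.dataset.source_path = "./my_scene"     # COLMAP-processed scene (hypothetical path)
cfg.gs.dataset.model_path = "./outputs/demo"  # where the trained model is written
cfg.init_wC.use = True                        # enable correspondence-based EDGS init
print(cfg.train.gs_epochs)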
    	
.gitattributes ADDED
@@ -0,0 +1,37 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.whl filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
    	
.gitignore ADDED
@@ -0,0 +1,159 @@
+# Local notebooks used for local debug
+notebooks_local/
+wandb/
+
+# Gradio files
+
+served_files/
+
+# hidden folders
+.*/**
+
+# The rest is taken from https://github.com/Anttwo/SuGaR
+*.pt
+*.pth
+output*
+*.slurm
+*.pyc
+*.ply
+*.obj
+sugar_tests.ipynb
+sugar_sh_tests.ipynb
+
+# To remove
+frosting*
+extract_shell.py
+train_frosting_refined.py
+train_frosting.py
+run_frosting_viewer.py
+slurm_a100.sh
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+learnableearthparser/fast_sampler/_sampler.c
    	
LICENSE.txt ADDED
@@ -0,0 +1,12 @@
+Copyright 2025, Dmytro Kotovenko, Olga Grebenkova, Björn Ommer
+Redistribution and use in source and binary forms, with or without modification, are permitted for non-commercial academic research and/or non-commercial personal use only provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+Any use of this software beyond the above specified conditions requires a separate license. Please contact the copyright holders to discuss license terms.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    	
README.md ADDED
@@ -0,0 +1,14 @@
+---
+title: EDGS
+emoji: 🎥
+colorFrom: pink
+colorTo: blue
+sdk: gradio
+sdk_version: 5.25.2
+app_file: app.py
+pinned: false
+python_version: "3.10"
+license: other
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
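The front matter above is the standard Hugging Face Spaces configuration block: sdk: gradio together with app_file: app.py tells the Hub to launch app.py, which must build a Gradio app at import time. As a rough, hypothetical sketch (not the actual EDGS interface, which is defined in the app.py diff below), the smallest entry point compatible with this configuration looks like:

# Hypothetical minimal app.py for a Space declared with `sdk: gradio`.
# The real interface for this Space is the gr.Blocks app in the app.py
# diff below; this stub only illustrates the expected entry point.
import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("## EDGS demo placeholder")

if __name__ == "__main__":
    demo.launch()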
    	
app.py ADDED
@@ -0,0 +1,434 @@
+import spaces
+import torch
+import os
+import shutil
+import tempfile
+import argparse
+import gradio as gr
+import sys
+import io
+import subprocess
+from PIL import Image
+import numpy as np
+from hydra import initialize, compose
+import hydra
+from omegaconf import OmegaConf
+import time
+
+def install_submodules():
+    subprocess.check_call(['pip', 'install', './submodules/RoMa'])
+
+
+
+STATIC_FILE_SERVING_FOLDER = "./served_files"
+MODEL_PATH = None
+os.makedirs(STATIC_FILE_SERVING_FOLDER, exist_ok=True)
+
+trainer = None
+
+class Tee(io.TextIOBase):
+
+    def __init__(self, *streams):
+        self.streams = streams
+
+    def write(self, data):
+        for stream in self.streams:
+            stream.write(data)
+        return len(data)
+
+    def flush(self):
+        for stream in self.streams:
+            stream.flush()
+
+
+def capture_logs(func, *args, **kwargs):
+    log_capture_string = io.StringIO()
+    tee = Tee(sys.__stdout__, log_capture_string)
+
+    with contextlib.redirect_stdout(tee):
+        result = func(*args, **kwargs)
+
+    return result, log_capture_string.getvalue()
+
+
+@spaces.GPU(duration=350)
+# Training Pipeline
+def run_training_pipeline(scene_dir,
+                          num_ref_views=16,
+                          num_corrs_per_view=20000,
+                          num_steps=1_000,
+                          mode_toggle="Ours (EDGS)"):
+    with initialize(config_path="./configs", version_base="1.1"):
+        cfg = compose(config_name="train")
+
+    scene_name = os.path.basename(scene_dir)
+    model_output_dir = f"./outputs/{scene_name}_trained"
+
+    cfg.wandb.mode = "disabled"
+    cfg.gs.dataset.model_path = model_output_dir
+    cfg.gs.dataset.source_path = scene_dir
+    cfg.gs.dataset.images = "images"
+
+    cfg.gs.opt.TEST_CAM_IDX_TO_LOG = 12
+    cfg.train.gs_epochs = 30000
+
+    if mode_toggle=="Ours (EDGS)":
+        cfg.gs.opt.opacity_reset_interval = 1_000_000
+        cfg.train.reduce_opacity = True
+        cfg.train.no_densify = True
+        cfg.train.max_lr = True
+
+        cfg.init_wC.use = True
+        cfg.init_wC.matches_per_ref = num_corrs_per_view
+        cfg.init_wC.nns_per_ref = 1
+        cfg.init_wC.num_refs = num_ref_views
+        cfg.init_wC.add_SfM_init = False
+        cfg.init_wC.scaling_factor = 0.00077 * 2.
+
+    set_seed(cfg.seed)
+    os.makedirs(cfg.gs.dataset.model_path, exist_ok=True)
+
+    global trainer
+    global MODEL_PATH
+    generator3dgs = hydra.utils.instantiate(cfg.gs, do_train_test_split=False)
+    trainer = EDGSTrainer(GS=generator3dgs, training_config=cfg.gs.opt, device=cfg.device, log_wandb=cfg.wandb.mode != 'disabled')
+
+    # Disable evaluation and saving
+    trainer.saving_iterations = []
+    trainer.evaluate_iterations = []
+
+    # Initialize
+    trainer.timer.start()
+    start_time = time.time()
+    trainer.init_with_corr(cfg.init_wC, roma_model=roma_model)
+    time_for_init = time.time()-start_time
+
+    viewpoint_cams = trainer.GS.scene.getTrainCameras()
+    path_cameras = generate_fully_smooth_cameras_with_tsp(existing_cameras=viewpoint_cams,
+                                                          n_selected=6,
+                                                          n_points_per_segment=30,
+                                                          closed=False)
+    path_cameras = path_cameras + path_cameras[::-1]
+
+    path_renderings = []
+    idx = 0
+    # Visualize after init
+    for _ in range(120):
+        with torch.no_grad():
+            viewpoint_cam = path_cameras[idx]
+            idx = (idx + 1) % len(path_cameras)
+            render_pkg = trainer.GS(viewpoint_cam)
+            image = render_pkg["render"]
+            image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
+            image_np = (image_np * 255).astype(np.uint8)
+            path_renderings.append(put_text_on_image(img=image_np,
+                                                     text=f"Init stage.\nTime:{time_for_init:.3f}s.   "))
+    path_renderings = path_renderings + [put_text_on_image(img=image_np, text=f"Start fitting.\nTime:{time_for_init:.3f}s.   ")]*30
+
+    # Train and save visualizations during training.
+    start_time = time.time()
+    for _ in range(int(num_steps//10)):
+        with torch.no_grad():
+            viewpoint_cam = path_cameras[idx]
+            idx = (idx + 1) % len(path_cameras)
+            render_pkg = trainer.GS(viewpoint_cam)
+            image = render_pkg["render"]
+            image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
+            image_np = (image_np * 255).astype(np.uint8)
+            path_renderings.append(put_text_on_image(
+                img=image_np,
+                text=f"Fitting stage.\nTime:{time_for_init + time.time()-start_time:.3f}s.   "))
+
+        cfg.train.gs_epochs = 10
+        trainer.train(cfg.train)
+        print(f"Time elapsed: {(time_for_init + time.time()-start_time):.2f}s.")
+        # if (cfg.init_wC.use == False) and (time_for_init + time.time()-start_time) > 60:
+        #     break
+    final_time = time.time()
+
+    # Add static frame. To highlight we're done
+    path_renderings += [put_text_on_image(
+        img=image_np, text=f"Done.\nTime:{time_for_init + final_time -start_time:.3f}s.   ")]*30
+    # Final rendering at the end.
+    for _ in range(len(path_cameras)):
+        with torch.no_grad():
+            viewpoint_cam = path_cameras[idx]
+            idx = (idx + 1) % len(path_cameras)
+            render_pkg = trainer.GS(viewpoint_cam)
+            image = render_pkg["render"]
+            image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
+            image_np = (image_np * 255).astype(np.uint8)
+            path_renderings.append(put_text_on_image(img=image_np,
+                                                     text=f"Final result.\nTime:{time_for_init + final_time -start_time:.3f}s.   "))
+
+    trainer.save_model()
+    final_video_path = os.path.join(STATIC_FILE_SERVING_FOLDER, f"{scene_name}_final.mp4")
+    save_numpy_frames_as_mp4(frames=path_renderings, output_path=final_video_path, fps=30, center_crop=0.85)
+    MODEL_PATH = cfg.gs.dataset.model_path
+    ply_path = os.path.join(cfg.gs.dataset.model_path, f"point_cloud/iteration_{trainer.gs_step}/point_cloud.ply")
+    shutil.copy(ply_path, os.path.join(STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply"))
+
+    return final_video_path, ply_path
+
+# Gradio Interface
+def gradio_interface(input_path, num_ref_views, num_corrs, num_steps):
+    images, scene_dir = run_full_pipeline(input_path, num_ref_views, num_corrs, max_size=1024)
+    shutil.copytree(scene_dir, STATIC_FILE_SERVING_FOLDER+'/scene_colmaped',  dirs_exist_ok=True)
+    (final_video_path, ply_path), log_output = capture_logs(run_training_pipeline,
+                                                            scene_dir,
+                                                            num_ref_views,
+                                                            num_corrs,
+                                                            num_steps
+                                                            )
+    images_rgb = [img[:, :, ::-1] for img in images]
+    return images_rgb, final_video_path, scene_dir, ply_path, log_output
+
+# Dummy Render Functions
+@spaces.GPU(duration=60)
+def render_all_views(scene_dir):
+    viewpoint_cams = trainer.GS.scene.getTrainCameras()
+    path_cameras = generate_fully_smooth_cameras_with_tsp(existing_cameras=viewpoint_cams,
+                                                          n_selected=8,
+                                                          n_points_per_segment=60,
+                                                          closed=False)
+    path_cameras = path_cameras + path_cameras[::-1]
+
+    path_renderings = []
+    with torch.no_grad():
+        for viewpoint_cam in path_cameras:
+            render_pkg = trainer.GS(viewpoint_cam)
+            image = render_pkg["render"]
+            image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
+            image_np = (image_np * 255).astype(np.uint8)
+            path_renderings.append(image_np)
+    save_numpy_frames_as_mp4(frames=path_renderings,
+                             output_path=os.path.join(STATIC_FILE_SERVING_FOLDER, "render_all_views.mp4"),
+                             fps=30,
+                             center_crop=0.85)
+
+    return os.path.join(STATIC_FILE_SERVING_FOLDER, "render_all_views.mp4")
+
+@spaces.GPU(duration=60)
+def render_circular_path(scene_dir):
+    viewpoint_cams = trainer.GS.scene.getTrainCameras()
+    path_cameras = generate_circular_camera_path(existing_cameras=viewpoint_cams,
+                                                 N=240,
+                                                 radius_scale=0.65,
+                                                 d=0)
+
+    path_renderings = []
+    with torch.no_grad():
+        for viewpoint_cam in path_cameras:
+            render_pkg = trainer.GS(viewpoint_cam)
+            image = render_pkg["render"]
+            image_np = np.clip(image.detach().cpu().numpy().transpose(1, 2, 0), 0, 1)
+            image_np = (image_np * 255).astype(np.uint8)
+            path_renderings.append(image_np)
+    save_numpy_frames_as_mp4(frames=path_renderings,
+                             output_path=os.path.join(STATIC_FILE_SERVING_FOLDER, "render_circular_path.mp4"),
+                             fps=30,
+                             center_crop=0.85)
+
+    return os.path.join(STATIC_FILE_SERVING_FOLDER, "render_circular_path.mp4")
+
+# Download Functions
+def download_cameras():
+    path = os.path.join(MODEL_PATH, "cameras.json")
+    return f"[📥 Download Cameras.json](file={path})"
+
+def download_model():
+    path = os.path.join(STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply")
+    return f"[📥 Download Pretrained Model (.ply)](file={path})"
+
+# Full pipeline helpers
+def run_full_pipeline(input_path, num_ref_views, num_corrs, max_size=1024):
+    tmpdirname = tempfile.mkdtemp()
+    scene_dir = os.path.join(tmpdirname, "scene")
+    os.makedirs(scene_dir, exist_ok=True)
+
+    selected_frames = process_input(input_path, num_ref_views, scene_dir, max_size)
+    run_colmap_on_scene(scene_dir)
+
+    return selected_frames, scene_dir
+
+# Preprocess Input
+def process_input(input_path, num_ref_views, output_dir, max_size=1024):
+    if isinstance(input_path, (str, os.PathLike)):
+        if os.path.isdir(input_path):
+            frames = []
+            for img_file in sorted(os.listdir(input_path)):
+                if img_file.lower().endswith(('jpg', 'jpeg', 'png')):
+                    img = Image.open(os.path.join(output_dir, img_file)).convert('RGB')
+                    img.thumbnail((1024, 1024))
+                    frames.append(np.array(img))
+        else:
+            frames = read_video_frames(video_input=input_path, max_size=max_size)
+    else:
+        frames = read_video_frames(video_input=input_path, max_size=max_size)
+
+    frames_scores = preprocess_frames(frames)
+    selected_frames_indices = select_optimal_frames(scores=frames_scores, k=min(num_ref_views, len(frames)))
+    selected_frames = [frames[frame_idx] for frame_idx in selected_frames_indices]
+
+    save_frames_to_scene_dir(frames=selected_frames, scene_dir=output_dir)
+    return selected_frames
+
+@spaces.GPU(duration=150)
+def preprocess_input(input_path, num_ref_views, max_size=1024):
+    tmpdirname = tempfile.mkdtemp()
+    scene_dir = os.path.join(tmpdirname, "scene")
+    os.makedirs(scene_dir, exist_ok=True)
+    selected_frames = process_input(input_path, num_ref_views, scene_dir, max_size)
+    run_colmap_on_scene(scene_dir)
+    return selected_frames, scene_dir
+
+def start_training(scene_dir, num_ref_views, num_corrs, num_steps):
+    return capture_logs(run_training_pipeline, scene_dir, num_ref_views, num_corrs, num_steps)
+
+# Gradio App
+with gr.Blocks() as demo:
+    with gr.Row():
+        with gr.Column(scale=6):
+            gr.Markdown("""
+            ## <span style='font-size: 20px;'>📄 EDGS: Eliminating Densification for Efficient Convergence of 3DGS</span>
+            🔗 <a href='https://compvis.github.io/EDGS' target='_blank'>Project Page</a>
+            """, elem_id="header")
+
+    gr.Markdown("""
+    ### <span style='font-size: 22px;'>🛠️ How to Use This Demo</span>
+
+    1. Upload a **front-facing video** or **a folder of images** of a **static** scene.
+    2. Use the sliders to configure the number of reference views, correspondences, and optimization steps.
+    3. First press on preprocess Input to extract frames from video(for videos) and COLMAP frames.
+    4.Then click **🚀 Start Reconstruction** to actually launch the reconstruction pipeline.
+    5. Watch the training visualization and explore the 3D model.
+    ‼️ **If you see nothing in the 3D model viewer**, try rotating or zooming — sometimes the initial camera orientation is off.
+
+
+    ✅ Best for scenes with small camera motion.
+    ❗ For full 360° or large-scale scenes, we recommend the Colab version (see project page).
+    """, elem_id="quickstart")
+    scene_dir_state = gr.State()
+    ply_model_state = gr.State()
+    with gr.Row():
+        with gr.Column(scale=2):
+            input_file = gr.File(label="Upload Video or Images",
+                file_types=[".mp4", ".avi", ".mov", ".png", ".jpg", ".jpeg"],
+                file_count="multiple")
+            gr.Examples(
+                examples = [
+                    [["assets/examples/video_bakery.mp4"]],
+                    [["assets/examples/video_flowers.mp4"]],
+                    [["assets/examples/video_fruits.mp4"]],
+                    [["assets/examples/video_plant.mp4"]],
+                    [["assets/examples/video_salad.mp4"]],
+                    [["assets/examples/video_tram.mp4"]],
+                    [["assets/examples/video_tulips.mp4"]]
+                    ],
+                inputs=[input_file],
+                label="🎞️ ALternatively, try an Example Video",
+                examples_per_page=4
+            )
+            ref_slider = gr.Slider(4, 32, value=16, step=1, label="Number of Reference Views")
+            corr_slider = gr.Slider(5000, 30000, value=20000, step=1000, label="Correspondences per Reference View")
+            fit_steps_slider = gr.Slider(100, 5000, value=400, step=100, label="Number of optimization steps")
+            preprocess_button = gr.Button("📸 Preprocess Input")
+            start_button = gr.Button("🚀 Start Reconstruction", interactive=False)
+            gallery = gr.Gallery(label="Selected Reference Views", columns=4, height=300)
+
+        with gr.Column(scale=3):
+            gr.Markdown("### 🏋️ Training Visualization")
+            video_output = gr.Video(label="Training Video", autoplay=True)
+            render_all_views_button = gr.Button("🎥 Render All-Views Path")
+            render_circular_path_button = gr.Button("🎥 Render Circular Path")
+            rendered_video_output = gr.Video(label="Rendered Video", autoplay=True)
+        with gr.Column(scale=5):
+            gr.Markdown("### 🌐 Final 3D Model")
+            model3d_viewer = gr.Model3D(label="3D Model Viewer")
+
+            gr.Markdown("### 📦 Output Files")
+            with gr.Row(height=50):
+                with gr.Column():
+                    #gr.Markdown(value=f"[📥 Download .ply](file/point_cloud_final.ply)")
+                    download_cameras_button = gr.Button("📥 Download Cameras.json")
+                    download_cameras_file = gr.File(label="📄 Cameras.json")
+                with gr.Column():
+                    download_model_button = gr.Button("📥 Download Pretrained Model (.ply)")
+                    download_model_file = gr.File(label="📄 Pretrained Model (.ply)")
+
+    log_output_box = gr.Textbox(label="🖥️ Log", lines=10, interactive=False)
+
+    def on_preprocess_click(input_file, num_ref_views):
+        images, scene_dir = preprocess_input(input_file, num_ref_views)
+        return gr.update(value=[x[...,::-1] for x in images]), scene_dir, gr.update(interactive=True)
+
+    def on_start_click(scene_dir, num_ref_views, num_corrs, num_steps):
+        (video_path, ply_path), logs = start_training(scene_dir, num_ref_views, num_corrs, num_steps)
+        return video_path, ply_path, logs
+
+    preprocess_button.click(
+        fn=on_preprocess_click,
+        inputs=[input_file, ref_slider],
+        outputs=[gallery, scene_dir_state, start_button]
+    )
+
+    start_button.click(
+        fn=on_start_click,
+        inputs=[scene_dir_state, ref_slider, corr_slider, fit_steps_slider],
+        outputs=[video_output, model3d_viewer, log_output_box]
+    )
+
+    render_all_views_button.click(fn=render_all_views, inputs=[scene_dir_state], outputs=[rendered_video_output])
+    render_circular_path_button.click(fn=render_circular_path, inputs=[scene_dir_state], outputs=[rendered_video_output])
+
+    download_cameras_button.click(fn=lambda: os.path.join(MODEL_PATH, "cameras.json"), inputs=[], outputs=[download_cameras_file])
+    download_model_button.click(fn=lambda: os.path.join(STATIC_FILE_SERVING_FOLDER, "point_cloud_final.ply"), inputs=[], outputs=[download_model_file])
+
+
+    gr.Markdown("""
+    ---
+    ### <span style='font-size: 20px;'>📖 Detailed Overview</span>
+
+    If you uploaded a video, it will be automatically cut into a smaller number of frames (default: 16).
+
+    The model pipeline:
+    1. 🧠 Runs PyCOLMAP to estimate camera intrinsics & poses (~3–7 seconds for <16 images).
+    2. 🔁 Computes 2D-2D correspondences between views. More correspondences generally improve quality.
+    3. 🔧 Optimizes a 3D Gaussian Splatting model for several steps.
+
+    ### 🎥 Training Visualization
+    You will see a visualization of the entire training process in the "Training Video" pane.
+
+    ### 🌀 Rendering & 3D Model
+    - Render the scene from a circular path of novel views.
+    - Or from camera views close to the original input.
+
+    The 3D model is shown in the right viewer. You can explore it interactively:
+    - On PC: WASD keys, arrow keys, and mouse clicks
+    - On mobile: pan and pinch to zoom
+
+    🕒 Note: the 3D viewer takes a few extra seconds (~5s) to display after training ends.
+
+    ---
+    Preloaded models coming soon. (TODO)
+    """, elem_id="details")
+
+
+
+
+if __name__ == "__main__":
+    install_submodules()
+    from source.utils_aux import set_seed
         | 
| 422 | 
            +
                from source.utils_preprocess import read_video_frames, preprocess_frames, select_optimal_frames, save_frames_to_scene_dir, run_colmap_on_scene
         | 
| 423 | 
            +
                from source.trainer import EDGSTrainer
         | 
| 424 | 
            +
                from source.visualization import generate_circular_camera_path, save_numpy_frames_as_mp4, generate_fully_smooth_cameras_with_tsp, put_text_on_image
         | 
| 425 | 
            +
                # Init RoMA model:
         | 
| 426 | 
            +
                sys.path.append('../submodules/RoMa')
         | 
| 427 | 
            +
                from romatch import roma_outdoor, roma_indoor
         | 
| 428 | 
            +
                
         | 
| 429 | 
            +
                roma_model = roma_indoor(device="cpu")
         | 
| 430 | 
            +
                roma_model = roma_model.to("cuda")
         | 
| 431 | 
            +
                roma_model.upsample_preds = False
         | 
| 432 | 
            +
                roma_model.symmetric = False
         | 
| 433 | 
            +
             | 
| 434 | 
            +
                demo.launch(share=True)
         | 
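Not part of the committed file: the callbacks above already define the whole flow (preprocess_input registers reference frames, start_training fits the Gaussians, render_circular_path renders a fly-through). A minimal headless sketch of the same chain, with placeholder slider values:

# Illustrative only: run the same pipeline without the Gradio UI.
# preprocess_input / start_training / render_circular_path are the helpers wired
# to the buttons above; 16 reference views, 20_000 correspondences and 400 fit
# steps are placeholder values.
images, scene_dir = preprocess_input("assets/examples/video_fruits.mp4", 16)
(video_path, ply_path), logs = start_training(scene_dir, 16, 20_000, 400)
circular_video_path = render_circular_path(scene_dir)
print(video_path, ply_path, circular_video_path)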
    	
assets/examples/video_bakery.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b65c813f2ef9637579350e145fdceed333544d933278be5613c6d49468f4eab0
size 6362238

assets/examples/video_flowers.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c81a7c28e0d59bad38d5a45f6bfa83b80990d1fb78a06f82137e8b57ec38e62b
size 6466943

assets/examples/video_fruits.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ac3b937e155d1965a314b478e940f6fd93e90371d3bf7d62d3225d840fe8e126
size 3356915

assets/examples/video_plant.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e67a40e62de9aacf0941d8cf33a2dd08d256fe60d0fc58f60426c251f9d8abd8
size 13023885

assets/examples/video_salad.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:367d86071a201c124383f0a076ce644b7f86b9c2fbce6aa595c4989ebf259bfb
size 8774427

assets/examples/video_tram.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:298c297155d3f52edcffcb6b6f9910c992b98ecfb93cfaf8fb64fb340aba1dae
size 4697915

assets/examples/video_tulips.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c19942319b8f7e33bb03cb4a39b11797fe38572bf7157af37471b7c8573fb495
size 7298210

assets/video_fruits_ours_full.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9c5b113566d3a083b81360b549ed89f70d5e81739f83e182518f6906811311a2
size 14839197
    	
configs/gs/base.yaml
ADDED
@@ -0,0 +1,51 @@
_target_: source.networks.Warper3DGS

verbose: True
viewpoint_stack: !!null
sh_degree: 3

opt:
  iterations: 30000
  position_lr_init: 0.00016
  position_lr_final: 1.6e-06
  position_lr_delay_mult: 0.01
  position_lr_max_steps: 30000
  feature_lr: 0.0025
  opacity_lr: 0.025
  scaling_lr: 0.005
  rotation_lr: 0.001
  percent_dense: 0.01
  lambda_dssim: 0.2
  densification_interval: 100
  opacity_reset_interval: 30000
  densify_from_iter: 500
  densify_until_iter: 15000
  densify_grad_threshold: 0.0002
  random_background: false
  save_iterations: [3000, 7000, 15000, 30000]
  batch_size: 64
  exposure_lr_init: 0.01
  exposure_lr_final: 0.0001
  exposure_lr_delay_steps: 0
  exposure_lr_delay_mult: 0.0

  TRAIN_CAM_IDX_TO_LOG: 50
  TEST_CAM_IDX_TO_LOG: 10

pipe:
  convert_SHs_python: False
  compute_cov3D_python: False
  debug: False
  antialiasing: False

dataset:
  densify_until_iter: 15000
  source_path: ''  # path to dataset
  model_path: ''   # path to logs
  images: images
  resolution: -1
  white_background: false
  data_device: cuda
  eval: false
  depths: ""
  train_test_exp: False
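The _target_ key points at source.networks.Warper3DGS, which suggests this block is meant to be instantiated through Hydra/OmegaConf. A minimal sketch under that assumption (paths are placeholders, not part of the commit):

# Illustrative sketch; assumes Hydra-style instantiation as implied by _target_.
from omegaconf import OmegaConf
from hydra.utils import instantiate

cfg = OmegaConf.load("configs/gs/base.yaml")
cfg.dataset.source_path = "/path/to/colmap/scene"  # placeholder
cfg.dataset.model_path = "/path/to/output"         # placeholder
gs = instantiate(cfg)  # builds source.networks.Warper3DGS with the opt/pipe/dataset blocks above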
    	
configs/train.yaml
ADDED
@@ -0,0 +1,38 @@
defaults:
  - gs: base
  - _self_

seed: 228

wandb:
  mode: "online" # "disabled" for no logging
  entity: "3dcorrespondence"
  project: "Adv3DGS"
  group: null
  name: null
  tag: "debug"

train:
  gs_epochs: 0 # number of 3dgs iterations
  reduce_opacity: True
  no_densify: False # if True, the model will not be densified
  max_lr: True

load:
  gs: null # path to 3dgs checkpoint
  gs_step: null # number of iterations, e.g. 7000

device: "cuda:0"
verbose: true

init_wC:
  use: True # use EDGS
  matches_per_ref: 15_000 # number of matches per reference
  num_refs: 180 # number of reference images
  nns_per_ref: 3 # number of nearest neighbors per reference
  scaling_factor: 0.001
  proj_err_tolerance: 0.01
  roma_model: "outdoors" # you can change this to "indoors" or "outdoors"
  add_SfM_init: False
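The defaults list pulls configs/gs/base.yaml in under the gs key, so a composed run sees both files as one tree. A sketch of composing it and overriding a few fields (illustrative only; the override values are placeholders):

# Illustrative sketch; assumes Hydra composition as implied by the defaults list.
from hydra import compose, initialize
from hydra.utils import instantiate

with initialize(config_path="configs", version_base=None):
    cfg = compose(
        config_name="train",
        overrides=[
            "gs.dataset.source_path=/path/to/scene",  # placeholder
            "train.gs_epochs=1000",
            "init_wC.matches_per_ref=20000",
            "wandb.mode=disabled",
        ],
    )
gs = instantiate(cfg.gs)  # same Warper3DGS target as in configs/gs/base.yaml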
    	
requirements.txt
ADDED
@@ -0,0 +1,32 @@
--extra-index-url https://download.pytorch.org/whl/cu124
torch
torchvision
torchaudio

# Required libraries from pip
Pillow
huggingface_hub
einops
safetensors
sympy==1.13.1
wandb
hydra-core
tqdm
torchmetrics
lpips
matplotlib
rich
plyfile
imageio
imageio-ffmpeg
numpy==1.26.4  # Match conda-installed version
opencv-python
pycolmap
moviepy
plotly
scikit-learn
ffmpeg

# Prebuilt CUDA rasterization / KNN wheels (Python 3.10, linux x86_64)
https://huggingface.co/spaces/magistrkoljan/test/resolve/main/wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl?download=true
https://huggingface.co/spaces/magistrkoljan/test/resolve/main/wheels/simple_knn-0.0.0-cp310-cp310-linux_x86_64.whl?download=true
    	
source/EDGS.code-workspace
ADDED
@@ -0,0 +1,11 @@
{
	"folders": [
		{
			"path": ".."
		},
		{
			"path": "../../../../.."
		}
	],
	"settings": {}
}
    	
source/__init__.py
ADDED
File without changes
    	
source/corr_init.py
ADDED
@@ -0,0 +1,682 @@
import sys
sys.path.append('../')
sys.path.append("../submodules")
sys.path.append('../submodules/RoMa')

from matplotlib import pyplot as plt
from PIL import Image
import torch
import numpy as np

#from tqdm import tqdm_notebook as tqdm
from tqdm import tqdm
from scipy.cluster.vq import kmeans, vq
from scipy.spatial.distance import cdist

import torch.nn.functional as F
from romatch import roma_outdoor, roma_indoor
from utils.sh_utils import RGB2SH
from romatch.utils import get_tuple_transform_ops


def pairwise_distances(matrix):
    """
    Computes the pairwise Euclidean distances between all vectors in the input matrix.

    Args:
        matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.

    Returns:
        torch.Tensor: Pairwise distance matrix of shape [N, N].
    """
    # Compute pairwise Euclidean distances (p=2)
    squared_diff = torch.cdist(matrix, matrix, p=2)
    return squared_diff


def k_closest_vectors(matrix, k):
    """
    Finds the k-closest vectors for each vector in the input matrix based on Euclidean distance.

    Args:
        matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.
        k (int): Number of closest vectors to return for each vector.

    Returns:
        torch.Tensor: Indices of the k-closest vectors for each vector, excluding the vector itself.
    """
    # Compute pairwise distances
    distances = pairwise_distances(matrix)

    # For each vector, sort distances and get the indices of the k-closest vectors (excluding itself)
    # Set diagonal distances to infinity to exclude the vector itself from the nearest neighbors
    distances.fill_diagonal_(float('inf'))

    # Get the indices of the k smallest distances (k-closest vectors)
    _, indices = torch.topk(distances, k, largest=False, dim=1)

    return indices
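A small usage sketch of k_closest_vectors (illustrative only; the random rows stand in for flattened 4x4 camera matrices, which is how neighbor selection is done below):

# Illustrative usage: 5 random "cameras" as flattened 4x4 matrices, 2 nearest neighbors each.
cams = torch.randn(5, 16)
nn_idx = k_closest_vectors(cams, k=2)
print(nn_idx.shape)  # torch.Size([5, 2]); row i holds the indices of the 2 closest other cameras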


def select_cameras_kmeans(cameras, K):
    """
    Selects K cameras from a set using K-means clustering.

    Args:
        cameras: NumPy array of shape (N, 16), representing N cameras with their 4x4 homogeneous matrices flattened.
        K: Number of clusters (cameras to select).

    Returns:
        selected_indices: List of indices of the cameras closest to the cluster centers.
    """
    # Ensure input is a NumPy array
    if not isinstance(cameras, np.ndarray):
        cameras = np.asarray(cameras)

    if cameras.shape[1] != 16:
        raise ValueError("Each camera must have 16 values corresponding to a flattened 4x4 matrix.")

    # Perform K-means clustering
    cluster_centers, _ = kmeans(cameras, K)

    # Assign each camera to a cluster and find distances to cluster centers
    cluster_assignments, _ = vq(cameras, cluster_centers)

    # Find the camera nearest to each cluster center
    selected_indices = []
    for k in range(K):
        cluster_members = cameras[cluster_assignments == k]
        distances = cdist([cluster_centers[k]], cluster_members)[0]
        nearest_camera_idx = np.where(cluster_assignments == k)[0][np.argmin(distances)]
        selected_indices.append(nearest_camera_idx)

    return selected_indices
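A quick usage sketch (illustrative only): picking 3 representative cameras out of 20 random flattened poses.

# Illustrative usage of select_cameras_kmeans with random data.
poses = np.random.rand(20, 16)
ref_idx = select_cameras_kmeans(poses, K=3)
print(sorted(ref_idx))  # three indices into `poses`, one per k-means cluster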


def compute_warp_and_confidence(viewpoint_cam1, viewpoint_cam2, roma_model, device="cuda", verbose=False, output_dict={}):
    """
    Computes the warp and confidence between two viewpoint cameras using the roma_model.

    Args:
        viewpoint_cam1: Source viewpoint camera.
        viewpoint_cam2: Target viewpoint camera.
        roma_model: Pre-trained Roma model for correspondence matching.
        device: Device to run the computation on.
        verbose: If True, displays the images.

    Returns:
        certainty: Confidence tensor.
        warp: Warp tensor.
        imB: Processed image B as numpy array.
    """
    # Prepare images
    imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
    imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
    imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
    imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))

    if verbose:
        fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(16, 8))
        cax1 = ax[0].imshow(imA)
        ax[0].set_title("Image 1")
        cax2 = ax[1].imshow(imB)
        ax[1].set_title("Image 2")
        fig.colorbar(cax1, ax=ax[0])
        fig.colorbar(cax2, ax=ax[1])

        for axis in ax:
            axis.axis('off')
        # Save the figure into the dictionary
        output_dict[f'image_pair'] = fig

    # Transform images
    ws, hs = roma_model.w_resized, roma_model.h_resized
    test_transform = get_tuple_transform_ops(resize=(hs, ws), normalize=True)
    im_A, im_B = test_transform((imA, imB))
    batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}

    # Forward pass through Roma model
    corresps = roma_model.forward(batch) if not roma_model.symmetric else roma_model.forward_symmetric(batch)
    finest_scale = 1
    hs, ws = roma_model.upsample_res if roma_model.upsample_preds else (hs, ws)

    # Process certainty and warp
    certainty = corresps[finest_scale]["certainty"]
    im_A_to_im_B = corresps[finest_scale]["flow"]
    if roma_model.attenuate_cert:
        low_res_certainty = F.interpolate(
            corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
        )
        certainty -= 0.5 * low_res_certainty * (low_res_certainty < 0)

    # Upsample predictions if needed
    if roma_model.upsample_preds:
        im_A_to_im_B = F.interpolate(
            im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
        )
        certainty = F.interpolate(
            certainty, size=(hs, ws), align_corners=False, mode="bilinear"
        )

    # Convert predictions to final format
    im_A_to_im_B = im_A_to_im_B.permute(0, 2, 3, 1)
    im_A_coords = torch.stack(torch.meshgrid(
        torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
        torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
        indexing='ij'
    ), dim=0).permute(1, 2, 0).unsqueeze(0).expand(im_A_to_im_B.size(0), -1, -1, -1)

    warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
    certainty = certainty.sigmoid()

    return certainty[0, 0], warp[0], np.array(imB)
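The dense warp and certainty returned here are typically thinned into a sparse match set before triangulation. A hedged sketch of that step (illustrative only; it assumes the standard romatch API, i.e. roma_model.sample and roma_model.to_pixel_coordinates, plus two viewpoint cameras and a loaded roma_model):

# Illustrative only: turn the dense outputs into sparse pixel correspondences.
certainty, warp, imB = compute_warp_and_confidence(cam_a, cam_b, roma_model)  # cam_a, cam_b: viewpoint cameras
matches, match_certainty = roma_model.sample(warp, certainty, num=5000)       # assumed romatch API
H_A, W_A = cam_a.original_image.shape[1:]
H_B, W_B, _ = imB.shape
kptsA, kptsB = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)   # assumed argument order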


def resize_batch(tensors_3d, tensors_4d, target_shape):
    """
    Resizes a batch of tensors with shapes [B, H, W] and [B, H, W, 4] to the target spatial dimensions.

    Args:
        tensors_3d: Tensor of shape [B, H, W].
        tensors_4d: Tensor of shape [B, H, W, 4].
        target_shape: Tuple (target_H, target_W) specifying the target spatial dimensions.

    Returns:
        resized_tensors_3d: Tensor of shape [B, target_H, target_W].
        resized_tensors_4d: Tensor of shape [B, target_H, target_W, 4].
    """
    target_H, target_W = target_shape

    # Resize [B, H, W] tensor
    resized_tensors_3d = F.interpolate(
        tensors_3d.unsqueeze(1), size=(target_H, target_W), mode="bilinear", align_corners=False
    ).squeeze(1)

    # Resize [B, H, W, 4] tensor
    B, _, _, C = tensors_4d.shape
    resized_tensors_4d = F.interpolate(
        tensors_4d.permute(0, 3, 1, 2), size=(target_H, target_W), mode="bilinear", align_corners=False
    ).permute(0, 2, 3, 1)

    return resized_tensors_3d, resized_tensors_4d


def aggregate_confidences_and_warps(viewpoint_stack, closest_indices, roma_model, source_idx, verbose=False, output_dict={}):
    """
    Aggregates confidences and warps by iterating over the nearest neighbors of the source viewpoint.

    Args:
        viewpoint_stack: Stack of viewpoint cameras.
        closest_indices: Indices of the nearest neighbors for each viewpoint.
        roma_model: Pre-trained Roma model.
        source_idx: Index of the source viewpoint.
        verbose: If True, displays intermediate results.

    Returns:
        certainties_max: Aggregated maximum confidences.
        warps_max: Aggregated warps corresponding to maximum confidences.
        certainties_max_idcs: Pixel-wise index of the neighbor image from which the best match was taken.
        imB_compound: List of the neighboring images.
    """
    certainties_all, warps_all, imB_compound = [], [], []

    for nn in tqdm(closest_indices[source_idx]):

        viewpoint_cam1 = viewpoint_stack[source_idx]
        viewpoint_cam2 = viewpoint_stack[nn]

        certainty, warp, imB = compute_warp_and_confidence(viewpoint_cam1, viewpoint_cam2, roma_model, verbose=verbose, output_dict=output_dict)
        certainties_all.append(certainty)
        warps_all.append(warp)
        imB_compound.append(imB)

    certainties_all = torch.stack(certainties_all, dim=0)
    target_shape = imB_compound[0].shape[:2]
    if verbose:
        print("certainties_all.shape:", certainties_all.shape)
        print("torch.stack(warps_all, dim=0).shape:", torch.stack(warps_all, dim=0).shape)
        print("target_shape:", target_shape)

    certainties_all_resized, warps_all_resized = resize_batch(certainties_all,
                                                              torch.stack(warps_all, dim=0),
                                                              target_shape
                                                              )

    if verbose:
        print("warps_all_resized.shape:", warps_all_resized.shape)
        for n, cert in enumerate(certainties_all):
            fig, ax = plt.subplots()
            cax = ax.imshow(cert.cpu().numpy(), cmap='viridis')
            fig.colorbar(cax, ax=ax)
            ax.set_title("Pixel-wise Confidence")
            output_dict[f'certainty_{n}'] = fig

        for n, warp in enumerate(warps_all):
            fig, ax = plt.subplots()
            cax = ax.imshow(warp.cpu().numpy()[:, :, :3], cmap='viridis')
            fig.colorbar(cax, ax=ax)
            ax.set_title("Pixel-wise warp")
            output_dict[f'warp_{n}'] = fig

        for n, cert in enumerate(certainties_all_resized):
            fig, ax = plt.subplots()
            cax = ax.imshow(cert.cpu().numpy(), cmap='viridis')
            fig.colorbar(cax, ax=ax)
            ax.set_title("Pixel-wise Confidence resized")
            output_dict[f'certainty_resized_{n}'] = fig

        for n, warp in enumerate(warps_all_resized):
            fig, ax = plt.subplots()
            cax = ax.imshow(warp.cpu().numpy()[:, :, :3], cmap='viridis')
            fig.colorbar(cax, ax=ax)
            ax.set_title("Pixel-wise warp resized")
            output_dict[f'warp_resized_{n}'] = fig

    certainties_max, certainties_max_idcs = torch.max(certainties_all_resized, dim=0)
    H, W = certainties_max.shape

    warps_max = warps_all_resized[certainties_max_idcs, torch.arange(H).unsqueeze(1), torch.arange(W)]

    imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
    imA = np.clip(imA * 255, 0, 255).astype(np.uint8)

    return certainties_max, warps_max, certainties_max_idcs, imA, imB_compound, certainties_all_resized, warps_all_resized
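Together with k_closest_vectors above, this is the per-reference aggregation step: match one source view against its nearest neighbors and keep, per pixel, the most confident warp. A usage sketch (illustrative only; viewpoint_stack and roma_model are assumed to exist, and world_view_transform is the usual 3DGS camera attribute):

# Illustrative only: aggregate matches for reference view 0 against its 3 nearest neighbors.
cam_feats = torch.stack([cam.world_view_transform.reshape(-1) for cam in viewpoint_stack])  # [N, 16], assumed attribute
closest_indices = k_closest_vectors(cam_feats, k=3)
out = {}
certainties_max, warps_max, certainties_max_idcs, imA, imB_compound, _, _ = \
    aggregate_confidences_and_warps(viewpoint_stack, closest_indices, roma_model,
                                    source_idx=0, verbose=False, output_dict=out)
print(certainties_max.shape, warps_max.shape)  # [H, W] and [H, W, 4]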



def extract_keypoints_and_colors(imA, imB_compound, certainties_max, certainties_max_idcs, matches, roma_model,
                                 verbose=False, output_dict={}):
    """
    Extracts keypoints and corresponding colors from the source image (imA) and multiple target images (imB_compound).

    Args:
        imA: Source image as a NumPy array (H_A, W_A, C).
        imB_compound: List of target images as NumPy arrays [(H_B, W_B, C), ...].
        certainties_max: Tensor of pixel-wise maximum confidences.
        certainties_max_idcs: Tensor of pixel-wise indices for the best matches.
        matches: Matches in normalized coordinates.
        roma_model: Roma model instance for keypoint operations.
        verbose: Whether to show intermediate outputs and visualize results.

    Returns:
        kptsA_np: Keypoints in imA in normalized coordinates.
        kptsB_np: Keypoints in imB in normalized coordinates.
        kptsA_color: Colors of keypoints in imA.
        kptsB_color: Colors of keypoints in imB based on certainties_max_idcs.
    """
    H_A, W_A, _ = imA.shape
    H, W = certainties_max.shape

    # Convert matches to pixel coordinates
    kptsA, kptsB = roma_model.to_pixel_coordinates(
        matches, W_A, H_A, H, W  # W, H
    )

    kptsA_np = kptsA.detach().cpu().numpy()
    kptsB_np = kptsB.detach().cpu().numpy()
    kptsA_np = kptsA_np[:, [1, 0]]

    if verbose:
        fig, ax = plt.subplots(figsize=(12, 6))
        cax = ax.imshow(imA)
        ax.set_title("Reference image, imA")
        output_dict[f'reference_image'] = fig

        fig, ax = plt.subplots(figsize=(12, 6))
        cax = ax.imshow(imB_compound[0])
        ax.set_title("Image to compare to, imB_compound")
        output_dict[f'imB_compound'] = fig

        fig, ax = plt.subplots(figsize=(12, 6))
        cax = ax.imshow(np.flipud(imA))
        cax = ax.scatter(kptsA_np[:, 0], H_A - kptsA_np[:, 1], s=.03)
        ax.set_title("Keypoints in imA")
        ax.set_xlim(0, W_A)
        ax.set_ylim(0, H_A)
        output_dict[f'kptsA'] = fig

        fig, ax = plt.subplots(figsize=(12, 6))
        cax = ax.imshow(np.flipud(imB_compound[0]))
        cax = ax.scatter(kptsB_np[:, 0], H_A - kptsB_np[:, 1], s=.03)
        ax.set_title("Keypoints in imB")
        ax.set_xlim(0, W_A)
        ax.set_ylim(0, H_A)
        output_dict[f'kptsB'] = fig

    # Keypoints are in (row, column) format, so the first value is always in range [0, height] and the second in range [0, width].

    kptsA_np = kptsA.detach().cpu().numpy()
    kptsB_np = kptsB.detach().cpu().numpy()

    # Extract colors for keypoints in imA (vectorized)
    # New experimental version
    kptsA_x = np.round(kptsA_np[:, 0] / 1.).astype(int)
    kptsA_y = np.round(kptsA_np[:, 1] / 1.).astype(int)
    kptsA_color = imA[np.clip(kptsA_x, 0, H - 1), np.clip(kptsA_y, 0, W - 1)]

    # Create a composite image from imB_compound
    imB_compound_np = np.stack(imB_compound, axis=0)
    H_B, W_B, _ = imB_compound[0].shape

    # Extract colors for keypoints in imB using certainties_max_idcs
    imB_np = imB_compound_np[
            certainties_max_idcs.detach().cpu().numpy(),
            np.arange(H).reshape(-1, 1),
            np.arange(W)
        ]

    if verbose:
        print("imB_np.shape:", imB_np.shape)
        print("imB_np:", imB_np)
        fig, ax = plt.subplots(figsize=(12, 6))
        cax = ax.imshow(np.flipud(imB_np))
        cax = ax.scatter(kptsB_np[:, 0], H_A - kptsB_np[:, 1], s=.03)
        ax.set_title("np.flipud(imB_np[0])")
        ax.set_xlim(0, W_A)
        ax.set_ylim(0, H_A)
        output_dict['np.flipud(imB_np[0])'] = fig

    kptsB_x = np.round(kptsB_np[:, 0]).astype(int)
    kptsB_y = np.round(kptsB_np[:, 1]).astype(int)

    certainties_max_idcs_np = certainties_max_idcs.detach().cpu().numpy()
    kptsB_proj_matrices_idx = certainties_max_idcs_np[np.clip(kptsA_x, 0, H - 1), np.clip(kptsA_y, 0, W - 1)]
    kptsB_color = imB_compound_np[kptsB_proj_matrices_idx, np.clip(kptsB_y, 0, H - 1), np.clip(kptsB_x, 0, W - 1)]

    # Normalize keypoints in both images to [-1, 1]
    kptsA_np[:, 0] = kptsA_np[:, 0] / H * 2.0 - 1.0
    kptsA_np[:, 1] = kptsA_np[:, 1] / W * 2.0 - 1.0
    kptsB_np[:, 0] = kptsB_np[:, 0] / W_B * 2.0 - 1.0
    kptsB_np[:, 1] = kptsB_np[:, 1] / H_B * 2.0 - 1.0

    return kptsA_np[:, [1, 0]], kptsB_np, kptsB_proj_matrices_idx, kptsA_color, kptsB_color
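Continuing the aggregation sketch above, the per-pixel best warp can be sampled and fed straight into this function (illustrative only; roma_model.sample is the assumed romatch API):

# Illustrative only: sparse matches from the aggregated warp, then keypoints and colors.
matches, cert = roma_model.sample(warps_max, certainties_max, num=10_000)  # assumed romatch API
kA, kB, kB_src_idx, kA_rgb, kB_rgb = extract_keypoints_and_colors(
    imA, imB_compound, certainties_max, certainties_max_idcs, matches, roma_model)
print(kA.shape, kB.shape, kA_rgb.shape)  # roughly (N, 2), (N, 2), (N, 3)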
         | 
| 394 | 
            +
             | 
def prepare_tensor(input_array, device):
    """
    Converts an input array to a torch tensor, clones it, and detaches it for safe computation.

    Args:
        input_array (array-like): The input array to convert.
        device (str or torch.device): The device to move the tensor to.

    Returns:
        torch.Tensor: A detached tensor clone of the input array on the specified device.
    """
    if not isinstance(input_array, torch.Tensor):
        return torch.tensor(input_array, dtype=torch.float32).to(device).clone().detach()
    return input_array.clone().detach().to(device).to(torch.float32)

def triangulate_points(P1, P2, k1_x, k1_y, k2_x, k2_y, device="cuda"):
    """
    Solves for a batch of 3D points given batches of projection matrices and corresponding image points.

    Parameters:
    - P1, P2: Tensors of projection matrices of size (batch_size, 4, 4) or (4, 4)
    - k1_x, k1_y: Tensors of shape (batch_size,)
    - k2_x, k2_y: Tensors of shape (batch_size,)

    Returns:
    - X: A tensor containing the 3D homogeneous coordinates, shape (batch_size, 4)
    """
    EPS = 1e-4

    # Ensure inputs are tensors
    P1 = prepare_tensor(P1, device)
    P2 = prepare_tensor(P2, device)
    k1_x = prepare_tensor(k1_x, device)
    k1_y = prepare_tensor(k1_y, device)
    k2_x = prepare_tensor(k2_x, device)
    k2_y = prepare_tensor(k2_y, device)
    batch_size = k1_x.shape[0]

    # Expand P1 and P2 if they are not batched
    if P1.ndim == 2:
        P1 = P1.unsqueeze(0).expand(batch_size, -1, -1)
    if P2.ndim == 2:
        P2 = P2.unsqueeze(0).expand(batch_size, -1, -1)

    # Extract columns from P1 and P2
    P1_0 = P1[:, :, 0]  # Shape: (batch_size, 4)
    P1_1 = P1[:, :, 1]
    P1_2 = P1[:, :, 2]

    P2_0 = P2[:, :, 0]
    P2_1 = P2[:, :, 1]
    P2_2 = P2[:, :, 2]

    # Reshape kx and ky to (batch_size, 1)
    k1_x = k1_x.view(-1, 1)
    k1_y = k1_y.view(-1, 1)
    k2_x = k2_x.view(-1, 1)
    k2_y = k2_y.view(-1, 1)

    # Construct the equations for each batch
    # For camera 1
    A1 = P1_0 - k1_x * P1_2  # Shape: (batch_size, 4)
    A2 = P1_1 - k1_y * P1_2
    # For camera 2
    A3 = P2_0 - k2_x * P2_2
    A4 = P2_1 - k2_y * P2_2

    # Stack the equations
    A = torch.stack([A1, A2, A3, A4], dim=1)  # Shape: (batch_size, 4, 4)

    # Right-hand side (constants)
    b = -A[:, :, 3]  # Shape: (batch_size, 4)
    A_reduced = A[:, :, :3]  # Coefficients of x, y, z

    # Solve using torch.linalg.lstsq (supports batching)
    X_xyz = torch.linalg.lstsq(A_reduced, b.unsqueeze(2)).solution.squeeze(2)  # Shape: (batch_size, 3)

    # Append 1 to get homogeneous coordinates
    ones = torch.ones((batch_size, 1), dtype=torch.float32, device=X_xyz.device)
    X = torch.cat([X_xyz, ones], dim=1)  # Shape: (batch_size, 4)

    # Now compute the errors of projections.
    seeked_splats_proj1 = (X.unsqueeze(1) @ P1).squeeze(1)
    seeked_splats_proj1 = seeked_splats_proj1 / (EPS + seeked_splats_proj1[:, [3]])
    seeked_splats_proj2 = (X.unsqueeze(1) @ P2).squeeze(1)
    seeked_splats_proj2 = seeked_splats_proj2 / (EPS + seeked_splats_proj2[:, [3]])
    proj1_target = torch.concat([k1_x, k1_y], dim=1)
    proj2_target = torch.concat([k2_x, k2_y], dim=1)
    errors_proj1 = torch.abs(seeked_splats_proj1[:, :2] - proj1_target).sum(1).detach().cpu().numpy()
    errors_proj2 = torch.abs(seeked_splats_proj2[:, :2] - proj2_target).sum(1).detach().cpu().numpy()

    return X, errors_proj1, errors_proj2

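# --- Illustrative sketch (not part of the original file) ---------------------
# Each view contributes two linear constraints on the homogeneous point X, built
# from its projection matrix and the observed (normalized) pixel coordinates;
# the point is the least-squares solution of the stacked 4x3 system, and the
# returned errors are L1 reprojection residuals in the two views. A tiny
# synthetic check with hypothetical, identity-like projection matrices:
def _triangulation_sanity_check(device="cpu"):
    P1 = torch.eye(4)
    P2 = torch.eye(4)
    P2[3, 0] = 0.1  # nudge the second camera so the two views differ
    k1_x, k1_y = torch.tensor([0.2]), torch.tensor([0.1])
    k2_x, k2_y = torch.tensor([0.25]), torch.tensor([0.1])
    X, err1, err2 = triangulate_points(P1, P2, k1_x, k1_y, k2_x, k2_y, device=device)
    # Large err1/err2 values flag unreliable matches; the caller masks them out.
    return X, err1, err2
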
def select_best_keypoints(
        NNs_triangulated_points, NNs_errors_proj1, NNs_errors_proj2, device="cuda"):
    """
    From all candidate triangulations (one per neighboring view), selects for each keypoint
    the 3D point with the smallest reprojection error.

    Args:
        NNs_triangulated_points: torch tensor with keypoints coordinates (num_nns, num_points, dim). dim can be arbitrary,
            usually 3 or 4 (for homogeneous representation).
        NNs_errors_proj1: numpy array with projection error of the estimated keypoint on the reference frame (num_nns, num_points).
        NNs_errors_proj2: numpy array with projection error of the estimated keypoint on the neighbor frame (num_nns, num_points).
    Returns:
        selected_keypoints: keypoints with the best score.
    """

    NNs_errors_proj = np.maximum(NNs_errors_proj1, NNs_errors_proj2)

    # Convert indices to PyTorch tensor
    indices = torch.from_numpy(np.argmin(NNs_errors_proj, axis=0)).long().to(device)

    # Create index tensor for the second dimension
    n_indices = torch.arange(NNs_triangulated_points.shape[1]).long().to(device)

    # Use advanced indexing to select elements
    NNs_triangulated_points_selected = NNs_triangulated_points[indices, n_indices, :]  # Shape: [N, k]

    return NNs_triangulated_points_selected, np.min(NNs_errors_proj, axis=0)

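# --- Illustrative sketch (not part of the original file) ---------------------
# For each keypoint, keep the triangulation whose worse reprojection error
# (max over the two views) is smallest across all neighbors. Toy call with
# 2 neighbors and 3 points (hypothetical error values):
def _selection_example(device="cpu"):
    pts = torch.zeros(2, 3, 4)                     # (num_nns, num_points, 4)
    err1 = np.array([[0.1, 5.0, 0.3], [0.4, 0.2, 0.9]])
    err2 = np.array([[0.2, 0.1, 0.8], [0.3, 0.3, 0.1]])
    best_pts, best_err = select_best_keypoints(pts, err1, err2, device=device)
    # best_err == [0.2, 0.3, 0.8]: per point, the smallest of the per-neighbor maxima.
    return best_pts, best_err
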
def init_gaussians_with_corr(gaussians, scene, cfg, device, verbose=False, roma_model=None):
    """
    For given input gaussians and a scene, instantiate a RoMa model (switch to the indoor
    variant if necessary) and process the scene's training frames to extract correspondences.
    Those correspondences are used to initialize the gaussians.

    Args:
        gaussians: object of the class GaussianModel that we need to enrich with gaussians.
        scene: object of the Scene class.
        cfg: configuration; the init_wC section is used.
    Returns:
        viewpoint_stack, closest_indices_selected, visualizations; the gaussians object is
        enriched in place.
    """
    if roma_model is None:
        if cfg.roma_model == "indoors":
            roma_model = roma_indoor(device=device)
        else:
            roma_model = roma_outdoor(device=device)
        roma_model.upsample_preds = False
        roma_model.symmetric = False
    M = cfg.matches_per_ref
    upper_thresh = roma_model.sample_thresh
    scaling_factor = cfg.scaling_factor
    expansion_factor = 1
    keypoint_fit_error_tolerance = cfg.proj_err_tolerance
    visualizations = {}
    viewpoint_stack = scene.getTrainCameras().copy()
    NUM_REFERENCE_FRAMES = min(cfg.num_refs, len(viewpoint_stack))
    NUM_NNS_PER_REFERENCE = min(cfg.nns_per_ref, len(viewpoint_stack))

    # Select reference cameras using K-means
    viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
    selected_indices = select_cameras_kmeans(cameras=viewpoint_cam_all.detach().cpu().numpy(), K=NUM_REFERENCE_FRAMES)
    selected_indices = sorted(selected_indices)

    # Find the k-closest vectors for each vector
    viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
    closest_indices = k_closest_vectors(viewpoint_cam_all, NUM_NNS_PER_REFERENCE)
    if verbose: print("Indices of k-closest vectors for each vector:\n", closest_indices)

    closest_indices_selected = closest_indices[:, :].detach().cpu().numpy()

    all_new_xyz = []
    all_new_features_dc = []
    all_new_features_rest = []
    all_new_opacities = []
    all_new_scaling = []
    all_new_rotation = []

    # Run roma_model.match once to warm up the model
    with torch.no_grad():
        viewpoint_cam1 = viewpoint_stack[0]
        viewpoint_cam2 = viewpoint_stack[1]
        imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
        imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
        imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
        imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
        warp, certainty_warp = roma_model.match(imA, imB, device=device)
        print("Once run full roma_model.match warp.shape:", warp.shape)
        print("Once run full roma_model.match certainty_warp.shape:", certainty_warp.shape)
        del warp, certainty_warp
        torch.cuda.empty_cache()

    for source_idx in tqdm(sorted(selected_indices)):
        # 1. Compute keypoints and warping for all the neighboring views
        with torch.no_grad():
            # Call the aggregation function to get imA and imB_compound
            certainties_max, warps_max, certainties_max_idcs, imA, imB_compound, certainties_all, warps_all = aggregate_confidences_and_warps(
                viewpoint_stack=viewpoint_stack,
                closest_indices=closest_indices_selected,
                roma_model=roma_model,
                source_idx=source_idx,
                verbose=verbose, output_dict=visualizations
            )

        # 2. Triangulate keypoints
        with torch.no_grad():
            matches = warps_max
            certainty = certainties_max
            certainty = certainty.clone()
            certainty[certainty > upper_thresh] = 1
            matches, certainty = (
                matches.reshape(-1, 4),
                certainty.reshape(-1),
            )

            # Select elements with high confidence based on certainty. These are basically
            # all of kptsA_np.
            good_samples = torch.multinomial(certainty,
                                             num_samples=min(expansion_factor * M, len(certainty)),
                                             replacement=False)

        reference_image_dict = {
            "ref_image": imA,
            "NNs_images": imB_compound,
            "certainties_all": certainties_all,
            "warps_all": warps_all,
            "triangulated_points": [],
            "triangulated_points_errors_proj1": [],
            "triangulated_points_errors_proj2": []
        }
        with torch.no_grad():
            for NN_idx in tqdm(range(len(warps_all))):
                matches_NN = warps_all[NN_idx].reshape(-1, 4)[good_samples]

                # Extract keypoints and colors
                kptsA_np, kptsB_np, kptsB_proj_matrices_idcs, kptsA_color, kptsB_color = extract_keypoints_and_colors(
                    imA, imB_compound, certainties_max, certainties_max_idcs, matches_NN, roma_model
                )

                proj_matrices_A = viewpoint_stack[source_idx].full_proj_transform
                proj_matrices_B = viewpoint_stack[closest_indices_selected[source_idx, NN_idx]].full_proj_transform
                triangulated_points, triangulated_points_errors_proj1, triangulated_points_errors_proj2 = triangulate_points(
                    P1=torch.stack([proj_matrices_A] * M, axis=0),
                    P2=torch.stack([proj_matrices_B] * M, axis=0),
                    k1_x=kptsA_np[:M, 0], k1_y=kptsA_np[:M, 1],
                    k2_x=kptsB_np[:M, 0], k2_y=kptsB_np[:M, 1])

                reference_image_dict["triangulated_points"].append(triangulated_points)
                reference_image_dict["triangulated_points_errors_proj1"].append(triangulated_points_errors_proj1)
                reference_image_dict["triangulated_points_errors_proj2"].append(triangulated_points_errors_proj2)

        # 3. Select the best triangulation per keypoint
        with torch.no_grad():
            NNs_triangulated_points_selected, NNs_triangulated_points_selected_proj_errors = select_best_keypoints(
                NNs_triangulated_points=torch.stack(reference_image_dict["triangulated_points"], dim=0),
                NNs_errors_proj1=np.stack(reference_image_dict["triangulated_points_errors_proj1"], axis=0),
                NNs_errors_proj2=np.stack(reference_image_dict["triangulated_points_errors_proj2"], axis=0))

        # 4. Save as gaussians
        viewpoint_cam1 = viewpoint_stack[source_idx]
        N = len(NNs_triangulated_points_selected)
        with torch.no_grad():
            new_xyz = NNs_triangulated_points_selected[:, :-1]
            all_new_xyz.append(new_xyz)  # seeked_splats
            all_new_features_dc.append(RGB2SH(torch.tensor(kptsA_color.astype(np.float32) / 255.)).unsqueeze(1))
            all_new_features_rest.append(torch.stack([gaussians._features_rest[-1].clone().detach() * 0.] * N, dim=0))
            # New version that sets points with large error invisible.
            # TODO: remove those points instead. However, it doesn't affect the performance.
            mask_bad_points = torch.tensor(
                NNs_triangulated_points_selected_proj_errors > keypoint_fit_error_tolerance,
                dtype=torch.float32).unsqueeze(1).to(device)
            all_new_opacities.append(torch.stack([gaussians._opacity[-1].clone().detach()] * N, dim=0) * 0. - mask_bad_points * (1e1))

            dist_points_to_cam1 = torch.linalg.norm(viewpoint_cam1.camera_center.clone().detach() - new_xyz,
                                                    dim=1, ord=2)
            #all_new_scaling.append(torch.log(((dist_points_to_cam1) / 1. * scaling_factor).unsqueeze(1).repeat(1, 3)))
            all_new_scaling.append(gaussians.scaling_inverse_activation((dist_points_to_cam1 * scaling_factor).unsqueeze(1).repeat(1, 3)))
            all_new_rotation.append(torch.stack([gaussians._rotation[-1].clone().detach()] * N, dim=0))

    all_new_xyz = torch.cat(all_new_xyz, dim=0)
    all_new_features_dc = torch.cat(all_new_features_dc, dim=0)
    new_tmp_radii = torch.zeros(all_new_xyz.shape[0])
    prune_mask = torch.ones(all_new_xyz.shape[0], dtype=torch.bool)

    gaussians.densification_postfix(all_new_xyz[prune_mask].to(device),
                                    all_new_features_dc[prune_mask].to(device),
                                    torch.cat(all_new_features_rest, dim=0)[prune_mask].to(device),
                                    torch.cat(all_new_opacities, dim=0)[prune_mask].to(device),
                                    torch.cat(all_new_scaling, dim=0)[prune_mask].to(device),
                                    torch.cat(all_new_rotation, dim=0)[prune_mask].to(device),
                                    new_tmp_radii[prune_mask].to(device))

    return viewpoint_stack, closest_indices_selected, visualizations

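# --- Illustrative sketch (not part of the original file) ---------------------
# init_gaussians_with_corr() reads its hyper-parameters from cfg (the init_wC
# section of the training configuration). The field names below are exactly the
# ones accessed in the function body; the concrete values are hypothetical
# placeholders, not the project defaults.
def _example_init_wc_cfg():
    from types import SimpleNamespace
    return SimpleNamespace(
        roma_model="outdoors",      # "indoors" switches to roma_indoor
        num_refs=180,               # reference frames selected via K-means
        nns_per_ref=3,              # nearest neighbors matched per reference
        matches_per_ref=15_000,     # correspondences sampled per reference (M)
        scaling_factor=0.001,       # initial scale ~ distance to camera * factor
        proj_err_tolerance=0.01,    # larger reprojection error -> opacity masked out
    )

# Typical call, given a GaussianModel `gaussians` and a Scene `scene` from the
# standard 3DGS training setup:
#     viewpoint_stack, closest_indices, viz = init_gaussians_with_corr(
#         gaussians, scene, _example_init_wc_cfg(), device="cuda")
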
    	
        source/corr_init_new.py
    ADDED
    
@@ -0,0 +1,904 @@
import sys
sys.path.append('../')
sys.path.append("../submodules")
sys.path.append('../submodules/RoMa')

from matplotlib import pyplot as plt
from PIL import Image
import torch
import numpy as np

#from tqdm import tqdm_notebook as tqdm
from tqdm import tqdm
from scipy.cluster.vq import kmeans, vq
from scipy.spatial.distance import cdist

import torch.nn.functional as F
from romatch import roma_outdoor, roma_indoor
from utils.sh_utils import RGB2SH
from romatch.utils import get_tuple_transform_ops

def pairwise_distances(matrix):
    """
    Computes the pairwise Euclidean distances between all vectors in the input matrix.

    Args:
        matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.

    Returns:
        torch.Tensor: Pairwise distance matrix of shape [N, N].
    """
    # Compute pairwise Euclidean distances
    distances = torch.cdist(matrix, matrix, p=2)
    return distances

def k_closest_vectors(matrix, k):
    """
    Finds the k-closest vectors for each vector in the input matrix based on Euclidean distance.

    Args:
        matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.
        k (int): Number of closest vectors to return for each vector.

    Returns:
        torch.Tensor: Indices of the k-closest vectors for each vector, excluding the vector itself.
    """
    # Compute pairwise distances
    distances = pairwise_distances(matrix)

    # For each vector, sort distances and get the indices of the k-closest vectors (excluding itself).
    # Set diagonal distances to infinity to exclude the vector itself from the nearest neighbors.
    distances.fill_diagonal_(float('inf'))

    # Get the indices of the k smallest distances (k-closest vectors)
    _, indices = torch.topk(distances, k, largest=False, dim=1)

    return indices

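# --- Illustrative sketch (not part of the original file) ---------------------
# In this project the "vectors" are flattened 4x4 world-to-view matrices, so the
# nearest neighbors are the cameras with the most similar pose. Toy example with
# hypothetical 2-D embeddings:
def _k_closest_example():
    poses = torch.tensor([[0.0, 0.0],
                          [0.1, 0.0],
                          [5.0, 5.0]])
    return k_closest_vectors(poses, 1)    # tensor([[1], [0], [1]])
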
def select_cameras_kmeans(cameras, K):
    """
    Selects K cameras from a set using K-means clustering.

    Args:
        cameras: NumPy array of shape (N, 16), representing N cameras with their 4x4 homogeneous matrices flattened.
        K: Number of clusters (cameras to select).

    Returns:
        selected_indices: List of indices of the cameras closest to the cluster centers.
    """
    # Ensure input is a NumPy array
    if not isinstance(cameras, np.ndarray):
        cameras = np.asarray(cameras)

    if cameras.shape[1] != 16:
        raise ValueError("Each camera must have 16 values corresponding to a flattened 4x4 matrix.")

    # Perform K-means clustering
    cluster_centers, _ = kmeans(cameras, K)

    # Assign each camera to a cluster and find distances to cluster centers
    cluster_assignments, _ = vq(cameras, cluster_centers)

    # Find the camera nearest to each cluster center
    selected_indices = []
    for k in range(K):
        cluster_members = cameras[cluster_assignments == k]
        distances = cdist([cluster_centers[k]], cluster_members)[0]
        nearest_camera_idx = np.where(cluster_assignments == k)[0][np.argmin(distances)]
        selected_indices.append(nearest_camera_idx)

    return selected_indices

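# --- Illustrative sketch (not part of the original file) ---------------------
# Picks K representative, well-spread cameras; the caller feeds in flattened
# world_view_transform matrices. Hypothetical random data for shape reference:
def _kmeans_selection_example():
    cams = np.random.rand(10, 16)             # 10 cameras, flattened 4x4 matrices
    return select_cameras_kmeans(cams, K=3)   # list of 3 indices into cams
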
def compute_warp_and_confidence(viewpoint_cam1, viewpoint_cam2, roma_model, device="cuda", verbose=False, output_dict={}):
    """
    Computes the warp and confidence between two viewpoint cameras using the roma_model.

    Args:
        viewpoint_cam1: Source viewpoint camera.
        viewpoint_cam2: Target viewpoint camera.
        roma_model: Pre-trained Roma model for correspondence matching.
        device: Device to run the computation on.
        verbose: If True, displays the images.

    Returns:
        certainty: Confidence tensor.
        warp: Warp tensor.
        imB: Processed image B as numpy array.
    """
    # Prepare images
    imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
    imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
    imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
    imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))

    if verbose:
        fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(16, 8))
        cax1 = ax[0].imshow(imA)
        ax[0].set_title("Image 1")
        cax2 = ax[1].imshow(imB)
        ax[1].set_title("Image 2")
        fig.colorbar(cax1, ax=ax[0])
        fig.colorbar(cax2, ax=ax[1])

        for axis in ax:
            axis.axis('off')
        # Save the figure into the dictionary
        output_dict[f'image_pair'] = fig

    # Transform images
    ws, hs = roma_model.w_resized, roma_model.h_resized
    test_transform = get_tuple_transform_ops(resize=(hs, ws), normalize=True)
    im_A, im_B = test_transform((imA, imB))
    batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}

    # Forward pass through Roma model
    corresps = roma_model.forward(batch) if not roma_model.symmetric else roma_model.forward_symmetric(batch)
    finest_scale = 1
    hs, ws = roma_model.upsample_res if roma_model.upsample_preds else (hs, ws)

    # Process certainty and warp
    certainty = corresps[finest_scale]["certainty"]
    im_A_to_im_B = corresps[finest_scale]["flow"]
    if roma_model.attenuate_cert:
        low_res_certainty = F.interpolate(
            corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
        )
        certainty -= 0.5 * low_res_certainty * (low_res_certainty < 0)

    # Upsample predictions if needed
    if roma_model.upsample_preds:
        im_A_to_im_B = F.interpolate(
            im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
        )
        certainty = F.interpolate(
            certainty, size=(hs, ws), align_corners=False, mode="bilinear"
        )

    # Convert predictions to final format
    im_A_to_im_B = im_A_to_im_B.permute(0, 2, 3, 1)
    im_A_coords = torch.stack(torch.meshgrid(
        torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
        torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
        indexing='ij'
    ), dim=0).permute(1, 2, 0).unsqueeze(0).expand(im_A_to_im_B.size(0), -1, -1, -1)

    warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
    certainty = certainty.sigmoid()

    return certainty[0, 0], warp[0], np.array(imB)

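# --- Illustrative sketch (not part of the original file) ---------------------
# The returned warp has shape [H, W, 4]: channels 0-1 hold the normalized grid
# coordinates of each pixel in image A, channels 2-3 the matched (normalized)
# coordinates in image B; certainty is the per-pixel confidence in [0, 1] after
# the sigmoid. A simple thresholding variant for turning the dense warp into a
# sparse match list (the training code samples with torch.multinomial instead):
def _warp_to_matches(warp, certainty, threshold=0.5):
    mask = certainty > threshold          # [H, W] boolean
    return warp[mask]                     # [N, 4] rows of (coords in A, coords in B)
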
def resize_batch(tensors_3d, tensors_4d, target_shape):
    """
    Resizes a batch of tensors with shapes [B, H, W] and [B, H, W, 4] to the target spatial dimensions.

    Args:
        tensors_3d: Tensor of shape [B, H, W].
        tensors_4d: Tensor of shape [B, H, W, 4].
        target_shape: Tuple (target_H, target_W) specifying the target spatial dimensions.

    Returns:
        resized_tensors_3d: Tensor of shape [B, target_H, target_W].
        resized_tensors_4d: Tensor of shape [B, target_H, target_W, 4].
    """
    target_H, target_W = target_shape

    # Resize [B, H, W] tensor
    resized_tensors_3d = F.interpolate(
        tensors_3d.unsqueeze(1), size=(target_H, target_W), mode="bilinear", align_corners=False
    ).squeeze(1)

    # Resize [B, H, W, 4] tensor
    B, _, _, C = tensors_4d.shape
    resized_tensors_4d = F.interpolate(
        tensors_4d.permute(0, 3, 1, 2), size=(target_H, target_W), mode="bilinear", align_corners=False
    ).permute(0, 2, 3, 1)

    return resized_tensors_3d, resized_tensors_4d

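# --- Illustrative sketch (not part of the original file) ---------------------
# Minimal usage example with hypothetical sizes: two confidence maps and their
# warps are brought to a common 120x160 resolution before aggregation.
def _resize_batch_example():
    conf = torch.rand(2, 288, 384)
    warps = torch.rand(2, 288, 384, 4)
    conf_r, warps_r = resize_batch(conf, warps, (120, 160))
    return conf_r.shape, warps_r.shape    # ([2, 120, 160], [2, 120, 160, 4])
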
def aggregate_confidences_and_warps(viewpoint_stack, closest_indices, roma_model, source_idx, verbose=False, output_dict={}):
    """
    Aggregates confidences and warps by iterating over the nearest neighbors of the source viewpoint.

    Args:
        viewpoint_stack: Stack of viewpoint cameras.
        closest_indices: Indices of the nearest neighbors for each viewpoint.
        roma_model: Pre-trained Roma model.
        source_idx: Index of the source viewpoint.
        verbose: If True, displays intermediate results.

    Returns:
        certainties_max: Aggregated maximum confidences.
        warps_max: Aggregated warps corresponding to maximum confidences.
        certainties_max_idcs: Pixel-wise index of the image from which the best match was taken.
        imB_compound: List of the neighboring images.
        certainties_all_resized: All per-neighbor confidences, resized to the target resolution.
        warps_all_resized: All per-neighbor warps, resized to the target resolution.
    """
    certainties_all, warps_all, imB_compound = [], [], []

    for nn in tqdm(closest_indices[source_idx]):

        viewpoint_cam1 = viewpoint_stack[source_idx]
        viewpoint_cam2 = viewpoint_stack[nn]

        certainty, warp, imB = compute_warp_and_confidence(viewpoint_cam1, viewpoint_cam2, roma_model, verbose=verbose, output_dict=output_dict)
        certainties_all.append(certainty)
        warps_all.append(warp)
        imB_compound.append(imB)

    certainties_all = torch.stack(certainties_all, dim=0)
    target_shape = imB_compound[0].shape[:2]
    if verbose:
        print("certainties_all.shape:", certainties_all.shape)
        print("torch.stack(warps_all, dim=0).shape:", torch.stack(warps_all, dim=0).shape)
        print("target_shape:", target_shape)

    certainties_all_resized, warps_all_resized = resize_batch(certainties_all,
                                                              torch.stack(warps_all, dim=0),
                                                              target_shape
                                                              )

    if verbose:
        print("warps_all_resized.shape:", warps_all_resized.shape)
        for n, cert in enumerate(certainties_all):
            fig, ax = plt.subplots()
            cax = ax.imshow(cert.cpu().numpy(), cmap='viridis')
            fig.colorbar(cax, ax=ax)
            ax.set_title("Pixel-wise Confidence")
            output_dict[f'certainty_{n}'] = fig

        for n, warp in enumerate(warps_all):
            fig, ax = plt.subplots()
            cax = ax.imshow(warp.cpu().numpy()[:, :, :3], cmap='viridis')
            fig.colorbar(cax, ax=ax)
            ax.set_title("Pixel-wise warp")
            output_dict[f'warp_{n}'] = fig

        for n, cert in enumerate(certainties_all_resized):
            fig, ax = plt.subplots()
            cax = ax.imshow(cert.cpu().numpy(), cmap='viridis')
            fig.colorbar(cax, ax=ax)
            ax.set_title("Pixel-wise Confidence resized")
            output_dict[f'certainty_resized_{n}'] = fig

        for n, warp in enumerate(warps_all_resized):
            fig, ax = plt.subplots()
            cax = ax.imshow(warp.cpu().numpy()[:, :, :3], cmap='viridis')
            fig.colorbar(cax, ax=ax)
            ax.set_title("Pixel-wise warp resized")
            output_dict[f'warp_resized_{n}'] = fig

    certainties_max, certainties_max_idcs = torch.max(certainties_all_resized, dim=0)
    H, W = certainties_max.shape

    warps_max = warps_all_resized[certainties_max_idcs, torch.arange(H).unsqueeze(1), torch.arange(W)]

    return certainties_max, warps_max, certainties_max_idcs, imB_compound, certainties_all_resized, warps_all_resized

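# --- Illustrative sketch (not part of the original file) ---------------------
# The advanced-indexing line above picks, for every pixel, the warp of the
# neighbor with the highest confidence. The same pattern on toy data:
def _per_pixel_best_warp_example():
    conf = torch.rand(3, 4, 5)            # (num_neighbors, H, W)
    warps = torch.rand(3, 4, 5, 4)        # (num_neighbors, H, W, 4)
    best_conf, best_idx = torch.max(conf, dim=0)
    best_warp = warps[best_idx, torch.arange(4).unsqueeze(1), torch.arange(5)]
    return best_warp.shape                # torch.Size([4, 5, 4])
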
def extract_keypoints_and_colors(imA, imB_compound, certainties_max, certainties_max_idcs, matches, roma_model,
                                 verbose=False, output_dict={}):
    """
    Extracts keypoints and corresponding colors from the source image (imA) and multiple target images (imB_compound).

    Args:
        imA: Source image as a NumPy array (H_A, W_A, C).
        imB_compound: List of target images as NumPy arrays [(H_B, W_B, C), ...].
        certainties_max: Tensor of pixel-wise maximum confidences.
        certainties_max_idcs: Tensor of pixel-wise indices for the best matches.
        matches: Matches in normalized coordinates.
        roma_model: Roma model instance for keypoint operations.
        verbose: Whether to show intermediate outputs and visualize results.

    Returns:
        kptsA_np: Keypoints in imA in normalized coordinates.
        kptsB_np: Keypoints in imB in normalized coordinates.
        kptsA_color: Colors of keypoints in imA.
        kptsB_color: Colors of keypoints in imB based on certainties_max_idcs.
    """
    H_A, W_A, _ = imA.shape
    H, W = certainties_max.shape

    # Convert matches to pixel coordinates
    kptsA, kptsB = roma_model.to_pixel_coordinates(
        matches, W_A, H_A, H, W  # W, H
    )

    kptsA_np = kptsA.detach().cpu().numpy()
    kptsB_np = kptsB.detach().cpu().numpy()
    kptsA_np = kptsA_np[:, [1, 0]]

    if verbose:
        fig, ax = plt.subplots(figsize=(12, 6))
        cax = ax.imshow(imA)
        ax.set_title("Reference image, imA")
        output_dict[f'reference_image'] = fig

        fig, ax = plt.subplots(figsize=(12, 6))
        cax = ax.imshow(imB_compound[0])
        ax.set_title("Image to compare to, imB_compound[0]")
        output_dict[f'imB_compound'] = fig

        fig, ax = plt.subplots(figsize=(12, 6))
        cax = ax.imshow(np.flipud(imA))
        cax = ax.scatter(kptsA_np[:, 0], H_A - kptsA_np[:, 1], s=.03)
        ax.set_title("Keypoints in imA")
        ax.set_xlim(0, W_A)
        ax.set_ylim(0, H_A)
        output_dict[f'kptsA'] = fig

        fig, ax = plt.subplots(figsize=(12, 6))
        cax = ax.imshow(np.flipud(imB_compound[0]))
        cax = ax.scatter(kptsB_np[:, 0], H_A - kptsB_np[:, 1], s=.03)
        ax.set_title("Keypoints in imB")
        ax.set_xlim(0, W_A)
        ax.set_ylim(0, H_A)
        output_dict[f'kptsB'] = fig

    # Keypoints are in (row, column) format, so the first value is always in range [0, height]
    # and the second is in range [0, width].

    kptsA_np = kptsA.detach().cpu().numpy()
    kptsB_np = kptsB.detach().cpu().numpy()

    # Extract colors for keypoints in imA (vectorized)
    # New experimental version
    kptsA_x = np.round(kptsA_np[:, 0] / 1.).astype(int)
    kptsA_y = np.round(kptsA_np[:, 1] / 1.).astype(int)
    kptsA_color = imA[np.clip(kptsA_x, 0, H - 1), np.clip(kptsA_y, 0, W - 1)]

    # Create a composite image from imB_compound
    imB_compound_np = np.stack(imB_compound, axis=0)
    H_B, W_B, _ = imB_compound[0].shape

    # Extract colors for keypoints in imB using certainties_max_idcs
    imB_np = imB_compound_np[
        certainties_max_idcs.detach().cpu().numpy(),
        np.arange(H).reshape(-1, 1),
        np.arange(W)
    ]

    if verbose:
        print("imB_np.shape:", imB_np.shape)
        print("imB_np:", imB_np)
        fig, ax = plt.subplots(figsize=(12, 6))
        cax = ax.imshow(np.flipud(imB_np))
        cax = ax.scatter(kptsB_np[:, 0], H_A - kptsB_np[:, 1], s=.03)
        ax.set_title("np.flipud(imB_np)")
        ax.set_xlim(0, W_A)
        ax.set_ylim(0, H_A)
        output_dict['np.flipud(imB_np)'] = fig

    kptsB_x = np.round(kptsB_np[:, 0]).astype(int)
    kptsB_y = np.round(kptsB_np[:, 1]).astype(int)

    certainties_max_idcs_np = certainties_max_idcs.detach().cpu().numpy()
    kptsB_proj_matrices_idx = certainties_max_idcs_np[np.clip(kptsA_x, 0, H - 1), np.clip(kptsA_y, 0, W - 1)]
    kptsB_color = imB_compound_np[kptsB_proj_matrices_idx, np.clip(kptsB_y, 0, H - 1), np.clip(kptsB_x, 0, W - 1)]

    # Normalize keypoints in both images
    kptsA_np[:, 0] = kptsA_np[:, 0] / H * 2.0 - 1.0
    kptsA_np[:, 1] = kptsA_np[:, 1] / W * 2.0 - 1.0
    kptsB_np[:, 0] = kptsB_np[:, 0] / W_B * 2.0 - 1.0
    kptsB_np[:, 1] = kptsB_np[:, 1] / H_B * 2.0 - 1.0

    return kptsA_np[:, [1, 0]], kptsB_np, kptsB_proj_matrices_idx, kptsA_color, kptsB_color

def prepare_tensor(input_array, device):
    """
    Converts an input array to a torch tensor, clones it, and detaches it for safe computation.
    Args:
        input_array (array-like): The input array to convert.
        device (str or torch.device): The device to move the tensor to.
    Returns:
        torch.Tensor: A detached tensor clone of the input array on the specified device.
    """
    if not isinstance(input_array, torch.Tensor):
        return torch.tensor(input_array, dtype=torch.float32).to(device).clone().detach()
    return input_array.clone().detach().to(device).to(torch.float32)

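A quick usage sketch of prepare_tensor with hypothetical inputs; device="cpu" keeps it runnable without a GPU:

import numpy as np
import torch

pts = np.array([[0.1, 0.2], [0.3, 0.4]])
t1 = prepare_tensor(pts, device="cpu")                  # NumPy array -> float32 tensor
t2 = prepare_tensor(torch.ones(2, 2).double(), "cpu")   # tensor is cloned, detached, cast
print(t1.dtype, t2.dtype)  # torch.float32 torch.float32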
def triangulate_points(P1, P2, k1_x, k1_y, k2_x, k2_y, device="cuda"):
    """
    Solves for a batch of 3D points given batches of projection matrices and corresponding image points.

    Parameters:
    - P1, P2: Tensors of projection matrices of size (batch_size, 4, 4) or (4, 4)
    - k1_x, k1_y: Tensors of shape (batch_size,)
    - k2_x, k2_y: Tensors of shape (batch_size,)

    Returns:
    - X: A tensor containing the 3D homogeneous coordinates, shape (batch_size, 4)
    - errors_proj1, errors_proj2: NumPy arrays with the per-point reprojection errors in each view
    """
    EPS = 1e-4

    # Ensure inputs are tensors
    P1 = prepare_tensor(P1, device)
    P2 = prepare_tensor(P2, device)
    k1_x = prepare_tensor(k1_x, device)
    k1_y = prepare_tensor(k1_y, device)
    k2_x = prepare_tensor(k2_x, device)
    k2_y = prepare_tensor(k2_y, device)
    batch_size = k1_x.shape[0]

    # Expand P1 and P2 if they are not batched
    if P1.ndim == 2:
        P1 = P1.unsqueeze(0).expand(batch_size, -1, -1)
    if P2.ndim == 2:
        P2 = P2.unsqueeze(0).expand(batch_size, -1, -1)

    # Extract columns from P1 and P2
    P1_0 = P1[:, :, 0]  # Shape: (batch_size, 4)
    P1_1 = P1[:, :, 1]
    P1_2 = P1[:, :, 2]

    P2_0 = P2[:, :, 0]
    P2_1 = P2[:, :, 1]
    P2_2 = P2[:, :, 2]

    # Reshape kx and ky to (batch_size, 1)
    k1_x = k1_x.view(-1, 1)
    k1_y = k1_y.view(-1, 1)
    k2_x = k2_x.view(-1, 1)
    k2_y = k2_y.view(-1, 1)

    # Construct the equations for each batch
    # For camera 1
    A1 = P1_0 - k1_x * P1_2  # Shape: (batch_size, 4)
    A2 = P1_1 - k1_y * P1_2
    # For camera 2
    A3 = P2_0 - k2_x * P2_2
    A4 = P2_1 - k2_y * P2_2

    # Stack the equations
    A = torch.stack([A1, A2, A3, A4], dim=1)  # Shape: (batch_size, 4, 4)

    # Right-hand side (constants)
    b = -A[:, :, 3]  # Shape: (batch_size, 4)
    A_reduced = A[:, :, :3]  # Coefficients of x, y, z

    # Solve using torch.linalg.lstsq (supports batching)
    X_xyz = torch.linalg.lstsq(A_reduced, b.unsqueeze(2)).solution.squeeze(2)  # Shape: (batch_size, 3)

    # Append 1 to get homogeneous coordinates
    ones = torch.ones((batch_size, 1), dtype=torch.float32, device=X_xyz.device)
    X = torch.cat([X_xyz, ones], dim=1)  # Shape: (batch_size, 4)

    # Now compute the reprojection errors.
    seeked_splats_proj1 = (X.unsqueeze(1) @ P1).squeeze(1)
    seeked_splats_proj1 = seeked_splats_proj1 / (EPS + seeked_splats_proj1[:, [3]])
    seeked_splats_proj2 = (X.unsqueeze(1) @ P2).squeeze(1)
    seeked_splats_proj2 = seeked_splats_proj2 / (EPS + seeked_splats_proj2[:, [3]])
    proj1_target = torch.concat([k1_x, k1_y], dim=1)
    proj2_target = torch.concat([k2_x, k2_y], dim=1)
    errors_proj1 = torch.abs(seeked_splats_proj1[:, :2] - proj1_target).sum(1).detach().cpu().numpy()
    errors_proj2 = torch.abs(seeked_splats_proj2[:, :2] - proj2_target).sum(1).detach().cpu().numpy()

    return X, errors_proj1, errors_proj2


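The core of triangulate_points is a batched least-squares solve: each keypoint contributes two linear equations per view, stacked into a (batch_size, 4, 3) system. A standalone sketch of that solver step on synthetic data (it only demonstrates the batched torch.linalg.lstsq call, not the projection-matrix convention used in this repository):

import torch

batch_size = 8
A_reduced = torch.randn(batch_size, 4, 3)   # coefficients of x, y, z (four equations per point)
b = torch.randn(batch_size, 4)              # right-hand side
X_xyz = torch.linalg.lstsq(A_reduced, b.unsqueeze(2)).solution.squeeze(2)
print(X_xyz.shape)  # torch.Size([8, 3])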
def select_best_keypoints(
        NNs_triangulated_points, NNs_errors_proj1, NNs_errors_proj2, device="cuda"):
    """
    From all candidate triangulations of each keypoint (one per neighbor view), selects the one with
    the lowest reprojection error.

    Args:
        NNs_triangulated_points: torch tensor with keypoint coordinates (num_nns, num_points, dim). dim can be arbitrary,
            usually 3 or 4 (for the homogeneous representation).
        NNs_errors_proj1: numpy array with the projection error of the estimated keypoint on the reference frame (num_nns, num_points).
        NNs_errors_proj2: numpy array with the projection error of the estimated keypoint on the neighbor frame (num_nns, num_points).
    Returns:
        selected_keypoints: keypoints with the best score, together with their per-point errors.
    """

    NNs_errors_proj = np.maximum(NNs_errors_proj1, NNs_errors_proj2)

    # Convert indices to PyTorch tensor
    indices = torch.from_numpy(np.argmin(NNs_errors_proj, axis=0)).long().to(device)

    # Create index tensor for the second dimension
    n_indices = torch.arange(NNs_triangulated_points.shape[1]).long().to(device)

    # Use advanced indexing to select elements
    NNs_triangulated_points_selected = NNs_triangulated_points[indices, n_indices, :]  # Shape: [N, k]

    return NNs_triangulated_points_selected, np.min(NNs_errors_proj, axis=0)


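A shape-level usage sketch of select_best_keypoints with synthetic inputs (it assumes the function above is in scope and uses device="cpu" so the tensors and indices live on the same device):

import numpy as np
import torch

# Three candidate triangulations (one per neighbor) for five keypoints, homogeneous coordinates.
cand = torch.randn(3, 5, 4)
err1 = np.random.rand(3, 5)
err2 = np.random.rand(3, 5)
best_pts, best_err = select_best_keypoints(cand, err1, err2, device="cpu")
print(best_pts.shape, best_err.shape)  # torch.Size([5, 4]) (5,)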
import time
from collections import defaultdict
from tqdm import tqdm

# def init_gaussians_with_corr_profiled(gaussians, scene, cfg, device, verbose=False, roma_model=None):
#     timings = defaultdict(list)  # To accumulate timings

#     if roma_model is None:
#         if cfg.roma_model == "indoors":
#             roma_model = roma_indoor(device=device)
#         else:
#             roma_model = roma_outdoor(device=device)
#         roma_model.upsample_preds = False
#         roma_model.symmetric = False

#     M = cfg.matches_per_ref
#     upper_thresh = roma_model.sample_thresh
#     scaling_factor = cfg.scaling_factor
#     expansion_factor = 1
#     keypoint_fit_error_tolerance = cfg.proj_err_tolerance
#     visualizations = {}
#     viewpoint_stack = scene.getTrainCameras().copy()
#     NUM_REFERENCE_FRAMES = min(cfg.num_refs, len(viewpoint_stack))
#     NUM_NNS_PER_REFERENCE = min(cfg.nns_per_ref, len(viewpoint_stack))

#     viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)

#     selected_indices = select_cameras_kmeans(cameras=viewpoint_cam_all.detach().cpu().numpy(), K=NUM_REFERENCE_FRAMES)
#     selected_indices = sorted(selected_indices)

#     viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
#     closest_indices = k_closest_vectors(viewpoint_cam_all, NUM_NNS_PER_REFERENCE)
#     closest_indices_selected = closest_indices[:, :].detach().cpu().numpy()

#     all_new_xyz = []
#     all_new_features_dc = []
#     all_new_features_rest = []
#     all_new_opacities = []
#     all_new_scaling = []
#     all_new_rotation = []

#     # Dummy first pass to initialize model
#     with torch.no_grad():
#         viewpoint_cam1 = viewpoint_stack[0]
#         viewpoint_cam2 = viewpoint_stack[1]
#         imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
#         imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
#         imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
#         imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
#         warp, certainty_warp = roma_model.match(imA, imB, device=device)
#         del warp, certainty_warp
#         torch.cuda.empty_cache()

#     # Main Loop over source_idx
#     for source_idx in tqdm(sorted(selected_indices), desc="Profiling source frames"):

#         # =================== Step 1: Aggregate Confidences and Warps ===================
#         start = time.time()
#         viewpoint_cam1 = viewpoint_stack[source_idx]
#         viewpoint_cam2 = viewpoint_stack[closest_indices_selected[source_idx, 0]]
#         imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
#         imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
#         imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
#         imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
#         warp, certainty_warp = roma_model.match(imA, imB, device=device)

#         certainties_max, warps_max, certainties_max_idcs, imB_compound, certainties_all, warps_all = aggregate_confidences_and_warps(
#             viewpoint_stack=viewpoint_stack,
#             closest_indices=closest_indices_selected,
#             roma_model=roma_model,
#             source_idx=source_idx,
#             verbose=verbose,
#             output_dict=visualizations
#         )

#         certainties_max = certainty_warp
#         with torch.no_grad():
#             warps_all = warps.unsqueeze(0)

#         timings['aggregation_warp_certainty'].append(time.time() - start)

#         # =================== Step 2: Good Samples Selection ===================
#         start = time.time()
#         certainty = certainties_max.reshape(-1).clone()
#         certainty[certainty > upper_thresh] = 1
#         good_samples = torch.multinomial(certainty, num_samples=min(expansion_factor * M, len(certainty)), replacement=False)
#         timings['good_samples_selection'].append(time.time() - start)

#         # =================== Step 3: Triangulate Keypoints for Each NN ===================
#         reference_image_dict = {
#             "triangulated_points": [],
#             "triangulated_points_errors_proj1": [],
#             "triangulated_points_errors_proj2": []
#         }

#         start = time.time()
#         for NN_idx in range(len(warps_all)):
#             matches_NN = warps_all[NN_idx].reshape(-1, 4)[good_samples]

#             # Extract keypoints and colors
#             kptsA_np, kptsB_np, kptsB_proj_matrices_idcs, kptsA_color, kptsB_color = extract_keypoints_and_colors(
#                 imA, imB_compound, certainties_max, certainties_max_idcs, matches_NN, roma_model
#             )

#             proj_matrices_A = viewpoint_stack[source_idx].full_proj_transform
#             proj_matrices_B = viewpoint_stack[closest_indices_selected[source_idx, NN_idx]].full_proj_transform
#             triangulated_points, triangulated_points_errors_proj1, triangulated_points_errors_proj2 = triangulate_points(
#                 P1=torch.stack([proj_matrices_A] * M, axis=0),
#                 P2=torch.stack([proj_matrices_B] * M, axis=0),
#                 k1_x=kptsA_np[:M, 0], k1_y=kptsA_np[:M, 1],
#                 k2_x=kptsB_np[:M, 0], k2_y=kptsB_np[:M, 1])

#             reference_image_dict["triangulated_points"].append(triangulated_points)
#             reference_image_dict["triangulated_points_errors_proj1"].append(triangulated_points_errors_proj1)
#             reference_image_dict["triangulated_points_errors_proj2"].append(triangulated_points_errors_proj2)
#         timings['triangulation_per_NN'].append(time.time() - start)

#         # =================== Step 4: Select Best Triangulated Points ===================
#         start = time.time()
#         NNs_triangulated_points_selected, NNs_triangulated_points_selected_proj_errors = select_best_keypoints(
#             NNs_triangulated_points=torch.stack(reference_image_dict["triangulated_points"], dim=0),
#             NNs_errors_proj1=np.stack(reference_image_dict["triangulated_points_errors_proj1"], axis=0),
#             NNs_errors_proj2=np.stack(reference_image_dict["triangulated_points_errors_proj2"], axis=0))
#         timings['select_best_keypoints'].append(time.time() - start)

#         # =================== Step 5: Create New Gaussians ===================
#         start = time.time()
#         viewpoint_cam1 = viewpoint_stack[source_idx]
#         N = len(NNs_triangulated_points_selected)
#         new_xyz = NNs_triangulated_points_selected[:, :-1]
#         all_new_xyz.append(new_xyz)
#         all_new_features_dc.append(RGB2SH(torch.tensor(kptsA_color.astype(np.float32) / 255.)).unsqueeze(1))
#         all_new_features_rest.append(torch.stack([gaussians._features_rest[-1].clone().detach() * 0.] * N, dim=0))

#         mask_bad_points = torch.tensor(
#             NNs_triangulated_points_selected_proj_errors > keypoint_fit_error_tolerance,
#             dtype=torch.float32).unsqueeze(1).to(device)

#         all_new_opacities.append(torch.stack([gaussians._opacity[-1].clone().detach()] * N, dim=0) * 0. - mask_bad_points * (1e1))

#         dist_points_to_cam1 = torch.linalg.norm(viewpoint_cam1.camera_center.clone().detach() - new_xyz, dim=1, ord=2)
#         all_new_scaling.append(gaussians.scaling_inverse_activation((dist_points_to_cam1 * scaling_factor).unsqueeze(1).repeat(1, 3)))
#         all_new_rotation.append(torch.stack([gaussians._rotation[-1].clone().detach()] * N, dim=0))
#         timings['save_gaussians'].append(time.time() - start)

#     # =================== Final Densification Postfix ===================
#     start = time.time()
#     all_new_xyz = torch.cat(all_new_xyz, dim=0)
#     all_new_features_dc = torch.cat(all_new_features_dc, dim=0)
#     new_tmp_radii = torch.zeros(all_new_xyz.shape[0])
#     prune_mask = torch.ones(all_new_xyz.shape[0], dtype=torch.bool)

#     gaussians.densification_postfix(
#         all_new_xyz[prune_mask].to(device),
#         all_new_features_dc[prune_mask].to(device),
#         torch.cat(all_new_features_rest, dim=0)[prune_mask].to(device),
#         torch.cat(all_new_opacities, dim=0)[prune_mask].to(device),
#         torch.cat(all_new_scaling, dim=0)[prune_mask].to(device),
#         torch.cat(all_new_rotation, dim=0)[prune_mask].to(device),
#         new_tmp_radii[prune_mask].to(device)
#     )
#     timings['final_densification_postfix'].append(time.time() - start)

#     # =================== Print Profiling Results ===================
#     print("\n=== Profiling Summary (average per frame) ===")
#     for key, times in timings.items():
#         print(f"{key:35s}: {sum(times) / len(times):.4f} sec (total {sum(times):.2f} sec)")

#     return viewpoint_stack, closest_indices_selected, visualizations


def extract_keypoints_and_colors_single(imA, imB, matches, roma_model, verbose=False, output_dict={}):
    """
    Extracts keypoints and corresponding colors from a source image (imA) and a single target image (imB).

    Args:
        imA: Source image as a NumPy array (H_A, W_A, C).
        imB: Target image as a NumPy array (H_B, W_B, C).
        matches: Matches in normalized coordinates (torch.Tensor).
        roma_model: Roma model instance for keypoint operations.
        verbose: If True, outputs intermediate visualizations.
    Returns:
        kptsA_np: Keypoints in imA (normalized).
        kptsB_np: Keypoints in imB (normalized).
        kptsA_color: Colors of keypoints in imA.
        kptsB_color: Colors of keypoints in imB.
    """
    H_A, W_A, _ = imA.shape
    H_B, W_B, _ = imB.shape

    # Convert matches to pixel coordinates
    # Matches format: (B, 4) = (x1_norm, y1_norm, x2_norm, y2_norm)
    kptsA = matches[:, :2]  # [N, 2]
    kptsB = matches[:, 2:]  # [N, 2]

    # Scale normalized coordinates [-1, 1] to pixel coordinates
    kptsA_pix = torch.zeros_like(kptsA)
    kptsB_pix = torch.zeros_like(kptsB)

    # Important! [Normalized to pixel space]
    kptsA_pix[:, 0] = (kptsA[:, 0] + 1) * (W_A - 1) / 2
    kptsA_pix[:, 1] = (kptsA[:, 1] + 1) * (H_A - 1) / 2

    kptsB_pix[:, 0] = (kptsB[:, 0] + 1) * (W_B - 1) / 2
    kptsB_pix[:, 1] = (kptsB[:, 1] + 1) * (H_B - 1) / 2

    kptsA_np = kptsA_pix.detach().cpu().numpy()
    kptsB_np = kptsB_pix.detach().cpu().numpy()

    # Extract colors
    kptsA_x = np.round(kptsA_np[:, 0]).astype(int)
    kptsA_y = np.round(kptsA_np[:, 1]).astype(int)
    kptsB_x = np.round(kptsB_np[:, 0]).astype(int)
    kptsB_y = np.round(kptsB_np[:, 1]).astype(int)

    kptsA_color = imA[np.clip(kptsA_y, 0, H_A - 1), np.clip(kptsA_x, 0, W_A - 1)]
    kptsB_color = imB[np.clip(kptsB_y, 0, H_B - 1), np.clip(kptsB_x, 0, W_B - 1)]

    # Normalize keypoints into [-1, 1] for downstream triangulation
    kptsA_np_norm = np.zeros_like(kptsA_np)
    kptsB_np_norm = np.zeros_like(kptsB_np)

    kptsA_np_norm[:, 0] = kptsA_np[:, 0] / (W_A - 1) * 2.0 - 1.0
    kptsA_np_norm[:, 1] = kptsA_np[:, 1] / (H_A - 1) * 2.0 - 1.0

    kptsB_np_norm[:, 0] = kptsB_np[:, 0] / (W_B - 1) * 2.0 - 1.0
    kptsB_np_norm[:, 1] = kptsB_np[:, 1] / (H_B - 1) * 2.0 - 1.0

    return kptsA_np_norm, kptsB_np_norm, kptsA_color, kptsB_color


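The coordinate convention above is a simple affine round-trip between the normalized range [-1, 1] and pixel indices. A tiny standalone check of that mapping (illustrative image size only):

import numpy as np

W, H = 640, 480
x_norm, y_norm = 0.25, -0.5
x_pix = (x_norm + 1) * (W - 1) / 2
y_pix = (y_norm + 1) * (H - 1) / 2
x_back = x_pix / (W - 1) * 2.0 - 1.0
y_back = y_pix / (H - 1) * 2.0 - 1.0
assert np.isclose(x_back, x_norm) and np.isclose(y_back, y_norm)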
def init_gaussians_with_corr_profiled(gaussians, scene, cfg, device, verbose=False, roma_model=None):
    timings = defaultdict(list)

    if roma_model is None:
        if cfg.roma_model == "indoors":
            roma_model = roma_indoor(device=device)
        else:
            roma_model = roma_outdoor(device=device)
        roma_model.upsample_preds = False
        roma_model.symmetric = False

    M = cfg.matches_per_ref
    upper_thresh = roma_model.sample_thresh
    scaling_factor = cfg.scaling_factor
    expansion_factor = 1
    keypoint_fit_error_tolerance = cfg.proj_err_tolerance
    visualizations = {}
    viewpoint_stack = scene.getTrainCameras().copy()
    NUM_REFERENCE_FRAMES = min(cfg.num_refs, len(viewpoint_stack))
    NUM_NNS_PER_REFERENCE = 1  # Only ONE neighbor now!

    viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)

    selected_indices = select_cameras_kmeans(cameras=viewpoint_cam_all.detach().cpu().numpy(), K=NUM_REFERENCE_FRAMES)
    selected_indices = sorted(selected_indices)

    viewpoint_cam_all = torch.stack([x.world_view_transform.flatten() for x in viewpoint_stack], axis=0)
    closest_indices = k_closest_vectors(viewpoint_cam_all, NUM_NNS_PER_REFERENCE)
    closest_indices_selected = closest_indices[:, :].detach().cpu().numpy()

    all_new_xyz = []
    all_new_features_dc = []
    all_new_features_rest = []
    all_new_opacities = []
    all_new_scaling = []
    all_new_rotation = []

    # Dummy first pass to initialize the model
    with torch.no_grad():
        viewpoint_cam1 = viewpoint_stack[0]
        viewpoint_cam2 = viewpoint_stack[1]
        imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
        imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
        imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
        imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
        warp, certainty_warp = roma_model.match(imA, imB, device=device)
        del warp, certainty_warp
        torch.cuda.empty_cache()

    # Main loop over source_idx
    for source_idx in tqdm(sorted(selected_indices), desc="Profiling source frames"):

        # =================== Step 1: Compute Warp and Certainty ===================
        start = time.time()
        viewpoint_cam1 = viewpoint_stack[source_idx]
        NNs = closest_indices_selected.shape[1]
        viewpoint_cam2 = viewpoint_stack[closest_indices_selected[source_idx, np.random.randint(NNs)]]
        imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0)
        imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0)
        imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8))
        imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8))
        warp, certainty_warp = roma_model.match(imA, imB, device=device)

        certainties_max = certainty_warp  # New manual sampling
        timings['aggregation_warp_certainty'].append(time.time() - start)

        # =================== Step 2: Good Samples Selection ===================
        start = time.time()
        certainty = certainties_max.reshape(-1).clone()
        certainty[certainty > upper_thresh] = 1
        good_samples = torch.multinomial(certainty, num_samples=min(expansion_factor * M, len(certainty)), replacement=False)
        timings['good_samples_selection'].append(time.time() - start)

        # =================== Step 3: Triangulate Keypoints ===================
        reference_image_dict = {
            "triangulated_points": [],
            "triangulated_points_errors_proj1": [],
            "triangulated_points_errors_proj2": []
        }

        start = time.time()
        matches_NN = warp.reshape(-1, 4)[good_samples]

        # Convert matches to pixel coordinates
        kptsA_np, kptsB_np, kptsA_color, kptsB_color = extract_keypoints_and_colors_single(
            np.array(imA).astype(np.uint8),
            np.array(imB).astype(np.uint8),
            matches_NN,
            roma_model
        )

        proj_matrices_A = viewpoint_stack[source_idx].full_proj_transform
        # With NUM_NNS_PER_REFERENCE == 1, index 0 is the same neighbor the warp was computed against.
        proj_matrices_B = viewpoint_stack[closest_indices_selected[source_idx, 0]].full_proj_transform

        triangulated_points, triangulated_points_errors_proj1, triangulated_points_errors_proj2 = triangulate_points(
            P1=torch.stack([proj_matrices_A] * M, axis=0),
            P2=torch.stack([proj_matrices_B] * M, axis=0),
            k1_x=kptsA_np[:M, 0], k1_y=kptsA_np[:M, 1],
            k2_x=kptsB_np[:M, 0], k2_y=kptsB_np[:M, 1])

        reference_image_dict["triangulated_points"].append(triangulated_points)
        reference_image_dict["triangulated_points_errors_proj1"].append(triangulated_points_errors_proj1)
        reference_image_dict["triangulated_points_errors_proj2"].append(triangulated_points_errors_proj2)
        timings['triangulation_per_NN'].append(time.time() - start)

        # =================== Step 4: Select Best Triangulated Points ===================
        start = time.time()
        NNs_triangulated_points_selected, NNs_triangulated_points_selected_proj_errors = select_best_keypoints(
            NNs_triangulated_points=torch.stack(reference_image_dict["triangulated_points"], dim=0),
            NNs_errors_proj1=np.stack(reference_image_dict["triangulated_points_errors_proj1"], axis=0),
            NNs_errors_proj2=np.stack(reference_image_dict["triangulated_points_errors_proj2"], axis=0))
        timings['select_best_keypoints'].append(time.time() - start)

        # =================== Step 5: Create New Gaussians ===================
        start = time.time()
        viewpoint_cam1 = viewpoint_stack[source_idx]
        N = len(NNs_triangulated_points_selected)
        new_xyz = NNs_triangulated_points_selected[:, :-1]
        all_new_xyz.append(new_xyz)
        all_new_features_dc.append(RGB2SH(torch.tensor(kptsA_color.astype(np.float32) / 255.)).unsqueeze(1))
        all_new_features_rest.append(torch.stack([gaussians._features_rest[-1].clone().detach() * 0.] * N, dim=0))

        mask_bad_points = torch.tensor(
            NNs_triangulated_points_selected_proj_errors > keypoint_fit_error_tolerance,
            dtype=torch.float32).unsqueeze(1).to(device)

        all_new_opacities.append(torch.stack([gaussians._opacity[-1].clone().detach()] * N, dim=0) * 0. - mask_bad_points * (1e1))

        dist_points_to_cam1 = torch.linalg.norm(viewpoint_cam1.camera_center.clone().detach() - new_xyz, dim=1, ord=2)
        all_new_scaling.append(gaussians.scaling_inverse_activation((dist_points_to_cam1 * scaling_factor).unsqueeze(1).repeat(1, 3)))
        all_new_rotation.append(torch.stack([gaussians._rotation[-1].clone().detach()] * N, dim=0))
        timings['save_gaussians'].append(time.time() - start)

    # =================== Final Densification Postfix ===================
    start = time.time()
    all_new_xyz = torch.cat(all_new_xyz, dim=0)
    all_new_features_dc = torch.cat(all_new_features_dc, dim=0)
    new_tmp_radii = torch.zeros(all_new_xyz.shape[0])
    prune_mask = torch.ones(all_new_xyz.shape[0], dtype=torch.bool)

    gaussians.densification_postfix(
        all_new_xyz[prune_mask].to(device),
        all_new_features_dc[prune_mask].to(device),
        torch.cat(all_new_features_rest, dim=0)[prune_mask].to(device),
        torch.cat(all_new_opacities, dim=0)[prune_mask].to(device),
        torch.cat(all_new_scaling, dim=0)[prune_mask].to(device),
        torch.cat(all_new_rotation, dim=0)[prune_mask].to(device),
        new_tmp_radii[prune_mask].to(device)
    )
    timings['final_densification_postfix'].append(time.time() - start)

    # =================== Print Profiling Results ===================
    print("\n=== Profiling Summary (average per frame) ===")
    for key, times in timings.items():
        print(f"{key:35s}: {sum(times) / len(times):.4f} sec (total {sum(times):.2f} sec)")

    return viewpoint_stack, closest_indices_selected, visualizations
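A hypothetical call-site sketch for the initializer above. The `gaussians` and `scene` objects come from the gaussian-splatting codebase, and `cfg` only needs the fields read inside the function; the numeric values below are placeholders, not the project's configured defaults (see configs/gs/base.yaml for those).

from types import SimpleNamespace

cfg = SimpleNamespace(
    roma_model="outdoors",   # or "indoors"
    matches_per_ref=5000,    # placeholder value
    num_refs=16,             # placeholder value
    scaling_factor=0.001,    # placeholder value
    proj_err_tolerance=0.01, # placeholder value
)
# viewpoint_stack, closest_idcs, vis = init_gaussians_with_corr_profiled(
#     gaussians, scene, cfg, device="cuda")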
    	
        source/data_utils.py
    ADDED
    
@@ -0,0 +1,28 @@
| 1 | 
            +
            def scene_cameras_train_test_split(scene, verbose=False):
         | 
| 2 | 
            +
                """
         | 
| 3 | 
            +
                Iterate over resolutions in the scene. For each resolution check if this resolution has test_cameras
         | 
| 4 | 
            +
                if it doesn't then extract every 8th camera from the train and put it to the test set. This follows the
         | 
| 5 | 
            +
                evaluation protocol suggested by Kerbl et al. in the seminal work on 3DGS. All changes to the input
         | 
| 6 | 
            +
                object scene are inplace changes.
         | 
| 7 | 
            +
                :param scene: Scene Class object from the gaussian-splatting.scene module
         | 
| 8 | 
            +
                :param verbose: Print initial and final stage of the function
         | 
| 9 | 
            +
                :return:  None
         | 
| 10 | 
            +
             | 
| 11 | 
            +
                """
         | 
| 12 | 
            +
                if verbose: print("Preparing train and test sets split...")
         | 
| 13 | 
            +
                for resolution in scene.train_cameras.keys():
         | 
| 14 | 
            +
                    if len(scene.test_cameras[resolution]) == 0:
         | 
| 15 | 
            +
                        if verbose:
         | 
| 16 | 
            +
                            print(f"Found no test_cameras for resolution {resolution}. Move every 8th camera out ouf total "+\
         | 
| 17 | 
            +
                                  f"{len(scene.train_cameras[resolution])} train cameras to the test set now")
         | 
| 18 | 
            +
                        N = len(scene.train_cameras[resolution])
         | 
| 19 | 
            +
                        scene.test_cameras[resolution] = [scene.train_cameras[resolution][idx] for idx in range(0, N) 
         | 
| 20 | 
            +
                                                          if idx % 8 == 0]
         | 
| 21 | 
            +
                        scene.train_cameras[resolution] = [scene.train_cameras[resolution][idx] for idx in range(0, N)
         | 
| 22 | 
            +
                                                           if idx % 8 != 0]
         | 
| 23 | 
            +
                        if verbose:
         | 
| 24 | 
            +
                            print(f"Done. Now train and test sets contain each {len(scene.train_cameras[resolution])} and " + \
         | 
| 25 | 
            +
                                  f"{len(scene.test_cameras[resolution])} cameras respectively.")
         | 
| 26 | 
            +
             | 
| 27 | 
            +
             | 
| 28 | 
            +
                return
         | 
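
To make the every-8th-camera protocol concrete, here is a small illustration on plain indices (no Scene object involved): with 16 cameras, indices 0 and 8 go to the test set and the remaining 14 stay in the train set.

# Illustrative only: mimic the split on a toy list of camera indices.
cameras = list(range(16))
test_cameras = [cam for idx, cam in enumerate(cameras) if idx % 8 == 0]
train_cameras = [cam for idx, cam in enumerate(cameras) if idx % 8 != 0]
print(test_cameras)        # [0, 8]
print(len(train_cameras))  # 14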
    	
    source/losses.py
    ADDED

    @@ -0,0 +1,100 @@
# Code is copied from the gaussian-splatting/utils/loss_utils.py

import torch
import torch.nn.functional as F
from torch.autograd import Variable
from math import exp

def l1_loss(network_output, gt, mean=True):
    return torch.abs(network_output - gt).mean() if mean else torch.abs(network_output - gt)

def l2_loss(network_output, gt):
    return ((network_output - gt) ** 2).mean()

def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()

def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window

def ssim(img1, img2, window_size=11, size_average=True, mask=None):
    channel = img1.size(-3)
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average, mask)

def _ssim(img1, img2, window, window_size, channel, size_average=True, mask=None):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if mask is not None:
        ssim_map = ssim_map * mask

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


def mse(img1, img2):
    return ((img1 - img2) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)

def psnr(img1, img2):
    """
    Computes the Peak Signal-to-Noise Ratio (PSNR) between two single images. NOT BATCHED!
    Args:
        img1 (torch.Tensor): The first image tensor, with pixel values scaled between 0 and 1.
                             Shape should be (channels, height, width).
        img2 (torch.Tensor): The second image tensor with the same shape as img1, used for comparison.

    Returns:
        torch.Tensor: A scalar tensor containing the PSNR value in decibels (dB).
    """
    mse = ((img1 - img2) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
    return 20 * torch.log10(1.0 / torch.sqrt(mse))


def tv_loss(image):
    """
    Computes the total variation (TV) loss for an image of shape [3, H, W].

    Args:
        image (torch.Tensor): Input image of shape [3, H, W]

    Returns:
        torch.Tensor: Scalar value representing the total variation loss.
    """
    # Ensure the image has the correct dimensions
    assert image.ndim == 3 and image.shape[0] == 3, "Input must be of shape [3, H, W]"

    # Compute the difference between adjacent pixels in the x-direction (width)
    diff_x = torch.abs(image[:, :, 1:] - image[:, :, :-1])

    # Compute the difference between adjacent pixels in the y-direction (height)
    diff_y = torch.abs(image[:, 1:, :] - image[:, :-1, :])

    # Sum the total variation in both directions
    tv_loss_value = torch.mean(diff_x) + torch.mean(diff_y)

    return tv_loss_value
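
A quick sanity check of these metrics on random tensors, combined into the λ-weighted photometric loss the trainer uses later (the 0.2 weight is illustrative; the real value comes from the training config):

import torch
from source.losses import l1_loss, ssim, psnr, tv_loss

render = torch.rand(3, 64, 64)   # fake rendered image in [0, 1]
gt = torch.rand(3, 64, 64)       # fake ground-truth image

lambda_dssim = 0.2  # illustrative value
loss = (1.0 - lambda_dssim) * l1_loss(render, gt) + lambda_dssim * (1.0 - ssim(render, gt))

# psnr expects a leading batch dimension, as in the trainer's evaluate().
print(f"loss={loss.item():.4f}, "
      f"PSNR={psnr(render.unsqueeze(0), gt.unsqueeze(0)).item():.2f} dB, "
      f"TV={tv_loss(render).item():.4f}")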
    	
    source/networks.py
    ADDED

    @@ -0,0 +1,52 @@
import torch

import sys
sys.path.append('./submodules/gaussian-splatting/')

from random import randint
from scene import Scene, GaussianModel
from gaussian_renderer import render
from source.data_utils import scene_cameras_train_test_split

class Warper3DGS(torch.nn.Module):
    def __init__(self, sh_degree, opt, pipe, dataset, viewpoint_stack, verbose,
                 do_train_test_split=True):
        """
        Init Warper using all the objects necessary for rendering gaussian splats.
        Here we merely link class attributes to the objects instantiated outside the class.
        """
        super(Warper3DGS, self).__init__()
        self.gaussians = GaussianModel(sh_degree)
        self.gaussians.tmp_radii = torch.zeros((self.gaussians.get_xyz.shape[0]), device="cuda")
        self.render = render
        self.gs_config_opt = opt
        bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
        self.bg = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
        self.pipe = pipe
        self.scene = Scene(dataset, self.gaussians, shuffle=False)
        if do_train_test_split:
            scene_cameras_train_test_split(self.scene, verbose=verbose)

        self.gaussians.training_setup(opt)
        self.viewpoint_stack = viewpoint_stack
        if not self.viewpoint_stack:
            self.viewpoint_stack = self.scene.getTrainCameras().copy()

    def forward(self, viewpoint_cam=None):
        """
        For a provided camera viewpoint_cam we render gaussians from this viewpoint.
        If no camera is provided, we use self.viewpoint_stack (a list of cameras).
        If the latter is empty, we reinitialize it using the self.scene object.
        """
        if not viewpoint_cam:
            if not self.viewpoint_stack:
                self.viewpoint_stack = self.scene.getTrainCameras().copy()
            viewpoint_cam = self.viewpoint_stack[randint(0, len(self.viewpoint_stack) - 1)]

        render_pkg = self.render(viewpoint_cam, self.gaussians, self.pipe, self.bg)
        return render_pkg
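
A sketch of how Warper3DGS is typically driven. The opt/pipe/dataset objects are the usual gaussian-splatting argument groups; their exact construction depends on configs/train.yaml and is assumed here, so treat this as illustrative rather than a fixed API:

# Illustrative only: assumes `opt`, `pipe` and `dataset` were built from the
# gaussian-splatting argument parsers / config elsewhere in this repo.
warper = Warper3DGS(sh_degree=3, opt=opt, pipe=pipe, dataset=dataset,
                    viewpoint_stack=None, verbose=True)

render_pkg = warper()                    # renders from a random training camera
image = render_pkg["render"]             # (3, H, W) tensor
render_pkg = warper(viewpoint_cam=warper.viewpoint_stack[0])  # explicit camera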
    	
    source/timer.py
    ADDED

    @@ -0,0 +1,24 @@
import time

class Timer:
    def __init__(self):
        self.start_time = None
        self.elapsed = 0
        self.paused = False

    def start(self):
        if self.start_time is None:
            self.start_time = time.time()
        elif self.paused:
            self.start_time = time.time() - self.elapsed
            self.paused = False

    def pause(self):
        if not self.paused:
            self.elapsed = time.time() - self.start_time
            self.paused = True

    def get_elapsed_time(self):
        if self.paused:
            return self.elapsed
        else:
            return time.time() - self.start_time
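
The trainer uses this Timer to exclude logging and evaluation from the reported training time: pause() before side work, start() to resume from the accumulated total. A minimal usage sketch:

import time
from source.timer import Timer

timer = Timer()
timer.start()
time.sleep(0.1)   # "training" work that should be timed
timer.pause()
time.sleep(0.1)   # logging/evaluation that should not count
timer.start()     # resumes from the accumulated elapsed time
time.sleep(0.1)
print(f"{timer.get_elapsed_time():.2f}s")  # roughly 0.2s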
    	
    source/trainer.py
    ADDED

    @@ -0,0 +1,262 @@
import torch
from random import randint
from tqdm.rich import trange
from tqdm import tqdm as tqdm
from source.networks import Warper3DGS
import wandb
import sys

sys.path.append('./submodules/gaussian-splatting/')
import lpips
from source.losses import ssim, l1_loss, psnr
from rich.console import Console
from rich.theme import Theme

custom_theme = Theme({
    "info": "dim cyan",
    "warning": "magenta",
    "danger": "bold red"
})

#from source.corr_init import init_gaussians_with_corr
from source.corr_init_new import init_gaussians_with_corr_profiled as init_gaussians_with_corr
from source.utils_aux import log_samples

from source.timer import Timer

class EDGSTrainer:
    def __init__(self,
                 GS: Warper3DGS,
                 training_config,
                 dataset_white_background=False,
                 device=torch.device('cuda'),
                 log_wandb=True,
                 ):
        self.GS = GS
        self.scene = GS.scene
        self.viewpoint_stack = GS.viewpoint_stack
        self.gaussians = GS.gaussians

        self.training_config = training_config
        self.GS_optimizer = GS.gaussians.optimizer
        self.dataset_white_background = dataset_white_background

        self.training_step = 1
        self.gs_step = 0
        self.CONSOLE = Console(width=120, theme=custom_theme)
        self.saving_iterations = training_config.save_iterations
        self.evaluate_iterations = None
        self.batch_size = training_config.batch_size
        self.ema_loss_for_log = 0.0

        # Logs in the format {step: {"loss1": loss1_value, "loss2": loss2_value}}
        self.logs_losses = {}
        self.lpips = lpips.LPIPS(net='vgg').to(device)
        self.device = device
        self.timer = Timer()
        self.log_wandb = log_wandb

    def load_checkpoints(self, load_cfg):
        # Load 3DGS checkpoint
        if load_cfg.gs:
            self.GS.gaussians.restore(
                torch.load(f"{load_cfg.gs}/chkpnt{load_cfg.gs_step}.pth")[0],
                self.training_config)
            self.GS_optimizer = self.GS.gaussians.optimizer
            self.CONSOLE.print(f"3DGS loaded from checkpoint for iteration {load_cfg.gs_step}",
                               style="info")
            self.training_step += load_cfg.gs_step
            self.gs_step += load_cfg.gs_step

    def train(self, train_cfg):
        # 3DGS training
        self.CONSOLE.print("Train 3DGS for {} iterations".format(train_cfg.gs_epochs), style="info")
        with trange(self.training_step, self.training_step + train_cfg.gs_epochs, desc="[green]Train gaussians") as progress_bar:
            for self.training_step in progress_bar:
                radii = self.train_step_gs(max_lr=train_cfg.max_lr, no_densify=train_cfg.no_densify)
                with torch.no_grad():
                    if train_cfg.no_densify:
                        self.prune(radii)
                    else:
                        self.densify_and_prune(radii)
                    if train_cfg.reduce_opacity:
                        # Slightly reduce opacity every few steps:
                        if self.gs_step < self.training_config.densify_until_iter and self.gs_step % 10 == 0:
                            opacities_new = torch.log(torch.exp(self.GS.gaussians._opacity.data) * 0.99)
                            self.GS.gaussians._opacity.data = opacities_new
                    self.timer.pause()
                    # Progress bar
                    if self.training_step % 10 == 0:
                        progress_bar.set_postfix({"[red]Loss": f"{self.ema_loss_for_log:.{7}f}"}, refresh=True)
                    # Log and save
                    if self.training_step in self.saving_iterations:
                        self.save_model()
                    if self.evaluate_iterations is not None:
                        if self.training_step in self.evaluate_iterations:
                            self.evaluate()
                    else:
                        if (self.training_step <= 3000 and self.training_step % 500 == 0) or \
                            (self.training_step > 3000 and self.training_step % 1000 == 228):
                            self.evaluate()

                    self.timer.start()


    def evaluate(self):
        torch.cuda.empty_cache()
        log_gen_images, log_real_images = [], []
        validation_configs = ({'name': 'test', 'cameras': self.scene.getTestCameras(), 'cam_idx': self.training_config.TEST_CAM_IDX_TO_LOG},
                              {'name': 'train',
                               'cameras': [self.scene.getTrainCameras()[idx % len(self.scene.getTrainCameras())] for idx in
                                           range(0, 150, 5)], 'cam_idx': 10})
        if self.log_wandb:
            wandb.log({"Number of Gaussians": len(self.GS.gaussians._xyz)}, step=self.training_step)
        for config in validation_configs:
            if config['cameras'] and len(config['cameras']) > 0:
                l1_test = 0.0
                psnr_test = 0.0
                ssim_test = 0.0
                lpips_splat_test = 0.0
                for idx, viewpoint in enumerate(config['cameras']):
                    image = torch.clamp(self.GS(viewpoint)["render"], 0.0, 1.0)
                    gt_image = torch.clamp(viewpoint.original_image.to(self.device), 0.0, 1.0)
                    l1_test += l1_loss(image, gt_image).double()
                    psnr_test += psnr(image.unsqueeze(0), gt_image.unsqueeze(0)).double()
                    ssim_test += ssim(image, gt_image).double()
                    lpips_splat_test += self.lpips(image, gt_image).detach().double()
                    if idx in [config['cam_idx']]:
                        log_gen_images.append(image)
                        log_real_images.append(gt_image)
                psnr_test /= len(config['cameras'])
                l1_test /= len(config['cameras'])
                ssim_test /= len(config['cameras'])
                lpips_splat_test /= len(config['cameras'])
                if self.log_wandb:
                    wandb.log({f"{config['name']}/L1": l1_test.item(), f"{config['name']}/PSNR": psnr_test.item(), \
                            f"{config['name']}/SSIM": ssim_test.item(), f"{config['name']}/LPIPS_splat": lpips_splat_test.item()}, step=self.training_step)
                self.CONSOLE.print("\n[ITER {}], #{} gaussians, Evaluating {}: L1={:.6f}, PSNR={:.6f}, SSIM={:.6f}, LPIPS_splat={:.6f}".format(
                    self.training_step, len(self.GS.gaussians._xyz), config['name'], l1_test.item(), psnr_test.item(), ssim_test.item(), lpips_splat_test.item()), style="info")
        if self.log_wandb:
            with torch.no_grad():
                log_samples(torch.stack((log_real_images[0], log_gen_images[0])), [], self.training_step, caption="Real and Generated Samples")
                wandb.log({"time": self.timer.get_elapsed_time()}, step=self.training_step)
        torch.cuda.empty_cache()

    def train_step_gs(self, max_lr=False, no_densify=False):
        self.gs_step += 1
        if max_lr:
            self.GS.gaussians.update_learning_rate(max(self.gs_step, 8_000))
        else:
            self.GS.gaussians.update_learning_rate(self.gs_step)
        # Every 1000 iterations we increase the levels of SH up to a maximum degree
        if self.gs_step % 1000 == 0:
            self.GS.gaussians.oneupSHdegree()

        # Pick a random camera
        if not self.viewpoint_stack:
            self.viewpoint_stack = self.scene.getTrainCameras().copy()
        viewpoint_cam = self.viewpoint_stack.pop(randint(0, len(self.viewpoint_stack) - 1))

        render_pkg = self.GS(viewpoint_cam=viewpoint_cam)
        image = render_pkg["render"]
        # Loss
        gt_image = viewpoint_cam.original_image.to(self.device)
        L1_loss = l1_loss(image, gt_image)

        ssim_loss = (1.0 - ssim(image, gt_image))
        loss = (1.0 - self.training_config.lambda_dssim) * L1_loss + \
               self.training_config.lambda_dssim * ssim_loss
        self.timer.pause()
        self.logs_losses[self.training_step] = {"loss": loss.item(),
                                                "L1_loss": L1_loss.item(),
                                                "ssim_loss": ssim_loss.item()}

        if self.log_wandb:
            for k, v in self.logs_losses[self.training_step].items():
                wandb.log({f"train/{k}": v}, step=self.training_step)
        self.ema_loss_for_log = 0.4 * self.logs_losses[self.training_step]["loss"] + 0.6 * self.ema_loss_for_log
        self.timer.start()
        self.GS_optimizer.zero_grad(set_to_none=True)
        loss.backward()
        with torch.no_grad():
            if self.gs_step < self.training_config.densify_until_iter and not no_densify:
                self.GS.gaussians.max_radii2D[render_pkg["visibility_filter"]] = torch.max(
                    self.GS.gaussians.max_radii2D[render_pkg["visibility_filter"]],
                    render_pkg["radii"][render_pkg["visibility_filter"]])
                self.GS.gaussians.add_densification_stats(render_pkg["viewspace_points"],
                                                          render_pkg["visibility_filter"])

        # Optimizer step
        self.GS_optimizer.step()
        self.GS_optimizer.zero_grad(set_to_none=True)
        return render_pkg["radii"]

    def densify_and_prune(self, radii=None):
        # Densification or pruning
        if self.gs_step < self.training_config.densify_until_iter:
            if (self.gs_step > self.training_config.densify_from_iter) and \
                    (self.gs_step % self.training_config.densification_interval == 0):
                size_threshold = 20 if self.gs_step > self.training_config.opacity_reset_interval else None
                self.GS.gaussians.densify_and_prune(self.training_config.densify_grad_threshold,
                                                    0.005,
                                                    self.GS.scene.cameras_extent,
                                                    size_threshold, radii)
            if self.gs_step % self.training_config.opacity_reset_interval == 0 or (
                    self.dataset_white_background and self.gs_step == self.training_config.densify_from_iter):
                self.GS.gaussians.reset_opacity()


    def save_model(self):
        print("\n[ITER {}] Saving Gaussians".format(self.gs_step))
        self.scene.save(self.gs_step)
        print("\n[ITER {}] Saving Checkpoint".format(self.gs_step))
        torch.save((self.GS.gaussians.capture(), self.gs_step),
                   self.scene.model_path + "/chkpnt" + str(self.gs_step) + ".pth")


    def init_with_corr(self, cfg, verbose=False, roma_model=None):
        """
        Initializes gaussians with matchings. Also removes SfM init points.
        Args:
            cfg: configuration section named init_wC. See train.yaml.
            verbose: whether to print intermediate results. Useful for debugging.
            roma_model: optionally, a pre-initialized RoMa model can be passed here
                to avoid re-initializing it every time.
        """
        if not cfg.use:
            return None
        N_splats_at_init = len(self.GS.gaussians._xyz)
        print("N_splats_at_init:", N_splats_at_init)
        camera_set, selected_indices, visualization_dict = init_gaussians_with_corr(
            self.GS.gaussians,
            self.scene,
            cfg,
            self.device,
            verbose=verbose,
            roma_model=roma_model)

        # Remove SfM points and leave only matching-based inits
        if not cfg.add_SfM_init:
            with torch.no_grad():
                N_splats_after_init = len(self.GS.gaussians._xyz)
                print("N_splats_after_init:", N_splats_after_init)
                self.gaussians.tmp_radii = torch.zeros(self.gaussians._xyz.shape[0]).to(self.device)
                mask = torch.concat([torch.ones(N_splats_at_init, dtype=torch.bool),
                                     torch.zeros(N_splats_after_init - N_splats_at_init, dtype=torch.bool)],
                                    axis=0)
                self.GS.gaussians.prune_points(mask)
        with torch.no_grad():
            gaussians = self.gaussians
            gaussians._scaling = gaussians.scaling_inverse_activation(gaussians.scaling_activation(gaussians._scaling) * 0.5)
        return visualization_dict


    def prune(self, radii, min_opacity=0.005):
        self.GS.gaussians.tmp_radii = radii
        if self.gs_step < self.training_config.densify_until_iter:
            prune_mask = (self.GS.gaussians.get_opacity < min_opacity).squeeze()
            self.GS.gaussians.prune_points(prune_mask)
            torch.cuda.empty_cache()
        self.GS.gaussians.tmp_radii = None
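
An end-to-end sketch of how EDGSTrainer is meant to be driven: correspondence-based initialization first, then the standard 3DGS optimization loop. The config attribute names (cfg.gs.opt, cfg.gs.dataset.white_background, cfg.init_wC, cfg.train) mirror the docstrings above and configs/train.yaml but are assumptions here, not a confirmed API:

# Illustrative only: `gs` is a Warper3DGS instance and `cfg` a Hydra/OmegaConf config.
trainer = EDGSTrainer(GS=gs,
                      training_config=cfg.gs.opt,                       # assumed config path
                      dataset_white_background=cfg.gs.dataset.white_background,
                      log_wandb=False)

trainer.timer.start()
trainer.init_with_corr(cfg.init_wC, verbose=False)  # correspondence-based init (EDGS)
trainer.train(cfg.train)                            # 3DGS optimization loop
trainer.save_model()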
    	
    source/utils_aux.py
    ADDED

    @@ -0,0 +1,92 @@
# Perlin noise code taken from https://gist.github.com/adefossez/0646dbe9ed4005480a2407c62aac8869
from types import SimpleNamespace
import random
import numpy as np
import torch
import torchvision
import wandb
import torchvision.transforms as T
import torchvision.transforms.functional as F
from PIL import Image

def parse_dict_to_namespace(dict_nested):
    """Turns a nested dictionary into nested namespaces"""
    if type(dict_nested) != dict and type(dict_nested) != list: return dict_nested
    x = SimpleNamespace()
    for key, val in dict_nested.items():
        if type(val) == dict:
            setattr(x, key, parse_dict_to_namespace(val))
        elif type(val) == list:
            setattr(x, key, [parse_dict_to_namespace(v) for v in val])
        else:
            setattr(x, key, val)
    return x

def set_seed(seed=42, cuda=True):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed_all(seed)


def log_samples(samples, scores, iteration, caption="Real Samples"):
    # Create a grid of images
    grid = torchvision.utils.make_grid(samples)

    # Log the images and scores to wandb
    wandb.log({
        f"{caption}_images": [wandb.Image(grid, caption=f"{caption}: {scores}")],
    }, step=iteration)


def pairwise_distances(matrix):
    """
    Computes the pairwise Euclidean distances between all vectors in the input matrix.

    Args:
        matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.

    Returns:
        torch.Tensor: Pairwise distance matrix of shape [N, N].
    """
    # Compute pairwise Euclidean (p=2) distances
    distances = torch.cdist(matrix, matrix, p=2)
    return distances

def k_closest_vectors(matrix, k):
    """
    Finds the k-closest vectors for each vector in the input matrix based on Euclidean distance.

    Args:
        matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality.
        k (int): Number of closest vectors to return for each vector.

    Returns:
        torch.Tensor: Indices of the k-closest vectors for each vector, excluding the vector itself.
    """
    # Compute pairwise distances
    distances = pairwise_distances(matrix)

    # For each vector, sort distances and get the indices of the k-closest vectors (excluding itself)
    # Set diagonal distances to infinity to exclude the vector itself from the nearest neighbors
    distances.fill_diagonal_(float('inf'))

    # Get the indices of the k smallest distances (k-closest vectors)
    _, indices = torch.topk(distances, k, largest=False, dim=1)

    return indices

def process_image(image_tensor):
    image_np = image_tensor.detach().cpu().numpy().transpose(1, 2, 0)
    return Image.fromarray(np.clip(image_np * 255, 0, 255).astype(np.uint8))


def normalize_keypoints(kpts_np, width, height):
    kpts_np[:, 0] = kpts_np[:, 0] / width * 2. - 1.
    kpts_np[:, 1] = kpts_np[:, 1] / height * 2. - 1.
    return kpts_np
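
A small check of k_closest_vectors: for distinct points on a line, the two nearest neighbours of each point are exactly the ones you would expect.

import torch
from source.utils_aux import k_closest_vectors

points = torch.tensor([[0.0], [1.0], [3.0], [10.0]])  # shape [N, D] with D = 1
print(k_closest_vectors(points, k=2))
# tensor([[1, 2],
#         [0, 2],
#         [1, 0],
#         [2, 1]])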
    	
        source/utils_preprocess.py
    ADDED
    
    | @@ -0,0 +1,334 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # This file contains function for video or image collection preprocessing.
         | 
| 2 | 
            +
            # For video we do the preprocessing and select k sharpest frames.
         | 
| 3 | 
            +
            # Afterwards scene is constructed 
         | 
| 4 | 
            +
            import cv2
         | 
| 5 | 
            +
            import numpy as np
         | 
| 6 | 
            +
            from tqdm import tqdm
         | 
| 7 | 
            +
            import pycolmap
         | 
| 8 | 
            +
            import os
         | 
| 9 | 
            +
            import time
         | 
| 10 | 
            +
            import tempfile
         | 
| 11 | 
            +
            from moviepy import VideoFileClip
         | 
| 12 | 
            +
            from matplotlib import pyplot as plt
         | 
| 13 | 
            +
            from PIL import Image
         | 
| 14 | 
            +
            import cv2
         | 
| 15 | 
            +
            from tqdm import tqdm
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            WORKDIR = "../outputs/"
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            def get_rotation_moviepy(video_path):
         | 
| 21 | 
            +
                clip = VideoFileClip(video_path)
         | 
| 22 | 
            +
                rotation = 0
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                try:
         | 
| 25 | 
            +
                    displaymatrix = clip.reader.infos['inputs'][0]['streams'][2]['metadata'].get('displaymatrix', '')
         | 
| 26 | 
            +
                    if 'rotation of' in displaymatrix:
         | 
| 27 | 
            +
                        angle = float(displaymatrix.strip().split('rotation of')[-1].split('degrees')[0])
         | 
| 28 | 
            +
                        rotation = int(angle) % 360
         | 
| 29 | 
            +
                        
         | 
| 30 | 
            +
                except Exception as e:
         | 
| 31 | 
            +
                    print(f"No displaymatrix rotation found: {e}")
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                clip.reader.close()
         | 
| 34 | 
            +
                #if clip.audio:
         | 
| 35 | 
            +
                #    clip.audio.reader.close_proc()
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                return rotation
         | 
| 38 | 
            +
             | 
# NOTE: this helper is redefined further below with an explicit docstring;
# the later definition is the one that takes effect at import time.
def resize_max_side(frame, max_size):
    h, w = frame.shape[:2]
    scale = max_size / max(h, w)
    if scale < 1:
        frame = cv2.resize(frame, (int(w * scale), int(h * scale)))
    return frame

def read_video_frames(video_input, k=1, max_size=1024):
    """
    Extracts every k-th frame from a video or list of images, resizes to max size, and returns frames as list.

    Parameters:
        video_input (str, file-like, or list): Path to video file, file-like object, or list of image files.
        k (int): Interval for frame extraction (every k-th frame).
        max_size (int): Maximum size for width or height after resizing.

    Returns:
        frames (list): List of resized frames (numpy arrays).
    """
    # Handle list of image files (not single video in a list)
    if isinstance(video_input, list):
        # If it's a single video in a list, treat it as video
        if len(video_input) == 1 and video_input[0].name.endswith(('.mp4', '.avi', '.mov')):
            video_input = video_input[0]  # unwrap single video file
        else:
            # Treat as list of images
            frames = []
            for img_file in video_input:
                img = Image.open(img_file.name).convert("RGB")
                img.thumbnail((max_size, max_size))
                frames.append(np.array(img)[..., ::-1])
            return frames

    # Handle file-like or path
    if hasattr(video_input, 'name'):
        video_path = video_input.name
    elif isinstance(video_input, (str, os.PathLike)):
        video_path = str(video_input)
    else:
        raise ValueError("Unsupported video input type. Must be a filepath, file-like object, or list of images.")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Error: Could not open video {video_path}.")

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_count = 0
    frames = []

    with tqdm(total=total_frames // k, desc="Processing Video", unit="frame") as pbar:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % k == 0:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                h, w = frame.shape[:2]
                scale = max(h, w) / max_size
                if scale > 1:
                    frame = cv2.resize(frame, (int(w / scale), int(h / scale)))
                frames.append(frame[..., [2, 1, 0]])
                pbar.update(1)
            frame_count += 1

    cap.release()
    return frames

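A minimal usage sketch (not part of the committed file); "my_clip.mp4" is an assumed input path:

# Hypothetical call: keep every 10th frame, capped at 1024 px on the longest side.
frames = read_video_frames("my_clip.mp4", k=10, max_size=1024)
print(f"Extracted {len(frames)} frames; first frame shape: {frames[0].shape}")
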
def resize_max_side(frame, max_size):
    """
    Resizes the frame so that its largest side equals max_size, maintaining aspect ratio.
    """
    height, width = frame.shape[:2]
    max_dim = max(height, width)

    if max_dim <= max_size:
        return frame  # No need to resize

    scale = max_size / max_dim
    new_width = int(width * scale)
    new_height = int(height * scale)

    resized_frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_AREA)
    return resized_frame

def variance_of_laplacian(image):
    # compute the Laplacian of the image and then return the focus
    # measure, which is simply the variance of the Laplacian
    return cv2.Laplacian(image, cv2.CV_64F).var()


def process_all_frames(IMG_FOLDER='/scratch/datasets/hq_data/night2_all_frames',
                       to_visualize=False,
                       save_images=True):
    dict_scores = {}
    for idx, img_name in tqdm(enumerate(sorted([x for x in os.listdir(IMG_FOLDER) if '.png' in x]))):

        img = cv2.imread(os.path.join(IMG_FOLDER, img_name))  # [250:, 100:]
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        fm = variance_of_laplacian(gray) + \
             variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.75, fy=0.75)) + \
             variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.5, fy=0.5)) + \
             variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.25, fy=0.25))
        if to_visualize:
            plt.figure()
            plt.title(f"Laplacian score: {fm:.2f}")
            plt.imshow(img[..., [2, 1, 0]])
            plt.show()
        dict_scores[idx] = {"idx": idx,
                            "img_name": img_name,
                            "score": fm}
        if save_images:
            dict_scores[idx]["img"] = img

    return dict_scores

# NOTE: superseded by the segment-based select_optimal_frames defined further below,
# which is the definition that takes effect at import time.
def select_optimal_frames(scores, k):
    """
    Selects a minimal subset of frames while ensuring no gaps exceed k.

    Args:
        scores (list of float): List of scores where index represents frame number.
        k (int): Maximum allowed gap between selected frames.

    Returns:
        list of int: Indices of selected frames.
    """
    n = len(scores)
    selected = [0, n - 1]
    i = 0  # Start at the first frame

    while i < n:
        # Find the best frame to select within the next k frames
        best_idx = max(range(i, min(i + k + 1, n)), key=lambda x: scores[x], default=None)

        if best_idx is None:
            break  # No more frames left

        selected.append(best_idx)
        i = best_idx + k + 1  # Move forward, ensuring gaps stay within k

    return sorted(selected)


def variance_of_laplacian(image):
    """
    Compute the variance of Laplacian as a focus measure.
    """
    return cv2.Laplacian(image, cv2.CV_64F).var()

def preprocess_frames(frames, verbose=False):
    """
    Compute sharpness scores for a list of frames using multi-scale Laplacian variance.

    Args:
        frames (list of np.ndarray): List of frames (BGR images).
        verbose (bool): If True, print scores.

    Returns:
        list of float: Sharpness scores for each frame.
    """
    scores = []

    for idx, frame in enumerate(tqdm(frames, desc="Scoring frames")):
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

        fm = (
            variance_of_laplacian(gray) +
            variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.75, fy=0.75)) +
            variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.5, fy=0.5)) +
            variance_of_laplacian(cv2.resize(gray, (0, 0), fx=0.25, fy=0.25))
        )

        if verbose:
            print(f"Frame {idx}: Sharpness Score = {fm:.2f}")

        scores.append(fm)

    return scores

def select_optimal_frames(scores, k):
    """
    Selects k frames by splitting into k segments and picking the sharpest frame from each.

    Args:
        scores (list of float): List of sharpness scores.
        k (int): Number of frames to select.

    Returns:
        list of int: Indices of selected frames.
    """
    n = len(scores)
    selected_indices = []
    segment_size = n // k

    for i in range(k):
        start = i * segment_size
        end = (i + 1) * segment_size if i < k - 1 else n  # Last chunk may be larger
        segment_scores = scores[start:end]

        if len(segment_scores) == 0:
            continue  # Safety check if some segment is empty

        best_in_segment = start + np.argmax(segment_scores)
        selected_indices.append(best_in_segment)

    return sorted(selected_indices)

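A sketch of how the scoring and selection helpers combine (illustration only, not part of the commit); the input video path is assumed:

# Hypothetical example: keep the 16 sharpest, well-spread frames.
frames = read_video_frames("my_clip.mp4", k=2)        # assumed input video
scores = preprocess_frames(frames)                    # multi-scale Laplacian sharpness per frame
keep = select_optimal_frames(scores, k=16)            # sharpest frame in each of 16 segments
selected_frames = [frames[i] for i in keep]
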
def save_frames_to_scene_dir(frames, scene_dir):
    """
    Saves a list of frames into the target scene directory under 'images/' subfolder.

    Args:
        frames (list of np.ndarray): List of frames (BGR images) to save.
        scene_dir (str): Target path where 'images/' subfolder will be created.
    """
    images_dir = os.path.join(scene_dir, "images")
    os.makedirs(images_dir, exist_ok=True)

    for idx, frame in enumerate(frames):
        filename = os.path.join(images_dir, f"{idx:08d}.png")  # 00000000.png, 00000001.png, etc.
        cv2.imwrite(filename, frame)

    print(f"Saved {len(frames)} frames to {images_dir}")


def run_colmap_on_scene(scene_dir):
    """
    Runs feature extraction, matching, and mapping on all images inside scene_dir/images using pycolmap.

    Args:
        scene_dir (str): Path to scene directory containing 'images' folder.

    TODO: if the function fails to match all frames, either increase the image size,
    increase the number of features, or remove the unmatched frames from scene_dir/images.
    """
    start_time = time.time()
    print(f"Running COLMAP pipeline on all images inside {scene_dir}")

    # Setup paths
    database_path = os.path.join(scene_dir, "database.db")
    sparse_path = os.path.join(scene_dir, "sparse")
    image_dir = os.path.join(scene_dir, "images")

    # Make sure output directories exist
    os.makedirs(sparse_path, exist_ok=True)

    # Step 1: Feature Extraction
    pycolmap.extract_features(
        database_path,
        image_dir,
        sift_options={
            "max_num_features": 512 * 2,
            "max_image_size": 512 * 1,
        }
    )
    print(f"Finished feature extraction in {(time.time() - start_time):.2f}s.")

    # Step 2: Feature Matching
    pycolmap.match_exhaustive(database_path)
    print(f"Finished feature matching in {(time.time() - start_time):.2f}s.")

    # Step 3: Mapping
    pipeline_options = pycolmap.IncrementalPipelineOptions()
    pipeline_options.min_num_matches = 15
    pipeline_options.multiple_models = True
    pipeline_options.max_num_models = 50
    pipeline_options.max_model_overlap = 20
    pipeline_options.min_model_size = 10
    pipeline_options.extract_colors = True
    pipeline_options.num_threads = 8
    pipeline_options.mapper.init_min_num_inliers = 30
    pipeline_options.mapper.init_max_error = 8.0
    pipeline_options.mapper.init_min_tri_angle = 5.0

    reconstruction = pycolmap.incremental_mapping(
        database_path=database_path,
        image_path=image_dir,
        output_path=sparse_path,
        options=pipeline_options,
    )
    print(f"Finished incremental mapping in {(time.time() - start_time):.2f}s.")

    # Step 4: Post-process Cameras to SIMPLE_PINHOLE
    recon_path = os.path.join(sparse_path, "0")
    reconstruction = pycolmap.Reconstruction(recon_path)

    for cam in reconstruction.cameras.values():
        cam.model = 'SIMPLE_PINHOLE'
        cam.params = cam.params[:3]  # Keep only [f, cx, cy]

    reconstruction.write(recon_path)

    print(f"Total pipeline time: {(time.time() - start_time):.2f}s.")
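An end-to-end sketch of how the helpers in this file chain together (illustration only, not part of the committed file); "scene_video.mp4" and "../outputs/my_scene" are assumed paths:

# Hypothetical pipeline: video -> sharpest frames -> COLMAP sparse reconstruction.
frames = read_video_frames("scene_video.mp4", k=2, max_size=1024)
scores = preprocess_frames(frames)
keep = select_optimal_frames(scores, k=16)
scene_dir = "../outputs/my_scene"
save_frames_to_scene_dir([frames[i] for i in keep], scene_dir)  # writes scene_dir/images/*.png
run_colmap_on_scene(scene_dir)                                  # writes scene_dir/sparse/0
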
    	
source/vggt_to_colmap.py ADDED
@@ -0,0 +1,598 @@
|  | |
| 1 | 
            +
            # Reuse code taken from the implementation of atakan-topaloglu:
         | 
| 2 | 
            +
            #  https://github.com/atakan-topaloglu/vggt/blob/main/vggt_to_colmap.py
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import os
         | 
| 5 | 
            +
            import argparse
         | 
| 6 | 
            +
            import numpy as np
         | 
| 7 | 
            +
            import torch
         | 
| 8 | 
            +
            import glob
         | 
| 9 | 
            +
            import struct
         | 
| 10 | 
            +
            from scipy.spatial.transform import Rotation
         | 
| 11 | 
            +
            import sys
         | 
| 12 | 
            +
            from PIL import Image
         | 
| 13 | 
            +
            import cv2
         | 
| 14 | 
            +
            import requests
         | 
| 15 | 
            +
            import tempfile
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            sys.path.append("submodules/vggt/")
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            from vggt.models.vggt import VGGT
         | 
| 20 | 
            +
            from vggt.utils.load_fn import load_and_preprocess_images
         | 
| 21 | 
            +
            from vggt.utils.pose_enc import pose_encoding_to_extri_intri
         | 
| 22 | 
            +
            from vggt.utils.geometry import unproject_depth_map_to_point_map
         | 
| 23 | 
            +
             | 
def load_model(device=None):
    """Load and initialize the VGGT model."""
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    model = VGGT.from_pretrained("facebook/VGGT-1B")

    # model = VGGT()
    # _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
    # model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))

    model.eval()
    model = model.to(device)
    return model, device

def process_images(image_dir, model, device):
    """Process images with VGGT and return predictions."""
    image_names = glob.glob(os.path.join(image_dir, "*"))
    image_names = sorted([f for f in image_names if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
    print(f"Found {len(image_names)} images")

    if len(image_names) == 0:
        raise ValueError(f"No images found in {image_dir}")

    original_images = []
    for img_path in image_names:
        img = Image.open(img_path).convert('RGB')
        original_images.append(np.array(img))

    images = load_and_preprocess_images(image_names).to(device)
    print(f"Preprocessed images shape: {images.shape}")

    print("Running inference...")
    # NOTE: assumes a CUDA device; bfloat16 on compute capability >= 8 (Ampere or newer), float16 otherwise.
    dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16

    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=dtype):
            predictions = model(images)

    print("Converting pose encoding to camera parameters...")
    extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
    predictions["extrinsic"] = extrinsic
    predictions["intrinsic"] = intrinsic

    for key in predictions.keys():
        if isinstance(predictions[key], torch.Tensor):
            predictions[key] = predictions[key].cpu().numpy().squeeze(0)  # remove batch dimension

    print("Computing 3D points from depth maps...")
    depth_map = predictions["depth"]  # (S, H, W, 1)
    world_points = unproject_depth_map_to_point_map(depth_map, predictions["extrinsic"], predictions["intrinsic"])
    predictions["world_points_from_depth"] = world_points

    predictions["original_images"] = original_images

    S, H, W = world_points.shape[:3]
    normalized_images = np.zeros((S, H, W, 3), dtype=np.float32)

    for i, img in enumerate(original_images):
        resized_img = cv2.resize(img, (W, H))
        normalized_images[i] = resized_img / 255.0

    predictions["images"] = normalized_images

    return predictions, image_names

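A minimal usage sketch for the two functions above (illustration only); the images folder is an assumed path and a CUDA device is assumed to be available:

# Hypothetical call: run VGGT on a scene's image folder.
model, device = load_model()   # downloads facebook/VGGT-1B on first use
predictions, image_names = process_images("../outputs/my_scene/images", model, device)
print(predictions["extrinsic"].shape, predictions["world_points_from_depth"].shape)
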
def extrinsic_to_colmap_format(extrinsics):
    """Convert extrinsic matrices to COLMAP format (quaternion + translation)."""
    num_cameras = extrinsics.shape[0]
    quaternions = []
    translations = []

    for i in range(num_cameras):
        # VGGT's extrinsic is camera-to-world (R|t) format
        R = extrinsics[i, :3, :3]
        t = extrinsics[i, :3, 3]

        # Convert rotation matrix to quaternion
        # COLMAP quaternion format is [qw, qx, qy, qz]
        rot = Rotation.from_matrix(R)
        quat = rot.as_quat()  # scipy returns [x, y, z, w]
        quat = np.array([quat[3], quat[0], quat[1], quat[2]])  # Convert to [w, x, y, z]

        quaternions.append(quat)
        translations.append(t)

    return np.array(quaternions), np.array(translations)

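A quick sanity check of the quaternion reordering (illustration only, not part of the commit): scipy returns [x, y, z, w], so an identity rotation maps to COLMAP's [1, 0, 0, 0].

# One identity extrinsic of shape (1, 3, 4): R = I, t = 0.
quats, trans = extrinsic_to_colmap_format(np.eye(4)[None, :3, :])
assert np.allclose(quats[0], [1, 0, 0, 0]) and np.allclose(trans[0], 0)
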
def download_file_from_url(url, filename):
    """Downloads a file from a URL, handling redirects."""
    try:
        response = requests.get(url, allow_redirects=False)
        response.raise_for_status()

        if response.status_code == 302:
            redirect_url = response.headers["Location"]
            response = requests.get(redirect_url, stream=True)
            response.raise_for_status()
        else:
            response = requests.get(url, stream=True)
            response.raise_for_status()

        with open(filename, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        print(f"Downloaded {filename} successfully.")
        return True

    except requests.exceptions.RequestException as e:
        print(f"Error downloading file: {e}")
        return False

def segment_sky(image_path, onnx_session, mask_filename=None):
    """
    Segments sky from an image using an ONNX model.
    """
    image = cv2.imread(image_path)

    result_map = run_skyseg(onnx_session, [320, 320], image)
    result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0]))

    # Fix: Invert the mask so that 255 = non-sky, 0 = sky
    # The model outputs low values for sky, high values for non-sky
    output_mask = np.zeros_like(result_map_original)
    output_mask[result_map_original < 32] = 255  # Use threshold of 32

    if mask_filename is not None:
        os.makedirs(os.path.dirname(mask_filename), exist_ok=True)
        cv2.imwrite(mask_filename, output_mask)

    return output_mask

def run_skyseg(onnx_session, input_size, image):
    """
    Runs sky segmentation inference using ONNX model.
    """
    import copy

    temp_image = copy.deepcopy(image)
    resize_image = cv2.resize(temp_image, dsize=(input_size[0], input_size[1]))
    x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB)
    x = np.array(x, dtype=np.float32)
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    x = (x / 255 - mean) / std
    x = x.transpose(2, 0, 1)
    x = x.reshape(-1, 3, input_size[0], input_size[1]).astype("float32")

    input_name = onnx_session.get_inputs()[0].name
    output_name = onnx_session.get_outputs()[0].name
    onnx_result = onnx_session.run([output_name], {input_name: x})

    onnx_result = np.array(onnx_result).squeeze()
    min_value = np.min(onnx_result)
    max_value = np.max(onnx_result)
    onnx_result = (onnx_result - min_value) / (max_value - min_value)
    onnx_result *= 255
    onnx_result = onnx_result.astype("uint8")

    return onnx_result

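A usage sketch for the sky segmentation helpers (illustration only); it assumes skyseg.onnx has already been downloaded (see download_file_from_url above), that onnxruntime is installed, and the image path is hypothetical:

import onnxruntime

session = onnxruntime.InferenceSession("skyseg.onnx")
mask = segment_sky("../outputs/my_scene/images/00000000.png", session)
print("Non-sky pixel fraction:", (mask == 255).mean())
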
def filter_and_prepare_points(predictions, conf_threshold, mask_sky=False, mask_black_bg=False,
                              mask_white_bg=False, stride=1, prediction_mode="Depthmap and Camera Branch"):
    """
    Filter points based on confidence and prepare for COLMAP format.
    Implementation matches the conventions in the original VGGT code.
    """

    if "Pointmap" in prediction_mode:
        print("Using Pointmap Branch")
        if "world_points" in predictions:
            pred_world_points = predictions["world_points"]
            pred_world_points_conf = predictions.get("world_points_conf", np.ones_like(pred_world_points[..., 0]))
        else:
            print("Warning: world_points not found in predictions, falling back to depth-based points")
            pred_world_points = predictions["world_points_from_depth"]
            pred_world_points_conf = predictions.get("depth_conf", np.ones_like(pred_world_points[..., 0]))
    else:
        print("Using Depthmap and Camera Branch")
        pred_world_points = predictions["world_points_from_depth"]
        pred_world_points_conf = predictions.get("depth_conf", np.ones_like(pred_world_points[..., 0]))

    colors_rgb = predictions["images"]

    S, H, W = pred_world_points.shape[:3]
    if colors_rgb.shape[:3] != (S, H, W):
        print(f"Reshaping colors_rgb from {colors_rgb.shape} to match {(S, H, W, 3)}")
        reshaped_colors = np.zeros((S, H, W, 3), dtype=np.float32)
        for i in range(S):
            if i < len(colors_rgb):
                reshaped_colors[i] = cv2.resize(colors_rgb[i], (W, H))
        colors_rgb = reshaped_colors

    colors_rgb = (colors_rgb * 255).astype(np.uint8)

    if mask_sky:
        print("Applying sky segmentation mask")
        try:
            import onnxruntime

            with tempfile.TemporaryDirectory() as temp_dir:
                print(f"Created temporary directory for sky segmentation: {temp_dir}")
                temp_images_dir = os.path.join(temp_dir, "images")
                sky_masks_dir = os.path.join(temp_dir, "sky_masks")
                os.makedirs(temp_images_dir, exist_ok=True)
                os.makedirs(sky_masks_dir, exist_ok=True)

                image_list = []
                for i, img in enumerate(colors_rgb):
                    img_path = os.path.join(temp_images_dir, f"image_{i:04d}.png")
                    image_list.append(img_path)
                    cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))

                skyseg_path = os.path.join(temp_dir, "skyseg.onnx")
                if not os.path.exists("skyseg.onnx"):
                    print("Downloading skyseg.onnx...")
                    download_success = download_file_from_url(
                        "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx",
                        skyseg_path
                    )
                    if not download_success:
                        print("Failed to download skyseg model, skipping sky filtering")
                        mask_sky = False
                else:
                    import shutil
                    shutil.copy("skyseg.onnx", skyseg_path)

                if mask_sky:
                    skyseg_session = onnxruntime.InferenceSession(skyseg_path)
                    sky_mask_list = []

                    for img_path in image_list:
                        mask_path = os.path.join(sky_masks_dir, os.path.basename(img_path))
                        sky_mask = segment_sky(img_path, skyseg_session, mask_path)

                        if sky_mask.shape[0] != H or sky_mask.shape[1] != W:
                            sky_mask = cv2.resize(sky_mask, (W, H))

                        sky_mask_list.append(sky_mask)

                    sky_mask_array = np.array(sky_mask_list)

                    sky_mask_binary = (sky_mask_array > 0.1).astype(np.float32)
                    pred_world_points_conf = pred_world_points_conf * sky_mask_binary
                    print(f"Applied sky mask, shape: {sky_mask_binary.shape}")

        except Exception as e:
            print(f"Error in sky segmentation: {e}")
            mask_sky = False

    vertices_3d = pred_world_points.reshape(-1, 3)
    conf = pred_world_points_conf.reshape(-1)
    colors_rgb_flat = colors_rgb.reshape(-1, 3)

    if len(conf) != len(colors_rgb_flat):
        print(f"WARNING: Shape mismatch between confidence ({len(conf)}) and colors ({len(colors_rgb_flat)})")
        min_size = min(len(conf), len(colors_rgb_flat))
        conf = conf[:min_size]
        vertices_3d = vertices_3d[:min_size]
        colors_rgb_flat = colors_rgb_flat[:min_size]

    if conf_threshold == 0.0:
        conf_thres_value = 0.0
    else:
        conf_thres_value = np.percentile(conf, conf_threshold)

    print(f"Using confidence threshold: {conf_threshold}% (value: {conf_thres_value:.4f})")
    conf_mask = (conf >= conf_thres_value) & (conf > 1e-5)

    if mask_black_bg:
        print("Filtering black background")
        black_bg_mask = colors_rgb_flat.sum(axis=1) >= 16
        conf_mask = conf_mask & black_bg_mask

    if mask_white_bg:
        print("Filtering white background")
        white_bg_mask = ~((colors_rgb_flat[:, 0] > 240) & (colors_rgb_flat[:, 1] > 240) & (colors_rgb_flat[:, 2] > 240))
        conf_mask = conf_mask & white_bg_mask

    filtered_vertices = vertices_3d[conf_mask]
    filtered_colors = colors_rgb_flat[conf_mask]

    if len(filtered_vertices) == 0:
        print("Warning: No points remaining after filtering. Using default point.")
        filtered_vertices = np.array([[0, 0, 0]])
        filtered_colors = np.array([[200, 200, 200]])

    print(f"Filtered to {len(filtered_vertices)} points")

    points3D = []
    point_indices = {}
    image_points2D = [[] for _ in range(len(pred_world_points))]

    print(f"Preparing points for COLMAP format with stride {stride}...")

    total_points = 0
    for img_idx in range(S):
        for y in range(0, H, stride):
            for x in range(0, W, stride):
                flat_idx = img_idx * H * W + y * W + x

                if flat_idx >= len(conf):
                    continue

                if conf[flat_idx] < conf_thres_value or conf[flat_idx] <= 1e-5:
                    continue

                if mask_black_bg and colors_rgb_flat[flat_idx].sum() < 16:
                    continue

                if mask_white_bg and all(colors_rgb_flat[flat_idx] > 240):
                    continue

                point3D = vertices_3d[flat_idx]
                rgb = colors_rgb_flat[flat_idx]

                if not np.all(np.isfinite(point3D)):
                    continue

                point_hash = hash_point(point3D, scale=100)

                if point_hash not in point_indices:
                    point_idx = len(points3D)
                    point_indices[point_hash] = point_idx

                    point_entry = {
                        "id": point_idx,
                        "xyz": point3D,
                        "rgb": rgb,
                        "error": 1.0,
                        "track": [(img_idx, len(image_points2D[img_idx]))]
                    }
                    points3D.append(point_entry)
                    total_points += 1
                else:
                    point_idx = point_indices[point_hash]
                    points3D[point_idx]["track"].append((img_idx, len(image_points2D[img_idx])))

                image_points2D[img_idx].append((x, y, point_indices[point_hash]))

    print(f"Prepared {len(points3D)} 3D points with {sum(len(pts) for pts in image_points2D)} observations for COLMAP")
    return points3D, image_points2D
            +
            def hash_point(point, scale=100):
         | 
| 373 | 
            +
                """Create a hash for a 3D point by quantizing coordinates."""
         | 
| 374 | 
            +
                quantized = tuple(np.round(point * scale).astype(int))
         | 
| 375 | 
            +
                return hash(quantized)
         | 
| 376 | 
            +
             | 
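The effect of this quantization is easy to check in isolation (illustrative values only, not part of the commit): two points that fall into the same 1/scale-sized bin hash identically and are merged into one COLMAP point, while points further apart remain distinct.

```python
import numpy as np

def hash_point(point, scale=100):
    """Same quantization as above: round to 1/scale units, then hash the tuple."""
    quantized = tuple(np.round(point * scale).astype(int))
    return hash(quantized)

p = np.array([1.234, 5.678, 9.012])
assert hash_point(p) == hash_point(p + 4e-4)  # same 0.01 bin -> treated as one point
assert hash_point(p) != hash_point(p + 0.02)  # different bin -> kept as a separate point
```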
| 377 | 
            +
            def write_colmap_cameras_txt(file_path, intrinsics, image_width, image_height):
         | 
| 378 | 
            +
                """Write camera intrinsics to COLMAP cameras.txt format."""
         | 
| 379 | 
            +
                with open(file_path, 'w') as f:
         | 
| 380 | 
            +
                    f.write("# Camera list with one line of data per camera:\n")
         | 
| 381 | 
            +
                    f.write("#   CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n")
         | 
| 382 | 
            +
                    f.write(f"# Number of cameras: {len(intrinsics)}\n")
         | 
| 383 | 
            +
                    
         | 
| 384 | 
            +
                    for i, intrinsic in enumerate(intrinsics):
         | 
| 385 | 
            +
                        camera_id = i + 1  # COLMAP uses 1-indexed camera IDs
         | 
| 386 | 
            +
                        model = "PINHOLE" 
         | 
| 387 | 
            +
                        
         | 
| 388 | 
            +
                        fx = intrinsic[0, 0]
         | 
| 389 | 
            +
                        fy = intrinsic[1, 1]
         | 
| 390 | 
            +
                        cx = intrinsic[0, 2]
         | 
| 391 | 
            +
                        cy = intrinsic[1, 2]
         | 
| 392 | 
            +
                        
         | 
| 393 | 
            +
                        f.write(f"{camera_id} {model} {image_width} {image_height} {fx} {fy} {cx} {cy}\n")
         | 
| 394 | 
            +
             | 
| 395 | 
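For reference, a hypothetical PINHOLE camera with fx = fy = 1000 and principal point (640, 360) on a 1280x720 image yields a cameras.txt data line of the following form (values made up for illustration):

```python
import numpy as np

intrinsic = np.array([[1000.0,    0.0, 640.0],
                      [   0.0, 1000.0, 360.0],
                      [   0.0,    0.0,   1.0]])
fx, fy = intrinsic[0, 0], intrinsic[1, 1]
cx, cy = intrinsic[0, 2], intrinsic[1, 2]
# CAMERA_ID MODEL WIDTH HEIGHT PARAMS[]
print(f"1 PINHOLE 1280 720 {fx} {fy} {cx} {cy}")
# -> 1 PINHOLE 1280 720 1000.0 1000.0 640.0 360.0
```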
            +
            def write_colmap_images_txt(file_path, quaternions, translations, image_points2D, image_names):
         | 
| 396 | 
            +
                """Write camera poses and keypoints to COLMAP images.txt format."""
         | 
| 397 | 
            +
                with open(file_path, 'w') as f:
         | 
| 398 | 
            +
                    f.write("# Image list with two lines of data per image:\n")
         | 
| 399 | 
            +
                    f.write("#   IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n")
         | 
| 400 | 
            +
                    f.write("#   POINTS2D[] as (X, Y, POINT3D_ID)\n")
         | 
| 401 | 
            +
                    
         | 
| 402 | 
            +
                    num_points = sum(len(points) for points in image_points2D)
         | 
| 403 | 
            +
                    avg_points = num_points / len(image_points2D) if image_points2D else 0
         | 
| 404 | 
            +
                    f.write(f"# Number of images: {len(quaternions)}, mean observations per image: {avg_points:.1f}\n")
         | 
| 405 | 
            +
                    
         | 
| 406 | 
            +
                    for i in range(len(quaternions)):
         | 
| 407 | 
            +
                        image_id = i + 1 
         | 
| 408 | 
            +
                        camera_id = i + 1  
         | 
| 409 | 
            +
                      
         | 
| 410 | 
            +
                        qw, qx, qy, qz = quaternions[i]
         | 
| 411 | 
            +
                        tx, ty, tz = translations[i]
         | 
| 412 | 
            +
                        
         | 
| 413 | 
            +
                        f.write(f"{image_id} {qw} {qx} {qy} {qz} {tx} {ty} {tz} {camera_id} {os.path.basename(image_names[i])}\n")
         | 
| 414 | 
            +
                        
         | 
| 415 | 
            +
                        points_line = " ".join([f"{x} {y} {point3d_id+1}" for x, y, point3d_id in image_points2D[i]])
         | 
| 416 | 
            +
                        f.write(f"{points_line}\n")
         | 
| 417 | 
            +
             | 
| 418 | 
            +
            def write_colmap_points3D_txt(file_path, points3D):
         | 
| 419 | 
            +
                """Write 3D points and tracks to COLMAP points3D.txt format."""
         | 
| 420 | 
            +
                with open(file_path, 'w') as f:
         | 
| 421 | 
            +
                    f.write("# 3D point list with one line of data per point:\n")
         | 
| 422 | 
            +
                    f.write("#   POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n")
         | 
| 423 | 
            +
                    
         | 
| 424 | 
            +
                    avg_track_length = sum(len(point["track"]) for point in points3D) / len(points3D) if points3D else 0
         | 
| 425 | 
            +
                    f.write(f"# Number of points: {len(points3D)}, mean track length: {avg_track_length:.4f}\n")
         | 
| 426 | 
            +
                    
         | 
| 427 | 
            +
                    for point in points3D:
         | 
| 428 | 
            +
                        point_id = point["id"] + 1  
         | 
| 429 | 
            +
                        x, y, z = point["xyz"]
         | 
| 430 | 
            +
                        r, g, b = point["rgb"]
         | 
| 431 | 
            +
                        error = point["error"]
         | 
| 432 | 
            +
                        
         | 
| 433 | 
            +
                        track = " ".join([f"{img_id+1} {point2d_idx}" for img_id, point2d_idx in point["track"]])
         | 
| 434 | 
            +
                        
         | 
| 435 | 
            +
                        f.write(f"{point_id} {x} {y} {z} {int(r)} {int(g)} {int(b)} {error} {track}\n")
         | 
| 436 | 
            +
             | 
| 437 | 
            +
            def write_colmap_cameras_bin(file_path, intrinsics, image_width, image_height):
         | 
| 438 | 
            +
                """Write camera intrinsics to COLMAP cameras.bin format."""
         | 
| 439 | 
            +
                with open(file_path, 'wb') as fid:
         | 
| 440 | 
            +
                    # Write number of cameras (uint64)
         | 
| 441 | 
            +
                    fid.write(struct.pack('<Q', len(intrinsics)))
         | 
| 442 | 
            +
                    
         | 
| 443 | 
            +
                    for i, intrinsic in enumerate(intrinsics):
         | 
| 444 | 
            +
                        camera_id = i + 1
         | 
| 445 | 
            +
                        model_id = 1 
         | 
| 446 | 
            +
                        
         | 
| 447 | 
            +
                        fx = float(intrinsic[0, 0])
         | 
| 448 | 
            +
                        fy = float(intrinsic[1, 1])
         | 
| 449 | 
            +
                        cx = float(intrinsic[0, 2])
         | 
| 450 | 
            +
                        cy = float(intrinsic[1, 2])
         | 
| 451 | 
            +
                        
         | 
| 452 | 
            +
                        # Camera ID (uint32)
         | 
| 453 | 
            +
                        fid.write(struct.pack('<I', camera_id))
         | 
| 454 | 
            +
                        # Model ID (uint32)
         | 
| 455 | 
            +
                        fid.write(struct.pack('<I', model_id))
         | 
| 456 | 
            +
                        # Width (uint64)
         | 
| 457 | 
            +
                        fid.write(struct.pack('<Q', image_width))
         | 
| 458 | 
            +
                        # Height (uint64)
         | 
| 459 | 
            +
                        fid.write(struct.pack('<Q', image_height))
         | 
| 460 | 
            +
                        
         | 
| 461 | 
            +
                        # Parameters (double)
         | 
| 462 | 
            +
                        fid.write(struct.pack('<dddd', fx, fy, cx, cy))
         | 
| 463 | 
            +
             | 
| 464 | 
            +
            def write_colmap_images_bin(file_path, quaternions, translations, image_points2D, image_names):
         | 
| 465 | 
            +
                """Write camera poses and keypoints to COLMAP images.bin format."""
         | 
| 466 | 
            +
                with open(file_path, 'wb') as fid:
         | 
| 467 | 
            +
                    # Write number of images (uint64)
         | 
| 468 | 
            +
                    fid.write(struct.pack('<Q', len(quaternions)))
         | 
| 469 | 
            +
                    
         | 
| 470 | 
            +
                    for i in range(len(quaternions)):
         | 
| 471 | 
            +
                        image_id = i + 1
         | 
| 472 | 
            +
                        camera_id = i + 1
         | 
| 473 | 
            +
                        
         | 
| 474 | 
            +
                        qw, qx, qy, qz = quaternions[i].astype(float)
         | 
| 475 | 
            +
                        tx, ty, tz = translations[i].astype(float)
         | 
| 476 | 
            +
                        
         | 
| 477 | 
            +
                        image_name = os.path.basename(image_names[i]).encode()
         | 
| 478 | 
            +
                        points = image_points2D[i]
         | 
| 479 | 
            +
                        
         | 
| 480 | 
            +
                        # Image ID (uint32)
         | 
| 481 | 
            +
                        fid.write(struct.pack('<I', image_id))
         | 
| 482 | 
            +
                        # Quaternion (double): qw, qx, qy, qz
         | 
| 483 | 
            +
                        fid.write(struct.pack('<dddd', qw, qx, qy, qz))
         | 
| 484 | 
            +
                        # Translation (double): tx, ty, tz
         | 
| 485 | 
            +
                        fid.write(struct.pack('<ddd', tx, ty, tz))
         | 
| 486 | 
            +
                        # Camera ID (uint32)
         | 
| 487 | 
            +
                        fid.write(struct.pack('<I', camera_id))
         | 
| 488 | 
            +
                        # Image name
         | 
| 489 | 
            +
                        fid.write(struct.pack('<I', len(image_name)))
         | 
| 490 | 
            +
                        fid.write(image_name)
         | 
| 491 | 
            +
                        
         | 
| 492 | 
            +
                        # Write number of 2D points (uint64)
         | 
| 493 | 
            +
                        fid.write(struct.pack('<Q', len(points)))
         | 
| 494 | 
            +
                        
         | 
| 495 | 
            +
                        # Write 2D points: x, y, point3D_id
         | 
| 496 | 
            +
                        for x, y, point3d_id in points:
         | 
| 497 | 
            +
                            fid.write(struct.pack('<dd', float(x), float(y)))
         | 
| 498 | 
            +
                            fid.write(struct.pack('<Q', point3d_id + 1))
         | 
| 499 | 
            +
             | 
| 500 | 
            +
            def write_colmap_points3D_bin(file_path, points3D):
         | 
| 501 | 
            +
                """Write 3D points and tracks to COLMAP points3D.bin format."""
         | 
| 502 | 
            +
                with open(file_path, 'wb') as fid:
         | 
| 503 | 
            +
                    # Write number of points (uint64)
         | 
| 504 | 
            +
                    fid.write(struct.pack('<Q', len(points3D)))
         | 
| 505 | 
            +
                    
         | 
| 506 | 
            +
                    for point in points3D:
         | 
| 507 | 
            +
                        point_id = point["id"] + 1
         | 
| 508 | 
            +
                        x, y, z = point["xyz"].astype(float)
         | 
| 509 | 
            +
                        r, g, b = point["rgb"].astype(np.uint8)
         | 
| 510 | 
            +
                        error = float(point["error"])
         | 
| 511 | 
            +
                        track = point["track"]
         | 
| 512 | 
            +
                        
         | 
| 513 | 
            +
                        # Point ID (uint64)
         | 
| 514 | 
            +
                        fid.write(struct.pack('<Q', point_id))
         | 
| 515 | 
            +
                        # Position (double): x, y, z
         | 
| 516 | 
            +
                        fid.write(struct.pack('<ddd', x, y, z))
         | 
| 517 | 
            +
                        # Color (uint8): r, g, b
         | 
| 518 | 
            +
                        fid.write(struct.pack('<BBB', int(r), int(g), int(b)))
         | 
| 519 | 
            +
                        # Error (double)
         | 
| 520 | 
            +
                        fid.write(struct.pack('<d', error))
         | 
| 521 | 
            +
                        
         | 
| 522 | 
            +
                        # Track: list of (image_id, point2D_idx)
         | 
| 523 | 
            +
                        fid.write(struct.pack('<Q', len(track)))
         | 
| 524 | 
            +
                        for img_id, point2d_idx in track:
         | 
| 525 | 
            +
                            fid.write(struct.pack('<II', img_id + 1, point2d_idx))
         | 
| 526 | 
            +
             | 
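As a sanity check on the binary layout (a minimal sketch, not part of this commit), cameras.bin written by write_colmap_cameras_bin above can be read back by unpacking the same struct sequence in order:

```python
import struct

def read_cameras_bin(path):
    """Mirror of write_colmap_cameras_bin: count, then per camera id, model, size, 4 PINHOLE params."""
    cameras = []
    with open(path, "rb") as fid:
        (num_cameras,) = struct.unpack("<Q", fid.read(8))
        for _ in range(num_cameras):
            camera_id, model_id = struct.unpack("<II", fid.read(8))
            width, height = struct.unpack("<QQ", fid.read(16))
            fx, fy, cx, cy = struct.unpack("<dddd", fid.read(32))
            cameras.append({"id": camera_id, "model_id": model_id,
                            "size": (width, height), "params": (fx, fy, cx, cy)})
    return cameras
```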
| 527 | 
            +
            def main():
         | 
| 528 | 
            +
                parser = argparse.ArgumentParser(description="Convert images to COLMAP format using VGGT")
         | 
| 529 | 
            +
                parser.add_argument("--image_dir", type=str, required=True, 
         | 
| 530 | 
            +
                                    help="Directory containing input images")
         | 
| 531 | 
            +
                parser.add_argument("--output_dir", type=str, default="colmap_output", 
         | 
| 532 | 
            +
                                    help="Directory to save COLMAP files")
         | 
| 533 | 
            +
                parser.add_argument("--conf_threshold", type=float, default=50.0, 
         | 
| 534 | 
            +
                                    help="Confidence threshold (0-100%) for including points")
         | 
| 535 | 
            +
                parser.add_argument("--mask_sky", action="store_true",
         | 
| 536 | 
            +
                                    help="Filter out points likely to be sky")
         | 
| 537 | 
            +
                parser.add_argument("--mask_black_bg", action="store_true",
         | 
| 538 | 
            +
                                    help="Filter out points with very dark/black color")
         | 
| 539 | 
            +
                parser.add_argument("--mask_white_bg", action="store_true",
         | 
| 540 | 
            +
                                    help="Filter out points with very bright/white color")
         | 
| 541 | 
            +
                parser.add_argument("--binary", action="store_true", 
         | 
| 542 | 
            +
                                    help="Output binary COLMAP files instead of text")
         | 
| 543 | 
            +
                parser.add_argument("--stride", type=int, default=1, 
         | 
| 544 | 
            +
                                    help="Stride for point sampling (higher = fewer points)")
         | 
| 545 | 
            +
                parser.add_argument("--prediction_mode", type=str, default="Depthmap and Camera Branch",
         | 
| 546 | 
            +
                                    choices=["Depthmap and Camera Branch", "Pointmap Branch"],
         | 
| 547 | 
            +
                                    help="Which prediction branch to use")
         | 
| 548 | 
            +
                
         | 
| 549 | 
            +
                args = parser.parse_args()
         | 
| 550 | 
            +
                
         | 
| 551 | 
            +
                os.makedirs(args.output_dir, exist_ok=True)
         | 
| 552 | 
            +
                
         | 
| 553 | 
            +
                model, device = load_model()
         | 
| 554 | 
            +
                
         | 
| 555 | 
            +
                predictions, image_names = process_images(args.image_dir, model, device)
         | 
| 556 | 
            +
                
         | 
| 557 | 
            +
                print("Converting camera parameters to COLMAP format...")
         | 
| 558 | 
            +
                quaternions, translations = extrinsic_to_colmap_format(predictions["extrinsic"])
         | 
| 559 | 
            +
                
         | 
| 560 | 
            +
                print(f"Filtering points with confidence threshold {args.conf_threshold}% and stride {args.stride}...")
         | 
| 561 | 
            +
                points3D, image_points2D = filter_and_prepare_points(
         | 
| 562 | 
            +
                    predictions, 
         | 
| 563 | 
            +
                    args.conf_threshold, 
         | 
| 564 | 
            +
                    mask_sky=args.mask_sky, 
         | 
| 565 | 
            +
                    mask_black_bg=args.mask_black_bg,
         | 
| 566 | 
            +
                    mask_white_bg=args.mask_white_bg,
         | 
| 567 | 
            +
                    stride=args.stride,
         | 
| 568 | 
            +
                    prediction_mode=args.prediction_mode
         | 
| 569 | 
            +
                )
         | 
| 570 | 
            +
                
         | 
| 571 | 
            +
                height, width = predictions["depth"].shape[1:3]
         | 
| 572 | 
            +
             | 
| 573 | 
            +
                print(f"Writing {'binary' if args.binary else 'text'} COLMAP files to {args.output_dir}...")
         | 
| 574 | 
            +
                if args.binary:
         | 
| 575 | 
            +
                    write_colmap_cameras_bin(
         | 
| 576 | 
            +
                        os.path.join(args.output_dir, "cameras.bin"), 
         | 
| 577 | 
            +
                        predictions["intrinsic"], width, height)
         | 
| 578 | 
            +
                    write_colmap_images_bin(
         | 
| 579 | 
            +
                        os.path.join(args.output_dir, "images.bin"), 
         | 
| 580 | 
            +
                        quaternions, translations, image_points2D, image_names)
         | 
| 581 | 
            +
                    write_colmap_points3D_bin(
         | 
| 582 | 
            +
                        os.path.join(args.output_dir, "points3D.bin"), 
         | 
| 583 | 
            +
                        points3D)
         | 
| 584 | 
            +
                else:
         | 
| 585 | 
            +
                    write_colmap_cameras_txt(
         | 
| 586 | 
            +
                        os.path.join(args.output_dir, "cameras.txt"), 
         | 
| 587 | 
            +
                        predictions["intrinsic"], width, height)
         | 
| 588 | 
            +
                    write_colmap_images_txt(
         | 
| 589 | 
            +
                        os.path.join(args.output_dir, "images.txt"), 
         | 
| 590 | 
            +
                        quaternions, translations, image_points2D, image_names)
         | 
| 591 | 
            +
                    write_colmap_points3D_txt(
         | 
| 592 | 
            +
                        os.path.join(args.output_dir, "points3D.txt"), 
         | 
| 593 | 
            +
                        points3D)
         | 
| 594 | 
            +
                
         | 
| 595 | 
            +
                print(f"COLMAP files successfully written to {args.output_dir}")
         | 
| 596 | 
            +
             | 
| 597 | 
            +
            if __name__ == "__main__":
         | 
| 598 | 
            +
                main()
         | 
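A typical run of this script uses only the flags defined in main(), for example `--image_dir ./images --output_dir colmap_output --conf_threshold 50 --stride 2`. The sketch below (hypothetical paths, assuming the default text output rather than --binary) counts the records that were written:

```python
def count_records(path):
    """Count non-comment, non-empty lines in a COLMAP text file."""
    with open(path) as f:
        return sum(1 for line in f if line.strip() and not line.startswith("#"))

out = "colmap_output"  # hypothetical --output_dir
print("cameras :", count_records(f"{out}/cameras.txt"))
print("images  :", count_records(f"{out}/images.txt") // 2)  # images.txt has two lines per image
print("points3D:", count_records(f"{out}/points3D.txt"))
```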
    	
        source/visualization.py
    ADDED
    
    | @@ -0,0 +1,1072 @@ | |
| 1 | 
            +
            from matplotlib import pyplot as plt
         | 
| 2 | 
            +
            import numpy as np
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import numpy as np
         | 
| 6 | 
            +
            from typing import List
         | 
| 7 | 
            +
            import sys
         | 
| 8 | 
            +
            sys.path.append('./submodules/gaussian-splatting/')
         | 
| 9 | 
            +
            from scene.cameras import Camera
         | 
| 10 | 
            +
            from PIL import Image
         | 
| 11 | 
            +
            import imageio
         | 
| 12 | 
            +
            from scipy.interpolate import splprep, splev
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            import cv2
         | 
| 15 | 
            +
            import numpy as np
         | 
| 16 | 
            +
            import plotly.graph_objects as go
         | 
| 17 | 
            +
            import numpy as np
         | 
| 18 | 
            +
            from scipy.spatial.transform import Rotation as R, Slerp
         | 
| 19 | 
            +
            from scipy.spatial import distance_matrix
         | 
| 20 | 
            +
            from sklearn.decomposition import PCA
         | 
| 21 | 
            +
            from scipy.interpolate import splprep, splev
         | 
| 22 | 
            +
            from typing import List
         | 
| 23 | 
            +
            from sklearn.mixture import GaussianMixture
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            def render_gaussians_rgb(generator3DGS, viewpoint_cam, visualize=False):
         | 
| 26 | 
            +
                """
         | 
| 27 | 
            +
    Render the gaussians of generator3DGS as seen from viewpoint_cam.
         | 
| 28 | 
            +
                Args:
         | 
| 29 | 
            +
                    generator3DGS : instance of the Generator3DGS class from the networks.py file
         | 
| 30 | 
            +
                    viewpoint_cam : camera instance
         | 
| 31 | 
            +
        visualize : boolean flag. If True, displays the rendered image inline with pyplot
         | 
| 32 | 
            +
                Returns:
         | 
| 33 | 
            +
                    uint8 numpy array with shape (H, W, 3) representing the image
         | 
| 34 | 
            +
                """
         | 
| 35 | 
            +
                with torch.no_grad():
         | 
| 36 | 
            +
                    render_pkg = generator3DGS(viewpoint_cam)
         | 
| 37 | 
            +
                    image = render_pkg["render"]
         | 
| 38 | 
            +
                    image_np = image.clone().detach().cpu().numpy().transpose(1, 2, 0)
         | 
| 39 | 
            +
             | 
| 40 | 
            +
        # Scale to [0, 255], clip, and convert to uint8
         | 
| 41 | 
            +
                    image_np = np.clip(image_np * 255, 0, 255).astype(np.uint8)
         | 
| 42 | 
            +
                    if visualize:
         | 
| 43 | 
            +
                        plt.figure(figsize=(12, 8))
         | 
| 44 | 
            +
                        plt.imshow(image_np)
         | 
| 45 | 
            +
                        plt.show()
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                    return image_np
         | 
| 48 | 
            +
             | 
| 49 | 
            +
            def render_gaussians_D_scores(generator3DGS, viewpoint_cam, mask=None, mask_channel=0, visualize=False):
         | 
| 50 | 
            +
                """
         | 
| 51 | 
            +
        Render the D_scores of the gaussians in generator3DGS as seen from viewpoint_cam.
         | 
| 52 | 
            +
                    Args:
         | 
| 53 | 
            +
                        generator3DGS : instance of the Generator3DGS class from the networks.py file
         | 
| 54 | 
            +
                        viewpoint_cam : camera instance
         | 
| 55 | 
            +
                visualize : boolean flag. If True, displays the rendered image inline with pyplot
         | 
| 56 | 
            +
                mask : optional mask to highlight specific gaussians. Must be of shape (N), where N is the number
         | 
| 57 | 
            +
                    of gaussians in generator3DGS.gaussians, and a torch tensor of floats; scale it according
         | 
| 58 | 
            +
                    to how strong you want the highlight color to be. A mask value of 10 is recommended.
         | 
| 59 | 
            +
                        mask_channel: to which color channel should we add mask
         | 
| 60 | 
            +
                    Returns:
         | 
| 61 | 
            +
                        uint8 numpy array with shape (H, W, 3) representing the generator3DGS.gaussians.D_scores rendered as colors
         | 
| 62 | 
            +
                    """
         | 
| 63 | 
            +
                with torch.no_grad():
         | 
| 64 | 
            +
                    # Visualize D_scores
         | 
| 65 | 
            +
                    generator3DGS.gaussians._features_dc = generator3DGS.gaussians._features_dc * 1e-4 + \
         | 
| 66 | 
            +
                                                           torch.stack([generator3DGS.gaussians.D_scores] * 3, axis=-1)
         | 
| 67 | 
            +
                    generator3DGS.gaussians._features_rest = generator3DGS.gaussians._features_rest * 1e-4
         | 
| 68 | 
            +
                    if mask is not None:
         | 
| 69 | 
            +
                        generator3DGS.gaussians._features_dc[..., mask_channel] += mask.unsqueeze(-1)
         | 
| 70 | 
            +
                    render_pkg = generator3DGS(viewpoint_cam)
         | 
| 71 | 
            +
                    image = render_pkg["render"]
         | 
| 72 | 
            +
                    image_np = image.clone().detach().cpu().numpy().transpose(1, 2, 0)
         | 
| 73 | 
            +
             | 
| 74 | 
            +
        # Scale to [0, 255], clip, and convert to uint8
         | 
| 75 | 
            +
                    image_np = np.clip(image_np * 255, 0, 255).astype(np.uint8)
         | 
| 76 | 
            +
                    if visualize:
         | 
| 77 | 
            +
                        plt.figure(figsize=(12, 8))
         | 
| 78 | 
            +
                        plt.imshow(image_np)
         | 
| 79 | 
            +
                        plt.show()
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                    if mask is not None:
         | 
| 82 | 
            +
                        generator3DGS.gaussians._features_dc[..., mask_channel] -= mask.unsqueeze(-1)
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                    generator3DGS.gaussians._features_dc = (generator3DGS.gaussians._features_dc - \
         | 
| 85 | 
            +
                                                                 torch.stack([generator3DGS.gaussians.D_scores] * 3, axis=-1)) * 1e4
         | 
| 86 | 
            +
                    generator3DGS.gaussians._features_rest = generator3DGS.gaussians._features_rest * 1e4
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                    return image_np
         | 
| 89 | 
            +
                
         | 
| 90 | 
            +
             | 
| 91 | 
            +
             | 
| 92 | 
            +
            def normalize(v):
         | 
| 93 | 
            +
                """
         | 
| 94 | 
            +
                Normalize a vector to unit length.
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                Parameters:
         | 
| 97 | 
            +
                    v (np.ndarray): Input vector.
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                Returns:
         | 
| 100 | 
            +
                    np.ndarray: Unit vector in the same direction as `v`.
         | 
| 101 | 
            +
                """
         | 
| 102 | 
            +
                return v / np.linalg.norm(v)
         | 
| 103 | 
            +
             | 
| 104 | 
            +
            def look_at_rotation(camera_position: np.ndarray, target: np.ndarray, world_up=np.array([0, 1, 0])):
         | 
| 105 | 
            +
                """
         | 
| 106 | 
            +
                Compute a rotation matrix for a camera looking at a target point.
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                Parameters:
         | 
| 109 | 
            +
                    camera_position (np.ndarray): The 3D position of the camera.
         | 
| 110 | 
            +
                    target (np.ndarray): The point the camera should look at.
         | 
| 111 | 
            +
                    world_up (np.ndarray): A vector that defines the global 'up' direction.
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                Returns:
         | 
| 114 | 
            +
                    np.ndarray: A 3x3 rotation matrix (camera-to-world) with columns [right, up, forward].
         | 
| 115 | 
            +
                """
         | 
| 116 | 
            +
                z_axis = normalize(target - camera_position)         # Forward direction
         | 
| 117 | 
            +
                x_axis = normalize(np.cross(world_up, z_axis))       # Right direction
         | 
| 118 | 
            +
                y_axis = np.cross(z_axis, x_axis)                    # Recomputed up
         | 
| 119 | 
            +
                return np.stack([x_axis, y_axis, z_axis], axis=1)
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                
         | 
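A small numeric check of the convention used here (made-up camera position; assumes normalize and look_at_rotation above are in scope): the returned matrix stacks [right, up, forward] as columns, so it is orthonormal and its third column points from the camera toward the target.

```python
import numpy as np

cam_pos = np.array([0.0, 0.0, -5.0])
target = np.array([0.0, 0.0, 0.0])
R = look_at_rotation(cam_pos, target)

assert np.allclose(R.T @ R, np.eye(3), atol=1e-6)        # columns form an orthonormal basis
assert np.allclose(R[:, 2], [0.0, 0.0, 1.0], atol=1e-6)  # forward = normalize(target - cam_pos)
```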
| 122 | 
            +
            def generate_circular_camera_path(existing_cameras: List[Camera], N: int = 12, radius_scale: float = 1.0, d: float = 2.0) -> List[Camera]:
         | 
| 123 | 
            +
                """
         | 
| 124 | 
            +
                Generate a circular path of cameras around an existing camera group, 
         | 
| 125 | 
            +
                with each new camera oriented to look at the average viewing direction.
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                Parameters:
         | 
| 128 | 
            +
                    existing_cameras (List[Camera]): List of existing camera objects to estimate average orientation and layout.
         | 
| 129 | 
            +
                    N (int): Number of new cameras to generate along the circular path.
         | 
| 130 | 
            +
                    radius_scale (float): Scale factor to adjust the radius of the circle.
         | 
| 131 | 
            +
                    d (float): Distance ahead of each camera used to estimate its look-at point.
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                Returns:
         | 
| 134 | 
            +
                    List[Camera]: A list of newly generated Camera objects forming a circular path and oriented toward a shared view center.
         | 
| 135 | 
            +
                """
         | 
| 136 | 
            +
                # Step 1: Compute average camera position
         | 
| 137 | 
            +
                center = np.mean([cam.T for cam in existing_cameras], axis=0)
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                # Estimate where each camera is looking
         | 
| 140 | 
            +
                # d denotes how far ahead each camera sees — you can scale this
         | 
| 141 | 
            +
                look_targets = [cam.T + cam.R[:, 2] * d for cam in existing_cameras]
         | 
| 142 | 
            +
                center_of_view = np.mean(look_targets, axis=0)
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                # Step 2: Define circular plane basis using fixed up vector
         | 
| 145 | 
            +
                avg_forward = normalize(np.mean([cam.R[:, 2] for cam in existing_cameras], axis=0))
         | 
| 146 | 
            +
                up_guess = np.array([0, 1, 0])
         | 
| 147 | 
            +
                right = normalize(np.cross(avg_forward, up_guess))
         | 
| 148 | 
            +
                up = normalize(np.cross(right, avg_forward))
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                # Step 3: Estimate radius
         | 
| 151 | 
            +
                avg_radius = np.mean([np.linalg.norm(cam.T - center) for cam in existing_cameras]) * radius_scale
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                # Step 4: Create cameras on a circular path
         | 
| 154 | 
            +
                angles = np.linspace(0, 2 * np.pi, N, endpoint=False)
         | 
| 155 | 
            +
                reference_cam = existing_cameras[0]
         | 
| 156 | 
            +
                new_cameras = []
         | 
| 157 | 
            +
             | 
| 158 | 
            +
                
         | 
| 159 | 
            +
                for i, a in enumerate(angles):
         | 
| 160 | 
            +
                    position = center + avg_radius * (np.cos(a) * right + np.sin(a) * up)
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                    if d < 1e-5 or radius_scale < 1e-5:
         | 
| 163 | 
            +
                        # Use same orientation as the first camera
         | 
| 164 | 
            +
                        R = reference_cam.R.copy()
         | 
| 165 | 
            +
                    else:
         | 
| 166 | 
            +
                        # Change orientation
         | 
| 167 | 
            +
                        R = look_at_rotation(position, center_of_view)
         | 
| 168 | 
            +
                    new_cameras.append(Camera(
         | 
| 169 | 
            +
                        R=R, 
         | 
| 170 | 
            +
                        T=position,                                   # New position
         | 
| 171 | 
            +
                        FoVx=reference_cam.FoVx,
         | 
| 172 | 
            +
                        FoVy=reference_cam.FoVy,
         | 
| 173 | 
            +
                        resolution=(reference_cam.image_width, reference_cam.image_height),
         | 
| 174 | 
            +
                        colmap_id=-1,
         | 
| 175 | 
            +
                        depth_params=None,
         | 
| 176 | 
            +
                        image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
         | 
| 177 | 
            +
                        invdepthmap=None,
         | 
| 178 | 
            +
                        image_name=f"circular_a={a:.3f}",
         | 
| 179 | 
            +
                        uid=i
         | 
| 180 | 
            +
                    ))
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                return new_cameras
         | 
| 183 | 
            +
             | 
| 184 | 
            +
             | 
| 185 | 
            +
            def save_numpy_frames_as_gif(frames, output_path="animation.gif", duration=100):
         | 
| 186 | 
            +
                """
         | 
| 187 | 
            +
                Save a list of RGB NumPy frames as a looping GIF animation.
         | 
| 188 | 
            +
             | 
| 189 | 
            +
                Parameters:
         | 
| 190 | 
            +
                    frames (List[np.ndarray]): List of RGB images as uint8 NumPy arrays (shape HxWx3).
         | 
| 191 | 
            +
                    output_path (str): Path to save the output GIF.
         | 
| 192 | 
            +
                    duration (int): Duration per frame in milliseconds.
         | 
| 193 | 
            +
             | 
| 194 | 
            +
                Returns:
         | 
| 195 | 
            +
                    None
         | 
| 196 | 
            +
                """
         | 
| 197 | 
            +
                pil_frames = [Image.fromarray(f) for f in frames]
         | 
| 198 | 
            +
                pil_frames[0].save(
         | 
| 199 | 
            +
                    output_path,
         | 
| 200 | 
            +
                    save_all=True,
         | 
| 201 | 
            +
                    append_images=pil_frames[1:],
         | 
| 202 | 
            +
                    duration=duration,  # duration per frame in ms
         | 
| 203 | 
            +
                    loop=0
         | 
| 204 | 
            +
                )
         | 
| 205 | 
            +
                print(f"GIF saved to: {output_path}")
         | 
| 206 | 
            +
             | 
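Illustrative usage with synthetic frames (hypothetical output name; assumes the helper above is in scope):

```python
import numpy as np

# A simple fade-to-white animation: 16 grayscale frames of increasing brightness.
frames = [np.full((64, 64, 3), i, dtype=np.uint8) for i in range(0, 256, 16)]
save_numpy_frames_as_gif(frames, output_path="fade.gif", duration=100)
```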
| 207 | 
            +
            def center_crop_frame(frame: np.ndarray, crop_fraction: float) -> np.ndarray:
         | 
| 208 | 
            +
                """
         | 
| 209 | 
            +
                Crop the central region of the frame by the given fraction.
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                Parameters:
         | 
| 212 | 
            +
                    frame (np.ndarray): Input RGB image (H, W, 3).
         | 
| 213 | 
            +
                    crop_fraction (float): Fraction of the original size to retain (e.g., 0.8 keeps 80%).
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                Returns:
         | 
| 216 | 
            +
                    np.ndarray: Cropped RGB image.
         | 
| 217 | 
            +
                """
         | 
| 218 | 
            +
                if crop_fraction >= 1.0:
         | 
| 219 | 
            +
                    return frame
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                h, w, _ = frame.shape
         | 
| 222 | 
            +
                new_h, new_w = int(h * crop_fraction), int(w * crop_fraction)
         | 
| 223 | 
            +
                start_y = (h - new_h) // 2
         | 
| 224 | 
            +
                start_x = (w - new_w) // 2
         | 
| 225 | 
            +
                return frame[start_y:start_y + new_h, start_x:start_x + new_w, :]
         | 
| 226 | 
            +
             | 
| 227 | 
            +
             | 
| 228 | 
            +
             | 
| 229 | 
            +
            def generate_smooth_closed_camera_path(existing_cameras: List[Camera], N: int = 120, d: float = 2.0, s=.25) -> List[Camera]:
         | 
| 230 | 
            +
                """
         | 
| 231 | 
            +
                Generate a smooth, closed path interpolating the positions of existing cameras.
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                Parameters:
         | 
| 234 | 
            +
                    existing_cameras (List[Camera]): List of existing cameras.
         | 
| 235 | 
            +
                    N (int): Number of points (cameras) to sample along the smooth path.
         | 
| 236 | 
            +
        d (float): Distance ahead for estimating the center of view.
        s (float): Smoothing factor passed to splprep when fitting the closed spline.
         | 
| 237 | 
            +
             | 
| 238 | 
            +
                Returns:
         | 
| 239 | 
            +
                    List[Camera]: A list of smoothly moving Camera objects along a closed loop.
         | 
| 240 | 
            +
                """
         | 
| 241 | 
            +
                # Step 1: Extract camera positions
         | 
| 242 | 
            +
                positions = np.array([cam.T for cam in existing_cameras])
         | 
| 243 | 
            +
                
         | 
| 244 | 
            +
                # Step 2: Estimate center of view
         | 
| 245 | 
            +
                look_targets = [cam.T + cam.R[:, 2] * d for cam in existing_cameras]
         | 
| 246 | 
            +
                center_of_view = np.mean(look_targets, axis=0)
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                # Step 3: Fit a smooth closed spline through the positions
         | 
| 249 | 
            +
                positions = np.vstack([positions, positions[0]])  # close the loop
         | 
| 250 | 
            +
    tck, u = splprep(positions.T, s=s, per=True)  # per=True fits a periodic (closed-loop) spline
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                # Step 4: Sample points along the spline
         | 
| 253 | 
            +
                u_fine = np.linspace(0, 1, N)
         | 
| 254 | 
            +
                smooth_path = np.stack(splev(u_fine, tck), axis=-1)
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                # Step 5: Generate cameras along the smooth path
         | 
| 257 | 
            +
                reference_cam = existing_cameras[0]
         | 
| 258 | 
            +
                new_cameras = []
         | 
| 259 | 
            +
             | 
| 260 | 
            +
                for i, pos in enumerate(smooth_path):
         | 
| 261 | 
            +
                    R = look_at_rotation(pos, center_of_view)
         | 
| 262 | 
            +
                    new_cameras.append(Camera(
         | 
| 263 | 
            +
                        R=R,
         | 
| 264 | 
            +
                        T=pos,
         | 
| 265 | 
            +
                        FoVx=reference_cam.FoVx,
         | 
| 266 | 
            +
                        FoVy=reference_cam.FoVy,
         | 
| 267 | 
            +
                        resolution=(reference_cam.image_width, reference_cam.image_height),
         | 
| 268 | 
            +
                        colmap_id=-1,
         | 
| 269 | 
            +
                        depth_params=None,
         | 
| 270 | 
            +
                        image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
         | 
| 271 | 
            +
                        invdepthmap=None,
         | 
| 272 | 
            +
                        image_name=f"smooth_path_i={i}",
         | 
| 273 | 
            +
                        uid=i
         | 
| 274 | 
            +
                    ))
         | 
| 275 | 
            +
             | 
| 276 | 
            +
                return new_cameras
         | 
| 277 | 
            +
             | 
| 278 | 
            +
             | 
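The closed-spline step can be exercised on its own (standalone sketch with synthetic positions, SciPy only): fit a periodic spline through a ring of points and resample it densely, exactly as the function above does for camera positions.

```python
import numpy as np
from scipy.interpolate import splprep, splev

theta = np.linspace(0, 2 * np.pi, 12, endpoint=False)
positions = np.stack([np.cos(theta), np.sin(theta), 0.1 * np.sin(3 * theta)], axis=-1)

positions = np.vstack([positions, positions[0]])     # close the loop, as above
tck, u = splprep(positions.T, s=0.25, per=True)      # per=True fits a periodic (closed) spline
u_fine = np.linspace(0, 1, 120)
smooth_path = np.stack(splev(u_fine, tck), axis=-1)  # (120, 3) smooth camera positions
print(smooth_path.shape)
```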
| 279 | 
            +
            def save_numpy_frames_as_mp4(frames, output_path="animation.mp4", fps=10, center_crop: float = 1.0):
         | 
| 280 | 
            +
                """
         | 
| 281 | 
            +
                Save a list of RGB NumPy frames as an MP4 video with optional center cropping.
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                Parameters:
         | 
| 284 | 
            +
                    frames (List[np.ndarray]): List of RGB images as uint8 NumPy arrays (shape HxWx3).
         | 
| 285 | 
            +
                    output_path (str): Path to save the output MP4.
         | 
| 286 | 
            +
                    fps (int): Frames per second for playback speed.
         | 
| 287 | 
            +
                    center_crop (float): Fraction (0 < center_crop <= 1.0) of central region to retain. 
         | 
| 288 | 
            +
                                         Use 1.0 for no cropping; 0.8 to crop to 80% center region.
         | 
| 289 | 
            +
             | 
| 290 | 
            +
                Returns:
         | 
| 291 | 
            +
                    None
         | 
| 292 | 
            +
                """
         | 
| 293 | 
            +
                with imageio.get_writer(output_path, fps=fps, codec='libx264', quality=8) as writer:
         | 
| 294 | 
            +
                    for frame in frames:
         | 
| 295 | 
            +
                        cropped = center_crop_frame(frame, center_crop)
         | 
| 296 | 
            +
                        writer.append_data(cropped)
         | 
| 297 | 
            +
                print(f"MP4 saved to: {output_path}")
         | 
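
# Illustrative usage sketch (added for this write-up, not part of the original
# module): one way save_numpy_frames_as_mp4 might be called. The gradient
# frames below are synthetic stand-ins for rendered images.
def _demo_save_numpy_frames_as_mp4(output_path: str = "demo.mp4") -> None:
    frames = [np.full((120, 160, 3), i * 25, dtype=np.uint8) for i in range(10)]
    save_numpy_frames_as_mp4(frames, output_path=output_path, fps=5, center_crop=0.9)
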
| 298 | 
            +
             | 
| 299 | 
            +
             | 
| 300 | 
            +
                
         | 
| 301 | 
            +
            def put_text_on_image(img: np.ndarray, text: str) -> np.ndarray:
         | 
| 302 | 
            +
                """
         | 
| 303 | 
            +
                Draws multiline white text on a copy of the input image, positioned near the bottom
         | 
| 304 | 
            +
                and around 80% of the image width. Handles '\n' characters to split text into multiple lines.
         | 
| 305 | 
            +
             | 
| 306 | 
            +
                Args:
         | 
| 307 | 
            +
                    img (np.ndarray): Input image as a (H, W, 3) uint8 numpy array.
         | 
| 308 | 
            +
                    text (str): Text string to draw on the image. Newlines '\n' are treated as line breaks.
         | 
| 309 | 
            +
             | 
| 310 | 
            +
                Returns:
         | 
| 311 | 
            +
                    np.ndarray: The output image with the text drawn on it.
         | 
| 312 | 
            +
                
         | 
| 313 | 
            +
                Notes:
         | 
| 314 | 
            +
                    - The function automatically adjusts line spacing and prevents text from going outside the image.
         | 
| 315 | 
            +
        - Text is drawn in white at font scale 1.0 with thickness 2, so it stays legible without dominating the frame.
         | 
| 316 | 
            +
                """
         | 
| 317 | 
            +
                img = img.copy()
         | 
| 318 | 
            +
                height, width, _ = img.shape
         | 
| 319 | 
            +
                
         | 
| 320 | 
            +
                font = cv2.FONT_HERSHEY_SIMPLEX
         | 
| 321 | 
            +
                font_scale = 1.
         | 
| 322 | 
            +
                color = (255, 255, 255)
         | 
| 323 | 
            +
                thickness = 2
         | 
| 324 | 
            +
                line_spacing = 5  # extra pixels between lines
         | 
| 325 | 
            +
                
         | 
| 326 | 
            +
                lines = text.split('\n')
         | 
| 327 | 
            +
                
         | 
| 328 | 
            +
                # Precompute the maximum text width to adjust starting x
         | 
| 329 | 
            +
                max_text_width = max(cv2.getTextSize(line, font, font_scale, thickness)[0][0] for line in lines)
         | 
| 330 | 
            +
                
         | 
| 331 | 
            +
                x = int(0.8 * width)
         | 
| 332 | 
            +
                x = min(x, width - max_text_width - 30)  # margin on right
         | 
| 333 | 
            +
                #x = int(0.03 * width)
         | 
| 334 | 
            +
                
         | 
| 335 | 
            +
                # Start near the bottom, but move up depending on number of lines
         | 
| 336 | 
            +
                total_text_height = len(lines) * (cv2.getTextSize('A', font, font_scale, thickness)[0][1] + line_spacing)
         | 
| 337 | 
            +
    y_start = int(height * 0.9) - total_text_height  # start 10% of the image height above the bottom edge
         | 
| 338 | 
            +
             | 
| 339 | 
            +
                for i, line in enumerate(lines):
         | 
| 340 | 
            +
                    y = y_start + i * (cv2.getTextSize(line, font, font_scale, thickness)[0][1] + line_spacing)
         | 
| 341 | 
            +
                    cv2.putText(img, line, (x, y), font, font_scale, color, thickness, cv2.LINE_AA)
         | 
| 342 | 
            +
                
         | 
| 343 | 
            +
                return img
         | 
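
# Illustrative sketch (assumption-labelled example, not original code): stamp a
# two-line caption onto a synthetic frame before writing it to video.
def _demo_put_text_on_image() -> None:
    frame = np.zeros((240, 320, 3), dtype=np.uint8)
    labelled = put_text_on_image(frame, "EDGS preview\nframe 0")
    assert labelled.shape == frame.shape  # the input frame itself is left untouched (a copy is modified)
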
| 344 | 
            +
             | 
| 345 | 
            +
             | 
| 346 | 
            +
             | 
| 347 | 
            +
             | 
| 348 | 
            +
            def catmull_rom_spline(P0, P1, P2, P3, n_points=20):
         | 
| 349 | 
            +
                """
         | 
| 350 | 
            +
                Compute Catmull-Rom spline segment between P1 and P2.
         | 
| 351 | 
            +
                """
         | 
| 352 | 
            +
                t = np.linspace(0, 1, n_points)[:, None]
         | 
| 353 | 
            +
             | 
| 354 | 
            +
                M = 0.5 * np.array([
         | 
| 355 | 
            +
                    [-1,  3, -3, 1],
         | 
| 356 | 
            +
                    [ 2, -5,  4, -1],
         | 
| 357 | 
            +
                    [-1,  0,  1, 0],
         | 
| 358 | 
            +
                    [ 0,  2,  0, 0]
         | 
| 359 | 
            +
                ])
         | 
| 360 | 
            +
             | 
| 361 | 
            +
                G = np.stack([P0, P1, P2, P3], axis=0)
         | 
| 362 | 
            +
                T = np.concatenate([t**3, t**2, t, np.ones_like(t)], axis=1)
         | 
| 363 | 
            +
             | 
| 364 | 
            +
                return T @ M @ G
         | 
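
# Quick sanity sketch (added for illustration): a Catmull-Rom segment starts at
# P1 and ends at P2; P0 and P3 only shape the tangents. Plain NumPy, no other
# project code assumed.
def _demo_catmull_rom_endpoints() -> None:
    P0 = np.array([0.0, 0.0, 0.0])
    P1 = np.array([1.0, 0.0, 0.0])
    P2 = np.array([2.0, 1.0, 0.0])
    P3 = np.array([3.0, 1.0, 0.0])
    seg = catmull_rom_spline(P0, P1, P2, P3, n_points=5)
    assert np.allclose(seg[0], P1) and np.allclose(seg[-1], P2)
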
| 365 | 
            +
             | 
| 366 | 
            +
            def sort_cameras_pca(existing_cameras: List[Camera]):
         | 
| 367 | 
            +
                """
         | 
| 368 | 
            +
                Sort cameras along the main PCA axis.
         | 
| 369 | 
            +
                """
         | 
| 370 | 
            +
                positions = np.array([cam.T for cam in existing_cameras])
         | 
| 371 | 
            +
                pca = PCA(n_components=1)
         | 
| 372 | 
            +
                scores = pca.fit_transform(positions)
         | 
| 373 | 
            +
                sorted_indices = np.argsort(scores[:, 0])
         | 
| 374 | 
            +
                return sorted_indices
         | 
| 375 | 
            +
             | 
| 376 | 
            +
            def generate_fully_smooth_cameras(existing_cameras: List[Camera], 
         | 
| 377 | 
            +
                                              n_selected: int = 30, 
         | 
| 378 | 
            +
                                              n_points_per_segment: int = 20, 
         | 
| 379 | 
            +
                                              d: float = 2.0,
         | 
| 380 | 
            +
                                              closed: bool = False) -> List[Camera]:
         | 
| 381 | 
            +
                """
         | 
| 382 | 
            +
                Generate a fully smooth camera path using PCA ordering, global Catmull-Rom spline for positions, and global SLERP for orientations.
         | 
| 383 | 
            +
             | 
| 384 | 
            +
                Args:
         | 
| 385 | 
            +
                    existing_cameras (List[Camera]): List of input cameras.
         | 
| 386 | 
            +
                    n_selected (int): Number of cameras to select after sorting.
         | 
| 387 | 
            +
                    n_points_per_segment (int): Number of interpolated points per spline segment.
         | 
| 388 | 
            +
        d (float): Distance ahead for estimating center of view (currently unused by this function).
         | 
| 389 | 
            +
                    closed (bool): Whether to close the path.
         | 
| 390 | 
            +
             | 
| 391 | 
            +
                Returns:
         | 
| 392 | 
            +
                    List[Camera]: List of smoothly moving Camera objects.
         | 
| 393 | 
            +
                """
         | 
| 394 | 
            +
                # 1. Sort cameras along PCA axis
         | 
| 395 | 
            +
                sorted_indices = sort_cameras_pca(existing_cameras)
         | 
| 396 | 
            +
                sorted_cameras = [existing_cameras[i] for i in sorted_indices]
         | 
| 397 | 
            +
                positions = np.array([cam.T for cam in sorted_cameras])
         | 
| 398 | 
            +
             | 
| 399 | 
            +
                # 2. Subsample uniformly
         | 
| 400 | 
            +
                idx = np.linspace(0, len(positions) - 1, n_selected).astype(int)
         | 
| 401 | 
            +
                sampled_positions = positions[idx]
         | 
| 402 | 
            +
                sampled_cameras = [sorted_cameras[i] for i in idx]
         | 
| 403 | 
            +
             | 
| 404 | 
            +
                # 3. Prepare for Catmull-Rom
         | 
| 405 | 
            +
                if closed:
         | 
| 406 | 
            +
                    sampled_positions = np.vstack([sampled_positions[-1], sampled_positions, sampled_positions[0], sampled_positions[1]])
         | 
| 407 | 
            +
                else:
         | 
| 408 | 
            +
                    sampled_positions = np.vstack([sampled_positions[0], sampled_positions, sampled_positions[-1], sampled_positions[-1]])
         | 
| 409 | 
            +
             | 
| 410 | 
            +
                # 4. Generate smooth path positions
         | 
| 411 | 
            +
                path_positions = []
         | 
| 412 | 
            +
                for i in range(1, len(sampled_positions) - 2):
         | 
| 413 | 
            +
                    segment = catmull_rom_spline(sampled_positions[i-1], sampled_positions[i], sampled_positions[i+1], sampled_positions[i+2], n_points_per_segment)
         | 
| 414 | 
            +
                    path_positions.append(segment)
         | 
| 415 | 
            +
                path_positions = np.concatenate(path_positions, axis=0)
         | 
| 416 | 
            +
             | 
| 417 | 
            +
                # 5. Global SLERP for rotations
         | 
| 418 | 
            +
                rotations = R.from_matrix([cam.R for cam in sampled_cameras])
         | 
| 419 | 
            +
                key_times = np.linspace(0, 1, len(rotations))
         | 
| 420 | 
            +
                slerp = Slerp(key_times, rotations)
         | 
| 421 | 
            +
             | 
| 422 | 
            +
                query_times = np.linspace(0, 1, len(path_positions))
         | 
| 423 | 
            +
                interpolated_rotations = slerp(query_times)
         | 
| 424 | 
            +
             | 
| 425 | 
            +
                # 6. Generate Camera objects
         | 
| 426 | 
            +
                reference_cam = existing_cameras[0]
         | 
| 427 | 
            +
                smooth_cameras = []
         | 
| 428 | 
            +
             | 
| 429 | 
            +
                for i, pos in enumerate(path_positions):
         | 
| 430 | 
            +
                    R_interp = interpolated_rotations[i].as_matrix()
         | 
| 431 | 
            +
             | 
| 432 | 
            +
                    smooth_cameras.append(Camera(
         | 
| 433 | 
            +
                        R=R_interp,
         | 
| 434 | 
            +
                        T=pos,
         | 
| 435 | 
            +
                        FoVx=reference_cam.FoVx,
         | 
| 436 | 
            +
                        FoVy=reference_cam.FoVy,
         | 
| 437 | 
            +
                        resolution=(reference_cam.image_width, reference_cam.image_height),
         | 
| 438 | 
            +
                        colmap_id=-1,
         | 
| 439 | 
            +
                        depth_params=None,
         | 
| 440 | 
            +
                        image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
         | 
| 441 | 
            +
                        invdepthmap=None,
         | 
| 442 | 
            +
                        image_name=f"fully_smooth_path_i={i}",
         | 
| 443 | 
            +
                        uid=i
         | 
| 444 | 
            +
                    ))
         | 
| 445 | 
            +
             | 
| 446 | 
            +
                return smooth_cameras
         | 
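
# Hedged usage sketch (not from the original code): turn training cameras into a
# smooth flythrough and render it to video. `scene`, `gaussians`, `render_fn`
# and `pipe_args` are placeholders for whatever rendering entry point the
# surrounding project exposes; `render_fn` is assumed to return a dict with a
# CHW float tensor in [0, 1] under the key "render".
def _demo_render_smooth_flythrough(scene, render_fn, gaussians, pipe_args,
                                   output_path: str = "flythrough.mp4") -> None:
    train_cams = scene.getTrainCameras()  # assumed scene API
    path_cams = generate_fully_smooth_cameras(train_cams, n_selected=20,
                                              n_points_per_segment=30, closed=False)
    frames = []
    for cam in path_cams:
        rendering = render_fn(cam, gaussians, pipe_args)["render"]
        frame = (rendering.clamp(0, 1).permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8)
        frames.append(frame)
    save_numpy_frames_as_mp4(frames, output_path=output_path, fps=30)
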
| 447 | 
            +
             | 
| 448 | 
            +
             | 
| 449 | 
            +
            def plot_cameras_and_smooth_path_with_orientation(existing_cameras: List[Camera], smooth_cameras: List[Camera], scale: float = 0.1):
         | 
| 450 | 
            +
                """
         | 
| 451 | 
            +
                Plot input cameras and smooth path cameras with their orientations in 3D.
         | 
| 452 | 
            +
             | 
| 453 | 
            +
                Args:
         | 
| 454 | 
            +
                    existing_cameras (List[Camera]): List of original input cameras.
         | 
| 455 | 
            +
                    smooth_cameras (List[Camera]): List of smooth path cameras.
         | 
| 456 | 
            +
                    scale (float): Length of orientation arrows.
         | 
| 457 | 
            +
             | 
| 458 | 
            +
                Returns:
         | 
| 459 | 
            +
                    None
         | 
| 460 | 
            +
                """
         | 
| 461 | 
            +
                # Input cameras
         | 
| 462 | 
            +
                input_positions = np.array([cam.T for cam in existing_cameras])
         | 
| 463 | 
            +
             | 
| 464 | 
            +
                # Smooth cameras
         | 
| 465 | 
            +
                smooth_positions = np.array([cam.T for cam in smooth_cameras])
         | 
| 466 | 
            +
             | 
| 467 | 
            +
                fig = go.Figure()
         | 
| 468 | 
            +
             | 
| 469 | 
            +
                # Plot input camera positions
         | 
| 470 | 
            +
                fig.add_trace(go.Scatter3d(
         | 
| 471 | 
            +
                    x=input_positions[:, 0], y=input_positions[:, 1], z=input_positions[:, 2],
         | 
| 472 | 
            +
                    mode='markers',
         | 
| 473 | 
            +
                    marker=dict(size=4, color='blue'),
         | 
| 474 | 
            +
                    name='Input Cameras'
         | 
| 475 | 
            +
                ))
         | 
| 476 | 
            +
             | 
| 477 | 
            +
                # Plot smooth path positions
         | 
| 478 | 
            +
                fig.add_trace(go.Scatter3d(
         | 
| 479 | 
            +
                    x=smooth_positions[:, 0], y=smooth_positions[:, 1], z=smooth_positions[:, 2],
         | 
| 480 | 
            +
                    mode='lines+markers',
         | 
| 481 | 
            +
                    line=dict(color='red', width=3),
         | 
| 482 | 
            +
                    marker=dict(size=2, color='red'),
         | 
| 483 | 
            +
                    name='Smooth Path Cameras'
         | 
| 484 | 
            +
                ))
         | 
| 485 | 
            +
             | 
| 486 | 
            +
                # Plot input camera orientations
         | 
| 487 | 
            +
                for cam in existing_cameras:
         | 
| 488 | 
            +
                    origin = cam.T
         | 
| 489 | 
            +
                    forward = cam.R[:, 2]  # Forward direction
         | 
| 490 | 
            +
             | 
| 491 | 
            +
                    fig.add_trace(go.Cone(
         | 
| 492 | 
            +
                        x=[origin[0]], y=[origin[1]], z=[origin[2]],
         | 
| 493 | 
            +
                        u=[forward[0]], v=[forward[1]], w=[forward[2]],
         | 
| 494 | 
            +
                        colorscale=[[0, 'blue'], [1, 'blue']],
         | 
| 495 | 
            +
                        sizemode="absolute",
         | 
| 496 | 
            +
                        sizeref=scale,
         | 
| 497 | 
            +
                        anchor="tail",
         | 
| 498 | 
            +
                        showscale=False,
         | 
| 499 | 
            +
                        name='Input Camera Direction'
         | 
| 500 | 
            +
                    ))
         | 
| 501 | 
            +
             | 
| 502 | 
            +
                # Plot smooth camera orientations
         | 
| 503 | 
            +
                for cam in smooth_cameras:
         | 
| 504 | 
            +
                    origin = cam.T
         | 
| 505 | 
            +
                    forward = cam.R[:, 2]  # Forward direction
         | 
| 506 | 
            +
             | 
| 507 | 
            +
                    fig.add_trace(go.Cone(
         | 
| 508 | 
            +
                        x=[origin[0]], y=[origin[1]], z=[origin[2]],
         | 
| 509 | 
            +
                        u=[forward[0]], v=[forward[1]], w=[forward[2]],
         | 
| 510 | 
            +
                        colorscale=[[0, 'red'], [1, 'red']],
         | 
| 511 | 
            +
                        sizemode="absolute",
         | 
| 512 | 
            +
                        sizeref=scale,
         | 
| 513 | 
            +
                        anchor="tail",
         | 
| 514 | 
            +
                        showscale=False,
         | 
| 515 | 
            +
                        name='Smooth Camera Direction'
         | 
| 516 | 
            +
                    ))
         | 
| 517 | 
            +
             | 
| 518 | 
            +
                fig.update_layout(
         | 
| 519 | 
            +
                    scene=dict(
         | 
| 520 | 
            +
                        xaxis_title='X',
         | 
| 521 | 
            +
                        yaxis_title='Y',
         | 
| 522 | 
            +
                        zaxis_title='Z',
         | 
| 523 | 
            +
                        aspectmode='data'
         | 
| 524 | 
            +
                    ),
         | 
| 525 | 
            +
                    title="Input Cameras and Smooth Path with Orientations",
         | 
| 526 | 
            +
                    margin=dict(l=0, r=0, b=0, t=30)
         | 
| 527 | 
            +
                )
         | 
| 528 | 
            +
             | 
| 529 | 
            +
                fig.show()
         | 
| 530 | 
            +
             | 
| 531 | 
            +
             | 
| 532 | 
            +
            def solve_tsp_nearest_neighbor(points: np.ndarray):
         | 
| 533 | 
            +
                """
         | 
| 534 | 
            +
                Solve TSP approximately using nearest neighbor heuristic.
         | 
| 535 | 
            +
             | 
| 536 | 
            +
                Args:
         | 
| 537 | 
            +
                    points (np.ndarray): (N, 3) array of points.
         | 
| 538 | 
            +
             | 
| 539 | 
            +
                Returns:
         | 
| 540 | 
            +
        List[int]: Approximate visiting order of the points (greedy; not guaranteed optimal).
         | 
| 541 | 
            +
                """
         | 
| 542 | 
            +
                N = points.shape[0]
         | 
| 543 | 
            +
                dist = distance_matrix(points, points)
         | 
| 544 | 
            +
                visited = [0]
         | 
| 545 | 
            +
                unvisited = set(range(1, N))
         | 
| 546 | 
            +
             | 
| 547 | 
            +
                while unvisited:
         | 
| 548 | 
            +
                    last = visited[-1]
         | 
| 549 | 
            +
                    next_city = min(unvisited, key=lambda city: dist[last, city])
         | 
| 550 | 
            +
                    visited.append(next_city)
         | 
| 551 | 
            +
                    unvisited.remove(next_city)
         | 
| 552 | 
            +
             | 
| 553 | 
            +
                return visited
         | 
| 554 | 
            +
             | 
| 555 | 
            +
            def solve_tsp_2opt(points: np.ndarray, n_iter: int = 1000) -> np.ndarray:
         | 
| 556 | 
            +
                """
         | 
| 557 | 
            +
                Solve TSP approximately using Nearest Neighbor + 2-Opt.
         | 
| 558 | 
            +
             | 
| 559 | 
            +
                Args:
         | 
| 560 | 
            +
                    points (np.ndarray): Array of shape (N, D) with points.
         | 
| 561 | 
            +
                    n_iter (int): Number of 2-opt iterations.
         | 
| 562 | 
            +
             | 
| 563 | 
            +
                Returns:
         | 
| 564 | 
            +
                    np.ndarray: Ordered list of indices.
         | 
| 565 | 
            +
                """
         | 
| 566 | 
            +
                n_points = points.shape[0]
         | 
| 567 | 
            +
             | 
| 568 | 
            +
                # === 1. Start with Nearest Neighbor
         | 
| 569 | 
            +
                unvisited = list(range(n_points))
         | 
| 570 | 
            +
                current = unvisited.pop(0)
         | 
| 571 | 
            +
                path = [current]
         | 
| 572 | 
            +
             | 
| 573 | 
            +
                while unvisited:
         | 
| 574 | 
            +
                    dists = np.linalg.norm(points[unvisited] - points[current], axis=1)
         | 
| 575 | 
            +
                    next_idx = unvisited[np.argmin(dists)]
         | 
| 576 | 
            +
                    unvisited.remove(next_idx)
         | 
| 577 | 
            +
                    path.append(next_idx)
         | 
| 578 | 
            +
                    current = next_idx
         | 
| 579 | 
            +
             | 
| 580 | 
            +
                # === 2. Apply 2-Opt improvements
         | 
| 581 | 
            +
                def path_length(path):
         | 
| 582 | 
            +
        # sum of Euclidean edge lengths along the path
        return sum(np.linalg.norm(points[path[i]] - points[path[i + 1]]) for i in range(len(path) - 1))
         | 
| 583 | 
            +
             | 
| 584 | 
            +
                best_length = path_length(path)
         | 
| 585 | 
            +
                improved = True
         | 
| 586 | 
            +
             | 
| 587 | 
            +
                for _ in range(n_iter):
         | 
| 588 | 
            +
                    if not improved:
         | 
| 589 | 
            +
                        break
         | 
| 590 | 
            +
                    improved = False
         | 
| 591 | 
            +
                    for i in range(1, n_points - 2):
         | 
| 592 | 
            +
                        for j in range(i + 1, n_points):
         | 
| 593 | 
            +
                            if j - i == 1: continue
         | 
| 594 | 
            +
                            new_path = path[:i] + path[i:j][::-1] + path[j:]
         | 
| 595 | 
            +
                            new_length = path_length(new_path)
         | 
| 596 | 
            +
                            if new_length < best_length:
         | 
| 597 | 
            +
                                path = new_path
         | 
| 598 | 
            +
                                best_length = new_length
         | 
| 599 | 
            +
                                improved = True
         | 
| 600 | 
            +
                                break
         | 
| 601 | 
            +
                        if improved:
         | 
| 602 | 
            +
                            break
         | 
| 603 | 
            +
             | 
| 604 | 
            +
                return np.array(path)
         | 
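
# Small illustrative comparison (added as an example): because solve_tsp_2opt
# starts from the same greedy tour and only accepts improvements, its tour
# should never be longer than the plain nearest-neighbour one.
def _demo_compare_tsp_heuristics(n_points: int = 25, seed: int = 0) -> None:
    rng = np.random.default_rng(seed)
    pts = rng.random((n_points, 3))

    def tour_length(order):
        return sum(np.linalg.norm(pts[order[i]] - pts[order[i + 1]])
                   for i in range(len(order) - 1))

    nn_order = solve_tsp_nearest_neighbor(pts)
    opt_order = list(solve_tsp_2opt(pts, n_iter=200))
    assert tour_length(opt_order) <= tour_length(nn_order) + 1e-9
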
| 605 | 
            +
             | 
| 606 | 
            +
            def generate_fully_smooth_cameras_with_tsp(existing_cameras: List[Camera], 
         | 
| 607 | 
            +
                                                       n_selected: int = 30, 
         | 
| 608 | 
            +
                                                       n_points_per_segment: int = 20, 
         | 
| 609 | 
            +
                                                       d: float = 2.0,
         | 
| 610 | 
            +
                                                       closed: bool = False) -> List[Camera]:
         | 
| 611 | 
            +
                """
         | 
| 612 | 
            +
                Generate a fully smooth camera path using TSP ordering, global Catmull-Rom spline for positions, and global SLERP for orientations.
         | 
| 613 | 
            +
             | 
| 614 | 
            +
                Args:
         | 
| 615 | 
            +
                    existing_cameras (List[Camera]): List of input cameras.
         | 
| 616 | 
            +
                    n_selected (int): Number of cameras to select after ordering.
         | 
| 617 | 
            +
                    n_points_per_segment (int): Number of interpolated points per spline segment.
         | 
| 618 | 
            +
        d (float): Distance ahead for estimating center of view (currently unused by this function).
         | 
| 619 | 
            +
                    closed (bool): Whether to close the path.
         | 
| 620 | 
            +
             | 
| 621 | 
            +
                Returns:
         | 
| 622 | 
            +
                    List[Camera]: List of smoothly moving Camera objects.
         | 
| 623 | 
            +
                """
         | 
| 624 | 
            +
                positions = np.array([cam.T for cam in existing_cameras])
         | 
| 625 | 
            +
             | 
| 626 | 
            +
                # 1. Solve approximate TSP
         | 
| 627 | 
            +
                order = solve_tsp_nearest_neighbor(positions)
         | 
| 628 | 
            +
                ordered_cameras = [existing_cameras[i] for i in order]
         | 
| 629 | 
            +
                ordered_positions = positions[order]
         | 
| 630 | 
            +
             | 
| 631 | 
            +
                # 2. Subsample uniformly
         | 
| 632 | 
            +
                idx = np.linspace(0, len(ordered_positions) - 1, n_selected).astype(int)
         | 
| 633 | 
            +
                sampled_positions = ordered_positions[idx]
         | 
| 634 | 
            +
                sampled_cameras = [ordered_cameras[i] for i in idx]
         | 
| 635 | 
            +
             | 
| 636 | 
            +
                # 3. Prepare for Catmull-Rom
         | 
| 637 | 
            +
                if closed:
         | 
| 638 | 
            +
                    sampled_positions = np.vstack([sampled_positions[-1], sampled_positions, sampled_positions[0], sampled_positions[1]])
         | 
| 639 | 
            +
                else:
         | 
| 640 | 
            +
                    sampled_positions = np.vstack([sampled_positions[0], sampled_positions, sampled_positions[-1], sampled_positions[-1]])
         | 
| 641 | 
            +
             | 
| 642 | 
            +
                # 4. Generate smooth path positions
         | 
| 643 | 
            +
                path_positions = []
         | 
| 644 | 
            +
                for i in range(1, len(sampled_positions) - 2):
         | 
| 645 | 
            +
                    segment = catmull_rom_spline(sampled_positions[i-1], sampled_positions[i], sampled_positions[i+1], sampled_positions[i+2], n_points_per_segment)
         | 
| 646 | 
            +
                    path_positions.append(segment)
         | 
| 647 | 
            +
                path_positions = np.concatenate(path_positions, axis=0)
         | 
| 648 | 
            +
             | 
| 649 | 
            +
                # 5. Global SLERP for rotations
         | 
| 650 | 
            +
                rotations = R.from_matrix([cam.R for cam in sampled_cameras])
         | 
| 651 | 
            +
                key_times = np.linspace(0, 1, len(rotations))
         | 
| 652 | 
            +
                slerp = Slerp(key_times, rotations)
         | 
| 653 | 
            +
             | 
| 654 | 
            +
                query_times = np.linspace(0, 1, len(path_positions))
         | 
| 655 | 
            +
                interpolated_rotations = slerp(query_times)
         | 
| 656 | 
            +
             | 
| 657 | 
            +
                # 6. Generate Camera objects
         | 
| 658 | 
            +
                reference_cam = existing_cameras[0]
         | 
| 659 | 
            +
                smooth_cameras = []
         | 
| 660 | 
            +
             | 
| 661 | 
            +
                for i, pos in enumerate(path_positions):
         | 
| 662 | 
            +
                    R_interp = interpolated_rotations[i].as_matrix()
         | 
| 663 | 
            +
             | 
| 664 | 
            +
                    smooth_cameras.append(Camera(
         | 
| 665 | 
            +
                        R=R_interp,
         | 
| 666 | 
            +
                        T=pos,
         | 
| 667 | 
            +
                        FoVx=reference_cam.FoVx,
         | 
| 668 | 
            +
                        FoVy=reference_cam.FoVy,
         | 
| 669 | 
            +
                        resolution=(reference_cam.image_width, reference_cam.image_height),
         | 
| 670 | 
            +
                        colmap_id=-1,
         | 
| 671 | 
            +
                        depth_params=None,
         | 
| 672 | 
            +
                        image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
         | 
| 673 | 
            +
                        invdepthmap=None,
         | 
| 674 | 
            +
                        image_name=f"fully_smooth_path_i={i}",
         | 
| 675 | 
            +
                        uid=i
         | 
| 676 | 
            +
                    ))
         | 
| 677 | 
            +
             | 
| 678 | 
            +
                return smooth_cameras
         | 
| 679 | 
            +
             | 
| 680 | 
            +
            from typing import List
         | 
| 681 | 
            +
            import numpy as np
         | 
| 682 | 
            +
            from sklearn.mixture import GaussianMixture
         | 
| 683 | 
            +
            from scipy.spatial.transform import Rotation as R, Slerp
         | 
| 684 | 
            +
            from PIL import Image
         | 
| 685 | 
            +
             | 
| 686 | 
            +
            def generate_clustered_smooth_cameras_with_tsp(existing_cameras: List[Camera], 
         | 
| 687 | 
            +
                                                            n_selected: int = 30, 
         | 
| 688 | 
            +
                                                            n_points_per_segment: int = 20, 
         | 
| 689 | 
            +
                                                            d: float = 2.0,
         | 
| 690 | 
            +
                                                            n_clusters: int = 5,
         | 
| 691 | 
            +
                                                            closed: bool = False) -> List[Camera]:
         | 
| 692 | 
            +
                """
         | 
| 693 | 
            +
    Generate a fully smooth camera path by clustering the cameras (GMM), solving a TSP over the cluster centers, and a TSP inside each cluster.
         | 
| 694 | 
            +
                Positions are normalized before clustering and denormalized before generating final cameras.
         | 
| 695 | 
            +
             | 
| 696 | 
            +
                Args:
         | 
| 697 | 
            +
                    existing_cameras (List[Camera]): List of input cameras.
         | 
| 698 | 
            +
                    n_selected (int): Number of cameras to select after ordering.
         | 
| 699 | 
            +
                    n_points_per_segment (int): Number of interpolated points per spline segment.
         | 
| 700 | 
            +
        d (float): Distance ahead for estimating center of view (currently unused by this function).
         | 
| 701 | 
            +
                    n_clusters (int): Number of GMM clusters.
         | 
| 702 | 
            +
                    closed (bool): Whether to close the path.
         | 
| 703 | 
            +
             | 
| 704 | 
            +
                Returns:
         | 
| 705 | 
            +
                    List[Camera]: Smooth path of Camera objects.
         | 
| 706 | 
            +
                """
         | 
| 707 | 
            +
                # Extract positions and rotations
         | 
| 708 | 
            +
                positions = np.array([cam.T for cam in existing_cameras])
         | 
| 709 | 
            +
    rotations = np.array([R.from_matrix(cam.R).as_quat() for cam in existing_cameras])  # note: unused below; rotations are re-derived from the sampled cameras
         | 
| 710 | 
            +
             | 
| 711 | 
            +
                # === Normalize positions
         | 
| 712 | 
            +
                mean_pos = np.mean(positions, axis=0)
         | 
| 713 | 
            +
                scale_pos = np.std(positions, axis=0)
         | 
| 714 | 
            +
                scale_pos[scale_pos == 0] = 1.0  # avoid division by zero
         | 
| 715 | 
            +
             | 
| 716 | 
            +
                positions_normalized = (positions - mean_pos) / scale_pos
         | 
| 717 | 
            +
             | 
| 718 | 
            +
                # === Features for clustering (only positions, not rotations)
         | 
| 719 | 
            +
                features = positions_normalized
         | 
| 720 | 
            +
             | 
| 721 | 
            +
                # === 1. GMM clustering
         | 
| 722 | 
            +
                gmm = GaussianMixture(n_components=n_clusters, covariance_type='full', random_state=42)
         | 
| 723 | 
            +
                cluster_labels = gmm.fit_predict(features)
         | 
| 724 | 
            +
             | 
| 725 | 
            +
                clusters = {}
         | 
| 726 | 
            +
                cluster_centers = []
         | 
| 727 | 
            +
             | 
| 728 | 
            +
                for cluster_id in range(n_clusters):
         | 
| 729 | 
            +
                    cluster_indices = np.where(cluster_labels == cluster_id)[0]
         | 
| 730 | 
            +
                    if len(cluster_indices) == 0:
         | 
| 731 | 
            +
                        continue
         | 
| 732 | 
            +
                    clusters[cluster_id] = cluster_indices
         | 
| 733 | 
            +
                    cluster_center = np.mean(features[cluster_indices], axis=0)
         | 
| 734 | 
            +
                    cluster_centers.append(cluster_center)
         | 
| 735 | 
            +
             | 
| 736 | 
            +
                cluster_centers = np.stack(cluster_centers)
         | 
| 737 | 
            +
             | 
| 738 | 
            +
    # === 2. Remap cluster centers to nearest existing cameras (disabled; kept for reference)
         | 
| 739 | 
            +
                if False:
         | 
| 740 | 
            +
                    mapped_centers = []
         | 
| 741 | 
            +
                    for center in cluster_centers:
         | 
| 742 | 
            +
                        dists = np.linalg.norm(features - center, axis=1)
         | 
| 743 | 
            +
                        nearest_idx = np.argmin(dists)
         | 
| 744 | 
            +
                        mapped_centers.append(features[nearest_idx])
         | 
| 745 | 
            +
                    mapped_centers = np.stack(mapped_centers)
         | 
| 746 | 
            +
                    cluster_centers = mapped_centers
         | 
| 747 | 
            +
                # === 3. Solve TSP between mapped cluster centers
         | 
| 748 | 
            +
                cluster_order = solve_tsp_2opt(cluster_centers)
         | 
| 749 | 
            +
             | 
| 750 | 
            +
                # === 4. For each cluster, solve TSP inside cluster
         | 
| 751 | 
            +
                final_indices = []
         | 
| 752 | 
            +
                for cluster_id in cluster_order:
         | 
| 753 | 
            +
                    cluster_indices = clusters[cluster_id]
         | 
| 754 | 
            +
                    cluster_positions = features[cluster_indices]
         | 
| 755 | 
            +
             | 
| 756 | 
            +
                    if len(cluster_positions) == 1:
         | 
| 757 | 
            +
                        final_indices.append(cluster_indices[0])
         | 
| 758 | 
            +
                        continue
         | 
| 759 | 
            +
             | 
| 760 | 
            +
                    local_order = solve_tsp_nearest_neighbor(cluster_positions)
         | 
| 761 | 
            +
                    ordered_cluster_indices = cluster_indices[local_order]
         | 
| 762 | 
            +
                    final_indices.extend(ordered_cluster_indices)
         | 
| 763 | 
            +
             | 
| 764 | 
            +
                ordered_cameras = [existing_cameras[i] for i in final_indices]
         | 
| 765 | 
            +
                ordered_positions = positions_normalized[final_indices]
         | 
| 766 | 
            +
             | 
| 767 | 
            +
                # === 5. Subsample uniformly
         | 
| 768 | 
            +
                idx = np.linspace(0, len(ordered_positions) - 1, n_selected).astype(int)
         | 
| 769 | 
            +
                sampled_positions = ordered_positions[idx]
         | 
| 770 | 
            +
                sampled_cameras = [ordered_cameras[i] for i in idx]
         | 
| 771 | 
            +
             | 
| 772 | 
            +
                # === 6. Prepare for Catmull-Rom spline
         | 
| 773 | 
            +
                if closed:
         | 
| 774 | 
            +
                    sampled_positions = np.vstack([sampled_positions[-1], sampled_positions, sampled_positions[0], sampled_positions[1]])
         | 
| 775 | 
            +
                else:
         | 
| 776 | 
            +
                    sampled_positions = np.vstack([sampled_positions[0], sampled_positions, sampled_positions[-1], sampled_positions[-1]])
         | 
| 777 | 
            +
             | 
| 778 | 
            +
                # === 7. Smooth path positions
         | 
| 779 | 
            +
                path_positions = []
         | 
| 780 | 
            +
                for i in range(1, len(sampled_positions) - 2):
         | 
| 781 | 
            +
                    segment = catmull_rom_spline(sampled_positions[i-1], sampled_positions[i], sampled_positions[i+1], sampled_positions[i+2], n_points_per_segment)
         | 
| 782 | 
            +
                    path_positions.append(segment)
         | 
| 783 | 
            +
                path_positions = np.concatenate(path_positions, axis=0)
         | 
| 784 | 
            +
             | 
| 785 | 
            +
                # === 8. Denormalize
         | 
| 786 | 
            +
                path_positions = path_positions * scale_pos + mean_pos
         | 
| 787 | 
            +
             | 
| 788 | 
            +
                # === 9. SLERP for rotations
         | 
| 789 | 
            +
                rotations = R.from_matrix([cam.R for cam in sampled_cameras])
         | 
| 790 | 
            +
                key_times = np.linspace(0, 1, len(rotations))
         | 
| 791 | 
            +
                slerp = Slerp(key_times, rotations)
         | 
| 792 | 
            +
             | 
| 793 | 
            +
                query_times = np.linspace(0, 1, len(path_positions))
         | 
| 794 | 
            +
                interpolated_rotations = slerp(query_times)
         | 
| 795 | 
            +
             | 
| 796 | 
            +
                # === 10. Generate Camera objects
         | 
| 797 | 
            +
                reference_cam = existing_cameras[0]
         | 
| 798 | 
            +
                smooth_cameras = []
         | 
| 799 | 
            +
             | 
| 800 | 
            +
                for i, pos in enumerate(path_positions):
         | 
| 801 | 
            +
                    R_interp = interpolated_rotations[i].as_matrix()
         | 
| 802 | 
            +
             | 
| 803 | 
            +
                    smooth_cameras.append(Camera(
         | 
| 804 | 
            +
                        R=R_interp,
         | 
| 805 | 
            +
                        T=pos,
         | 
| 806 | 
            +
                        FoVx=reference_cam.FoVx,
         | 
| 807 | 
            +
                        FoVy=reference_cam.FoVy,
         | 
| 808 | 
            +
                        resolution=(reference_cam.image_width, reference_cam.image_height),
         | 
| 809 | 
            +
                        colmap_id=-1,
         | 
| 810 | 
            +
                        depth_params=None,
         | 
| 811 | 
            +
                        image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
         | 
| 812 | 
            +
                        invdepthmap=None,
         | 
| 813 | 
            +
                        image_name=f"clustered_smooth_path_i={i}",
         | 
| 814 | 
            +
                        uid=i
         | 
| 815 | 
            +
                    ))
         | 
| 816 | 
            +
             | 
| 817 | 
            +
                return smooth_cameras
         | 
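
# Illustrative call (added as an example, not original code): the clustered
# variant targets unordered captures whose cameras form several spatial groups.
# The parameter values are assumptions for illustration; n_clusters must not
# exceed the number of input cameras.
def _demo_clustered_smooth_path(existing_cameras: List[Camera]) -> List[Camera]:
    return generate_clustered_smooth_cameras_with_tsp(existing_cameras,
                                                      n_selected=40,
                                                      n_points_per_segment=25,
                                                      n_clusters=6,
                                                      closed=False)
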
| 818 | 
            +
             | 
| 819 | 
            +
             | 
| 820 | 
            +
            # def generate_clustered_path(existing_cameras: List[Camera], 
         | 
| 821 | 
            +
            #                              n_points_per_segment: int = 20, 
         | 
| 822 | 
            +
            #                              d: float = 2.0,
         | 
| 823 | 
            +
            #                              n_clusters: int = 5,
         | 
| 824 | 
            +
            #                              closed: bool = False) -> List[Camera]:
         | 
| 825 | 
            +
            #     """
         | 
| 826 | 
            +
            #     Generate a smooth camera path using GMM clustering and TSP on cluster centers.
         | 
| 827 | 
            +
             | 
| 828 | 
            +
            #     Args:
         | 
| 829 | 
            +
            #         existing_cameras (List[Camera]): List of input cameras.
         | 
| 830 | 
            +
            #         n_points_per_segment (int): Number of interpolated points per spline segment.
         | 
| 831 | 
            +
            #         d (float): Distance ahead for estimating center of view.
         | 
| 832 | 
            +
            #         n_clusters (int): Number of GMM clusters (zones).
         | 
| 833 | 
            +
            #         closed (bool): Whether to close the path.
         | 
| 834 | 
            +
             | 
| 835 | 
            +
            #     Returns:
         | 
| 836 | 
            +
            #         List[Camera]: Smooth path of Camera objects.
         | 
| 837 | 
            +
            #     """
         | 
| 838 | 
            +
            #     # Extract positions and rotations
         | 
| 839 | 
            +
            #     positions = np.array([cam.T for cam in existing_cameras])
         | 
| 840 | 
            +
             | 
| 841 | 
            +
            #     # === Normalize positions
         | 
| 842 | 
            +
            #     mean_pos = np.mean(positions, axis=0)
         | 
| 843 | 
            +
            #     scale_pos = np.std(positions, axis=0)
         | 
| 844 | 
            +
            #     scale_pos[scale_pos == 0] = 1.0
         | 
| 845 | 
            +
             | 
| 846 | 
            +
            #     positions_normalized = (positions - mean_pos) / scale_pos
         | 
| 847 | 
            +
             | 
| 848 | 
            +
            #     # === 1. GMM clustering (only positions)
         | 
| 849 | 
            +
            #     gmm = GaussianMixture(n_components=n_clusters, covariance_type='full', random_state=42)
         | 
| 850 | 
            +
            #     cluster_labels = gmm.fit_predict(positions_normalized)
         | 
| 851 | 
            +
             | 
| 852 | 
            +
            #     cluster_centers = []
         | 
| 853 | 
            +
            #     for cluster_id in range(n_clusters):
         | 
| 854 | 
            +
            #         cluster_indices = np.where(cluster_labels == cluster_id)[0]
         | 
| 855 | 
            +
            #         if len(cluster_indices) == 0:
         | 
| 856 | 
            +
            #             continue
         | 
| 857 | 
            +
            #         cluster_center = np.mean(positions_normalized[cluster_indices], axis=0)
         | 
| 858 | 
            +
            #         cluster_centers.append(cluster_center)
         | 
| 859 | 
            +
             | 
| 860 | 
            +
            #     cluster_centers = np.stack(cluster_centers)
         | 
| 861 | 
            +
             | 
| 862 | 
            +
            #     # === 2. Solve TSP between cluster centers
         | 
| 863 | 
            +
            #     cluster_order = solve_tsp_2opt(cluster_centers)
         | 
| 864 | 
            +
             | 
| 865 | 
            +
            #     # === 3. Reorder cluster centers
         | 
| 866 | 
            +
            #     ordered_centers = cluster_centers[cluster_order]
         | 
| 867 | 
            +
             | 
| 868 | 
            +
            #     # === 4. Prepare Catmull-Rom spline
         | 
| 869 | 
            +
            #     if closed:
         | 
| 870 | 
            +
            #         ordered_centers = np.vstack([ordered_centers[-1], ordered_centers, ordered_centers[0], ordered_centers[1]])
         | 
| 871 | 
            +
            #     else:
         | 
| 872 | 
            +
            #         ordered_centers = np.vstack([ordered_centers[0], ordered_centers, ordered_centers[-1], ordered_centers[-1]])
         | 
| 873 | 
            +
             | 
| 874 | 
            +
            #     # === 5. Generate smooth path positions
         | 
| 875 | 
            +
            #     path_positions = []
         | 
| 876 | 
            +
            #     for i in range(1, len(ordered_centers) - 2):
         | 
| 877 | 
            +
            #         segment = catmull_rom_spline(ordered_centers[i-1], ordered_centers[i], ordered_centers[i+1], ordered_centers[i+2], n_points_per_segment)
         | 
| 878 | 
            +
            #         path_positions.append(segment)
         | 
| 879 | 
            +
            #     path_positions = np.concatenate(path_positions, axis=0)
         | 
| 880 | 
            +
             | 
| 881 | 
            +
            #     # === 6. Denormalize back
         | 
| 882 | 
            +
            #     path_positions = path_positions * scale_pos + mean_pos
         | 
| 883 | 
            +
             | 
| 884 | 
            +
            #     # === 7. Generate dummy rotations (constant forward facing)
         | 
| 885 | 
            +
            #     reference_cam = existing_cameras[0]
         | 
| 886 | 
            +
            #     default_rotation = R.from_matrix(reference_cam.R)
         | 
| 887 | 
            +
             | 
| 888 | 
            +
            #     # For simplicity, fixed rotation for all
         | 
| 889 | 
            +
            #     smooth_cameras = []
         | 
| 890 | 
            +
             | 
| 891 | 
            +
            #     for i, pos in enumerate(path_positions):
         | 
| 892 | 
            +
            #         R_interp = default_rotation.as_matrix()
         | 
| 893 | 
            +
             | 
| 894 | 
            +
            #         smooth_cameras.append(Camera(
         | 
| 895 | 
            +
            #             R=R_interp,
         | 
| 896 | 
            +
            #             T=pos,
         | 
| 897 | 
            +
            #             FoVx=reference_cam.FoVx,
         | 
| 898 | 
            +
            #             FoVy=reference_cam.FoVy,
         | 
| 899 | 
            +
            #             resolution=(reference_cam.image_width, reference_cam.image_height),
         | 
| 900 | 
            +
            #             colmap_id=-1,
         | 
| 901 | 
            +
            #             depth_params=None,
         | 
| 902 | 
            +
            #             image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
         | 
| 903 | 
            +
            #             invdepthmap=None,
         | 
| 904 | 
            +
            #             image_name=f"cluster_path_i={i}",
         | 
| 905 | 
            +
            #             uid=i
         | 
| 906 | 
            +
            #         ))
         | 
| 907 | 
            +
             | 
| 908 | 
            +
            #     return smooth_cameras
         | 
| 909 | 
            +
             | 
| 910 | 
            +
            from typing import List
         | 
| 911 | 
            +
            import numpy as np
         | 
| 912 | 
            +
            from sklearn.cluster import KMeans
         | 
| 913 | 
            +
            from scipy.spatial.transform import Rotation as R, Slerp
         | 
| 914 | 
            +
            from PIL import Image
         | 
| 915 | 
            +
             | 
| 916 | 
            +
            def generate_clustered_path(existing_cameras: List[Camera], 
         | 
| 917 | 
            +
                                         n_points_per_segment: int = 20, 
         | 
| 918 | 
            +
                                         d: float = 2.0,
         | 
| 919 | 
            +
                                         n_clusters: int = 5,
         | 
| 920 | 
            +
                                         closed: bool = False) -> List[Camera]:
         | 
| 921 | 
            +
                """
         | 
| 922 | 
            +
                Generate a smooth camera path using K-Means clustering and TSP on cluster centers.
         | 
| 923 | 
            +
             | 
| 924 | 
            +
                Args:
         | 
| 925 | 
            +
                    existing_cameras (List[Camera]): List of input cameras.
         | 
| 926 | 
            +
                    n_points_per_segment (int): Number of interpolated points per spline segment.
         | 
| 927 | 
            +
        d (float): Distance ahead for estimating center of view (currently unused by this function).
         | 
| 928 | 
            +
                    n_clusters (int): Number of KMeans clusters (zones).
         | 
| 929 | 
            +
                    closed (bool): Whether to close the path.
         | 
| 930 | 
            +
             | 
| 931 | 
            +
                Returns:
         | 
| 932 | 
            +
                    List[Camera]: Smooth path of Camera objects.
         | 
| 933 | 
            +
                """
         | 
| 934 | 
            +
                # Extract positions
         | 
| 935 | 
            +
                positions = np.array([cam.T for cam in existing_cameras])
         | 
| 936 | 
            +
             | 
| 937 | 
            +
                # === Normalize positions
         | 
| 938 | 
            +
                mean_pos = np.mean(positions, axis=0)
         | 
| 939 | 
            +
                scale_pos = np.std(positions, axis=0)
         | 
| 940 | 
            +
                scale_pos[scale_pos == 0] = 1.0
         | 
| 941 | 
            +
             | 
| 942 | 
            +
                positions_normalized = (positions - mean_pos) / scale_pos
         | 
| 943 | 
            +
             | 
| 944 | 
            +
    # === 1. K-Means clustering (only positions)
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init='auto')
    cluster_labels = kmeans.fit_predict(positions_normalized)

    cluster_centers = []
    for cluster_id in range(n_clusters):
        cluster_indices = np.where(cluster_labels == cluster_id)[0]
        if len(cluster_indices) == 0:
            continue
        cluster_center = np.mean(positions_normalized[cluster_indices], axis=0)
        cluster_centers.append(cluster_center)

    cluster_centers = np.stack(cluster_centers)

    # === 2. Solve TSP between cluster centers
    cluster_order = solve_tsp_2opt(cluster_centers)

    # === 3. Reorder cluster centers
    ordered_centers = cluster_centers[cluster_order]

    # === 4. Prepare Catmull-Rom spline
    if closed:
        ordered_centers = np.vstack([ordered_centers[-1], ordered_centers, ordered_centers[0], ordered_centers[1]])
    else:
        ordered_centers = np.vstack([ordered_centers[0], ordered_centers, ordered_centers[-1], ordered_centers[-1]])

    # === 5. Generate smooth path positions
    path_positions = []
    for i in range(1, len(ordered_centers) - 2):
        segment = catmull_rom_spline(ordered_centers[i-1], ordered_centers[i], ordered_centers[i+1], ordered_centers[i+2], n_points_per_segment)
        path_positions.append(segment)
    path_positions = np.concatenate(path_positions, axis=0)

    # === 6. Denormalize back
    path_positions = path_positions * scale_pos + mean_pos

    # === 7. Generate dummy rotations (constant forward facing)
    reference_cam = existing_cameras[0]
    default_rotation = R.from_matrix(reference_cam.R)

    # For simplicity, fixed rotation for all
    smooth_cameras = []

    for i, pos in enumerate(path_positions):
        R_interp = default_rotation.as_matrix()

        smooth_cameras.append(Camera(
            R=R_interp,
            T=pos,
            FoVx=reference_cam.FoVx,
            FoVy=reference_cam.FoVy,
            resolution=(reference_cam.image_width, reference_cam.image_height),
            colmap_id=-1,
            depth_params=None,
            image=Image.fromarray(np.zeros((reference_cam.image_height, reference_cam.image_width, 3), dtype=np.uint8)),
            invdepthmap=None,
            image_name=f"cluster_path_i={i}",
            uid=i
        ))

    return smooth_cameras


def visualize_image_with_points(image, points):
    """
    Visualize an image with points overlaid on top. This is useful for correspondences visualizations

    Parameters:
    - image: PIL Image object
    - points: Numpy array of shape [N, 2] containing (x, y) coordinates of points

    Returns:
    - None (displays the visualization)
    """

    # Convert PIL image to numpy array
    img_array = np.array(image)

    # Create a figure and axis
    fig, ax = plt.subplots(figsize=(7,7))

    # Display the image
    ax.imshow(img_array)

    # Scatter plot the points on top of the image
    ax.scatter(points[:, 0], points[:, 1], color='red', marker='o', s=1)

    # Show the plot
    plt.show()


def visualize_correspondences(image1, points1, image2, points2):
    """
    Visualize two images concatenated horizontally with key points and correspondences.

    Parameters:
    - image1: PIL Image object (left image)
    - points1: Numpy array of shape [N, 2] containing (x, y) coordinates of key points for image1
    - image2: PIL Image object (right image)
    - points2: Numpy array of shape [N, 2] containing (x, y) coordinates of key points for image2

    Returns:
    - None (displays the visualization)
    """

    # Concatenate images horizontally
    concatenated_image = np.concatenate((np.array(image1), np.array(image2)), axis=1)

    # Create a figure and axis
    fig, ax = plt.subplots(figsize=(10,10))

    # Display the concatenated image
    ax.imshow(concatenated_image)

    # Plot key points on the left image
    ax.scatter(points1[:, 0], points1[:, 1], color='red', marker='o', s=10)

    # Plot key points on the right image
    ax.scatter(points2[:, 0] + image1.width, points2[:, 1], color='blue', marker='o', s=10)

    # Draw lines connecting corresponding key points
    for i in range(len(points1)):
        ax.plot([points1[i, 0], points2[i, 0] + image1.width], [points1[i, 1], points2[i, 1]])#, color='green')

    # Show the plot
    plt.show()
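A minimal usage sketch for the two helpers above, assuming PIL images and NumPy keypoint arrays of shape [N, 2] as described in the docstrings; the file paths and coordinates are placeholders:

```python
from PIL import Image
import numpy as np

# Hypothetical inputs: two images and N corresponding (x, y) keypoints per image.
image1 = Image.open("imA.jpg")
image2 = Image.open("imB.jpg")
points1 = np.array([[10.0, 20.0], [120.0, 45.0]])
points2 = np.array([[12.0, 22.0], [118.0, 47.0]])

visualize_image_with_points(image1, points1)
visualize_correspondences(image1, points1, image2, points2)
```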
    	
        submodules/RoMa/.gitignore
    ADDED
    
@@ -0,0 +1,11 @@
*.egg-info*
*.vscode*
*__pycache__*
vis*
workspace*
.venv
.DS_Store
jobs/*
*ignore_me*
*.pth
wandb*
    	
        submodules/RoMa/LICENSE
    ADDED
    
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Johan Edstedt

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
    	
        submodules/RoMa/README.md
    ADDED
    
@@ -0,0 +1,123 @@
# 
<p align="center">
  <h1 align="center"> <ins>RoMa</ins> 🏛️:<br> Robust Dense Feature Matching <br> ⭐CVPR 2024⭐</h1>
  <p align="center">
    <a href="https://scholar.google.com/citations?user=Ul-vMR0AAAAJ">Johan Edstedt</a>
    ·
    <a href="https://scholar.google.com/citations?user=HS2WuHkAAAAJ">Qiyu Sun</a>
    ·
    <a href="https://scholar.google.com/citations?user=FUE3Wd0AAAAJ">Georg Bökman</a>
    ·
    <a href="https://scholar.google.com/citations?user=6WRQpCQAAAAJ">Mårten Wadenbäck</a>
    ·
    <a href="https://scholar.google.com/citations?user=lkWfR08AAAAJ">Michael Felsberg</a>
  </p>
  <h2 align="center"><p>
    <a href="https://arxiv.org/abs/2305.15404" align="center">Paper</a> | 
    <a href="https://parskatt.github.io/RoMa" align="center">Project Page</a>
  </p></h2>
  <div align="center"></div>
</p>
<br/>
<p align="center">
    <img src="https://github.com/Parskatt/RoMa/assets/22053118/15d8fea7-aa6d-479f-8a93-350d950d006b" alt="example" width=80%>
    <br>
    <em>RoMa is the robust dense feature matcher capable of estimating pixel-dense warps and reliable certainties for almost any image pair.</em>
</p>

## Setup/Install
In your python environment (tested on Linux python 3.10), run:
```bash
pip install -e .
```
## Demo / How to Use
We provide two demos in the [demos folder](demo).
Here's the gist of it:
```python
from romatch import roma_outdoor
roma_model = roma_outdoor(device=device)
# Match
warp, certainty = roma_model.match(imA_path, imB_path, device=device)
# Sample matches for estimation
matches, certainty = roma_model.sample(warp, certainty)
# Convert to pixel coordinates (RoMa produces matches in [-1,1]x[-1,1])
kptsA, kptsB = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
# Find a fundamental matrix (or anything else of interest)
F, mask = cv2.findFundamentalMat(
    kptsA.cpu().numpy(), kptsB.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
)
```

**New**: You can also match arbitrary keypoints with RoMa. See [match_keypoints](romatch/models/matcher.py) in RegressionMatcher.

## Settings

### Resolution
By default RoMa uses an initial resolution of (560,560) which is then upsampled to (864,864). 
You can change this at construction (see roma_outdoor kwargs).
You can also change this later, by changing the roma_model.w_resized, roma_model.h_resized, and roma_model.upsample_res.

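As a minimal sketch of both options: the `coarse_res`/`upsample_res` keyword arguments mirror the bundled demo scripts, and the concrete numbers below are placeholders rather than recommended settings.

```python
from romatch import roma_outdoor

# Option 1: set the resolutions at construction (kwargs as used in the demos).
roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))

# Option 2: adjust an already constructed model via its attributes.
roma_model.w_resized = 560
roma_model.h_resized = 560
roma_model.upsample_res = (864, 864)
```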
### Sampling
roma_model.sample_thresh controls the thresholding used when sampling matches for estimation. In certain cases a lower or higher threshold may improve results.
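For illustration, the threshold can simply be set on the model before sampling; the value below is arbitrary, not a recommendation:

```python
# Illustrative only: loosen or tighten the sampling threshold before drawing matches.
roma_model.sample_thresh = 0.05
matches, certainty = roma_model.sample(warp, certainty)
```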

## Reproducing Results
The experiments in the paper are provided in the [experiments folder](experiments).

### Training
1. First follow the instructions provided here: https://github.com/Parskatt/DKM for downloading and preprocessing datasets.
2. Run the relevant experiment, e.g.,
```bash
torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d experiments/roma_outdoor.py
```
### Testing
```bash
python experiments/roma_outdoor.py --only_test --benchmark mega-1500
```
## License
All our code except DINOv2 is MIT license.
DINOv2 has an Apache 2 license [DINOv2](https://github.com/facebookresearch/dinov2/blob/main/LICENSE).

## Acknowledgement
Our codebase builds on the code in [DKM](https://github.com/Parskatt/DKM).

## Tiny RoMa
If you find that RoMa is too heavy, you might want to try Tiny RoMa which is built on top of XFeat.
```python
from romatch import tiny_roma_v1_outdoor
tiny_roma_model = tiny_roma_v1_outdoor(device=device)
```
Mega1500:
|  | AUC@5 | AUC@10 | AUC@20 |
|----------|----------|----------|----------|
| XFeat    | 46.4    | 58.9    | 69.2    |
| XFeat*    |  51.9   | 67.2    | 78.9    |
| Tiny RoMa v1    | 56.4 | 69.5 | 79.5     |
| RoMa    |  -   | -    | -    |

Mega-8-Scenes (See DKM):
|  | AUC@5 | AUC@10 | AUC@20 |
|----------|----------|----------|----------|
| XFeat    | -    | -    | -    |
| XFeat*    |  50.1   | 64.4    | 75.2    |
| Tiny RoMa v1    | 57.7 | 70.5 | 79.6     |
| RoMa    |  -   | -    | -    |

IMC22 :'):
|  | mAA@10 |
|----------|----------|
| XFeat    | 42.1    |
| XFeat*    |  -   |
| Tiny RoMa v1    | 42.2 |
| RoMa    |  -   |

## BibTeX
If you find our models useful, please consider citing our paper!
```
@article{edstedt2024roma,
title={{RoMa: Robust Dense Feature Matching}},
author={Edstedt, Johan and Sun, Qiyu and Bökman, Georg and Wadenbäck, Mårten and Felsberg, Michael},
journal={IEEE Conference on Computer Vision and Pattern Recognition},
year={2024}
}
```
    	
        submodules/RoMa/data/.gitignore
    ADDED
    
@@ -0,0 +1,2 @@
*
!.gitignore
    	
        submodules/RoMa/demo/demo_3D_effect.py
    ADDED
    
@@ -0,0 +1,47 @@
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np
from romatch.utils.utils import tensor_to_pil

from romatch import roma_outdoor

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.backends.mps.is_available():
    device = torch.device('mps')

if __name__ == "__main__":
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
    parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
    parser.add_argument("--save_path", default="demo/gif/roma_warp_toronto", type=str)

    args, _ = parser.parse_known_args()
    im1_path = args.im_A_path
    im2_path = args.im_B_path
    save_path = args.save_path

    # Create model
    roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
    roma_model.symmetric = False

    H, W = roma_model.get_output_resolution()

    im1 = Image.open(im1_path).resize((W, H))
    im2 = Image.open(im2_path).resize((W, H))

    # Match
    warp, certainty = roma_model.match(im1_path, im2_path, device=device)
    # Sampling not needed, but can be done with model.sample(warp, certainty)
    x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
    x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)

    coords_A, coords_B = warp[...,:2], warp[...,2:]
    for i, x in enumerate(np.linspace(0,2*np.pi,200)):
        t = (1 + np.cos(x))/2
        interp_warp = (1-t)*coords_A + t*coords_B
        im2_transfer_rgb = F.grid_sample(
        x2[None], interp_warp[None], mode="bilinear", align_corners=False
        )[0]
        tensor_to_pil(im2_transfer_rgb, unnormalize=False).save(f"{save_path}_{i:03d}.jpg")
    	
        submodules/RoMa/demo/demo_fundamental.py
    ADDED
    
@@ -0,0 +1,34 @@
from PIL import Image
import torch
import cv2
from romatch import roma_outdoor

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.backends.mps.is_available():
    device = torch.device('mps')

if __name__ == "__main__":
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
    parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)

    args, _ = parser.parse_known_args()
    im1_path = args.im_A_path
    im2_path = args.im_B_path

    # Create model
    roma_model = roma_outdoor(device=device)


    W_A, H_A = Image.open(im1_path).size
    W_B, H_B = Image.open(im2_path).size

    # Match
    warp, certainty = roma_model.match(im1_path, im2_path, device=device)
    # Sample matches for estimation
    matches, certainty = roma_model.sample(warp, certainty)
    kpts1, kpts2 = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
    F, mask = cv2.findFundamentalMat(
        kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
    )
    	
        submodules/RoMa/demo/demo_match.py
    ADDED
    
@@ -0,0 +1,50 @@
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
import torch
from PIL import Image
import torch.nn.functional as F
import numpy as np
from romatch.utils.utils import tensor_to_pil

from romatch import roma_outdoor

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.backends.mps.is_available():
    device = torch.device('mps')

if __name__ == "__main__":
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
    parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
    parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)

    args, _ = parser.parse_known_args()
    im1_path = args.im_A_path
    im2_path = args.im_B_path
    save_path = args.save_path

    # Create model
    roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))

    H, W = roma_model.get_output_resolution()

    im1 = Image.open(im1_path).resize((W, H))
    im2 = Image.open(im2_path).resize((W, H))

    # Match
    warp, certainty = roma_model.match(im1_path, im2_path, device=device)
    # Sampling not needed, but can be done with model.sample(warp, certainty)
    x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
    x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)

    im2_transfer_rgb = F.grid_sample(
    x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
    )[0]
    im1_transfer_rgb = F.grid_sample(
    x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
    )[0]
    warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2)
    white_im = torch.ones((H,2*W),device=device)
    vis_im = certainty * warp_im + (1 - certainty) * white_im
    tensor_to_pil(vis_im, unnormalize=False).save(save_path)
    	
        submodules/RoMa/demo/demo_match_opencv_sift.py
    ADDED
    
@@ -0,0 +1,43 @@
from PIL import Image
import numpy as np

import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt



if __name__ == "__main__":
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
    parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
    parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)

    args, _ = parser.parse_known_args()
    im1_path = args.im_A_path
    im2_path = args.im_B_path
    save_path = args.save_path

    img1 = cv.imread(im1_path,cv.IMREAD_GRAYSCALE)          # queryImage
    img2 = cv.imread(im2_path,cv.IMREAD_GRAYSCALE) # trainImage
    # Initiate SIFT detector
    sift = cv.SIFT_create()
    # find the keypoints and descriptors with SIFT
    kp1, des1 = sift.detectAndCompute(img1,None)
    kp2, des2 = sift.detectAndCompute(img2,None)
    # BFMatcher with default params
    bf = cv.BFMatcher()
    matches = bf.knnMatch(des1,des2,k=2)
    # Apply ratio test
    good = []
    for m,n in matches:
        if m.distance < 0.75*n.distance:
            good.append([m])
    # cv.drawMatchesKnn expects list of lists as matches.
    draw_params = dict(matchColor = (255,0,0), # draw matches in red color
                   singlePointColor = None,
                   flags = 2)

    img3 = cv.drawMatchesKnn(img1,kp1,img2,kp2,good,None,**draw_params)
    Image.fromarray(img3).save("demo/sift_matches.png")
    	
        submodules/RoMa/demo/demo_match_tiny.py
    ADDED
    
@@ -0,0 +1,77 @@
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
import torch
from PIL import Image
import torch.nn.functional as F
import numpy as np
from romatch.utils.utils import tensor_to_pil

from romatch import tiny_roma_v1_outdoor

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.backends.mps.is_available():
    device = torch.device('mps')

if __name__ == "__main__":
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
    parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
    parser.add_argument("--save_A_path", default="demo/tiny_roma_warp_A.jpg", type=str)
    parser.add_argument("--save_B_path", default="demo/tiny_roma_warp_B.jpg", type=str)

    args, _ = parser.parse_known_args()
    im1_path = args.im_A_path
    im2_path = args.im_B_path

    # Create model
    roma_model = tiny_roma_v1_outdoor(device=device)

    # Match
    warp, certainty1 = roma_model.match(im1_path, im2_path)

    h1, w1 = warp.shape[:2]

    # maybe im1.size != im2.size
    im1 = Image.open(im1_path).resize((w1, h1))
    im2 = Image.open(im2_path)
    x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
    x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)

    h2, w2 = x2.shape[1:]
    g1_p2x = w2 / 2 * (warp[..., 2] + 1)
    g1_p2y = h2 / 2 * (warp[..., 3] + 1)
    g2_p1x = torch.zeros((h2, w2), dtype=torch.float32).to(device) - 2
    g2_p1y = torch.zeros((h2, w2), dtype=torch.float32).to(device) - 2

    x, y = torch.meshgrid(
        torch.arange(w1, device=device),
        torch.arange(h1, device=device),
        indexing="xy",
    )
    g2x = torch.round(g1_p2x[y, x]).long()
    g2y = torch.round(g1_p2y[y, x]).long()
    idx_x = torch.bitwise_and(0 <= g2x, g2x < w2)
    idx_y = torch.bitwise_and(0 <= g2y, g2y < h2)
    idx = torch.bitwise_and(idx_x, idx_y)
    g2_p1x[g2y[idx], g2x[idx]] = x[idx].float() * 2 / w1 - 1
    g2_p1y[g2y[idx], g2x[idx]] = y[idx].float() * 2 / h1 - 1

    certainty2 = F.grid_sample(
        certainty1[None][None],
        torch.stack([g2_p1x, g2_p1y], dim=2)[None],
        mode="bilinear",
        align_corners=False,
    )[0]

    white_im1 = torch.ones((h1, w1), device = device)
    white_im2 = torch.ones((h2, w2), device = device)

    certainty1 = F.avg_pool2d(certainty1[None], kernel_size=5, stride=1, padding=2)[0]
    certainty2 = F.avg_pool2d(certainty2[None], kernel_size=5, stride=1, padding=2)[0]

    vis_im1 = certainty1 * x1 + (1 - certainty1) * white_im1
    vis_im2 = certainty2 * x2 + (1 - certainty2) * white_im2

    tensor_to_pil(vis_im1, unnormalize=False).save(args.save_A_path)
    tensor_to_pil(vis_im2, unnormalize=False).save(args.save_B_path)
    	
        submodules/RoMa/demo/gif/.gitignore
    ADDED
    
@@ -0,0 +1,2 @@
*
!.gitignore
    	
        submodules/RoMa/experiments/eval_roma_outdoor.py
    ADDED
    
@@ -0,0 +1,57 @@
import json

from romatch.benchmarks import MegadepthDenseBenchmark
from romatch.benchmarks import MegaDepthPoseEstimationBenchmark, HpatchesHomogBenchmark
from romatch.benchmarks import Mega1500PoseLibBenchmark

def test_mega_8_scenes(model, name):
    mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
                                                scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
                                                    'mega_8_scenes_0025_0.1_0.3.npz',
                                                    'mega_8_scenes_0021_0.1_0.3.npz',
                                                    'mega_8_scenes_0008_0.1_0.3.npz',
                                                    'mega_8_scenes_0032_0.1_0.3.npz',
                                                    'mega_8_scenes_1589_0.1_0.3.npz',
                                                    'mega_8_scenes_0063_0.1_0.3.npz',
                                                    'mega_8_scenes_0024_0.1_0.3.npz',
                                                    'mega_8_scenes_0019_0.3_0.5.npz',
                                                    'mega_8_scenes_0025_0.3_0.5.npz',
                                                    'mega_8_scenes_0021_0.3_0.5.npz',
                                                    'mega_8_scenes_0008_0.3_0.5.npz',
                                                    'mega_8_scenes_0032_0.3_0.5.npz',
                                                    'mega_8_scenes_1589_0.3_0.5.npz',
                                                    'mega_8_scenes_0063_0.3_0.5.npz',
                                                    'mega_8_scenes_0024_0.3_0.5.npz'])
    mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
    print(mega_8_scenes_results)
    json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))

def test_mega1500(model, name):
    mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
    json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))

def test_mega1500_poselib(model, name):
    mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth")
    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
    json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))

def test_mega_dense(model, name):
    megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000)
    megadense_results = megadense_benchmark.benchmark(model)
    json.dump(megadense_results, open(f"results/mega_dense_{name}.json", "w"))

def test_hpatches(model, name):
    hpatches_benchmark = HpatchesHomogBenchmark("data/hpatches")
    hpatches_results = hpatches_benchmark.benchmark(model)
    json.dump(hpatches_results, open(f"results/hpatches_{name}.json", "w"))


if __name__ == "__main__":
    from romatch import roma_outdoor
    device = "cuda"
    model = roma_outdoor(device = device, coarse_res = 672, upsample_res = 1344)
    experiment_name = "roma_latest"
    test_mega1500(model, experiment_name)
    #test_mega1500_poselib(model, experiment_name)
    	
        submodules/RoMa/experiments/eval_tiny_roma_v1_outdoor.py
    ADDED
    
@@ -0,0 +1,84 @@
import torch
import os
from pathlib import Path
import json
from romatch.benchmarks import ScanNetBenchmark
from romatch.benchmarks import Mega1500PoseLibBenchmark, ScanNetPoselibBenchmark
from romatch.benchmarks import MegaDepthPoseEstimationBenchmark

def test_mega_8_scenes(model, name):
    mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
                                                scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
                                                    'mega_8_scenes_0025_0.1_0.3.npz',
                                                    'mega_8_scenes_0021_0.1_0.3.npz',
                                                    'mega_8_scenes_0008_0.1_0.3.npz',
                                                    'mega_8_scenes_0032_0.1_0.3.npz',
                                                    'mega_8_scenes_1589_0.1_0.3.npz',
                                                    'mega_8_scenes_0063_0.1_0.3.npz',
                                                    'mega_8_scenes_0024_0.1_0.3.npz',
                                                    'mega_8_scenes_0019_0.3_0.5.npz',
                                                    'mega_8_scenes_0025_0.3_0.5.npz',
                                                    'mega_8_scenes_0021_0.3_0.5.npz',
                                                    'mega_8_scenes_0008_0.3_0.5.npz',
                                                    'mega_8_scenes_0032_0.3_0.5.npz',
                                                    'mega_8_scenes_1589_0.3_0.5.npz',
                                                    'mega_8_scenes_0063_0.3_0.5.npz',
                                                    'mega_8_scenes_0024_0.3_0.5.npz'])
    mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
    print(mega_8_scenes_results)
    json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))

def test_mega1500(model, name):
    mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
    json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))

def test_mega1500_poselib(model, name):
    #model.exact_softmax = True
    mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1)
    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
    json.dump(mega1500_results, open(f"results/mega1500_poselib_{name}.json", "w"))

def test_mega_8_scenes_poselib(model, name):
    mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1,
                                                  scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
                                                    'mega_8_scenes_0025_0.1_0.3.npz',
                                                    'mega_8_scenes_0021_0.1_0.3.npz',
                                                    'mega_8_scenes_0008_0.1_0.3.npz',
                                                    'mega_8_scenes_0032_0.1_0.3.npz',
                                                    'mega_8_scenes_1589_0.1_0.3.npz',
                                                    'mega_8_scenes_0063_0.1_0.3.npz',
                                                    'mega_8_scenes_0024_0.1_0.3.npz',
                                                    'mega_8_scenes_0019_0.3_0.5.npz',
                                                    'mega_8_scenes_0025_0.3_0.5.npz',
                                                    'mega_8_scenes_0021_0.3_0.5.npz',
                                                    'mega_8_scenes_0008_0.3_0.5.npz',
                                                    'mega_8_scenes_0032_0.3_0.5.npz',
                                                    'mega_8_scenes_1589_0.3_0.5.npz',
                                                    'mega_8_scenes_0063_0.3_0.5.npz',
                                                    'mega_8_scenes_0024_0.3_0.5.npz'])
    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
    json.dump(mega1500_results, open(f"results/mega_8_scenes_poselib_{name}.json", "w"))

def test_scannet_poselib(model, name):
    scannet_benchmark = ScanNetPoselibBenchmark("data/scannet")
    scannet_results = scannet_benchmark.benchmark(model)
    json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))

def test_scannet(model, name):
    scannet_benchmark = ScanNetBenchmark("data/scannet")
    scannet_results = scannet_benchmark.benchmark(model)
    json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))

if __name__ == "__main__":
    os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
    os.environ["OMP_NUM_THREADS"] = "16"
    torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
    from romatch import tiny_roma_v1_outdoor

    experiment_name = Path(__file__).stem
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = tiny_roma_v1_outdoor(device)
    #test_mega1500_poselib(model, experiment_name)
    test_mega_8_scenes_poselib(model, experiment_name)
    	
submodules/RoMa/experiments/roma_indoor.py ADDED
@@ -0,0 +1,320 @@
import os
import torch
from argparse import ArgumentParser

from torch import nn
from torch.utils.data import ConcatDataset
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

import json
import wandb
from tqdm import tqdm

from romatch.benchmarks import MegadepthDenseBenchmark
from romatch.datasets.megadepth import MegadepthBuilder
from romatch.datasets.scannet import ScanNetBuilder
from romatch.losses.robust_loss import RobustLosses
from romatch.benchmarks import MegadepthDenseBenchmark, ScanNetBenchmark
from romatch.train.train import train_k_steps
from romatch.models.matcher import *
from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
from romatch.models.encoders import *
from romatch.checkpointing import CheckPoint

resolutions = {"low":(448, 448), "medium":(14*8*5, 14*8*5), "high":(14*8*6, 14*8*6)}

def get_model(pretrained_backbone=True, resolution = "medium", **kwargs):
    gp_dim = 512
    feat_dim = 512
    decoder_dim = gp_dim + feat_dim
    cls_to_coord_res = 64
    coordinate_decoder = TransformerDecoder(
        nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]),
        decoder_dim,
        cls_to_coord_res**2 + 1,
        is_classifier=True,
        amp = True,
        pos_enc = False,)
    dw = True
    hidden_blocks = 8
    kernel_size = 5
    displacement_emb = "linear"
    disable_local_corr_grad = True

    conv_refiner = nn.ModuleDict(
        {
            "16": ConvRefiner(
                2 * 512+128+(2*7+1)**2,
                2 * 512+128+(2*7+1)**2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=128,
                local_corr_radius = 7,
                corr_in_other = True,
                amp = True,
                disable_local_corr_grad = disable_local_corr_grad,
                bn_momentum = 0.01,
            ),
            "8": ConvRefiner(
                2 * 512+64+(2*3+1)**2,
                2 * 512+64+(2*3+1)**2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=64,
                local_corr_radius = 3,
                corr_in_other = True,
                amp = True,
                disable_local_corr_grad = disable_local_corr_grad,
                bn_momentum = 0.01,
            ),
            "4": ConvRefiner(
                2 * 256+32+(2*2+1)**2,
                2 * 256+32+(2*2+1)**2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=32,
                local_corr_radius = 2,
                corr_in_other = True,
                amp = True,
                disable_local_corr_grad = disable_local_corr_grad,
                bn_momentum = 0.01,
            ),
            "2": ConvRefiner(
                2 * 64+16,
                128+16,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=16,
                amp = True,
                disable_local_corr_grad = disable_local_corr_grad,
                bn_momentum = 0.01,
            ),
            "1": ConvRefiner(
                2 * 9 + 6,
                24,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks = hidden_blocks,
                displacement_emb = displacement_emb,
                displacement_emb_dim = 6,
                amp = True,
                disable_local_corr_grad = disable_local_corr_grad,
                bn_momentum = 0.01,
            ),
        }
    )
    kernel_temperature = 0.2
    learn_temperature = False
    no_cov = True
    kernel = CosKernel
    only_attention = False
    basis = "fourier"
    gp16 = GP(
        kernel,
        T=kernel_temperature,
        learn_temperature=learn_temperature,
        only_attention=only_attention,
        gp_dim=gp_dim,
        basis=basis,
        no_cov=no_cov,
    )
    gps = nn.ModuleDict({"16": gp16})
    proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
    proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
    proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
    proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
    proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
    proj = nn.ModuleDict({
        "16": proj16,
        "8": proj8,
        "4": proj4,
        "2": proj2,
        "1": proj1,
        })
    displacement_dropout_p = 0.0
    gm_warp_dropout_p = 0.0
    decoder = Decoder(coordinate_decoder,
                      gps,
                      proj,
                      conv_refiner,
                      detach=True,
                      scales=["16", "8", "4", "2", "1"],
                      displacement_dropout_p = displacement_dropout_p,
                      gm_warp_dropout_p = gm_warp_dropout_p)
    h,w = resolutions[resolution]
    encoder = CNNandDinov2(
        cnn_kwargs = dict(
            pretrained=pretrained_backbone,
            amp = True),
        amp = True,
        use_vgg = True,
    )
    matcher = RegressionMatcher(encoder, decoder, h=h, w=w, alpha=1, beta=0,**kwargs)
    return matcher

def train(args):
    dist.init_process_group('nccl')
    #torch._dynamo.config.verbose=True
    gpus = int(os.environ['WORLD_SIZE'])
    # create model and move it to GPU with id rank
    rank = dist.get_rank()
    print(f"Start running DDP on rank {rank}")
    device_id = rank % torch.cuda.device_count()
    romatch.LOCAL_RANK = device_id
    torch.cuda.set_device(device_id)

    resolution = args.train_resolution
    wandb_log = not args.dont_log_wandb
    experiment_name = os.path.splitext(os.path.basename(__file__))[0]
    wandb_mode = "online" if wandb_log and rank == 0 and False else "disabled"
    wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
    checkpoint_dir = "workspace/checkpoints/"
    h,w = resolutions[resolution]
    model = get_model(pretrained_backbone=True, resolution=resolution, attenuate_cert = False).to(device_id)
    # Num steps
    global_step = 0
    batch_size = args.gpu_batch_size
    step_size = gpus*batch_size
    romatch.STEP_SIZE = step_size

    N = (32 * 250000)  # 250k steps of batch size 32
    # checkpoint every
    k = 25000 // romatch.STEP_SIZE

    # Data
    mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
    use_horizontal_flip_aug = True
    rot_prob = 0
    depth_interpolation_mode = "bilinear"
    megadepth_train1 = mega.build_scenes(
        split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
        ht=h,wt=w,
    )
    megadepth_train2 = mega.build_scenes(
        split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
        ht=h,wt=w,
    )
    megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
    mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)

    scannet = ScanNetBuilder(data_root="data/scannet")
    scannet_train = scannet.build_scenes(split="train", ht=h, wt=w, use_horizontal_flip_aug = use_horizontal_flip_aug)
    scannet_train = ConcatDataset(scannet_train)
    scannet_ws = scannet.weight_scenes(scannet_train, alpha=0.75)

    # Loss and optimizer
    depth_loss_scannet = RobustLosses(
        ce_weight=0.0,
        local_dist={1:4, 2:4, 4:8, 8:8},
        local_largest_scale=8,
        depth_interpolation_mode=depth_interpolation_mode,
        alpha = 0.5,
        c = 1e-4,)
    # Loss and optimizer
    depth_loss_mega = RobustLosses(
        ce_weight=0.01,
        local_dist={1:4, 2:4, 4:8, 8:8},
        local_largest_scale=8,
        depth_interpolation_mode=depth_interpolation_mode,
        alpha = 0.5,
        c = 1e-4,)
    parameters = [
        {"params": model.encoder.parameters(), "lr": romatch.STEP_SIZE * 5e-6 / 8},
        {"params": model.decoder.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
    ]
    optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[(9*N/romatch.STEP_SIZE)//10])
    megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h,w=w)
    checkpointer = CheckPoint(checkpoint_dir, experiment_name)
    model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
    romatch.GLOBAL_STEP = global_step
    ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters = False, gradient_as_bucket_view=True)
    grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
    grad_clip_norm = 0.01
    for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
        mega_sampler = torch.utils.data.WeightedRandomSampler(
            mega_ws, num_samples = batch_size * k, replacement=False
        )
        mega_dataloader = iter(
            torch.utils.data.DataLoader(
                megadepth_train,
                batch_size = batch_size,
                sampler = mega_sampler,
                num_workers = 8,
            )
        )
        scannet_ws_sampler = torch.utils.data.WeightedRandomSampler(
            scannet_ws, num_samples=batch_size * k, replacement=False
        )
        scannet_dataloader = iter(
            torch.utils.data.DataLoader(
                scannet_train,
                batch_size=batch_size,
                sampler=scannet_ws_sampler,
                num_workers=gpus * 8,
            )
        )
        for n_k in tqdm(range(n, n + 2 * k, 2),disable = romatch.RANK > 0):
            train_k_steps(
                n_k, 1, mega_dataloader, ddp_model, depth_loss_mega, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm, progress_bar=False
            )
            train_k_steps(
                n_k + 1, 1, scannet_dataloader, ddp_model, depth_loss_scannet, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm, progress_bar=False
            )
        checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
        wandb.log(megadense_benchmark.benchmark(model), step = romatch.GLOBAL_STEP)

def test_scannet(model, name, resolution, sample_mode):
    scannet_benchmark = ScanNetBenchmark("data/scannet")
    scannet_results = scannet_benchmark.benchmark(model)
    json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))

if __name__ == "__main__":
    import warnings
    warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
    warnings.filterwarnings('ignore')#, category=UserWarning)#, message='WARNING batched routines are designed for small sizes.')
    os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
    os.environ["OMP_NUM_THREADS"] = "16"

    import romatch
    parser = ArgumentParser()
    parser.add_argument("--test", action='store_true')
    parser.add_argument("--debug_mode", action='store_true')
    parser.add_argument("--dont_log_wandb", action='store_true')
    parser.add_argument("--train_resolution", default='medium')
    parser.add_argument("--gpu_batch_size", default=4, type=int)
    parser.add_argument("--wandb_entity", required = False)

    args, _ = parser.parse_known_args()
    romatch.DEBUG_MODE = args.debug_mode
    if not args.test:
        train(args)
    experiment_name = os.path.splitext(os.path.basename(__file__))[0]
    checkpoint_dir = "workspace/"
    checkpoint_name = checkpoint_dir + experiment_name + ".pth"
    test_resolution = "medium"
    sample_mode = "threshold_balanced"
    symmetric = True
    upsample_preds = False
    attenuate_cert = True

    model = get_model(pretrained_backbone=False, resolution = test_resolution, sample_mode = sample_mode, upsample_preds = upsample_preds, symmetric=symmetric, name=experiment_name, attenuate_cert = attenuate_cert)
    model = model.cuda()
    states = torch.load(checkpoint_name)
    model.load_state_dict(states["model"])
    test_scannet(model, experiment_name, resolution = test_resolution, sample_mode = sample_mode)
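The train() entry point above calls dist.init_process_group('nccl') and reads WORLD_SIZE from the environment, so it is meant to be started through a distributed launcher rather than plain python. A minimal launch sketch, assuming torchrun is available and the command is run from the submodules/RoMa directory; the launcher choice, GPU count, and working directory are illustrative assumptions, not part of this commit:

# Hypothetical launcher sketch (not part of this commit): torchrun sets
# WORLD_SIZE, RANK, LOCAL_RANK and the rendezvous variables that
# dist.init_process_group('nccl') consumes inside train().
import subprocess

def launch_roma_indoor(num_gpus: int = 4) -> None:
    # --gpu_batch_size and --dont_log_wandb are flags defined by the
    # script's own ArgumentParser above.
    subprocess.run(
        [
            "torchrun",
            f"--nproc_per_node={num_gpus}",
            "experiments/roma_indoor.py",
            "--gpu_batch_size", "4",
            "--dont_log_wandb",
        ],
        check=True,
    )

if __name__ == "__main__":
    launch_roma_indoor()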
    	
submodules/RoMa/experiments/train_roma_outdoor.py ADDED
@@ -0,0 +1,307 @@
import os
import torch
from argparse import ArgumentParser

from torch import nn
from torch.utils.data import ConcatDataset
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import json
import wandb

from romatch.benchmarks import MegadepthDenseBenchmark
from romatch.datasets.megadepth import MegadepthBuilder
from romatch.losses.robust_loss import RobustLosses
from romatch.benchmarks import MegaDepthPoseEstimationBenchmark, MegadepthDenseBenchmark, HpatchesHomogBenchmark

from romatch.train.train import train_k_steps
from romatch.models.matcher import *
from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
from romatch.models.encoders import *
from romatch.checkpointing import CheckPoint

resolutions = {"low":(448, 448), "medium":(14*8*5, 14*8*5), "high":(14*8*6, 14*8*6)}

def get_model(pretrained_backbone=True, resolution = "medium", **kwargs):
    import warnings
    warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
    gp_dim = 512
    feat_dim = 512
    decoder_dim = gp_dim + feat_dim
    cls_to_coord_res = 64
    coordinate_decoder = TransformerDecoder(
        nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]),
        decoder_dim,
        cls_to_coord_res**2 + 1,
        is_classifier=True,
        amp = True,
        pos_enc = False,)
    dw = True
    hidden_blocks = 8
    kernel_size = 5
    displacement_emb = "linear"
    disable_local_corr_grad = True

    conv_refiner = nn.ModuleDict(
        {
            "16": ConvRefiner(
                2 * 512+128+(2*7+1)**2,
                2 * 512+128+(2*7+1)**2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=128,
                local_corr_radius = 7,
                corr_in_other = True,
                amp = True,
                disable_local_corr_grad = disable_local_corr_grad,
                bn_momentum = 0.01,
            ),
            "8": ConvRefiner(
                2 * 512+64+(2*3+1)**2,
                2 * 512+64+(2*3+1)**2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=64,
                local_corr_radius = 3,
                corr_in_other = True,
                amp = True,
                disable_local_corr_grad = disable_local_corr_grad,
                bn_momentum = 0.01,
            ),
            "4": ConvRefiner(
                2 * 256+32+(2*2+1)**2,
                2 * 256+32+(2*2+1)**2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=32,
                local_corr_radius = 2,
                corr_in_other = True,
                amp = True,
                disable_local_corr_grad = disable_local_corr_grad,
                bn_momentum = 0.01,
            ),
            "2": ConvRefiner(
                2 * 64+16,
                128+16,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=16,
                amp = True,
                disable_local_corr_grad = disable_local_corr_grad,
                bn_momentum = 0.01,
            ),
            "1": ConvRefiner(
                2 * 9 + 6,
                24,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks = hidden_blocks,
                displacement_emb = displacement_emb,
                displacement_emb_dim = 6,
                amp = True,
                disable_local_corr_grad = disable_local_corr_grad,
                bn_momentum = 0.01,
            ),
        }
    )
    kernel_temperature = 0.2
    learn_temperature = False
    no_cov = True
    kernel = CosKernel
    only_attention = False
    basis = "fourier"
    gp16 = GP(
        kernel,
        T=kernel_temperature,
        learn_temperature=learn_temperature,
        only_attention=only_attention,
        gp_dim=gp_dim,
        basis=basis,
        no_cov=no_cov,
    )
    gps = nn.ModuleDict({"16": gp16})
    proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
    proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
    proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
    proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
    proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
    proj = nn.ModuleDict({
        "16": proj16,
        "8": proj8,
        "4": proj4,
        "2": proj2,
        "1": proj1,
        })
    displacement_dropout_p = 0.0
    gm_warp_dropout_p = 0.0
    decoder = Decoder(coordinate_decoder,
                      gps,
                      proj,
                      conv_refiner,
                      detach=True,
                      scales=["16", "8", "4", "2", "1"],
                      displacement_dropout_p = displacement_dropout_p,
                      gm_warp_dropout_p = gm_warp_dropout_p)
    h,w = resolutions[resolution]
    encoder = CNNandDinov2(
        cnn_kwargs = dict(
            pretrained=pretrained_backbone,
            amp = True),
        amp = True,
        use_vgg = True,
    )
    matcher = RegressionMatcher(encoder, decoder, h=h, w=w,**kwargs)
    return matcher

def train(args):
    dist.init_process_group('nccl')
    #torch._dynamo.config.verbose=True
    gpus = int(os.environ['WORLD_SIZE'])
    # create model and move it to GPU with id rank
    rank = dist.get_rank()
    print(f"Start running DDP on rank {rank}")
    device_id = rank % torch.cuda.device_count()
    romatch.LOCAL_RANK = device_id
    torch.cuda.set_device(device_id)

    resolution = args.train_resolution
    wandb_log = not args.dont_log_wandb
    experiment_name = os.path.splitext(os.path.basename(__file__))[0]
    wandb_mode = "online" if wandb_log and rank == 0 else "disabled"
    wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
    checkpoint_dir = "workspace/checkpoints/"
    h,w = resolutions[resolution]
    model = get_model(pretrained_backbone=True, resolution=resolution, attenuate_cert = False).to(device_id)
    # Num steps
    global_step = 0
    batch_size = args.gpu_batch_size
    step_size = gpus*batch_size
    romatch.STEP_SIZE = step_size

    N = (32 * 250000)  # 250k steps of batch size 32
    # checkpoint every
    k = 25000 // romatch.STEP_SIZE

    # Data
    mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
    use_horizontal_flip_aug = True
    rot_prob = 0
    depth_interpolation_mode = "bilinear"
    megadepth_train1 = mega.build_scenes(
        split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
        ht=h,wt=w,
    )
    megadepth_train2 = mega.build_scenes(
        split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
        ht=h,wt=w,
    )
    megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
    mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)
    # Loss and optimizer
    depth_loss = RobustLosses(
        ce_weight=0.01,
        local_dist={1:4, 2:4, 4:8, 8:8},
        local_largest_scale=8,
        depth_interpolation_mode=depth_interpolation_mode,
        alpha = 0.5,
        c = 1e-4,)
    parameters = [
        {"params": model.encoder.parameters(), "lr": romatch.STEP_SIZE * 5e-6 / 8},
        {"params": model.decoder.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
    ]
    optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[(9*N/romatch.STEP_SIZE)//10])
    megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h,w=w)
    checkpointer = CheckPoint(checkpoint_dir, experiment_name)
    model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
    romatch.GLOBAL_STEP = global_step
    ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters = False, gradient_as_bucket_view=True)
    grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
    grad_clip_norm = 0.01
    for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
        mega_sampler = torch.utils.data.WeightedRandomSampler(
            mega_ws, num_samples = batch_size * k, replacement=False
        )
        mega_dataloader = iter(
            torch.utils.data.DataLoader(
                megadepth_train,
                batch_size = batch_size,
                sampler = mega_sampler,
                num_workers = 8,
            )
        )
        train_k_steps(
            n, k, mega_dataloader, ddp_model, depth_loss, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm,
        )
        checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
        wandb.log(megadense_benchmark.benchmark(model), step = romatch.GLOBAL_STEP)

def test_mega_8_scenes(model, name):
    mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
                                                scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
                                                    'mega_8_scenes_0025_0.1_0.3.npz',
                                                    'mega_8_scenes_0021_0.1_0.3.npz',
                                                    'mega_8_scenes_0008_0.1_0.3.npz',
                                                    'mega_8_scenes_0032_0.1_0.3.npz',
                                                    'mega_8_scenes_1589_0.1_0.3.npz',
                                                    'mega_8_scenes_0063_0.1_0.3.npz',
                                                    'mega_8_scenes_0024_0.1_0.3.npz',
                                                    'mega_8_scenes_0019_0.3_0.5.npz',
                                                    'mega_8_scenes_0025_0.3_0.5.npz',
                                                    'mega_8_scenes_0021_0.3_0.5.npz',
                                                    'mega_8_scenes_0008_0.3_0.5.npz',
                                                    'mega_8_scenes_0032_0.3_0.5.npz',
                                                    'mega_8_scenes_1589_0.3_0.5.npz',
                                                    'mega_8_scenes_0063_0.3_0.5.npz',
                                                    'mega_8_scenes_0024_0.3_0.5.npz'])
    mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
    print(mega_8_scenes_results)
    json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))

def test_mega1500(model, name):
    mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
    json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))

def test_mega_dense(model, name):
    megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000)
    megadense_results = megadense_benchmark.benchmark(model)
    json.dump(megadense_results, open(f"results/mega_dense_{name}.json", "w"))

def test_hpatches(model, name):
    hpatches_benchmark = HpatchesHomogBenchmark("data/hpatches")
    hpatches_results = hpatches_benchmark.benchmark(model)
    json.dump(hpatches_results, open(f"results/hpatches_{name}.json", "w"))
| 289 | 
            +
             | 
| 290 | 
            +
             | 
| 291 | 
            +
            if __name__ == "__main__":
         | 
| 292 | 
            +
                os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
         | 
| 293 | 
            +
                os.environ["OMP_NUM_THREADS"] = "16"
         | 
| 294 | 
            +
                torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
         | 
| 295 | 
            +
                import romatch
         | 
| 296 | 
            +
                parser = ArgumentParser()
         | 
| 297 | 
            +
                parser.add_argument("--only_test", action='store_true')
         | 
| 298 | 
            +
                parser.add_argument("--debug_mode", action='store_true')
         | 
| 299 | 
            +
                parser.add_argument("--dont_log_wandb", action='store_true')
         | 
| 300 | 
            +
                parser.add_argument("--train_resolution", default='medium')
         | 
| 301 | 
            +
                parser.add_argument("--gpu_batch_size", default=8, type=int)
         | 
| 302 | 
            +
                parser.add_argument("--wandb_entity", required = False)
         | 
| 303 | 
            +
             | 
| 304 | 
            +
                args, _ = parser.parse_known_args()
         | 
| 305 | 
            +
                romatch.DEBUG_MODE = args.debug_mode
         | 
| 306 | 
            +
                if not args.only_test:
         | 
| 307 | 
            +
                    train(args)
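Illustrative arithmetic (not part of the commit) for the MultiStepLR milestone used in train() above: with N = 2_000_000 pairs and STEP_SIZE = gpus * batch_size, the single milestone fires after 90% of the optimizer steps. A minimal check, assuming 8 GPUs with a per-GPU batch of 8:

N, step_size = 2_000_000, 8 * 8
milestone = (9 * N / step_size) // 10      # same expression as in the script
total_steps = N / step_size
print(milestone, 0.9 * total_steps)        # 28125.0 28125.0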
         | 
    	
        submodules/RoMa/experiments/train_tiny_roma_v1_outdoor.py
    ADDED
    
    | @@ -0,0 +1,498 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
            import torch.nn.functional as F
         | 
| 4 | 
            +
            import os
         | 
| 5 | 
            +
            import torch
         | 
| 6 | 
            +
            from argparse import ArgumentParser
         | 
| 7 | 
            +
            from pathlib import Path
         | 
| 8 | 
            +
            import math
         | 
| 9 | 
            +
            import numpy as np
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            from torch import nn
         | 
| 12 | 
            +
            from torch.utils.data import ConcatDataset
         | 
| 13 | 
            +
            import torch.distributed as dist
         | 
| 14 | 
            +
            from torch.nn.parallel import DistributedDataParallel as DDP
         | 
| 15 | 
            +
            import json
         | 
| 16 | 
            +
            import wandb
         | 
| 17 | 
            +
            from PIL import Image
         | 
| 18 | 
            +
            from torchvision.transforms import ToTensor
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            from romatch.benchmarks import MegadepthDenseBenchmark, ScanNetBenchmark
         | 
| 21 | 
            +
            from romatch.benchmarks import Mega1500PoseLibBenchmark, ScanNetPoselibBenchmark
         | 
| 22 | 
            +
            from romatch.datasets.megadepth import MegadepthBuilder
         | 
| 23 | 
            +
            from romatch.losses.robust_loss_tiny_roma import RobustLosses
         | 
| 24 | 
            +
            from romatch.benchmarks import MegaDepthPoseEstimationBenchmark, MegadepthDenseBenchmark, HpatchesHomogBenchmark
         | 
| 25 | 
            +
            from romatch.train.train import train_k_steps
         | 
| 26 | 
            +
            from romatch.checkpointing import CheckPoint
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            resolutions = {"low":(448, 448), "medium":(14*8*5, 14*8*5), "high":(14*8*6, 14*8*6), "xfeat": (600,800), "big": (768, 1024)}
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            def kde(x, std = 0.1):
         | 
| 31 | 
            +
                # use a gaussian kernel to estimate density
         | 
| 32 | 
            +
                x = x.half() # Do it in half precision TODO: remove hardcoding
         | 
| 33 | 
            +
                scores = (-torch.cdist(x,x)**2/(2*std**2)).exp()
         | 
| 34 | 
            +
                density = scores.sum(dim=-1)
         | 
| 35 | 
            +
                return density
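Illustrative sketch (not part of the commit): kde() above scores how crowded each match is so that balanced sampling can later down-weight dense clusters; a stand-alone full-precision equivalent:

import torch
pts = torch.rand(500, 4)                                   # stand-in for (x_A, y_A, x_B, y_B) matches
scores = (-torch.cdist(pts, pts) ** 2 / (2 * 0.1 ** 2)).exp()
density = scores.sum(dim=-1)                               # same quantity kde() returns (fp16 there)
print(density.shape)                                       # torch.Size([500])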
         | 
| 36 | 
            +
             | 
| 37 | 
            +
            class BasicLayer(nn.Module):
         | 
| 38 | 
            +
                """
         | 
| 39 | 
            +
                    Basic Convolutional Layer: Conv2d -> BatchNorm -> ReLU
         | 
| 40 | 
            +
                """
         | 
| 41 | 
            +
                def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False, relu = True):
         | 
| 42 | 
            +
                    super().__init__()
         | 
| 43 | 
            +
                    self.layer = nn.Sequential(
         | 
| 44 | 
            +
                                                    nn.Conv2d( in_channels, out_channels, kernel_size, padding = padding, stride=stride, dilation=dilation, bias = bias),
         | 
| 45 | 
            +
                                                    nn.BatchNorm2d(out_channels, affine=False),
         | 
| 46 | 
            +
                                                    nn.ReLU(inplace = True) if relu else nn.Identity()
         | 
| 47 | 
            +
                                                )
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                def forward(self, x):
         | 
| 50 | 
            +
                    return self.layer(x)
         | 
| 51 | 
            +
             | 
| 52 | 
            +
            class XFeatModel(nn.Module):
         | 
| 53 | 
            +
                """
         | 
| 54 | 
            +
                    Implementation of architecture described in 
         | 
| 55 | 
            +
                    "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
         | 
| 56 | 
            +
                """
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                def __init__(self, xfeat = None, 
         | 
| 59 | 
            +
                             freeze_xfeat = True, 
         | 
| 60 | 
            +
                             sample_mode = "threshold_balanced", 
         | 
| 61 | 
            +
                             symmetric = False, 
         | 
| 62 | 
            +
                             exact_softmax = False):
         | 
| 63 | 
            +
                    super().__init__()
         | 
| 64 | 
            +
                    if xfeat is None:
         | 
| 65 | 
            +
                        xfeat = torch.hub.load('verlab/accelerated_features', 'XFeat', pretrained = True, top_k = 4096).net
         | 
| 66 | 
            +
                        del xfeat.heatmap_head, xfeat.keypoint_head, xfeat.fine_matcher
         | 
| 67 | 
            +
                    if freeze_xfeat:
         | 
| 68 | 
            +
                        xfeat.train(False)
         | 
| 69 | 
            +
                        self.xfeat = [xfeat]# hide params from ddp
         | 
| 70 | 
            +
                    else:
         | 
| 71 | 
            +
                        self.xfeat = nn.ModuleList([xfeat])
         | 
| 72 | 
            +
                    self.freeze_xfeat = freeze_xfeat
         | 
| 73 | 
            +
                    match_dim = 256
         | 
| 74 | 
            +
                    self.coarse_matcher = nn.Sequential(
         | 
| 75 | 
            +
                        BasicLayer(64+64+2, match_dim,),
         | 
| 76 | 
            +
                        BasicLayer(match_dim, match_dim,), 
         | 
| 77 | 
            +
                        BasicLayer(match_dim, match_dim,), 
         | 
| 78 | 
            +
                        BasicLayer(match_dim, match_dim,), 
         | 
| 79 | 
            +
                        nn.Conv2d(match_dim, 3, kernel_size=1, bias=True, padding=0))
         | 
| 80 | 
            +
                    fine_match_dim = 64
         | 
| 81 | 
            +
                    self.fine_matcher = nn.Sequential(
         | 
| 82 | 
            +
                        BasicLayer(24+24+2, fine_match_dim,),
         | 
| 83 | 
            +
                        BasicLayer(fine_match_dim, fine_match_dim,), 
         | 
| 84 | 
            +
                        BasicLayer(fine_match_dim, fine_match_dim,), 
         | 
| 85 | 
            +
                        BasicLayer(fine_match_dim, fine_match_dim,), 
         | 
| 86 | 
            +
                        nn.Conv2d(fine_match_dim, 3, kernel_size=1, bias=True, padding=0),)
         | 
| 87 | 
            +
                    self.sample_mode = sample_mode
         | 
| 88 | 
            +
                    self.sample_thresh = 0.2
         | 
| 89 | 
            +
                    self.symmetric = symmetric
         | 
| 90 | 
            +
                    self.exact_softmax = exact_softmax
         | 
| 91 | 
            +
                
         | 
| 92 | 
            +
                @property
         | 
| 93 | 
            +
                def device(self):
         | 
| 94 | 
            +
                    return self.fine_matcher[-1].weight.device
         | 
| 95 | 
            +
                
         | 
| 96 | 
            +
                def preprocess_tensor(self, x):
         | 
| 97 | 
            +
                    """ Guarantee that image is divisible by 32 to avoid aliasing artifacts. """
         | 
| 98 | 
            +
                    H, W = x.shape[-2:]
         | 
| 99 | 
            +
                    _H, _W = (H//32) * 32, (W//32) * 32
         | 
| 100 | 
            +
                    rh, rw = H/_H, W/_W
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                    x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False)
         | 
| 103 | 
            +
                    return x, rh, rw        
         | 
| 104 | 
            +
                
         | 
| 105 | 
            +
                def forward_single(self, x):
         | 
| 106 | 
            +
                    with torch.inference_mode(self.freeze_xfeat or not self.training):
         | 
| 107 | 
            +
                        xfeat = self.xfeat[0]
         | 
| 108 | 
            +
                        with torch.no_grad():
         | 
| 109 | 
            +
                            x = x.mean(dim=1, keepdim = True)
         | 
| 110 | 
            +
                            x = xfeat.norm(x)
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                        #main backbone
         | 
| 113 | 
            +
                        x1 = xfeat.block1(x)
         | 
| 114 | 
            +
                        x2 = xfeat.block2(x1 + xfeat.skip1(x))
         | 
| 115 | 
            +
                        x3 = xfeat.block3(x2)
         | 
| 116 | 
            +
                        x4 = xfeat.block4(x3)
         | 
| 117 | 
            +
                        x5 = xfeat.block5(x4)
         | 
| 118 | 
            +
                        x4 = F.interpolate(x4, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
         | 
| 119 | 
            +
                        x5 = F.interpolate(x5, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
         | 
| 120 | 
            +
                        feats = xfeat.block_fusion( x3 + x4 + x5 )
         | 
| 121 | 
            +
                    if self.freeze_xfeat:
         | 
| 122 | 
            +
                        return x2.clone(), feats.clone()
         | 
| 123 | 
            +
                    return x2, feats
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                def to_pixel_coordinates(self, coords, H_A, W_A, H_B = None, W_B = None):
         | 
| 126 | 
            +
                    if coords.shape[-1] == 2:
         | 
| 127 | 
            +
                        return self._to_pixel_coordinates(coords, H_A, W_A) 
         | 
| 128 | 
            +
                    
         | 
| 129 | 
            +
                    if isinstance(coords, (list, tuple)):
         | 
| 130 | 
            +
                        kpts_A, kpts_B = coords[0], coords[1]
         | 
| 131 | 
            +
                    else:
         | 
| 132 | 
            +
                        kpts_A, kpts_B = coords[...,:2], coords[...,2:]
         | 
| 133 | 
            +
                    return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(kpts_B, H_B, W_B)
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                def _to_pixel_coordinates(self, coords, H, W):
         | 
| 136 | 
            +
                    kpts = torch.stack((W/2 * (coords[...,0]+1), H/2 * (coords[...,1]+1)),axis=-1)
         | 
| 137 | 
            +
                    return kpts
         | 
| 138 | 
            +
                
         | 
| 139 | 
            +
                def pos_embed(self, corr_volume: torch.Tensor):
         | 
| 140 | 
            +
                    B, H1, W1, H0, W0 = corr_volume.shape 
         | 
| 141 | 
            +
                    grid = torch.stack(
         | 
| 142 | 
            +
                            torch.meshgrid(
         | 
| 143 | 
            +
                                torch.linspace(-1+1/W1,1-1/W1, W1), 
         | 
| 144 | 
            +
                                torch.linspace(-1+1/H1,1-1/H1, H1), 
         | 
| 145 | 
            +
                                indexing = "xy"), 
         | 
| 146 | 
            +
                            dim = -1).float().to(corr_volume).reshape(H1*W1, 2)
         | 
| 147 | 
            +
                    down = 4
         | 
| 148 | 
            +
                    if not self.training and not self.exact_softmax:
         | 
| 149 | 
            +
                        grid_lr = torch.stack(
         | 
| 150 | 
            +
                            torch.meshgrid(
         | 
| 151 | 
            +
                                torch.linspace(-1+down/W1,1-down/W1, W1//down), 
         | 
| 152 | 
            +
                                torch.linspace(-1+down/H1,1-down/H1, H1//down), 
         | 
| 153 | 
            +
                                indexing = "xy"), 
         | 
| 154 | 
            +
                            dim = -1).float().to(corr_volume).reshape(H1*W1 //down**2, 2)
         | 
| 155 | 
            +
                        cv = corr_volume
         | 
| 156 | 
            +
                        best_match = cv.reshape(B,H1*W1,H0,W0).amax(dim=1) # B, HW, H, W
         | 
| 157 | 
            +
                        P_lowres = torch.cat((cv[:,::down,::down].reshape(B,H1*W1 // down**2,H0,W0), best_match[:,None]),dim=1).softmax(dim=1)
         | 
| 158 | 
            +
                        pos_embeddings = torch.einsum('bchw,cd->bdhw', P_lowres[:,:-1], grid_lr)
         | 
| 159 | 
            +
                        pos_embeddings += P_lowres[:,-1] * grid[best_match].permute(0,3,1,2)
         | 
| 160 | 
            +
                    else:
         | 
| 161 | 
            +
                        P = corr_volume.reshape(B,H1*W1,H0,W0).softmax(dim=1) # B, HW, H, W
         | 
| 162 | 
            +
                        pos_embeddings = torch.einsum('bchw,cd->bdhw', P, grid)
         | 
| 163 | 
            +
                    return pos_embeddings
         | 
| 164 | 
            +
                
         | 
| 165 | 
            +
                def visualize_warp(self, warp, certainty, im_A = None, im_B = None, 
         | 
| 166 | 
            +
                                   im_A_path = None, im_B_path = None, symmetric = True, save_path = None, unnormalize = False):
         | 
| 167 | 
            +
                    device = warp.device
         | 
| 168 | 
            +
                    H,W2,_ = warp.shape
         | 
| 169 | 
            +
                    W = W2//2 if symmetric else W2
         | 
| 170 | 
            +
                    if im_A is None:
         | 
| 171 | 
            +
                        from PIL import Image
         | 
| 172 | 
            +
                        im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
         | 
| 173 | 
            +
                    if not isinstance(im_A, torch.Tensor):
         | 
| 174 | 
            +
                        im_A = im_A.resize((W,H))
         | 
| 175 | 
            +
                        im_B = im_B.resize((W,H))    
         | 
| 176 | 
            +
                        x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
         | 
| 177 | 
            +
                        if symmetric:
         | 
| 178 | 
            +
                            x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
         | 
| 179 | 
            +
                    else:
         | 
| 180 | 
            +
                        if symmetric:
         | 
| 181 | 
            +
                            x_A = im_A
         | 
| 182 | 
            +
                        x_B = im_B
         | 
| 183 | 
            +
                    im_A_transfer_rgb = F.grid_sample(
         | 
| 184 | 
            +
                    x_B[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
         | 
| 185 | 
            +
                    )[0]
         | 
| 186 | 
            +
                    if symmetric:
         | 
| 187 | 
            +
                        im_B_transfer_rgb = F.grid_sample(
         | 
| 188 | 
            +
                        x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
         | 
| 189 | 
            +
                        )[0]
         | 
| 190 | 
            +
                        warp_im = torch.cat((im_A_transfer_rgb,im_B_transfer_rgb),dim=2)
         | 
| 191 | 
            +
                        white_im = torch.ones((H,2*W),device=device)
         | 
| 192 | 
            +
                    else:
         | 
| 193 | 
            +
                        warp_im = im_A_transfer_rgb
         | 
| 194 | 
            +
                        white_im = torch.ones((H, W), device = device)
         | 
| 195 | 
            +
                    vis_im = certainty * warp_im + (1 - certainty) * white_im
         | 
| 196 | 
            +
                    if save_path is not None:
         | 
| 197 | 
            +
                        from romatch.utils import tensor_to_pil
         | 
| 198 | 
            +
                        tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path)
         | 
| 199 | 
            +
                    return vis_im
         | 
| 200 | 
            +
                 
         | 
| 201 | 
            +
                def corr_volume(self, feat0, feat1):
         | 
| 202 | 
            +
                    """
         | 
| 203 | 
            +
                        input:
         | 
| 204 | 
            +
                            feat0 -> torch.Tensor(B, C, H, W)
         | 
| 205 | 
            +
                            feat1 -> torch.Tensor(B, C, H, W)
         | 
| 206 | 
            +
                        return:
         | 
| 207 | 
            +
                            corr_volume -> torch.Tensor(B, H, W, H, W)
         | 
| 208 | 
            +
                    """
         | 
| 209 | 
            +
                    B, C, H0, W0 = feat0.shape
         | 
| 210 | 
            +
                    B, C, H1, W1 = feat1.shape
         | 
| 211 | 
            +
                    feat0 = feat0.view(B, C, H0*W0)
         | 
| 212 | 
            +
                    feat1 = feat1.view(B, C, H1*W1)
         | 
| 213 | 
            +
                    corr_volume = torch.einsum('bci,bcj->bji', feat0, feat1).reshape(B, H1, W1, H0 , W0)/math.sqrt(C) #16*16*16
         | 
| 214 | 
            +
                    return corr_volume
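Illustrative shape check (not part of the commit) for the corr_volume contract documented above:

import math, torch
B, C, H, W = 1, 64, 30, 40
f0, f1 = torch.randn(B, C, H, W), torch.randn(B, C, H, W)
cv = torch.einsum('bci,bcj->bji', f0.view(B, C, H * W), f1.view(B, C, H * W))
cv = cv.reshape(B, H, W, H, W) / math.sqrt(C)              # (B, H1, W1, H0, W0)
assert cv.shape == (B, H, W, H, W)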
         | 
| 215 | 
            +
                
         | 
| 216 | 
            +
                @torch.inference_mode()
         | 
| 217 | 
            +
                def match_from_path(self, im0_path, im1_path):
         | 
| 218 | 
            +
                    device = self.device
         | 
| 219 | 
            +
                    im0 = ToTensor()(Image.open(im0_path))[None].to(device)
         | 
| 220 | 
            +
                    im1 = ToTensor()(Image.open(im1_path))[None].to(device)
         | 
| 221 | 
            +
                    return self.match(im0, im1, batched = False)
         | 
| 222 | 
            +
                
         | 
| 223 | 
            +
                @torch.inference_mode()
         | 
| 224 | 
            +
                def match(self, im0, im1, *args, batched = True):
         | 
| 225 | 
            +
            # dispatch on input type: paths and PIL images are converted to batched tensors first
         | 
| 226 | 
            +
                    if isinstance(im0, (str, Path)):
         | 
| 227 | 
            +
                        return self.match_from_path(im0, im1)
         | 
| 228 | 
            +
                    elif isinstance(im0, Image.Image):
         | 
| 229 | 
            +
                        batched = False
         | 
| 230 | 
            +
                        device = self.device
         | 
| 231 | 
            +
                        im0 = ToTensor()(im0)[None].to(device)
         | 
| 232 | 
            +
                        im1 = ToTensor()(im1)[None].to(device)
         | 
| 233 | 
            +
             
         | 
| 234 | 
            +
                    B,C,H0,W0 = im0.shape
         | 
| 235 | 
            +
                    B,C,H1,W1 = im1.shape
         | 
| 236 | 
            +
                    self.train(False)
         | 
| 237 | 
            +
                    corresps = self.forward({"im_A":im0, "im_B":im1})
         | 
| 238 | 
            +
                    #return 1,1
         | 
| 239 | 
            +
                    flow = F.interpolate(
         | 
| 240 | 
            +
                        corresps[4]["flow"], 
         | 
| 241 | 
            +
                        size = (H0, W0), 
         | 
| 242 | 
            +
                        mode = "bilinear", align_corners = False).permute(0,2,3,1).reshape(B,H0,W0,2)
         | 
| 243 | 
            +
                    grid = torch.stack(
         | 
| 244 | 
            +
                        torch.meshgrid(
         | 
| 245 | 
            +
                            torch.linspace(-1+1/W0,1-1/W0, W0), 
         | 
| 246 | 
            +
                            torch.linspace(-1+1/H0,1-1/H0, H0), 
         | 
| 247 | 
            +
                            indexing = "xy"), 
         | 
| 248 | 
            +
                        dim = -1).float().to(flow.device).expand(B, H0, W0, 2)
         | 
| 249 | 
            +
                    
         | 
| 250 | 
            +
                    certainty = F.interpolate(corresps[4]["certainty"], size = (H0,W0), mode = "bilinear", align_corners = False)
         | 
| 251 | 
            +
                    warp, cert = torch.cat((grid, flow), dim = -1), certainty[:,0].sigmoid()
         | 
| 252 | 
            +
                    if batched:
         | 
| 253 | 
            +
                        return warp, cert
         | 
| 254 | 
            +
                    else:
         | 
| 255 | 
            +
                        return warp[0], cert[0]
         | 
| 256 | 
            +
             | 
| 257 | 
            +
                def sample(
         | 
| 258 | 
            +
                    self,
         | 
| 259 | 
            +
                    matches,
         | 
| 260 | 
            +
                    certainty,
         | 
| 261 | 
            +
                    num=10000,
         | 
| 262 | 
            +
                ):
         | 
| 263 | 
            +
                    if "threshold" in self.sample_mode:
         | 
| 264 | 
            +
                        upper_thresh = self.sample_thresh
         | 
| 265 | 
            +
                        certainty = certainty.clone()
         | 
| 266 | 
            +
                        certainty[certainty > upper_thresh] = 1
         | 
| 267 | 
            +
                    matches, certainty = (
         | 
| 268 | 
            +
                        matches.reshape(-1, 4),
         | 
| 269 | 
            +
                        certainty.reshape(-1),
         | 
| 270 | 
            +
                    )
         | 
| 271 | 
            +
                    expansion_factor = 4 if "balanced" in self.sample_mode else 1
         | 
| 272 | 
            +
                    good_samples = torch.multinomial(certainty, 
         | 
| 273 | 
            +
                                      num_samples = min(expansion_factor*num, len(certainty)), 
         | 
| 274 | 
            +
                                      replacement=False)
         | 
| 275 | 
            +
                    good_matches, good_certainty = matches[good_samples], certainty[good_samples]
         | 
| 276 | 
            +
                    if "balanced" not in self.sample_mode:
         | 
| 277 | 
            +
                        return good_matches, good_certainty
         | 
| 278 | 
            +
                    density = kde(good_matches, std=0.1)
         | 
| 279 | 
            +
                    p = 1 / (density+1)
         | 
| 280 | 
            +
                    p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
         | 
| 281 | 
            +
                    balanced_samples = torch.multinomial(p, 
         | 
| 282 | 
            +
                                      num_samples = min(num,len(good_certainty)), 
         | 
| 283 | 
            +
                                      replacement=False)
         | 
| 284 | 
            +
                    return good_matches[balanced_samples], good_certainty[balanced_samples]
         | 
| 285 | 
            +
                        
         | 
| 286 | 
            +
                def forward(self, batch):
         | 
| 287 | 
            +
                    """
         | 
| 288 | 
            +
                        input:
         | 
| 289 | 
            +
                            x -> torch.Tensor(B, C, H, W) grayscale or rgb images
         | 
| 290 | 
            +
                        return:
         | 
| 291 | 
            +
             | 
| 292 | 
            +
                    """
         | 
| 293 | 
            +
                    im0 = batch["im_A"]
         | 
| 294 | 
            +
                    im1 = batch["im_B"]
         | 
| 295 | 
            +
                    corresps = {}
         | 
| 296 | 
            +
                    im0, rh0, rw0 = self.preprocess_tensor(im0)
         | 
| 297 | 
            +
                    im1, rh1, rw1 = self.preprocess_tensor(im1)
         | 
| 298 | 
            +
                    B, C, H0, W0 = im0.shape
         | 
| 299 | 
            +
                    B, C, H1, W1 = im1.shape
         | 
| 300 | 
            +
                    to_normalized = torch.tensor((2/W1, 2/H1, 1)).to(im0.device)[None,:,None,None]
         | 
| 301 | 
            +
             
         | 
| 302 | 
            +
                    if im0.shape[-2:] == im1.shape[-2:]:
         | 
| 303 | 
            +
                        x = torch.cat([im0, im1], dim=0)
         | 
| 304 | 
            +
                        x = self.forward_single(x)
         | 
| 305 | 
            +
                        feats_x0_c, feats_x1_c = x[1].chunk(2)
         | 
| 306 | 
            +
                        feats_x0_f, feats_x1_f = x[0].chunk(2)
         | 
| 307 | 
            +
                    else:
         | 
| 308 | 
            +
                        feats_x0_f, feats_x0_c = self.forward_single(im0)
         | 
| 309 | 
            +
                        feats_x1_f, feats_x1_c = self.forward_single(im1)
         | 
| 310 | 
            +
                    corr_volume = self.corr_volume(feats_x0_c, feats_x1_c)
         | 
| 311 | 
            +
                    coarse_warp = self.pos_embed(corr_volume)
         | 
| 312 | 
            +
                    coarse_matches = torch.cat((coarse_warp, torch.zeros_like(coarse_warp[:,-1:])), dim=1)
         | 
| 313 | 
            +
                    feats_x1_c_warped = F.grid_sample(feats_x1_c, coarse_matches.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False)
         | 
| 314 | 
            +
                    coarse_matches_delta = self.coarse_matcher(torch.cat((feats_x0_c, feats_x1_c_warped, coarse_warp), dim=1))
         | 
| 315 | 
            +
                    coarse_matches = coarse_matches + coarse_matches_delta * to_normalized
         | 
| 316 | 
            +
                    corresps[8] = {"flow": coarse_matches[:,:2], "certainty": coarse_matches[:,2:]}
         | 
| 317 | 
            +
                    coarse_matches_up = F.interpolate(coarse_matches, size = feats_x0_f.shape[-2:], mode = "bilinear", align_corners = False)        
         | 
| 318 | 
            +
                    coarse_matches_up_detach = coarse_matches_up.detach()#note the detach
         | 
| 319 | 
            +
                    feats_x1_f_warped = F.grid_sample(feats_x1_f, coarse_matches_up_detach.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False)
         | 
| 320 | 
            +
                    fine_matches_delta = self.fine_matcher(torch.cat((feats_x0_f, feats_x1_f_warped, coarse_matches_up_detach[:,:2]), dim=1))
         | 
| 321 | 
            +
                    fine_matches = coarse_matches_up_detach+fine_matches_delta * to_normalized
         | 
| 322 | 
            +
                    corresps[4] = {"flow": fine_matches[:,:2], "certainty": fine_matches[:,2:]}
         | 
| 323 | 
            +
                    return corresps
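Illustrative usage sketch for the XFeatModel class above (not part of the commit). It assumes a CUDA device and network access so torch.hub can fetch the XFeat backbone, and uses random images as stand-ins; in practice the matcher heads would be restored from a checkpoint such as tiny_roma_v1_outdoor.pth, as in the __main__ block below.

import torch
from PIL import Image
im_A = Image.fromarray((torch.rand(480, 640, 3) * 255).byte().numpy())
im_B = Image.fromarray((torch.rand(480, 640, 3) * 255).byte().numpy())
model = XFeatModel(freeze_xfeat=False).eval().cuda()
warp, certainty = model.match(im_A, im_B)                  # dense warp (H, W, 4) and confidence (H, W)
matches, conf = model.sample(warp, certainty, num=1000)    # density-balanced sparse matches
kpts_A, kpts_B = model.to_pixel_coordinates(matches, 480, 640, 480, 640)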
         | 
| 324 | 
            +
                
         | 
| 325 | 
            +
             | 
| 326 | 
            +
             | 
| 327 | 
            +
             | 
| 328 | 
            +
             | 
| 329 | 
            +
            def train(args):
         | 
| 330 | 
            +
                rank = 0
         | 
| 331 | 
            +
                gpus = 1
         | 
| 332 | 
            +
                device_id = rank % torch.cuda.device_count()
         | 
| 333 | 
            +
                romatch.LOCAL_RANK = 0
         | 
| 334 | 
            +
                torch.cuda.set_device(device_id)
         | 
| 335 | 
            +
                    
         | 
| 336 | 
            +
                resolution = "big"
         | 
| 337 | 
            +
                wandb_log = not args.dont_log_wandb
         | 
| 338 | 
            +
                experiment_name = Path(__file__).stem
         | 
| 339 | 
            +
                wandb_mode = "online" if wandb_log and rank == 0 else "disabled"
         | 
| 340 | 
            +
                wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
         | 
| 341 | 
            +
                checkpoint_dir = "workspace/checkpoints/"
         | 
| 342 | 
            +
                h,w = resolutions[resolution]
         | 
| 343 | 
            +
                model = XFeatModel(freeze_xfeat = False).to(device_id)
         | 
| 344 | 
            +
                # Num steps
         | 
| 345 | 
            +
                global_step = 0
         | 
| 346 | 
            +
                batch_size = args.gpu_batch_size
         | 
| 347 | 
            +
                step_size = gpus*batch_size
         | 
| 348 | 
            +
                romatch.STEP_SIZE = step_size
         | 
| 349 | 
            +
                
         | 
| 350 | 
            +
                N = 2_000_000  # 2M pairs
         | 
| 351 | 
            +
                # checkpoint every
         | 
| 352 | 
            +
                k = 25000 // romatch.STEP_SIZE
         | 
| 353 | 
            +
             | 
| 354 | 
            +
                # Data
         | 
| 355 | 
            +
                mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
         | 
| 356 | 
            +
                use_horizontal_flip_aug = True
         | 
| 357 | 
            +
                normalize = False # don't imgnet normalize
         | 
| 358 | 
            +
                rot_prob = 0
         | 
| 359 | 
            +
                depth_interpolation_mode = "bilinear"
         | 
| 360 | 
            +
                megadepth_train1 = mega.build_scenes(
         | 
| 361 | 
            +
                    split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
         | 
| 362 | 
            +
                    ht=h,wt=w, normalize = normalize
         | 
| 363 | 
            +
                )
         | 
| 364 | 
            +
                megadepth_train2 = mega.build_scenes(
         | 
| 365 | 
            +
                    split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
         | 
| 366 | 
            +
                    ht=h,wt=w, normalize = normalize
         | 
| 367 | 
            +
                )
         | 
| 368 | 
            +
                megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
         | 
| 369 | 
            +
                mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)
         | 
| 370 | 
            +
                # Loss and optimizer
         | 
| 371 | 
            +
                depth_loss = RobustLosses(
         | 
| 372 | 
            +
                    ce_weight=0.01, 
         | 
| 373 | 
            +
                    local_dist={4:4},
         | 
| 374 | 
            +
                    depth_interpolation_mode=depth_interpolation_mode,
         | 
| 375 | 
            +
                    alpha = {4:0.15, 8:0.15},
         | 
| 376 | 
            +
                    c = 1e-4,
         | 
| 377 | 
            +
                    epe_mask_prob_th = 0.001,
         | 
| 378 | 
            +
                    )
         | 
| 379 | 
            +
                parameters = [
         | 
| 380 | 
            +
                    {"params": model.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
         | 
| 381 | 
            +
                ]
         | 
| 382 | 
            +
                optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
         | 
| 383 | 
            +
                lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
         | 
| 384 | 
            +
                    optimizer, milestones=[(9*N/romatch.STEP_SIZE)//10])
         | 
| 385 | 
            +
                #megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h,w=w)
         | 
| 386 | 
            +
                mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 30)
         | 
| 387 | 
            +
             | 
| 388 | 
            +
                checkpointer = CheckPoint(checkpoint_dir, experiment_name)
         | 
| 389 | 
            +
                model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
         | 
| 390 | 
            +
                romatch.GLOBAL_STEP = global_step
         | 
| 391 | 
            +
                grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
         | 
| 392 | 
            +
                grad_clip_norm = 0.01
         | 
| 393 | 
            +
                #megadense_benchmark.benchmark(model)
         | 
| 394 | 
            +
                for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
         | 
| 395 | 
            +
                    mega_sampler = torch.utils.data.WeightedRandomSampler(
         | 
| 396 | 
            +
                        mega_ws, num_samples = batch_size * k, replacement=False
         | 
| 397 | 
            +
                    )
         | 
| 398 | 
            +
                    mega_dataloader = iter(
         | 
| 399 | 
            +
                        torch.utils.data.DataLoader(
         | 
| 400 | 
            +
                            megadepth_train,
         | 
| 401 | 
            +
                            batch_size = batch_size,
         | 
| 402 | 
            +
                            sampler = mega_sampler,
         | 
| 403 | 
            +
                            num_workers = 8,
         | 
| 404 | 
            +
                        )
         | 
| 405 | 
            +
                    )
         | 
| 406 | 
            +
                    train_k_steps(
         | 
| 407 | 
            +
                        n, k, mega_dataloader, model, depth_loss, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm,
         | 
| 408 | 
            +
                    )
         | 
| 409 | 
            +
                    checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
         | 
| 410 | 
            +
                    wandb.log(mega1500_benchmark.benchmark(model, model_name=experiment_name), step = romatch.GLOBAL_STEP)
         | 
| 411 | 
            +
             | 
| 412 | 
            +
            def test_mega_8_scenes(model, name):
         | 
| 413 | 
            +
                mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
         | 
| 414 | 
            +
                                                            scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
         | 
| 415 | 
            +
                                                                'mega_8_scenes_0025_0.1_0.3.npz',
         | 
| 416 | 
            +
                                                                'mega_8_scenes_0021_0.1_0.3.npz',
         | 
| 417 | 
            +
                                                                'mega_8_scenes_0008_0.1_0.3.npz',
         | 
| 418 | 
            +
                                                                'mega_8_scenes_0032_0.1_0.3.npz',
         | 
| 419 | 
            +
                                                                'mega_8_scenes_1589_0.1_0.3.npz',
         | 
| 420 | 
            +
                                                                'mega_8_scenes_0063_0.1_0.3.npz',
         | 
| 421 | 
            +
                                                                'mega_8_scenes_0024_0.1_0.3.npz',
         | 
| 422 | 
            +
                                                                'mega_8_scenes_0019_0.3_0.5.npz',
         | 
| 423 | 
            +
                                                                'mega_8_scenes_0025_0.3_0.5.npz',
         | 
| 424 | 
            +
                                                                'mega_8_scenes_0021_0.3_0.5.npz',
         | 
| 425 | 
            +
                                                                'mega_8_scenes_0008_0.3_0.5.npz',
         | 
| 426 | 
            +
                                                                'mega_8_scenes_0032_0.3_0.5.npz',
         | 
| 427 | 
            +
                                                                'mega_8_scenes_1589_0.3_0.5.npz',
         | 
| 428 | 
            +
                                                                'mega_8_scenes_0063_0.3_0.5.npz',
         | 
| 429 | 
            +
                                                                'mega_8_scenes_0024_0.3_0.5.npz'])
         | 
| 430 | 
            +
                mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
         | 
| 431 | 
            +
                print(mega_8_scenes_results)
         | 
| 432 | 
            +
                json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))
         | 
| 433 | 
            +
             | 
| 434 | 
            +
            def test_mega1500(model, name):
         | 
| 435 | 
            +
                mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
         | 
| 436 | 
            +
                mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
         | 
| 437 | 
            +
                json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
         | 
| 438 | 
            +
             | 
| 439 | 
            +
            def test_mega1500_poselib(model, name):
         | 
| 440 | 
            +
                mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1)
         | 
| 441 | 
            +
                mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
         | 
| 442 | 
            +
                json.dump(mega1500_results, open(f"results/mega1500_poselib_{name}.json", "w"))
         | 
| 443 | 
            +
             | 
| 444 | 
            +
            def test_mega_8_scenes_poselib(model, name):
         | 
| 445 | 
            +
                mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1,
         | 
| 446 | 
            +
                                                              scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
         | 
| 447 | 
            +
                                                                'mega_8_scenes_0025_0.1_0.3.npz',
         | 
| 448 | 
            +
                                                                'mega_8_scenes_0021_0.1_0.3.npz',
         | 
| 449 | 
            +
                                                                'mega_8_scenes_0008_0.1_0.3.npz',
         | 
| 450 | 
            +
                                                                'mega_8_scenes_0032_0.1_0.3.npz',
         | 
| 451 | 
            +
                                                                'mega_8_scenes_1589_0.1_0.3.npz',
         | 
| 452 | 
            +
                                                                'mega_8_scenes_0063_0.1_0.3.npz',
         | 
| 453 | 
            +
                                                                'mega_8_scenes_0024_0.1_0.3.npz',
         | 
| 454 | 
            +
                                                                'mega_8_scenes_0019_0.3_0.5.npz',
         | 
| 455 | 
            +
                                                                'mega_8_scenes_0025_0.3_0.5.npz',
         | 
| 456 | 
            +
                                                                'mega_8_scenes_0021_0.3_0.5.npz',
         | 
| 457 | 
            +
                                                                'mega_8_scenes_0008_0.3_0.5.npz',
         | 
| 458 | 
            +
                                                                'mega_8_scenes_0032_0.3_0.5.npz',
         | 
| 459 | 
            +
                                                                'mega_8_scenes_1589_0.3_0.5.npz',
         | 
| 460 | 
            +
                                                                'mega_8_scenes_0063_0.3_0.5.npz',
         | 
| 461 | 
            +
                                                                'mega_8_scenes_0024_0.3_0.5.npz'])
         | 
| 462 | 
            +
                mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
         | 
| 463 | 
            +
                json.dump(mega1500_results, open(f"results/mega_8_scenes_poselib_{name}.json", "w"))
         | 
| 464 | 
            +
             | 
| 465 | 
            +
            def test_scannet_poselib(model, name):
         | 
| 466 | 
            +
                scannet_benchmark = ScanNetPoselibBenchmark("data/scannet")
         | 
| 467 | 
            +
                scannet_results = scannet_benchmark.benchmark(model)
         | 
| 468 | 
            +
                json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
         | 
| 469 | 
            +
             | 
| 470 | 
            +
            def test_scannet(model, name):
         | 
| 471 | 
            +
                scannet_benchmark = ScanNetBenchmark("data/scannet")
         | 
| 472 | 
            +
                scannet_results = scannet_benchmark.benchmark(model)
         | 
| 473 | 
            +
                json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
         | 
| 474 | 
            +
             | 
| 475 | 
            +
            if __name__ == "__main__":
         | 
| 476 | 
            +
                os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
         | 
| 477 | 
            +
                os.environ["OMP_NUM_THREADS"] = "16"
         | 
| 478 | 
            +
                torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
         | 
| 479 | 
            +
                import romatch
         | 
| 480 | 
            +
                parser = ArgumentParser()
         | 
| 481 | 
            +
                parser.add_argument("--only_test", action='store_true')
         | 
| 482 | 
            +
                parser.add_argument("--debug_mode", action='store_true')
         | 
| 483 | 
            +
                parser.add_argument("--dont_log_wandb", action='store_true')
         | 
| 484 | 
            +
                parser.add_argument("--train_resolution", default='medium')
         | 
| 485 | 
            +
                parser.add_argument("--gpu_batch_size", default=8, type=int)
         | 
| 486 | 
            +
                parser.add_argument("--wandb_entity", required = False)
         | 
| 487 | 
            +
             | 
| 488 | 
            +
                args, _ = parser.parse_known_args()
         | 
| 489 | 
            +
                romatch.DEBUG_MODE = args.debug_mode
         | 
| 490 | 
            +
                if not args.only_test:
         | 
| 491 | 
            +
                    train(args)
         | 
| 492 | 
            +
             | 
| 493 | 
            +
                experiment_name = "tiny_roma_v1_outdoor"#Path(__file__).stem
         | 
| 494 | 
            +
                device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         | 
| 495 | 
            +
                model = XFeatModel(freeze_xfeat=False, exact_softmax=False).to(device)
         | 
| 496 | 
            +
                model.load_state_dict(torch.load(f"{experiment_name}.pth"))
         | 
| 497 | 
            +
                test_mega1500_poselib(model, experiment_name)
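Illustrative note (not part of the commit): with --only_test the script above skips train(args) and goes straight to reloading tiny_roma_v1_outdoor.pth and running the PoseLib MegaDepth-1500 evaluation. A hypothetical invocation from the RoMa submodule root:

import subprocess, sys
subprocess.run(
    [sys.executable, "experiments/train_tiny_roma_v1_outdoor.py", "--only_test"],
    check=True,
)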
         | 
| 498 | 
            +
                
         | 
    	
        submodules/RoMa/requirements.txt
    ADDED
    
    | @@ -0,0 +1,14 @@ | |
| 1 | 
            +
            torch
         | 
| 2 | 
            +
            einops
         | 
| 3 | 
            +
            torchvision
         | 
| 4 | 
            +
            opencv-python
         | 
| 5 | 
            +
            kornia
         | 
| 6 | 
            +
            albumentations
         | 
| 7 | 
            +
            loguru
         | 
| 8 | 
            +
            tqdm
         | 
| 9 | 
            +
            matplotlib
         | 
| 10 | 
            +
            h5py
         | 
| 11 | 
            +
            wandb
         | 
| 12 | 
            +
            timm
         | 
| 13 | 
            +
            poselib
         | 
| 14 | 
            +
#xformers # Optional, used for memory-efficient attention
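Illustrative only: these requirements belong to the RoMa submodule; installing them from the top level of this repository would look roughly like the following (the relative path is assumed from this commit's layout).

import subprocess, sys
subprocess.run(
    [sys.executable, "-m", "pip", "install", "-r", "submodules/RoMa/requirements.txt"],
    check=True,
)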
         | 
    	
        submodules/RoMa/romatch/__init__.py
    ADDED
    
    | @@ -0,0 +1,8 @@ | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            from .models import roma_outdoor, tiny_roma_v1_outdoor, roma_indoor
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            DEBUG_MODE = False
         | 
| 5 | 
            +
            RANK = int(os.environ.get('RANK', default = 0))
         | 
| 6 | 
            +
            GLOBAL_STEP = 0
         | 
| 7 | 
            +
            STEP_SIZE = 1
         | 
| 8 | 
            +
            LOCAL_RANK = -1
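Illustrative only (not part of the commit): the package exposes the pretrained model constructors alongside a few process-level globals that the training scripts mutate. Assuming the usual device keyword on the constructors:

import torch
import romatch
print(romatch.RANK, romatch.GLOBAL_STEP, romatch.STEP_SIZE)    # 0 0 1 by default
device = "cuda" if torch.cuda.is_available() else "cpu"
matcher = romatch.tiny_roma_v1_outdoor(device=device)          # weights fetched on first use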
         | 
    	
submodules/RoMa/romatch/benchmarks/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .hpatches_sequences_homog_benchmark import HpatchesHomogBenchmark
from .scannet_benchmark import ScanNetBenchmark
from .megadepth_pose_estimation_benchmark import MegaDepthPoseEstimationBenchmark
from .megadepth_dense_benchmark import MegadepthDenseBenchmark
from .megadepth_pose_estimation_benchmark_poselib import Mega1500PoseLibBenchmark
# from .scannet_benchmark_poselib import ScanNetPoselibBenchmark
    	
submodules/RoMa/romatch/benchmarks/hpatches_sequences_homog_benchmark.py
ADDED
@@ -0,0 +1,113 @@
from PIL import Image
import numpy as np

import os

from tqdm import tqdm
from romatch.utils import pose_auc
import cv2


class HpatchesHomogBenchmark:
    """Hpatches grid goes from [0,n-1] instead of [0.5,n-0.5]"""

    def __init__(self, dataset_path) -> None:
        seqs_dir = "hpatches-sequences-release"
        self.seqs_path = os.path.join(dataset_path, seqs_dir)
        self.seq_names = sorted(os.listdir(self.seqs_path))
        # Ignored sequences are the same as in LoFTR.
        self.ignore_seqs = set(
            [
                "i_contruction",
                "i_crownnight",
                "i_dc",
                "i_pencils",
                "i_whitebuilding",
                "v_artisans",
                "v_astronautis",
                "v_talent",
            ]
        )

    def convert_coordinates(self, im_A_coords, im_A_to_im_B, wq, hq, wsup, hsup):
        offset = 0.5  # Hpatches assumes that the center of the top-left pixel is at [0,0] (I think)
        im_A_coords = (
            np.stack(
                (
                    wq * (im_A_coords[..., 0] + 1) / 2,
                    hq * (im_A_coords[..., 1] + 1) / 2,
                ),
                axis=-1,
            )
            - offset
        )
        im_A_to_im_B = (
            np.stack(
                (
                    wsup * (im_A_to_im_B[..., 0] + 1) / 2,
                    hsup * (im_A_to_im_B[..., 1] + 1) / 2,
                ),
                axis=-1,
            )
            - offset
        )
        return im_A_coords, im_A_to_im_B

    def benchmark(self, model, model_name=None):
        n_matches = []
        homog_dists = []
        for seq_idx, seq_name in tqdm(
            enumerate(self.seq_names), total=len(self.seq_names)
        ):
            im_A_path = os.path.join(self.seqs_path, seq_name, "1.ppm")
            im_A = Image.open(im_A_path)
            w1, h1 = im_A.size
            for im_idx in range(2, 7):
                im_B_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm")
                im_B = Image.open(im_B_path)
                w2, h2 = im_B.size
                H = np.loadtxt(
                    os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
                )
                dense_matches, dense_certainty = model.match(
                    im_A_path, im_B_path
                )
                good_matches, _ = model.sample(dense_matches, dense_certainty, 5000)
                pos_a, pos_b = self.convert_coordinates(
                    good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2
                )
                try:
                    H_pred, inliers = cv2.findHomography(
                        pos_a,
                        pos_b,
                        method=cv2.RANSAC,
                        confidence=0.99999,
                        ransacReprojThreshold=3 * min(w2, h2) / 480,
                    )
                except Exception:
                    H_pred = None
                if H_pred is None:
                    H_pred = np.zeros((3, 3))
                    H_pred[2, 2] = 1.0
                corners = np.array(
                    [[0, 0, 1], [0, h1 - 1, 1], [w1 - 1, 0, 1], [w1 - 1, h1 - 1, 1]]
                )
                real_warped_corners = np.dot(corners, np.transpose(H))
                real_warped_corners = (
                    real_warped_corners[:, :2] / real_warped_corners[:, 2:]
                )
                warped_corners = np.dot(corners, np.transpose(H_pred))
                warped_corners = warped_corners[:, :2] / warped_corners[:, 2:]
                mean_dist = np.mean(
                    np.linalg.norm(real_warped_corners - warped_corners, axis=1)
                ) / (min(w2, h2) / 480.0)
                homog_dists.append(mean_dist)

        n_matches = np.array(n_matches)
        thresholds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        auc = pose_auc(np.array(homog_dists), thresholds)
        return {
            "hpatches_homog_auc_3": auc[2],
            "hpatches_homog_auc_5": auc[4],
            "hpatches_homog_auc_10": auc[9],
        }
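A hedged sketch of wiring this benchmark to a RoMa matcher; the dataset location used here is an assumption and must contain the extracted hpatches-sequences-release folder:

# Minimal sketch: run the HPatches homography benchmark against a pretrained RoMa matcher.
import torch
from romatch import roma_outdoor
from romatch.benchmarks import HpatchesHomogBenchmark

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = roma_outdoor(device=device)
bench = HpatchesHomogBenchmark("data/hpatches")  # assumed path; expects data/hpatches/hpatches-sequences-release
results = bench.benchmark(model)
print(results)  # {"hpatches_homog_auc_3": ..., "hpatches_homog_auc_5": ..., "hpatches_homog_auc_10": ...}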
    	
submodules/RoMa/romatch/benchmarks/megadepth_dense_benchmark.py
ADDED
@@ -0,0 +1,106 @@
import torch
import numpy as np
import tqdm
from romatch.datasets import MegadepthBuilder
from romatch.utils import warp_kpts
from torch.utils.data import ConcatDataset
import romatch

class MegadepthDenseBenchmark:
    def __init__(self, data_root="data/megadepth", h=384, w=512, num_samples=2000) -> None:
        mega = MegadepthBuilder(data_root=data_root)
        self.dataset = ConcatDataset(
            mega.build_scenes(split="test_loftr", ht=h, wt=w)
        )  # fixed resolution of 384,512
        self.num_samples = num_samples

    def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches):
        b, h1, w1, d = dense_matches.shape
        with torch.no_grad():
            x1 = dense_matches[..., :2].reshape(b, h1 * w1, 2)
            mask, x2 = warp_kpts(
                x1.double(),
                depth1.double(),
                depth2.double(),
                T_1to2.double(),
                K1.double(),
                K2.double(),
            )
            x2 = torch.stack(
                (w1 * (x2[..., 0] + 1) / 2, h1 * (x2[..., 1] + 1) / 2), dim=-1
            )
            prob = mask.float().reshape(b, h1, w1)
        x2_hat = dense_matches[..., 2:]
        x2_hat = torch.stack(
            (w1 * (x2_hat[..., 0] + 1) / 2, h1 * (x2_hat[..., 1] + 1) / 2), dim=-1
        )
        gd = (x2_hat - x2.reshape(b, h1, w1, 2)).norm(dim=-1)
        gd = gd[prob == 1]
        pck_1 = (gd < 1.0).float().mean()
        pck_3 = (gd < 3.0).float().mean()
        pck_5 = (gd < 5.0).float().mean()
        return gd, pck_1, pck_3, pck_5, prob

    def benchmark(self, model, batch_size=8):
        model.train(False)
        with torch.no_grad():
            gd_tot = 0.0
            pck_1_tot = 0.0
            pck_3_tot = 0.0
            pck_5_tot = 0.0
            sampler = torch.utils.data.WeightedRandomSampler(
                torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples
            )
            B = batch_size
            dataloader = torch.utils.data.DataLoader(
                self.dataset, batch_size=B, num_workers=batch_size, sampler=sampler
            )
            for idx, data in tqdm.tqdm(enumerate(dataloader), disable=romatch.RANK > 0):
                im_A, im_B, depth1, depth2, T_1to2, K1, K2 = (
                    data["im_A"].cuda(),
                    data["im_B"].cuda(),
                    data["im_A_depth"].cuda(),
                    data["im_B_depth"].cuda(),
                    data["T_1to2"].cuda(),
                    data["K1"].cuda(),
                    data["K2"].cuda(),
                )
                matches, certainty = model.match(im_A, im_B, batched=True)
                gd, pck_1, pck_3, pck_5, prob = self.geometric_dist(
                    depth1, depth2, T_1to2, K1, K2, matches
                )
                if romatch.DEBUG_MODE:
                    from romatch.utils.utils import tensor_to_pil
                    import torch.nn.functional as F
                    path = "vis"
                    H, W = model.get_output_resolution()
                    white_im = torch.ones((B, 1, H, W), device="cuda")
                    im_B_transfer_rgb = F.grid_sample(
                        im_B.cuda(), matches[:, :, :W, 2:], mode="bilinear", align_corners=False
                    )
                    warp_im = im_B_transfer_rgb
                    c_b = certainty[:, None]  # (certainty*0.9 + 0.1*torch.ones_like(certainty))[:,None]
                    vis_im = c_b * warp_im + (1 - c_b) * white_im
                    for b in range(B):
                        import os
                        os.makedirs(f"{path}/{model.name}/{idx}_{b}_{H}_{W}", exist_ok=True)
                        tensor_to_pil(vis_im[b], unnormalize=True).save(
                            f"{path}/{model.name}/{idx}_{b}_{H}_{W}/warp.jpg")
                        tensor_to_pil(im_A[b].cuda(), unnormalize=True).save(
                            f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_A.jpg")
                        tensor_to_pil(im_B[b].cuda(), unnormalize=True).save(
                            f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_B.jpg")

                gd_tot, pck_1_tot, pck_3_tot, pck_5_tot = (
                    gd_tot + gd.mean(),
                    pck_1_tot + pck_1,
                    pck_3_tot + pck_3,
                    pck_5_tot + pck_5,
                )
        return {
            "epe": gd_tot.item() / len(dataloader),
            "mega_pck_1": pck_1_tot.item() / len(dataloader),
            "mega_pck_3": pck_3_tot.item() / len(dataloader),
            "mega_pck_5": pck_5_tot.item() / len(dataloader),
        }
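The PCK@1/3/5 thresholds above are measured in pixels at the fixed 384x512 working resolution, after mapping the model's normalized [-1, 1] coordinates to pixels via w * (x + 1) / 2 (and h * (y + 1) / 2 for y), as in geometric_dist. A small numeric check of that mapping:

# Numeric check of the normalized-to-pixel mapping used in geometric_dist above.
import torch

w1, h1 = 512, 384
x_norm = torch.tensor([[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]])   # (x, y) in [-1, 1]
x_pix = torch.stack((w1 * (x_norm[:, 0] + 1) / 2,
                     h1 * (x_norm[:, 1] + 1) / 2), dim=-1)
print(x_pix)  # tensor([[  0.,   0.], [256., 192.], [512., 384.]])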
    	
submodules/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark.py
ADDED
@@ -0,0 +1,118 @@
import numpy as np
import torch
from romatch.utils import *
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
import romatch
import kornia.geometry.epipolar as kepi

class MegaDepthPoseEstimationBenchmark:
    def __init__(self, data_root="data/megadepth", scene_names=None) -> None:
        if scene_names is None:
            self.scene_names = [
                "0015_0.1_0.3.npz",
                "0015_0.3_0.5.npz",
                "0022_0.1_0.3.npz",
                "0022_0.3_0.5.npz",
                "0022_0.5_0.7.npz",
            ]
        else:
            self.scene_names = scene_names
        self.scenes = [
            np.load(f"{data_root}/{scene}", allow_pickle=True)
            for scene in self.scene_names
        ]
        self.data_root = data_root

    def benchmark(self, model, model_name=None):
        with torch.no_grad():
            data_root = self.data_root
            tot_e_t, tot_e_R, tot_e_pose = [], [], []
            thresholds = [5, 10, 20]
            for scene_ind in range(len(self.scenes)):
                import os
                scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
                scene = self.scenes[scene_ind]
                pairs = scene["pair_infos"]
                intrinsics = scene["intrinsics"]
                poses = scene["poses"]
                im_paths = scene["image_paths"]
                pair_inds = range(len(pairs))
                for pairind in tqdm(pair_inds):
                    idx1, idx2 = pairs[pairind][0]
                    K1 = intrinsics[idx1].copy()
                    T1 = poses[idx1].copy()
                    R1, t1 = T1[:3, :3], T1[:3, 3]
                    K2 = intrinsics[idx2].copy()
                    T2 = poses[idx2].copy()
                    R2, t2 = T2[:3, :3], T2[:3, 3]
                    R, t = compute_relative_pose(R1, t1, R2, t2)
                    T1_to_2 = np.concatenate((R, t[:, None]), axis=-1)
                    im_A_path = f"{data_root}/{im_paths[idx1]}"
                    im_B_path = f"{data_root}/{im_paths[idx2]}"
                    dense_matches, dense_certainty = model.match(
                        im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy()
                    )
                    sparse_matches, _ = model.sample(
                        dense_matches, dense_certainty, 5_000
                    )

                    im_A = Image.open(im_A_path)
                    w1, h1 = im_A.size
                    im_B = Image.open(im_B_path)
                    w2, h2 = im_B.size
                    if True:  # Note: kept True as in the DKM/RoMa papers; there is very little difference compared to setting it to False.
                        scale1 = 1200 / max(w1, h1)
                        scale2 = 1200 / max(w2, h2)
                        w1, h1 = scale1 * w1, scale1 * h1
                        w2, h2 = scale2 * w2, scale2 * h2
                        K1, K2 = K1.copy(), K2.copy()
                        K1[:2] = K1[:2] * scale1
                        K2[:2] = K2[:2] * scale2

                    kpts1, kpts2 = model.to_pixel_coordinates(sparse_matches, h1, w1, h2, w2)
                    kpts1, kpts2 = kpts1.cpu().numpy(), kpts2.cpu().numpy()
                    for _ in range(5):
                        shuffling = np.random.permutation(np.arange(len(kpts1)))
                        kpts1 = kpts1[shuffling]
                        kpts2 = kpts2[shuffling]
                        try:
                            threshold = 0.5
                            norm_threshold = threshold / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
                            R_est, t_est, mask = estimate_pose(
                                kpts1,
                                kpts2,
                                K1,
                                K2,
                                norm_threshold,
                                conf=0.99999,
                            )
                            T1_to_2_est = np.concatenate((R_est, t_est), axis=-1)
                            e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
                            e_pose = max(e_t, e_R)
                        except Exception as e:
                            print(repr(e))
                            e_t, e_R = 90, 90
                            e_pose = max(e_t, e_R)
                        tot_e_t.append(e_t)
                        tot_e_R.append(e_R)
                        tot_e_pose.append(e_pose)
            tot_e_pose = np.array(tot_e_pose)
            auc = pose_auc(tot_e_pose, thresholds)
            acc_5 = (tot_e_pose < 5).mean()
            acc_10 = (tot_e_pose < 10).mean()
            acc_15 = (tot_e_pose < 15).mean()
            acc_20 = (tot_e_pose < 20).mean()
            map_5 = acc_5
            map_10 = np.mean([acc_5, acc_10])
            map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
            print(f"{model_name} auc: {auc}")
            return {
                "auc_5": auc[0],
                "auc_10": auc[1],
                "auc_20": auc[2],
                "map_5": map_5,
                "map_10": map_10,
                "map_20": map_20,
            }
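For reference, the reported metrics aggregate per-pair pose errors (each the maximum of the rotation and translation angular error, in degrees): auc_* comes from pose_auc over the 5/10/20 degree thresholds, while map_* averages the accuracies at 5/10/15/20 degrees exactly as in the code above. A toy illustration with made-up error values:

# Toy illustration of how the map_* numbers are aggregated from per-pair pose errors
# (the error values here are made up, not benchmark results).
import numpy as np

tot_e_pose = np.array([2.0, 4.0, 8.0, 12.0, 25.0, 90.0])  # max(rotation, translation) error in degrees
acc_5  = (tot_e_pose < 5).mean()    # 2/6
acc_10 = (tot_e_pose < 10).mean()   # 3/6
acc_15 = (tot_e_pose < 15).mean()   # 4/6
acc_20 = (tot_e_pose < 20).mean()   # 4/6
map_5, map_10, map_20 = acc_5, np.mean([acc_5, acc_10]), np.mean([acc_5, acc_10, acc_15, acc_20])
print(map_5, map_10, map_20)        # approx. 0.333 0.417 0.542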
